A look at how PyTorch implements the Transformer and Multi-Head Attention, to understand both how the core components of the Transformer are written and what engineering decisions were made, or required, for the implementation.

In the PyTorch source code as of v1.12.0 we find the tensor library ATen (under aten/), the experimental JAX-like functorch, and the primary source of PyTorch itself under the subdirectory torch/.

Structure of pytorch/torch/nn/modules/

Looking around inside torch/ first, we see familiar subpackages:

  • cuda for providing CUDA support on GPUs
  • onnx for exporting models to the Open Neural Network Exchange (ONNX), an open-source standard for representing models in an interoperable way
  • fft for Fast Fourier Transforms and other related jazz
  • optim for optimizers
  • a profiler
  • support for sparse operations and operators
  • support for quantization
  • a linear algebra library linalg
  • nn for neural network model components

If we go into [pytorch/torch/nn/modules/](<https://github.com/pytorch/pytorch/tree/master/torch/nn/modules>) we see a list of modules, including but not limited to¹ activation, batchnorm, conv, distance, dropout, linear, loss, module, normalization, padding, pooling, rnn, sparse, and transformer.

The __init__.py follows the usual from .some_module import object, other_object syntax and then defines the “public” classes and functions via an __all__ list.

# the Module base class that all layers inherit from
from .module import Module
# linear layers
from .linear import Identity, Linear  # and more...
# convolution operations (including transposed convolutions)
from .conv import Conv1d, Conv2d, Conv3d, \
    ConvTranspose1d, ConvTranspose2d, ConvTranspose3d  # and more...
# activation functions
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
    Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU  # and more...
# losses
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
    CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
    MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, HuberLoss, \
    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss  # and more...
# transformer
from .transformer import TransformerEncoder, TransformerDecoder, \
    TransformerEncoderLayer, TransformerDecoderLayer, Transformer

# then export all the public entities with __all__
__all__ = [
    'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
    # ... (abridged)
    'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
]
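To see the effect of __all__ at runtime, a quick check like the following should work, assuming a recent PyTorch installation (v1.12 or later). The full contents of __all__ differ between versions, so only the five transformer classes are checked here:

```python
import torch.nn.modules as modules

# The five transformer classes should all be listed as public exports.
transformer_names = [
    "Transformer",
    "TransformerEncoder",
    "TransformerDecoder",
    "TransformerEncoderLayer",
    "TransformerDecoderLayer",
]
exported = {name: name in modules.__all__ for name in transformer_names}
print(exported)

# A star import only binds the names listed in __all__.
namespace = {}
exec("from torch.nn.modules import *", namespace)
print("Transformer" in namespace, "Module" in namespace)
```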

The Transformer Module

If we take a look inside the transformer module therein - and filter on code definitions with GitHub’s handy “jump to” feature - we see that there are five classes defined:

  1. Transformer - A transformer model. User is able to modify the attributes as needed
  2. TransformerEncoder - TransformerEncoder is a stack of N encoder layers
  3. TransformerDecoder - TransformerDecoder is a stack of N decoder layers
  4. TransformerEncoderLayer - TransformerEncoderLayer is made up of self-attn and feedforward network
  5. TransformerDecoderLayer - TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network

All of these classes implement the standard (vanilla) architecture from the paper Attention Is All You Need (Vaswani et al., 2017).
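A quick way to see how these five classes nest is to build a small nn.Transformer and inspect its submodules. This is a sketch assuming PyTorch v1.12 or later; the sizes are arbitrary and chosen only to keep the model small:

```python
import torch.nn as nn

model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=3, dim_feedforward=64, batch_first=True)

# The encoder and decoder are the "stack" classes...
print(type(model.encoder).__name__)  # TransformerEncoder
print(type(model.decoder).__name__)  # TransformerDecoder
# ...each holding a ModuleList of per-layer modules.
print(len(model.encoder.layers), type(model.encoder.layers[0]).__name__)  # 2 TransformerEncoderLayer
print(len(model.decoder.layers), type(model.decoder.layers[0]).__name__)  # 3 TransformerDecoderLayer
# nn.Transformer also hands a final LayerNorm to both stacks.
print(type(model.encoder.norm).__name__, type(model.decoder.norm).__name__)  # LayerNorm LayerNorm
```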

Let’s start with the Transformer implementation itself.

PyTorch’s Transformer - torch.nn.Transformer

If we look at the implementation of the [Transformer in PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer), we see that the model is initialized with the following parameters (the arguments to __init__), which I comment on / describe inline:

def __init__(
    self,
    d_model: int = 512,                                          # input embedding dimension
    nhead: int = 8,                                              # number of attention heads
    num_encoder_layers: int = 6,                                 # number of encoder layers
    num_decoder_layers: int = 6,                                 # number of decoder layers
    dim_feedforward: int = 2048,                                 # hidden dimension of the feedforward network
    dropout: float = 0.1,                                        # dropout probability
    activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, # activation function
    # you can pass a custom encoder and / or decoder
    custom_encoder: Optional[Any] = None,
    custom_decoder: Optional[Any] = None,
    # engineering arguments
    layer_norm_eps: float = 1e-5, # epsilon added to the variance in the LayerNorm denominator (numerical stability)
    batch_first: bool = False,    # whether the batch (sample-wise) dimension comes first or second
    # whether to apply layer norm before (True) or after (False) the attention and feedforward blocks
    norm_first: bool = False,
    # "factory" keyword arguments: the device the parameter tensors are created on and their data type
    device=None,
    dtype=None,
) -> None:
    ...  # body omitted

First the housekeeping arguments…

  • The layer_norm_eps argument is a small constant added to the variance in the denominator of the LayerNorm operation, for numerical stability in cases where the variance of the values being normalized is close to zero. For details see the Note on Normalization including LayerNorm and BatchNorm.
  • With batch_first you can specify whether the batch (sample-wise) dimension comes first, i.e. tensors of shape (batch, sequence, feature), which is how batches come out of DataLoaders unless you tweak the collate function, or second, i.e. (sequence, batch, feature), which is the default.
  • Interestingly, with norm_first we can apply Layer Normalization before each sublayer (Multi-Head Attention and the feedforward block) rather than after it, as discussed in the paper On Layer Normalization in the Transformer Architecture by Ruibin Xiong and colleagues from 2020, which compares these “Pre-LN” and “Post-LN” placements.
  • You can pass a custom encoder and / or decoder via custom_encoder and custom_decoder.
  • There are also a couple of “factory” keyword arguments, device and dtype, that specify the device the parameter tensors are created on and their data type.

The more interesting arguments…

  1. d_model (integer, defaulting to 512) - the input embedding dimension
  2. nhead (integer, defaulting to 8) - the number of attention heads
  3. num_encoder_layers (integer, defaulting to 6) - the number of encoder layers
  4. num_decoder_layers (integer, defaulting to 6) - the number of decoder layers
  5. dim_feedforward (integer, defaulting to 2048) - the hidden dimension of the feedforward network
  6. dropout (floating point, defaulting to 0.1) - the dropout probability
  7. activation (either a string or a callable, defaulting to F.relu) - the activation function
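To make these arguments concrete, here is a minimal sketch of a forward pass through a small model, assuming PyTorch v1.12 or later. The sizes are arbitrary, the random src and tgt tensors stand in for already-embedded sequences (nn.Transformer itself contains no embedding or positional-encoding layers), and the causal target mask comes from the model's generate_square_subsequent_mask helper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# d_model must be divisible by nhead: each head works on 32 / 4 = 8 dimensions.
model = nn.Transformer(
    d_model=32,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=64,
    dropout=0.1,
    activation="gelu",  # a string is accepted as well as a callable
    batch_first=True,   # tensors are (batch, sequence, feature)
)
model.eval()  # disable dropout for a deterministic forward pass

batch, src_len, tgt_len = 2, 10, 7
src = torch.randn(batch, src_len, 32)  # stand-in for an embedded source sequence
tgt = torch.randn(batch, tgt_len, 32)  # stand-in for an embedded target sequence

# Float mask with -inf above the diagonal: target position i cannot attend to positions > i.
tgt_mask = model.generate_square_subsequent_mask(tgt_len)

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 7, 32]), i.e. (batch, tgt_len, d_model)
```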

The forward pass through the Transformer model is not particularly enlightening in itself because sensibly it delegates all work to modular components, but it is here that we begin our overview of the Transformer before zooming in. The forward pass is as follows:

# arguments to Transformer.forward

# src: the sequence to the encoder (required).
# tgt: the sequence to the decoder (required).
# src_mask: the additive mask for the src sequence (optional).
# tgt_mask: the additive mask for the tgt sequence (optional).
# memory_mask: the additive mask for the encoder output (optional).
# src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
# tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
# memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

# ...some batching and dimension-related checks are performed before the main work is done...

# the operative section of the forward method

memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = self.decoder(
    tgt,
    memory,
    tgt_mask=tgt_mask,
    memory_mask=memory_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)

So what’s going on is that we pass the input src to the encoder, which by default is a TransformerEncoder (a stack of TransformerEncoderLayer layers), and then pass its output, the memory, together with the target sequence tgt to the decoder, which by default is a TransformerDecoder (a stack of TransformerDecoderLayer layers). That is the default case; if a custom_encoder or custom_decoder was passed to __init__, it is used instead.
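As a sanity check of this delegation, calling the encoder and decoder by hand should reproduce the full forward pass. A sketch assuming PyTorch v1.12 or later, with dropout disabled via eval() so that both paths are deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Transformer(d_model=16, nhead=2, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=32)
model.eval()  # turn off dropout

# Default layout (batch_first=False): (sequence, batch, feature)
src = torch.randn(5, 3, 16)
tgt = torch.randn(4, 3, 16)

full = model(src, tgt)

# The same computation, spelled out as in Transformer.forward:
memory = model.encoder(src)          # encode the source sequence
manual = model.decoder(tgt, memory)  # decode tgt while attending to the memory

print(torch.allclose(full, manual, atol=1e-6))  # True
```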

So let’s look now at the TransformerEncoder.

The Transformer Encoder - TransformerEncoder

Luckily the implementation of TransformerEncoder is equally simple in its modularity: all it does is make a few checks to see whether nested tensors can be used, and then do the following:

for mod in self.layers:
    if convert_to_nested:
        output = mod(output, src_mask=mask)
    else:
        output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

if convert_to_nested:
    output = output.to_padded_tensor(0.0)

if self.norm is not None:
    output = self.norm(output)

We can throw out the code related to the use of nested tensors and effectively think of this as:

  1. Loop over the layers in the encoder, applying each in turn
  2. Normalize, if a normalization layer was provided

for mod in self.layers:
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

if self.norm is not None:
    output = self.norm(output)
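The simplified loop can be checked against the real module. This sketch assumes PyTorch v1.12 or later and passes no padding mask, so the nested-tensor path is never taken:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=3, norm=nn.LayerNorm(16))
encoder.eval()

x = torch.randn(2, 6, 16)  # (batch, sequence, feature)

# The simplified loop: apply each layer in turn, then the optional final norm.
output = x
for mod in encoder.layers:
    output = mod(output)
if encoder.norm is not None:
    output = encoder.norm(output)

print(torch.allclose(output, encoder(x), atol=1e-6))  # True
```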

Similarly, we can see from the TransformerEncoder’s __init__ method that the layers are literally just clones² of the encoder_layer passed in, and that the number of layers, the optional final normalization layer (norm) and the enable_nested_tensor flag (defaulting to False) are simply saved as attributes.

def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=False):
    super(TransformerEncoder, self).__init__()
    self.layers = _get_clones(encoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm
    self.enable_nested_tensor = enable_nested_tensor
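In v1.12, _get_clones builds a ModuleList of copy.deepcopy copies of the layer you pass in. The sketch below, assuming PyTorch v1.12 or later, shows the consequence: each layer owns its own parameter tensors, but they all start from the same initial values (nn.Transformer re-initializes its parameters after construction; a standalone TransformerEncoder does not). For contrast, it also builds a ModuleList that repeats one layer object:

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=3)

w0 = encoder.layers[0].linear1.weight
w1 = encoder.layers[1].linear1.weight

# Deep copies: every layer owns separate storage...
print(w0 is w1, w0.data_ptr() == w1.data_ptr())  # False False
# ...but the values start out identical, copied from `layer`.
print(torch.equal(w0, w1))  # True

# Contrast: repeating the same module object shares a single set of weights.
shared = nn.ModuleList([layer] * 3)
print(shared[0].linear1.weight is shared[1].linear1.weight)  # True
# parameters() de-duplicates, so the "3-layer" list has only one layer's worth.
print(sum(p.numel() for p in shared.parameters()) ==
      sum(p.numel() for p in layer.parameters()))  # True
```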

With that we can move fairly quickly on to look at the first meaningful or substantive piece of implementation that makes use of fundamental neural network building blocks like Linear layers or Dropout: TransformerEncoderLayer.

The Transformer Encoder Layer - TransformerEncoderLayer

Let’s begin by reminding ourselves that the TransformerEncoderLayer is made up of

  1. A self-attention block, that in this case consists of Multi-Head Attention, and…
  2. A feedforward (densely connected) network, which consists of two linear layers, each followed by a dropout layer, with the activation function applied immediately after the first linear layer (i.e. before its dropout layer)

The Encoder Self-Attention Block

# self-attention block
def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
    x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
    return self.dropout1(x)
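self_attn returns a tuple of (attention output, attention weights), which is why _sa_block indexes the result with [0]. A small sketch, assuming PyTorch v1.12 or later and arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)  # (batch, sequence, embed_dim)

# Self-attention: query, key and value are the same tensor.
attn_out, attn_weights = mha(x, x, x)
print(attn_out.shape)      # torch.Size([2, 5, 16])
print(attn_weights.shape)  # torch.Size([2, 5, 5]), averaged over the 4 heads by default

# With need_weights=False the second element is None, so only [0] is useful.
attn_out_only, no_weights = mha(x, x, x, need_weights=False)
print(no_weights)  # None
```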

The Encoder Feed Forward Block

# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
    x = self.linear2(self.dropout(self.activation(self.linear1(x))))
    return self.dropout2(x)
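The feedforward block is position-wise: it expands each d_model-dimensional vector to dim_feedforward dimensions and projects it back. A sketch assuming PyTorch v1.12 or later; note that _ff_block is a private helper, so its name and signature may change between versions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=64)
layer.eval()  # dropout becomes the identity

x = torch.randn(3, 2, 16)  # (sequence, batch, d_model), the default layout

# Expand d_model -> dim_feedforward, apply the default ReLU, project back to d_model.
hidden = F.relu(layer.linear1(x))
print(hidden.shape)  # torch.Size([3, 2, 64])
ff = layer.linear2(hidden)
print(ff.shape)      # torch.Size([3, 2, 16])

# Matches the module's own private helper in eval mode.
print(torch.allclose(ff, layer._ff_block(x)))  # True
```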

If we turn our eyes upward to the __init__ method, we see that all of these components are fairly straightforward, with the exception of self_attn, which is assigned an instance of MultiheadAttention parametrized by d_model, nhead and dropout.

# in the imports we have
from .activation import MultiheadAttention

# an extract from the __init__ of TransformerEncoderLayer
factory_kwargs = {'device': device, 'dtype': dtype}  # passed on to every layer with parameters
super(TransformerEncoderLayer, self).__init__()

# define the self-attention block
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)

# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

self.norm_first = norm_first
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
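What these excerpts do not show is how forward wires the pieces together: each block sits inside a residual connection, and norm_first decides whether norm1 and norm2 are applied to the block's input (Pre-LN) or after the residual addition (Post-LN, the default). The function encoder_layer_forward below is a hypothetical re-implementation of that logic, written from the attribute names in the extract; it ignores masks and the fast path, assumes PyTorch v1.12 or later, and checks itself against the real layer in eval mode:

```python
import torch
import torch.nn as nn


def encoder_layer_forward(layer: nn.TransformerEncoderLayer, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of TransformerEncoderLayer's forward (no masks)."""

    def sa_block(h: torch.Tensor) -> torch.Tensor:
        # self-attention (query = key = value), keep the output only, then dropout1
        return layer.dropout1(layer.self_attn(h, h, h, need_weights=False)[0])

    def ff_block(h: torch.Tensor) -> torch.Tensor:
        # linear1 -> activation -> dropout -> linear2 -> dropout2
        return layer.dropout2(layer.linear2(layer.dropout(layer.activation(layer.linear1(h)))))

    if layer.norm_first:
        # Pre-LN: normalize the input to each block, then add the residual
        x = x + sa_block(layer.norm1(x))
        x = x + ff_block(layer.norm2(x))
    else:
        # Post-LN (the default): add the residual, then normalize
        x = layer.norm1(x + sa_block(x))
        x = layer.norm2(x + ff_block(x))
    return x


torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, sequence, d_model)
results = {}
for norm_first in (False, True):
    layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32,
                                       batch_first=True, norm_first=norm_first)
    layer.eval()  # dropout off, so the comparison is deterministic
    results[norm_first] = torch.allclose(encoder_layer_forward(layer, x), layer(x), atol=1e-5)
print(results)  # {False: True, True: True}
```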

The dropout layers are in fact identical and could have been applied via the function torch.nn.functional.dropout as I explain in the Note on PyTorch’s Dropout layer and Object-Oriented Interfaces - Dropout. (Not so for the LayerNorm layers, norm1 and norm2, which have trainable weights and biases.)

So all that remains to understand the encoder of the Transformer is to study the Multi-Head Attention module provided by PyTorch in pytorch/torch/nn/modules/activation.py.

Multi-Head Attention in PyTorch

To be continued…

Forthcoming Sections

Upcoming sections on

  • the use of stub files in the PyTorch codebase, for example [pytorch/torch/_C/_functions.pyi](<https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/_C/_functions.pyi>) amongst other [.pyi](<https://fileinfo.com/extension/pyi>) files, and interfaces
  • the .in template-file convention familiar from GNU Autoconf (which uses [.in](<https://fileinfo.com/extension/in>) and [.ac](<https://fileinfo.com/extension/ac>) files to build application code), as seen in PyTorch in files such as [pytorch/torch/_C/_VariableFunctions.pyi.in](<https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/_C/_VariableFunctions.pyi.in>), which PyTorch fills in with its own code-generation scripts rather than with Autoconf
    • Side note on [.am](<https://fileinfo.com/extension/am>) files for Automake, a GNU tool for automatically generating Makefile.in files, which requires use of GNU Autoconf. An example is in the [Makefile.am](<https://github.com/espeak-ng/espeak-ng/blob/master/Makefile.am>) in espeak-ng

Note on PyTorch’s Dropout layer and Object-Oriented Interfaces - Dropout

A quick aside or note to say that whilst PyTorch provides a “stateful”, object-oriented interface to work with operations like dropout or even activation functions like the ReLU, these class-based interfaces are mostly thin wrappers around the torch.nn.functional interface, as mentioned by Ed Yang in his excellent PyTorch Developer Podcast.

Let’s look at the head of the script defining Torch’s class-based Dropout layer for such an example. I’ve abridged the code and added my comments inline:

# off the bat we import torch.nn.functional
from .. import functional as F
from torch import Tensor
from .module import Module

# we have a base class that the specific dropout classes inherit from
class _DropoutNd(Module):
    # __constants__ marks attributes that TorchScript can treat as constants;
    # the annotations below are type hints for convenience / readability
    __constants__ = ['p', 'inplace']
    p: float
    inplace: bool

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super(_DropoutNd, self).__init__()
        # basically the init method is just a setter that does value validation on p
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.inplace = inplace

class Dropout(_DropoutNd):
    r"""During training, randomly zeroes some of the elements of the input
    tensor with probability p using samples from a Bernoulli
    distribution. Each channel will be zeroed out independently on every forward
    call."""

    # no __init__ is defined for Dropout -> it uses the inherited _DropoutNd.__init__

    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, self.training, self.inplace)

We can see that this is mostly boilerplate and that the only necessary line is the following, which could be called directly, passing the input, the dropout probability, a flag saying whether we are in training mode (dropout does nothing otherwise) and a flag determining whether the operation should be performed in-place:

F.dropout(input, self.p, self.training, self.inplace)
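The equivalence is easy to check: with the same random seed, the module and the function draw the same Bernoulli mask. The sketch below assumes PyTorch v1.12 or later. It also shows one practical difference: the module reads self.training, so calling .eval() on a model switches dropout off without passing a flag through every call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.ones(4, 8)
drop = nn.Dropout(p=0.5)  # modules start in training mode

# Same seed -> same random mask, so the module and the function agree exactly.
torch.manual_seed(0)
out_module = drop(x)
torch.manual_seed(0)
out_function = F.dropout(x, p=0.5, training=True, inplace=False)
print(torch.equal(out_module, out_function))  # True

# Kept elements are scaled by 1 / (1 - p) = 2; dropped elements are zero.
print(set(out_module.unique().tolist()) <= {0.0, 2.0})  # True

# After .eval(), self.training is False and the module is the identity.
drop.eval()
print(torch.equal(drop(x), x))  # True
```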

If you’re reading this and you know why such object-oriented interfaces are provided, beyond the need or desire for consistency from Torch users when defining their networks, please reach out to let me know.

Note on Normalization including LayerNorm and BatchNorm

We mentioned that a small ε (epsilon) term, layer_norm_eps, is added to the denominator of the LayerNorm layer in the Transformer. What does this mean and why is it needed?

A quick reminder that normalization, the re-centering and re-scaling of activations (the inputs to a layer) in neural networks, is used to improve optimization, although why these techniques are effective is not well understood. See for example How Does Batch Normalization Help Optimization? from Shibani Santurkar et al. (2018), which puts forward that batch normalization makes the optimization landscape significantly smoother, which induces more predictive and stable behavior of the gradients, allowing for faster training.

Originally, batch normalization was proposed by Sergey Ioffe and Christian Szegedy in 2015 in Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Its benefit is that it allows the use of much higher learning rates and makes convergence less sensitive to weight initialization. BatchNorm also acts as a regularizer and in some cases eliminates the need for dropout. In BatchNorm you compute the mean and variance over a minibatch $\mathcal{B} = \{x_1, \ldots, x_m\}$, since doing this over the whole training set is intractable:

$$
\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_\mathcal{B}\right)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
$$

Where $\gamma$ and $\beta$ are learnable parameters.

Note that we have the $\epsilon$ in the denominator. When a feature takes (almost) the same value across the whole minibatch, its variance $\sigma_\mathcal{B}^2$ becomes very small (close to zero); without $\epsilon$ the normalized value $\hat{x}_i$ would blow up, and at exactly $\sigma_\mathcal{B}^2 = 0$ it would be undefined.

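This is exactly what is going on in Layer Normalization, proposed by Jimmy Lei Ba, Jamie Ryan Kiros and Geoffrey E. Hinton in 2016. Here we have the same equation (watch out: different notation), but this time the summary statistics (mean and standard deviation) are computed not over a minibatch but over the $H$ hidden units of a single layer $l$, for a single training case:

$$
\mu^l = \frac{1}{H}\sum_{i=1}^{H} a_i^l, \qquad
\sigma^l = \sqrt{\frac{1}{H}\sum_{i=1}^{H} \left(a_i^l - \mu^l\right)^2}
$$

where $a_i^l$ is the summed input to the $i$-th hidden unit. PyTorch's nn.LayerNorm then computes $(x - \mu)/\sqrt{\sigma^2 + \epsilon}$ before applying its learnable weight and bias, so $\epsilon$ plays exactly the same role as in BatchNorm.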

This was originally proposed in the context of recurrent neural networks, where it is not obvious how to apply BatchNorm. The authors adapted batch normalization by instead computing the mean and variance used for normalization from all of the summed inputs to the neurons in a layer on a single training case, and kept a learnable adaptive bias and gain for each neuron, which are applied after the normalization but before the non-linearity (the activation function).

Unlike batch normalization, layer normalization performs exactly the same computation at training and test times. BatchNorm cannot rely on minibatch statistics at inference time, since batch sizes vary and you might want the model’s output for a single sample, where those statistics are meaningless; instead it normalizes with running estimates of the mean and variance accumulated during training, so its behavior differs between the two modes.
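A from-scratch layer_norm helper (hypothetical, written for this note) matches nn.LayerNorm, gives identical results in training and evaluation mode, and works on a single sample. For contrast, nn.BatchNorm1d produces different outputs in the two modes, because in eval mode it switches to its running estimates. A sketch assuming PyTorch v1.12 or later:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
x = torch.randn(1, 4, d_model)  # a single sample: no batch statistics needed


def layer_norm(x, weight, bias, eps=1e-5):
    # statistics over the feature (last) dimension, separately for every position
    mu = x.mean(dim=-1, keepdim=True)
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
    return (x - mu) / torch.sqrt(var + eps) * weight + bias


ln = nn.LayerNorm(d_model)
manual = layer_norm(x, ln.weight, ln.bias)
ln_train = ln(x)
ln.eval()
ln_eval = ln(x)
ln_ok = torch.allclose(manual, ln_train, atol=1e-5)
print(ln_ok, torch.equal(ln_train, ln_eval))  # True True

bn = nn.BatchNorm1d(d_model)
batch = torch.randn(8, d_model)
bn_train = bn(batch)  # normalizes with this batch's statistics (and updates running stats)
bn.eval()
bn_eval = bn(batch)   # normalizes with the running estimates instead
print(torch.allclose(bn_train, bn_eval, atol=1e-3))  # False
```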

Further Reading

One of the best ways to understand a concept is to approach it from several angles, for example by reading about it from several authors or, in the case of deep learning engineering, by studying several different implementations of the same model. There is a Transformer implemented in the more modern JAX for the WMT machine translation task, which contains all the same components.

Footnotes

  1. The full list of Python modules available at the time of writing (commit 664058fa83) comprised: activation, adaptive, batchnorm, channelshuffle, container, conv, distance, dropout, flatten, fold, instancenorm, lazy, linear, loss, module, normalization, padding, pixelshuffle, pooling, rnn, sparse, transformer, upsampling, utils, _functions, and an __init__.py.

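  2. Clones rather than references to a single layer: _get_clones builds the list with copy.deepcopy, so each layer gets its own parameters. Putting the same layer object into the list N times would make every layer share one set of weights, because Python names and containers hold references to objects rather than copies of them.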