A look at how PyTorch implements the Transformer and Multi-Head Attention, to understand both how the core components of the Transformer are written and what engineering decisions were made or required in the implementation.
In the PyTorch source code as of v1.12.0 we find a tensor library (aten), the experimental JAX-like functorch, and the primary source of PyTorch itself under the subdirectory torch/.
Structure of pytorch/torch/nn/modules/
Starting by looking around inside torch/, we see familiar subpackages:
- cuda for providing CUDA support on GPUs
- onnx for exporting to the Open Neural Network Exchange (ONNX), an open-source standard for representing models in an interoperable way
- fft for Fast Fourier Transforms and other related jazz
- optim for optimizers
- a profiler
- support for sparse operations and operators
- support for quantization
- a linear algebra library linalg
- nn for neural network model components
If we go into [pytorch/torch/nn/modules/](<https://github.com/pytorch/pytorch/tree/master/torch/nn/modules>) we see a list of modules, including but not limited to¹: activation, batchnorm, conv, distance, dropout, linear, loss, module, normalization, padding, pooling, rnn, sparse, and transformer.
The __init__.py follows the usual from .some_module import object, other_object syntax and then defines the "public" classes and functions via an __all__ list.
# the Module abstract base class
from .module import Module
# convolution operations (including transposed convolutions)
from .conv import Conv1d, Conv2d, Conv3d, \
    ConvTranspose1d, ConvTranspose2d, ConvTranspose3d  # and more...
# activation functions
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
    Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU  # and more...
# losses
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
    CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
    MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, HuberLoss, \
    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss  # and more...
# transformer
from .transformer import TransformerEncoder, TransformerDecoder, \
    TransformerEncoderLayer, TransformerDecoderLayer, Transformer
# then export all the public entities with __all__
__all__ = [
    'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
    # ...
    'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
]
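As a quick illustration of the effect of that __all__ list (my own example, not from the PyTorch source): a wildcard import from the package only brings in the names listed there.
# a small illustration: because __all__ is defined in torch/nn/modules/__init__.py,
# a wildcard import only exposes the names listed in it
from torch.nn.modules import *

print(Linear)  # <class 'torch.nn.modules.linear.Linear'> -- listed in __all__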
The Transformer Module
If we take a look inside the transformer module therein - and filter on code definitions with GitHub’s handy “jump to” feature - we see that there are five classes defined:
- Transformer - A transformer model. The user is able to modify the attributes as needed.
- TransformerEncoder - a stack of N encoder layers
- TransformerDecoder - a stack of N decoder layers
- TransformerEncoderLayer - made up of self-attn and a feedforward network
- TransformerDecoderLayer - made up of self-attn, multi-head-attn and a feedforward network
All of these classes are implemented as standard, vanilla architectures based on the paper Attention Is All You Need.
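These classes nest into one another. As a rough sketch of my own (not taken from the PyTorch source), the same composition can be written out by hand using the defaults:
import torch.nn as nn

# stack encoder/decoder layers into an encoder and a decoder, then combine
# them in the top-level Transformer via the custom_encoder/custom_decoder args
enc_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(enc_layer, num_layers=6)
dec_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
decoder = nn.TransformerDecoder(dec_layer, num_layers=6)
model = nn.Transformer(custom_encoder=encoder, custom_decoder=decoder)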
Let's look at the Transformer implementation.
PyTorch’s Transformer - torch.nn.Transformer
If we look at the implementation of the [Transformer in PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer), we see that we initialize the model with the following parameters (arguments to __init__), which I comment / describe inline:
def __init__(
self,
d_model: int = 512, # input embedding dimension
nhead: int = 8, # number of attention heads
num_encoder_layers: int = 6, # number of encoder layers
num_decoder_layers: int = 6, # number of decoder layers
dim_feedforward: int = 2048, # linear layer dimensionality
dropout: float = 0.1, # dropout probability
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, # activation function
# You can pass custom encoder or decoders
custom_encoder: Optional[Any] = None,
custom_decoder: Optional[Any] = None,
# engineering arguments
layer_norm_eps: float = 1e-5, # epsilon added to layer norm denominator - numerical stability
batch_first: bool = False, # whether to put the batch (sample-wise) dimension first or second
# whether to apply the layer norm before or after the multi-head attention block
norm_first: bool = False,
# "factory" keyword arguments: the device the parameter tensors are on and their data type
device=None,
dtype=None,
) -> None:
First, the housekeeping arguments…
- The layer_norm_eps argument is an additional term added to the denominator of the LayerNorm operation for numerical stability in cases where the weights' standard deviation is close to zero. For details see the Note on Normalization including LayerNorm and BatchNorm below.
- batch_first specifies whether to put the batch (sample-wise) dimension first, as comes out of DataLoaders unless you tweak the collate function, or second.
- Interestingly, norm_first lets us apply the Layer Normalization layer before or after Multi-Head Attention, as discussed in the paper On Layer Normalization in the Transformer Architecture by Ruibin Xiong and colleagues from 2020.
- You can pass a custom encoder and/or decoder via custom_encoder and custom_decoder.
- There are also a couple of "factory" keyword arguments (device and dtype) that specify the device the parameter tensors are on and their data type.
The more interesting arguments…
- d_model (integer, defaulting to 512) - the input embedding dimension
- nhead (integer, defaulting to 8) - the number of attention heads
- num_encoder_layers (integer, defaulting to 6) - the number of encoder layers
- num_decoder_layers (integer, defaulting to 6) - the number of decoder layers
- dim_feedforward (integer, defaulting to 2048) - the linear layer dimensionality
- dropout (floating point, defaulting to 0.1) - the dropout probability
- activation (either a string or a callable, defaulting to F.relu) - the activation function
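To make those defaults concrete, here is a minimal usage sketch of my own (not from the PyTorch source). With the default batch_first=False, the expected shapes are (sequence length, batch size, d_model):
import torch
import torch.nn as nn

# a minimal sketch: build the default Transformer and push a dummy
# source/target pair through it
model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
                       num_decoder_layers=6, dim_feedforward=2048, dropout=0.1)
src = torch.rand(10, 32, 512)  # (S, N, E): source length 10, batch 32
tgt = torch.rand(20, 32, 512)  # (T, N, E): target length 20, batch 32
out = model(src, tgt)          # -> shape (20, 32, 512)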
The forward pass through the Transformer model is not particularly enlightening in itself because, sensibly, it delegates all the work to modular components, but it is here that we begin our overview of the Transformer before zooming in. The forward pass is as follows:
# arguments to Transformer.forward
# src: the sequence to the encoder (required).
# tgt: the sequence to the decoder (required).
# src_mask: the additive mask for the src sequence (optional).
# tgt_mask: the additive mask for the tgt sequence (optional).
# memory_mask: the additive mask for the encoder output (optional).
# src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
# tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
# memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
# ...some batching and dimension-related checks are performed before the main work is done...
# the operative section of the forward method
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = self.decoder(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
So clearly what's going on is that we pass the input to the encoder, which by default is the TransformerEncoder, a stack of TransformerEncoderLayer layers, and then pass the output of that to the TransformerDecoder, a stack of TransformerDecoderLayer layers. (A reminder that this is the default case, i.e. when no custom_encoder or custom_decoder is supplied.)
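As an aside on the mask arguments listed above: the additive tgt_mask is typically the usual causal (subsequent) mask, which the Transformer class can generate for you. A small sketch of my own:
import torch
import torch.nn as nn

# a causal mask puts -inf above the diagonal so that each target position
# can only attend to itself and to earlier positions
model = nn.Transformer()
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))  # (T, T)
out = model(src, tgt, tgt_mask=tgt_mask)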
So let's look now at the TransformerEncoder.
The Transformer Encoder - TransformerEncoder
Luckily the implementation of TransformerEncoder is equally simple in its modularity: all it does is make a few checks to see whether nested tensors can be used, and then does the following
for mod in self.layers:
    if convert_to_nested:
        output = mod(output, src_mask=mask)
    else:
        output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

if convert_to_nested:
    output = output.to_padded_tensor(0.0)

if self.norm is not None:
    output = self.norm(output)
We can throw out the code related to the use of nested tensors and effectively think of this as:
- Loop over the layers in the encoder, applying each in turn
- Normalize if normalization is used
for mod in self.layers:
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

if self.norm is not None:
    output = self.norm(output)
Similarly, we can see from the TransformerEncoder's __init__ method that the layers are literally just clones², that the number of layers is saved, that whether to apply normalization is saved, and that whether to enable nested tensor use (enable_nested_tensor) is saved, defaulting to False.
def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=False):
    super(TransformerEncoder, self).__init__()
    self.layers = _get_clones(encoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm
    self.enable_nested_tensor = enable_nested_tensor
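For reference, the _get_clones helper amounts to N independent deep copies collected in a ModuleList — roughly the following:
import copy
from torch.nn import ModuleList

# roughly what _get_clones does: N deep copies of the layer, so the clones
# do not share parameters (see footnote 2)
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for _ in range(N)])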
With that we can move fairly quickly on to the first meaningful or substantive piece of the implementation, one that makes use of fundamental neural network building blocks like Linear layers and Dropout: the TransformerEncoderLayer.
The Transformer Encoder Layer - TransformerEncoderLayer
Let's begin by reminding ourselves that the TransformerEncoderLayer is made up of:
- A self-attention block, which in this case consists of Multi-Head Attention, and…
- A feedforward (densely connected) network, which consists of two linear layers, each followed by a dropout layer, with the activation function applied immediately after the first linear layer (i.e. before its dropout layer)
The Encoder Self-Attention Block
# self-attention block
def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
    x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
    return self.dropout1(x)
The Encoder Feed Forward Block
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
    x = self.linear2(self.dropout(self.activation(self.linear1(x))))
    return self.dropout2(x)
If we turn our eyes upward to the __init__ method, we see that all of these components are fairly straightforward, with the exception of the definition of self_attn, which is assigned a MultiheadAttention module parametrized with d_model, nhead, and dropout.
# in the imports we have
from .activation import MultiheadAttention
# an extract from the __init__ of TransformerEncoderLayer
super(TransformerEncoderLayer, self).__init__()
# define the self-attention block
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
The dropout layers are in fact identical and could have been applied via the function torch.nn.functional.dropout, as I explain in the Note on PyTorch's Dropout layer and Object-Oriented Interfaces below. (Not so for the LayerNorm layers, norm1 and norm2, which have trainable weights and biases.)
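To see how these pieces fit together, here is a lightly paraphrased sketch of the operative part of TransformerEncoderLayer.forward (the fast-path checks are omitted), where norm_first selects between the pre-norm and post-norm residual arrangements:
def forward(self, src, src_mask=None, src_key_padding_mask=None):
    # residual connections around the self-attention and feed-forward blocks,
    # with LayerNorm applied before (norm_first=True) or after (norm_first=False)
    x = src
    if self.norm_first:
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        x = x + self._ff_block(self.norm2(x))
    else:
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
    return x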
So all that remains to understand the encoder of the Transformer is to study the Multi-Head Attention module provided by PyTorch in pytorch/torch/nn/modules/activation.py.
Multi-Head Attention in PyTorch
To be continued…
Forthcoming Sections
Upcoming sections on:
- the use of stub files (see also) in the PyTorch codebase, for example [pytorch/torch/_C/_functions.pyi](<https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/_C/_functions.pyi>) amongst other [.pyi](<https://fileinfo.com/extension/pyi>) files, and interfaces
- the use of GNU Autoconf with [.in](<https://fileinfo.com/extension/in>) and [.ac](<https://fileinfo.com/extension/ac>) files to build application code, for example [pytorch/torch/_C/_VariableFunctions.pyi.in](<https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/_C/_VariableFunctions.pyi.in>)
Note on PyTorch’s Dropout layer and Object-Oriented Interfaces - Dropout
A quick aside or note to say that whilst PyTorch provides a "stateful", object-oriented interface to work with operations like dropout, or even activation functions like the ReLU, these class-based interfaces are mostly thin wrappers around the torch.nn.functional interface, as mentioned by Ed Yang in his excellent PyTorch Developer Podcast.
Let's look at the head of the script defining Torch's class-based Dropout layer as an example. I've abridged the code and added my comments inline:
# off the bat we import torch.nn.functional
from .. import functional as F
# we have an abstract base class that the specific dropout classes inherit from
class _DropoutNd(Module):
    # some type hints are provided for convenience / readability
    __constants__ = ['p', 'inplace']
    p: float
    inplace: bool

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super(_DropoutNd, self).__init__()
        # basically the __init__ method is just a setter that does value validation on p
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.inplace = inplace

class Dropout(_DropoutNd):
    r"""During training, randomly zeroes some of the elements of the input
    tensor with probability p using samples from a Bernoulli
    distribution. Each channel will be zeroed out independently on every forward
    call."""

    # no __init__ is defined for Dropout -> it uses the inherited _DropoutNd.__init__
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, self.training, self.inplace)
We can see that this is mostly boilerplate and that the only necessary line is the following, which could be run directly passing the dropout probability value and flag determining whether the operation should be performed in-place or not:
F.dropout(input, self.p, self.training, self.inplace)
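As a quick illustration of my own (not from the PyTorch source), the module call and the functional call perform the same operation:
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4, 8)
drop = nn.Dropout(p=0.5)
drop.train()  # the module tracks training mode itself...
y_module = drop(x)
# ...whereas the functional form takes the training flag explicitly;
# the outputs differ only because dropout is stochastic
y_functional = F.dropout(x, p=0.5, training=True)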
If you’re reading this and you know why such object-oriented interfaces are provided, beyond the need or desire for consistency from Torch users when defining their networks, please reach out to let me know.
Note on Normalization including LayerNorm and BatchNorm
We mentioned that there was an $\epsilon$ (the layer_norm_eps argument) added to the denominator of the LayerNorm layer in the Transformer. What does this mean and why is it needed?
A quick reminder that normalization, the re-scaling and centering of weights in neural networks, is used or required to improve optimization, although the mechanism for the techniques’ effectiveness is not well described. See for example How Does Batch Normalization Help Optimization? from Shibani Santurkar et al. (2018) which puts forward that batch normalization makes the optimization landscape significantly smoother, which induces a more predictive and stable behavior of the gradients, allowing for faster training.
Originally, batch normalization was proposed by Sergey Ioffe and Christian Szegedy in 2015 in Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Its benefit is that it allows the use of much higher learning rates and makes convergence less sensitive to weights' initialization. BatchNorm also acts as a regularizer and in some cases eliminates the need for dropout. In BatchNorm you compute the first and second centered moments over a minibatch $\mathcal{B} = \{x_1, \dots, x_m\}$, since doing this over the training set is intractable:

$$\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_\mathcal{B}\right)^2$$

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$

Where $\gamma$ and $\beta$ are learnable parameters.
Note that we have the $\epsilon$ in the denominator. Sometimes, under conditions where the weights are almost identical, the standard deviation of the weights' values can become very small (close to zero), and therefore the normalized output value $\hat{x}_i$ becomes undefined at $\sigma_\mathcal{B} = 0$ (which the $\epsilon$ prevents).
This is exactly the same as what is going on in Layer Normalization, proposed by Jimmy Lei Ba, Jamie Ryan Kiros and Geoffrey E. Hinton in 2016. Here we have the same equation (watch out: different notation), but this time we compute the summary statistics (mean and standard deviation) not over a minibatch but rather over a single layer.
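In the notation of the Layer Normalization paper, with $a_i^l$ the summed input to unit $i$ of layer $l$ and $H$ the number of hidden units in that layer, the statistics are:

$$\mu^l = \frac{1}{H}\sum_{i=1}^{H} a_i^l, \qquad \sigma^l = \sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a_i^l - \mu^l\right)^2}$$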
This was originally proposed in the context of recurrent neural networks, where how to apply BatchNorm is not obvious. The authors modified batch normalization by instead computing the mean and variance used for normalization from all of the summed inputs to the neurons in a layer on a single training case, and kept a learnable adaptive bias and gain for each neuron, applied after the normalization but before the non-linearity (the activation function).
Unlike batch normalization, layer normalization performs exactly the same computation at training and test times. (You can’t apply BatchNorm consistently at inference time or in general since batch sizes are variable and at inference or prediction time you might want to compute the model’s output for a single sample, where obviously these numbers are meaningless.)
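A small demonstration of that difference (my own sketch):
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)  # keeps running statistics; behaves differently in eval()
ln = nn.LayerNorm(8)    # no running statistics; identical computation in train and eval

x = torch.randn(1, 8)   # a single sample
bn.eval()               # in eval mode BatchNorm falls back on its running mean/var,
y_bn = bn(x)            # so even a "batch" of one sample can be normalized
y_ln = ln(x)            # LayerNorm normalizes over the feature dimension per sample
# note: calling bn(x) in training mode with a batch of one raises an error,
# since meaningful batch statistics cannot be computed from a single sample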
Further Reading
One of the best ways to understand a concept is to approach it from several angles, for example to read about it from several authors or, in the case of deep learning engineering, to study several different implementations of the same model. There is, for example, a Transformer implemented for WMT in the much faster and more modern JAX, which contains all the same components.
- The Annotated Transformer (revised edition)
- The Illustrated Transformer by Jay Alammar
- Transformer posts by Ketan Doshi:
- Overview of functionality - How Transformers are used, and why they are better than RNNs. Components of the architecture, and behavior during Training and Inference
- How it works - Internal operation end-to-end. How data flows and what computations are performed, including matrix representations
- Transformers Explained Visually - Multi-head Attention, deep dive
- Why Attention Boosts Performance - Not just what Attention does but why it works so well. How does Attention capture the relationships between words in a sentence
- PyTorch Transformer docs, i.e. for torch.nn.Transformer
Footnotes
1. The full list of Python modules available at the time of writing (commit 664058fa83) comprised: activation, adaptive, batchnorm, channelshuffle, container, conv, distance, dropout, flatten, fold, instancenorm, lazy, linear, loss, module, normalization, padding, pixelshuffle, pooling, rnn, sparse, transformer, upsampling, utils, _functions, and an __init__.py. ↩
2. Clones, i.e. deep copies, rather than references to one and the same object, since then the same parameters would be shared in memory because of Python's reference semantics. ↩