Cross-Attention in Transformer Architecture
Excerpt
Merge two embedding sequences regardless of modality, e.g., image with text in Stable Diffusion U-Net.
Cross attention is:
- an attention mechanism in the Transformer architecture that mixes two different embedding sequences
- the two sequences must have the same dimension
- the two sequences can be of different modalities (e.g. text, image, sound)
- one of the sequences defines the output length, as it plays the role of the query input
- the other sequence then produces the key and value inputs
Cross-attention Applications
- image-text classification with Perceiver
- machine translation: cross-attention helps the decoder predict the next token of the translated text
Cross-attention vs Self-attention
Except for the inputs, the cross-attention calculation is the same as self-attention. Cross-attention asymmetrically combines two separate embedding sequences of the same dimension, whereas the input to self-attention is a single embedding sequence. One of the sequences serves as the query input, while the other provides the key and value inputs. An alternative form of cross-attention in SelfDoc uses the query and value from one sequence and the key from the other.
The feed-forward layer is related to cross-attention, except that the feed-forward layer does not use softmax and one of its "input sequences" is static, namely its learned weights. The paper Augmenting Self-attention with Persistent Memory shows that the feed-forward layer calculation can be written in the same form as self-attention.
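As a rough illustration of that analogy (a minimal PyTorch sketch, not the paper's exact formulation; all names here are illustrative), a feed-forward layer and an attention layer over static, learned "memory" vectors differ mainly in the nonlinearity and in the fact that the second "sequence" is a set of learned weights rather than actual inputs:

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, d_ff = 512, 2048
x = torch.randn(10, d_model)                 # sequence of 10 token embeddings

W1 = nn.Linear(d_model, d_ff, bias=False)    # rows act like static "keys"
W2 = nn.Linear(d_ff, d_model, bias=False)    # columns act like static "values"

# Standard feed-forward layer: ReLU between two learned projections.
ffn_out = W2(F.relu(W1(x)))

# Attention-like view: score the input against the learned keys, normalize
# with softmax instead of ReLU, and take a weighted sum of the learned values.
scores = W1(x)                                           # (10, d_ff)
attn_like_out = F.softmax(scores, dim=-1) @ W2.weight.T  # (10, d_model)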
Cross-attention Algorithm
- Let us have two embedding (token) sequences S1 and S2
- Calculate keys and values from sequence S1
- Calculate queries from sequence S2
- Calculate the attention matrix from the queries and keys
- Apply the attention matrix to the values
- The output sequence has the length of sequence S2 and the same embedding dimension
In an equation, with keys and values computed from S1, queries from S2, and d the embedding dimension:

$$\operatorname{softmax}\!\left(\frac{(W_Q S_2)\,(W_K S_1)^\top}{\sqrt{d}}\right) W_V S_1$$
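Below is a minimal single-head sketch of these steps in PyTorch (the names are illustrative, not taken from any particular library; masking and multi-head reshaping are omitted):

import torch
import torch.nn as nn

d = 64                      # shared embedding dimension of both sequences
s1 = torch.randn(20, d)     # S1: provides keys and values (length 20)
s2 = torch.randn(5, d)      # S2: provides queries (length 5)

to_q = nn.Linear(d, d, bias=False)
to_k = nn.Linear(d, d, bias=False)
to_v = nn.Linear(d, d, bias=False)

q = to_q(s2)                                     # (5, d)
k = to_k(s1)                                     # (20, d)
v = to_v(s1)                                     # (20, d)

attn = torch.softmax(q @ k.T / d**0.5, dim=-1)   # (5, 20) attention matrix
out = attn @ v                                   # (5, d): length and dimension of S2

Note that the output has 5 rows, the length of S2, while S1 only influences it through the attention weights and the values.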
Cross-attention Alternatives
The Feature-wise Linear Modulation (FiLM) layer is a simpler alternative that does not require the input to be a sequence and has linear computational complexity.
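A minimal sketch of such a FiLM layer (hypothetical names, assuming the conditioning signal is a single vector per example): the conditioning input predicts a per-feature scale and shift that modulate the main features linearly, so the cost grows only linearly with the number of features.

import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise linear modulation: features are scaled and shifted
    by parameters predicted from a conditioning vector."""
    def __init__(self, cond_dim, num_features):
        super().__init__()
        self.to_scale_shift = nn.Linear(cond_dim, 2 * num_features)

    def forward(self, features, cond):
        # features: (batch, num_features), cond: (batch, cond_dim)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return scale * features + shift

film = FiLM(cond_dim=32, num_features=128)
out = film(torch.randn(4, 128), torch.randn(4, 32))   # (4, 128)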
Cross-attention Implementation
Have a look at the CrossAttention implementation in the Diffusers library, which can generate images with Stable Diffusion. In this case, cross-attention is used to condition the transformers inside the U-Net layers with a text prompt for image generation. The constructor shows how we can also have different embedding dimensions, and if you step through with a debugger, you will also see the different sequence lengths of the two modalities.
class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """
In particular, look at this part, where you can see how query, key, and value interact. This follows the encoder-decoder pattern: the key and value are created from the encoder hidden states (the text embeddings), while the query comes from the U-Net hidden states.
# Queries are projected from the U-Net hidden states.
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
# Without encoder hidden states this degenerates to self-attention.
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
# Keys and values are projected from the (text) encoder hidden states.
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# Attention matrix from queries and keys, then applied to the values.
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
Cross-Attention in Popular Architectures
Cross-attention is widely used in encoder-decoder and multi-modal use cases.
Cross-Attention in Transformer Decoder
Cross-attention was described in the original Transformer paper, but it was not yet given this name there. Transformer decoding starts with the full input sequence but an empty decoding sequence. Cross-attention introduces information from the input sequence into the layers of the decoder, so that it can predict the next output token. The decoder then appends the token to the output sequence and repeats this autoregressive process until the EOS token is generated.
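A minimal sketch of this decoder-side cross-attention with PyTorch's nn.MultiheadAttention (the tensors and shapes are illustrative): the query comes from the partially decoded sequence, while the key and value come from the encoder output, so the result has the decoder's sequence length.

import torch
import torch.nn as nn

d_model = 512
encoder_output = torch.randn(30, 1, d_model)   # full input sequence (length 30, batch 1)
decoder_states = torch.randn(7, 1, d_model)    # tokens decoded so far (length 7)

cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=8)

# Query from the decoder, key and value from the encoder output;
# the output has the decoder's sequence length.
out, attn_weights = cross_attn(query=decoder_states,
                               key=encoder_output,
                               value=encoder_output)
print(out.shape)   # torch.Size([7, 1, 512])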
Cross-Attention in Stable Diffusion
Stable Diffusion uses cross-attention for image generation to condition the transformers inside the denoising U-Net with a text prompt.
Cross-Attention in Perceiver IO
Perceiver IO is a general-purpose, multi-modal architecture that can handle a wide variety of inputs as well as outputs. Perceiver can be applied, for example, to image-text classification. Perceiver IO uses cross-attention for merging:
- multimodal input sequences (e.g. image, text, audio) into a low-dimensional latent sequence
- an "output query" or "command" to decode the output value, e.g. predict this masked word
An advantage of the Perceiver architecture is that, in general, you can work with very large inputs. The Hierarchical Perceiver architecture can process even longer input sequences by splitting them into subsequences and then merging them again. Hierarchical Perceiver also learns the positional encodings in a separate training step with a reconstruction loss.
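A minimal sketch of that input-merging bottleneck (illustrative only, not the actual Perceiver IO code): a short learned latent array queries a very long input array, so the attention matrix is latent_len × input_len instead of the quadratic input_len × input_len of self-attention.

import torch
import torch.nn as nn

d = 256
input_len, latent_len = 10_000, 64              # long multimodal input, short latent array

inputs = torch.randn(input_len, d)              # e.g. flattened image/text/audio features
latents = nn.Parameter(torch.randn(latent_len, d))   # learned latent queries

to_q, to_k, to_v = (nn.Linear(d, d, bias=False) for _ in range(3))

q = to_q(latents)                               # (64, d)
k, v = to_k(inputs), to_v(inputs)               # (10000, d)

attn = torch.softmax(q @ k.T / d**0.5, dim=-1)  # (64, 10000): cost grows only linearly with input length
compressed = attn @ v                           # (64, d): input distilled into the latent sequence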
Cross-Attention in SelfDoc
In SelfDoc, cross-attention is integrated in a special way: the first step of their Cross-Modality Encoder uses the value and query from sequence A and the key from sequence B.
Other Cross-Attention Examples
- DeepMind's RETRO Transformer uses cross-attention to incorporate the sequences retrieved from its database
- Code example: HuggingFace BERT (the key and value come from the encoder, while the query comes from the decoder)
- CrossViT: only a simplified version of cross-attention is used here
- On the Strengths of Cross-Attention in Pretrained Transformers for Machine Translation