Cross-Attention: Connecting Different Information Sources
Cross-attention is the bridge that allows transformers to align and combine information from different sequences, making it fundamental for tasks like translation, image captioning, and multimodal understanding.
What is Cross-Attention?
Unlike self-attention, where Q, K, and V all come from the same sequence, cross-attention uses:
- Queries (Q) from one sequence (e.g., decoder)
- Keys (K) and Values (V) from another sequence (e.g., encoder)
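In symbols, a single cross-attention head can be written as follows. The names $X_{\text{tgt}}$ and $X_{\text{src}}$ (target and source hidden states) and $W_Q$, $W_K$, $W_V$ (learned projections) are notation assumed here for illustration:

```latex
\begin{aligned}
Q &= X_{\text{tgt}} W_Q, \qquad K = X_{\text{src}} W_K, \qquad V = X_{\text{src}} W_V \\
\operatorname{CrossAttention}(X_{\text{tgt}}, X_{\text{src}})
  &= \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
\end{aligned}
```

Setting $X_{\text{src}} = X_{\text{tgt}}$ recovers ordinary self-attention, so the only structural difference is where K and V come from.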
Why Cross-Attention?
The Connection Problem
Many tasks require relating two different sequences:
- Translation: Source language → Target language
- Image Captioning: Image features → Text description
- VQA: Question + Image → Answer
- Speech Recognition: Audio → Text
Cross-attention provides the mechanism to:
- Align elements between sequences
- Transfer information from source to target
- Learn relationships across modalities
How Cross-Attention Works
Step-by-Step Process
1. Extract Representations from Both Sequences
The first step involves obtaining hidden representations from two different sources:
Source Sequence Processing:
- The source (e.g., English sentence in translation) passes through an encoder
- Produces contextualized representations: shape [batch_size, source_length, model_dimension]
- Each position contains information about the token and its context
- These become the "knowledge base" that the decoder will query
Target Sequence Processing:
- The target (e.g., the French sentence being generated) passes through masked self-attention first
- Creates target hidden states: shape [batch_size, target_length, model_dimension]
- Each position knows about previous target tokens (via causal masking)
- These become the "queries" that ask questions of the source
2. Generate Queries, Keys, and Values
This is where cross-attention differs fundamentally from self-attention:
Queries from Target:
- Apply learned weight matrix Wq to target hidden states
- Transforms target representations into "questions": What information do I need?
- Dimensionality: [batch_size, target_length, key_dimension]
- Each target position creates its own query
Keys from Source:
- Apply learned weight matrix Wk to source hidden states
- Transforms source into "indices": What information is available?
- Dimensionality: [batch_size, source_length, key_dimension]
- Must match query dimension for dot product compatibility
Values from Source:
- Apply learned weight matrix Wv to source hidden states
- Transforms source into "content": The actual information to retrieve
- Dimensionality: [batch_size, source_length, value_dimension]
- Contains the information that will be aggregated
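The projection step can be sketched in NumPy as below. This is a minimal illustration, not a reference implementation: the sizes are arbitrary, the hidden states are random stand-ins for real encoder and decoder outputs, and the weight matrices `W_q`, `W_k`, `W_v` are randomly initialised rather than learned.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not taken from any particular model)
batch_size, source_length, target_length = 2, 7, 5
model_dim, key_dim, value_dim = 16, 8, 8

# Random stand-ins for the encoder output and the decoder hidden states
source_hidden = rng.standard_normal((batch_size, source_length, model_dim))
target_hidden = rng.standard_normal((batch_size, target_length, model_dim))

# Projection matrices; learned in a real model, random here
W_q = rng.standard_normal((model_dim, key_dim)) / np.sqrt(model_dim)
W_k = rng.standard_normal((model_dim, key_dim)) / np.sqrt(model_dim)
W_v = rng.standard_normal((model_dim, value_dim)) / np.sqrt(model_dim)

Q = target_hidden @ W_q  # queries come from the target
K = source_hidden @ W_k  # keys come from the source
V = source_hidden @ W_v  # values come from the source

print(Q.shape, K.shape, V.shape)
```

Note that Q follows the target length while K and V follow the source length; only the last dimension of Q and K has to agree.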
3. Compute Cross-Attention
The attention mechanism connects target queries to source keys/values:
Scoring Phase:
- Compute dot product between each query and all keys
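- Formula: $Q K^\top / \sqrt{d_k}$
- Results in attention scores: [batch_size, target_length, source_length]
- Each target position gets a score for every source position
- This is the same scaled dot-product attention used in self-attention; dividing by $\sqrt{d_k}$ keeps the scores from growing with the key dimension, which would otherwise push the softmax into saturated regions where gradients are very small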
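The effect of the scaling factor can be checked numerically. The sketch below assumes query and key components are independent with zero mean and unit variance, which is an idealisation of real activations; under that assumption the raw dot product has variance of roughly d_k, and dividing by the square root of d_k brings it back to roughly 1.

```python
import numpy as np

rng = np.random.default_rng(0)
num_pairs = 20_000  # number of random (query, key) pairs per setting

results = {}
for d_k in (16, 64, 256):
    q = rng.standard_normal((num_pairs, d_k))
    k = rng.standard_normal((num_pairs, d_k))
    raw = np.sum(q * k, axis=-1)   # unscaled dot products
    scaled = raw / np.sqrt(d_k)    # scaled dot products
    results[d_k] = (raw.var(), scaled.var())
    # Raw variance grows roughly like d_k; scaled variance stays near 1
    print(f"d_k={d_k:4d}  raw variance={raw.var():8.2f}  scaled variance={scaled.var():.2f}")
```

Without the scaling, larger key dimensions would produce larger score magnitudes, and the softmax would concentrate almost all weight on one position.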
Attention Weights:
- Apply softmax across source dimension
- Converts scores to probability distribution
- Each target position's weights sum to 1
- High weights indicate strong alignment between target and source positions
Information Aggregation:
- Multiply attention weights by value vectors
- Weighted sum: each target position gets a mix of source values
- Output shape: [batch_size, target_length, value_dimension]
- Result is a context vector informed by relevant source positions
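Putting the scoring, softmax, and aggregation phases together, a single-head cross-attention function might look like the sketch below. The function names and the random inputs are illustrative assumptions; multi-head splitting, masking, and the output projection are left out to keep the three phases visible.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(Q, K, V):
    """Single-head cross-attention.

    Q: (batch, target_length, key_dim)
    K: (batch, source_length, key_dim)
    V: (batch, source_length, value_dim)
    """
    d_k = Q.shape[-1]
    # Scoring: every target position against every source position
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)  # (batch, tgt, src)
    # Attention weights: softmax over the source dimension
    weights = softmax(scores, axis=-1)
    # Aggregation: weighted sum of source values
    context = weights @ V                                # (batch, tgt, value_dim)
    return context, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 5, 8))
K = rng.standard_normal((2, 7, 8))
V = rng.standard_normal((2, 7, 4))

context, weights = cross_attention(Q, K, V)
print(context.shape, weights.shape)
```

Each row of `weights` is a distribution over source positions, so inspecting it shows which source tokens a given target position is aligned with.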
Cross-Attention in Transformers
Encoder-Decoder Architecture
Each transformer decoder layer places cross-attention in the middle, between masked self-attention and the feed-forward network.
Decoder Layer Structure:
Sublayer 1: Masked Self-Attention
- Processes the target sequence independently
- Uses causal masking (each position sees only itself and earlier positions)
- Purpose: Build contextual understanding of target sequence
- Q, K, V all come from decoder input
- Enables each position to integrate information from earlier target tokens
Sublayer 2: Cross-Attention (The Bridge)
- Connects decoder to encoder representations
- Query source: Output from masked self-attention
  - Represents "what the decoder currently understands"
  - Asks: "What source information do I need?"
- Key/Value source: Encoder output
  - Represents "what information is available from source"
  - Provides: The actual content to retrieve
- Purpose: Pull relevant information from source sequence
- No causal masking on source (can attend to all source positions)
Sublayer 3: Feed-Forward Network
- Processes the fused representation
- Applies non-linear transformations
- Integrates cross-attention context with target understanding
Why This Order Matters:
- Self-attention first ensures the decoder understands the target sequence structure
- Cross-attention second allows informed querying based on target context
- FFN last integrates both streams of information
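The ordering described above can be sketched as a small single-head decoder layer. This is a simplified illustration under several assumptions: inputs are unbatched, weights are random rather than learned, residual connections are added as in the standard transformer even though they are not discussed above, and layer normalization and multi-head splitting are omitted. The `params` dictionary layout is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q_in, kv_in, W_q, W_k, W_v, mask=None):
    # Queries come from q_in; keys and values come from kv_in
    Q, K, V = q_in @ W_q, kv_in @ W_k, kv_in @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)  # blocked positions get zero weight
    return softmax(scores) @ V

def decoder_layer(target, encoder_out, params):
    tgt_len = target.shape[0]
    causal = np.tril(np.ones((tgt_len, tgt_len), dtype=bool))
    # 1. Masked self-attention: Q, K, V all from the target
    x = target + attention(target, target, *params["self"], mask=causal)
    # 2. Cross-attention: Q from the target, K and V from the encoder, no causal mask
    x = x + attention(x, encoder_out, *params["cross"])
    # 3. Position-wise feed-forward network with ReLU
    W1, W2 = params["ffn"]
    return x + np.maximum(0.0, x @ W1) @ W2

rng = np.random.default_rng(0)
d_model, d_ff = 8, 16

def init(*shape):
    return rng.standard_normal(shape) / np.sqrt(shape[0])

params = {
    "self": [init(d_model, d_model) for _ in range(3)],
    "cross": [init(d_model, d_model) for _ in range(3)],
    "ffn": (init(d_model, d_ff), init(d_ff, d_model)),
}
target = rng.standard_normal((4, d_model))       # 4 target positions
encoder_out = rng.standard_normal((6, d_model))  # 6 source positions

out = decoder_layer(target, encoder_out, params)
print(out.shape)
```

Because the cross-attention queries are built from the output of masked self-attention, each query already reflects the earlier target tokens when it looks up source information.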
Information Flow Breakdown:
Source Path:
Input Text → Tokenization → Encoder Stack → Encoder Output (K, V)
↓
Stored for all decoder steps

Target Path:
Previous Outputs → Embedding → Masked Self-Attention → Updated Target Representation (Q)
↓
Cross-Attention Layer (Q meets K, V from encoder)
↓
Context-Aware Representation
↓
Feed-Forward Network
↓
Final Decoder Output
Key Insight: The encoder output is computed once and reused for all decoding steps. Cross-attention allows each target position to dynamically select which source positions are relevant, creating adaptive alignment.
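The reuse of encoder output can be illustrated with a short sketch. Here the keys and values are projected once, and each decoding step only computes a new query. The decoder states are random stand-ins and the weights are not learned, so this shows the data flow only; it also checks that step-by-step decoding gives the same result as attending with all queries at once.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d_model, d_k = 8, 8
W_q, W_k, W_v = (rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(3))

encoder_out = rng.standard_normal((6, d_model))  # 6 source positions

# Computed once, before decoding starts
K = encoder_out @ W_k
V = encoder_out @ W_v

# Stand-in for decoder states that would be produced one step at a time
target_states = rng.standard_normal((4, d_model))

step_outputs = []
for t in range(target_states.shape[0]):
    q_t = target_states[t] @ W_q                  # only the query is new at each step
    weights_t = softmax(K @ q_t / np.sqrt(d_k))   # one weight per source position
    step_outputs.append(weights_t @ V)
step_outputs = np.stack(step_outputs)

# Same computation with all target positions at once
Q = target_states @ W_q
all_at_once = softmax(Q @ K.T / np.sqrt(d_k)) @ V

print(np.allclose(step_outputs, all_at_once))
```

Since the cross-attention K and V do not depend on the target, an inference implementation can cache them for the whole generation.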
Types of Cross-Attention
1. Encoder-Decoder Attention
Classic transformer architecture:
- Used in: Machine translation, summarization
- Q: Decoder states
- K, V: Encoder outputs
2. Multi-Modal Cross-Attention
Between different modalities:
- Vision-Language: Flamingo, BLIP (text layers cross-attend to image features)
- Audio-Text: Whisper (the text decoder cross-attends to encoded audio)
- Video-Text: Video understanding models
3. Memory Attention
Attending to external memory:
- Retrieval-Augmented: RAG models
- Memory Networks: Neural Turing Machines
- Q: Current state
- K, V: Memory bank
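Memory attention can be read as a soft key-value lookup. The toy sketch below uses a hand-built memory bank with no learned projections, which is an assumption made purely for illustration: the query is set equal to one of the keys, so nearly all attention weight lands on that slot and the read-out is close to its value.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Toy memory bank: 4 slots, each with a key (dim 4) and a value (dim 3)
memory_keys = 4.0 * np.eye(4)
memory_values = np.arange(12, dtype=float).reshape(4, 3)

# Current state used as the query; here it matches the key of slot 2
query = memory_keys[2]

d_k = query.shape[-1]
weights = softmax(memory_keys @ query / np.sqrt(d_k))  # one weight per memory slot
read = weights @ memory_values                          # soft read from memory

print(weights.round(3))
print(read.round(3))  # close to memory_values[2]
```

With a query that only partly matches several keys, the read-out becomes a blend of the corresponding values instead of a single slot.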
4. Cross-Attention in Diffusion
Conditioning image generation:
- Q: Image features at timestep t
- K, V: Text embeddings
- Guides generation based on text
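A rough sketch of this conditioning pattern: the spatial feature map is flattened so that every location becomes one query, and the text token embeddings supply the keys and values. All sizes, the random features, and the random projection matrices are assumptions for illustration; real diffusion models add multiple heads, normalization, and an output projection.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
height, width, channels = 4, 4, 8       # image feature map at some timestep
num_tokens, text_dim, d_k = 6, 12, 8    # text prompt embeddings

image_features = rng.standard_normal((height, width, channels))
text_embeddings = rng.standard_normal((num_tokens, text_dim))

# Random stand-ins for learned projections
W_q = rng.standard_normal((channels, d_k)) / np.sqrt(channels)
W_k = rng.standard_normal((text_dim, d_k)) / np.sqrt(text_dim)
W_v = rng.standard_normal((text_dim, channels)) / np.sqrt(text_dim)

# Flatten spatial positions into a sequence of queries
locations = image_features.reshape(height * width, channels)
Q = locations @ W_q          # one query per spatial location
K = text_embeddings @ W_k    # one key per text token
V = text_embeddings @ W_v    # one value per text token

weights = softmax(Q @ K.T / np.sqrt(d_k))                  # (locations, tokens)
conditioned = (weights @ V).reshape(height, width, channels)

# Reshaping the weights gives one spatial attention map per text token
token_maps = weights.reshape(height, width, num_tokens)
print(conditioned.shape, token_maps.shape)
```

The per-token maps show which image regions each word influences, which is how text guides what appears where.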
