
Cross-Attention: Bridging Different Modalities

Summary
Understand cross-attention, the mechanism that enables transformers to align and fuse information from different sources, sequences, or modalities.

Cross-Attention: Connecting Different Information Sources

Cross-attention is the bridge that allows transformers to align and combine information from different sequences, making it fundamental for tasks like translation, image captioning, and multimodal understanding.

What is Cross-Attention?

Unlike self-attention where Q, K, and V come from the same sequence, cross-attention uses:

  • Queries (Q) from one sequence (e.g., decoder)
  • Keys (K) and Values (V) from another sequence (e.g., encoder)
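
\[
\mathrm{CrossAttention}(Q_{\text{target}}, K_{\text{source}}, V_{\text{source}}) = \mathrm{softmax}\!\left(\frac{Q_{\text{target}} K_{\text{source}}^{\top}}{\sqrt{d_k}}\right) V_{\text{source}}
\]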

Why Cross-Attention?

The Connection Problem

Many tasks require relating two different sequences:

  • Translation: Source language → Target language
  • Image Captioning: Image features → Text description
  • VQA: Question + Image → Answer
  • Speech Recognition: Audio → Text

Cross-attention provides the mechanism to:

  • Align elements between sequences
  • Transfer information from source to target
  • Learn relationships across modalities

How Cross-Attention Works

Step-by-Step Process

1. Extract Representations from Both Sequences

The first step involves obtaining hidden representations from two different sources:

Source Sequence Processing:

  • The source (e.g., English sentence in translation) passes through an encoder
  • Produces contextualized representations: shape [batch_size, source_length, model_dimension]
  • Each position contains information about the token and its context
  • These become the "knowledge base" that the decoder will query

Target Sequence Processing:

  • The target (e.g., French sentence being generated) is processed by masked self-attention first
  • Creates target hidden states: shape [batch_size, target_length, model_dimension]
  • Each position knows about previous target tokens (via causal masking)
  • These become the "queries" that ask questions of the source

2. Generate Queries, Keys, and Values

This is where cross-attention differs fundamentally from self-attention:

Queries from Target:

  • Apply learned weight matrix Wq to target hidden states
  • Transforms target representations into "questions": What information do I need?
  • Dimensionality: [batch_size, target_length, key_dimension]
  • Each target position creates its own query

Keys from Source:

  • Apply learned weight matrix Wk to source hidden states
  • Transforms source into "indices": What information is available?
  • Dimensionality: [batch_size, source_length, key_dimension]
  • Must match query dimension for dot product compatibility

Values from Source:

  • Apply learned weight matrix Wv to source hidden states
  • Transforms source into "content": The actual information to retrieve
  • Dimensionality: [batch_size, source_length, value_dimension]
  • Contains the information that will be aggregated
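
The three projections above can be sketched in NumPy. This is only an illustration under stated assumptions: the hidden states and the weight matrices (W_q, W_k, W_v) are random stand-ins rather than trained values, and the concrete sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

batch_size, target_length, source_length = 2, 4, 6
model_dim, key_dim, value_dim = 16, 8, 8

# Random stand-ins for decoder (target) and encoder (source) hidden states.
target_hidden = rng.standard_normal((batch_size, target_length, model_dim))
source_hidden = rng.standard_normal((batch_size, source_length, model_dim))

# Projection matrices; learned in a real model, random here (assumption).
W_q = rng.standard_normal((model_dim, key_dim)) / np.sqrt(model_dim)
W_k = rng.standard_normal((model_dim, key_dim)) / np.sqrt(model_dim)
W_v = rng.standard_normal((model_dim, value_dim)) / np.sqrt(model_dim)

Q = target_hidden @ W_q  # queries come from the target
K = source_hidden @ W_k  # keys come from the source
V = source_hidden @ W_v  # values come from the source

print(Q.shape, K.shape, V.shape)  # (2, 4, 8) (2, 6, 8) (2, 6, 8)
```

Note that Q keeps the target length while K and V keep the source length; only the last dimension of Q and K has to agree.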

3. Compute Cross-Attention

The attention mechanism connects target queries to source keys/values:

Scoring Phase:

  • Compute dot product between each query and all keys
  • Formula: $QK^{\top} / \sqrt{d_k}$
  • Results in attention scores: [batch_size, target_length, source_length]
  • Each target position gets a score for every source position
  • This is the same scaled dot-product attention used in self-attention. Dividing by $\sqrt{d_k}$ keeps the score variance from growing with $d_k$; without it, large scores saturate the softmax and its gradients become very small
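
A small numerical check can make the role of the scaling factor concrete. The sketch below assumes query and key entries are independent with zero mean and unit variance (an idealization, not a property of trained models), and `score_std` is a hypothetical helper written for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def score_std(d_k, scaled, n_samples=20000):
    """Empirical standard deviation of query-key dot products."""
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    scores = np.sum(q * k, axis=-1)       # one dot product per sample
    if scaled:
        scores = scores / np.sqrt(d_k)    # the attention scaling factor
    return scores.std()

for d_k in (16, 64, 256):
    raw = score_std(d_k, scaled=False)
    scaled = score_std(d_k, scaled=True)
    print(f"d_k={d_k:4d}  raw std={raw:6.2f}  scaled std={scaled:5.2f}")
```

Under these assumptions the raw standard deviation grows roughly like the square root of d_k, while the scaled one stays close to 1 regardless of d_k.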

Attention Weights:

  • Apply softmax across source dimension
  • Converts scores to probability distribution
  • Each target position's weights sum to 1
  • High weights indicate strong alignment between target and source positions

Information Aggregation:

  • Multiply attention weights by value vectors
  • Weighted sum: each target position gets a mix of source values
  • Output shape: [batch_size, target_length, value_dimension]
  • Result is a context vector informed by relevant source positions
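
Putting scoring, softmax, and aggregation together, a single-head version might look like the following NumPy sketch. The function names `softmax` and `cross_attention` are chosen for this example, and the inputs are random placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(Q, K, V):
    """Q: [B, T, d_k], K: [B, S, d_k], V: [B, S, d_v] -> ([B, T, d_v], [B, T, S])."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # [B, T, S]
    weights = softmax(scores, axis=-1)                # normalize over source positions
    context = weights @ V                             # weighted sum of source values
    return context, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 4, 8))    # target_length = 4
K = rng.standard_normal((2, 6, 8))    # source_length = 6
V = rng.standard_normal((2, 6, 10))   # value_dimension = 10

context, weights = cross_attention(Q, K, V)
print(weights.shape, context.shape)   # (2, 4, 6) (2, 4, 10)
```

Each row of `weights` is a distribution over the six source positions, so it can be read as a soft alignment for one target position.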

Cross-Attention in Transformers

Encoder-Decoder Architecture

Each transformer decoder layer contains three sublayers, with cross-attention in the middle, between masked self-attention and the feed-forward network.

Decoder Layer Structure:

Layer 1: Masked Self-Attention

  • Processes the target sequence independently of the source
  • Uses causal masking (can only see previous positions)
  • Purpose: Build contextual understanding of target sequence
  • Q, K, V all come from decoder input
  • Enables each position to integrate information from earlier target tokens

Layer 2: Cross-Attention (The Bridge)

  • Connects decoder to encoder representations
  • Query source: Output from masked self-attention
    • Represents "what the decoder currently understands"
    • Asks: "What source information do I need?"
  • Key/Value source: Encoder output
    • Represents "what information is available from source"
    • Provides: The actual content to retrieve
  • Purpose: Pull relevant information from source sequence
  • No causal masking on the source: every target position can attend to all source positions (padding positions are typically still masked out)
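
Although no causal mask is applied to the source, batched source sequences are usually padded to a common length, and those padding positions are commonly excluded from attention. A minimal NumPy sketch, assuming a hypothetical boolean array `source_is_padding` and at least one real token per sequence:

```python
import numpy as np

def masked_cross_attention(Q, K, V, source_is_padding):
    """source_is_padding: bool [B, S], True where the source token is padding."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)             # [B, T, S]
    # Broadcast the mask over target positions; padded keys get -inf.
    scores = np.where(source_is_padding[:, None, :], -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                                     # exp(-inf) == 0
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 3, 8))
K = rng.standard_normal((2, 5, 8))
V = rng.standard_normal((2, 5, 8))
# Second source sequence has only three real tokens; the last two are padding.
source_is_padding = np.array([[False, False, False, False, False],
                              [False, False, False, True, True]])

context, weights = masked_cross_attention(Q, K, V, source_is_padding)
print(weights[1].round(2))  # last two columns are zero
```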

Layer 3: Feed-Forward Network

  • Processes the fused representation
  • Applies non-linear transformations
  • Integrates cross-attention context with target understanding

Why This Order Matters:

  1. Self-attention first ensures the decoder understands the target sequence structure
  2. Cross-attention second allows informed querying based on target context
  3. FFN last integrates both streams of information
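
The ordering can be sketched as a single-head decoder layer in NumPy. Several things here are assumptions not taken from the text: the residual connections and layer normalization follow the original Transformer design, the ReLU feed-forward width of 4x is arbitrary, and all weights are random.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q_in, kv_in, W_q, W_k, W_v, causal=False):
    Q, K, V = q_in @ W_q, kv_in @ W_k, kv_in @ W_v
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    if causal:
        # Hide future positions (strictly upper triangle) from each query.
        t = scores.shape[-1]
        future = np.triu(np.ones((t, t), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ V

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def decoder_layer(target, encoder_out, p):
    # 1. Masked self-attention: Q, K, V all come from the target.
    x = layer_norm(target + attention(target, target, *p["self"], causal=True))
    # 2. Cross-attention: Q from the target, K and V from the encoder output.
    x = layer_norm(x + attention(x, encoder_out, *p["cross"]))
    # 3. Position-wise feed-forward network.
    hidden = np.maximum(0.0, x @ p["W1"])
    return layer_norm(x + hidden @ p["W2"])

rng = np.random.default_rng(0)
d = 16

def mat(rows, cols):
    return rng.standard_normal((rows, cols)) / np.sqrt(rows)

params = {
    "self": [mat(d, d) for _ in range(3)],
    "cross": [mat(d, d) for _ in range(3)],
    "W1": mat(d, 4 * d),
    "W2": mat(4 * d, d),
}
target = rng.standard_normal((2, 4, d))        # [B, target_length, d]
encoder_out = rng.standard_normal((2, 6, d))   # [B, source_length, d]

out = decoder_layer(target, encoder_out, params)
print(out.shape)  # (2, 4, 16)
```

Because only the self-attention sublayer mixes target positions, and it does so causally, changing a later target token leaves the outputs at earlier positions unchanged.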

Information Flow Breakdown:

Source Path: Input Text → Tokenization → Encoder Stack → Encoder Output (K, V), stored for all decoder steps

Target Path: Previous Outputs → Embedding → Masked Self-Attention → Updated Target Representation (Q)

Cross-Attention Layer (Q meets K, V from encoder) → Context-Aware Representation → Feed-Forward Network → Final Decoder Output

Key Insight: The encoder output is computed once and reused for all decoding steps. Cross-attention allows each target position to dynamically select which source positions are relevant, creating adaptive alignment.
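
One practical consequence is that the key and value projections of the encoder output can also be computed once and cached. The sketch below is a simplified single-head illustration; `cross_attend_step` is a hypothetical helper, and the decoder states are random stand-ins for what a real decoder would produce at each step.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

encoder_out = rng.standard_normal((1, 5, d))   # [batch, source_length, d]

# Computed once, before decoding starts, then reused at every step.
K_cached = encoder_out @ W_k
V_cached = encoder_out @ W_v

def cross_attend_step(decoder_state):
    """decoder_state: [batch, 1, d], the newest target position only."""
    q = decoder_state @ W_q
    scores = q @ K_cached.transpose(0, 2, 1) / np.sqrt(d)   # [batch, 1, source_length]
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ V_cached, w

for step in range(3):
    decoder_state = rng.standard_normal((1, 1, d))  # placeholder for a real decoder state
    context, w = cross_attend_step(decoder_state)
    print(step, w.round(2))  # a different alignment over the same source at each step
```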

Types of Cross-Attention

1. Encoder-Decoder Attention

Classic transformer architecture:

  • Used in: Machine translation, summarization
  • Q: Decoder states
  • K, V: Encoder outputs

2. Multi-Modal Cross-Attention

Between different modalities:

  • Vision-Language: Flamingo, BLIP
  • Audio-Text: Whisper
  • Video-Text: Video understanding models

3. Memory Attention

Attending to external memory:

  • Retrieval-Augmented: RAG models
  • Memory Networks: Neural Turing Machines
  • Q: Current state
  • K, V: Memory bank

4. Cross-Attention in Diffusion

Conditioning image generation:

  • Q: Image features at timestep t
  • K, V: Text embeddings
  • Guides generation based on text
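
As a rough sketch of this conditioning pattern, the spatial feature map can be flattened into a sequence of positions that query the text embeddings. Everything concrete here is an assumption for illustration: the sizes, the random weights, and the simple residual addition at the end; real diffusion models wrap this in multi-head attention inside a larger denoising network.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels = 1, 4, 4, 32
num_text_tokens, text_dim, d_k = 7, 48, 16

# Random stand-ins for noisy image features at timestep t and text embeddings.
image_features = rng.standard_normal((batch, height, width, channels))
text_embeddings = rng.standard_normal((batch, num_text_tokens, text_dim))

# Projections map both modalities into a shared key dimension.
W_q = rng.standard_normal((channels, d_k)) / np.sqrt(channels)
W_k = rng.standard_normal((text_dim, d_k)) / np.sqrt(text_dim)
W_v = rng.standard_normal((text_dim, channels)) / np.sqrt(text_dim)

# Flatten the spatial grid: each location becomes one query position.
pixels = image_features.reshape(batch, height * width, channels)
Q = pixels @ W_q            # [B, H*W, d_k]
K = text_embeddings @ W_k   # [B, num_text_tokens, d_k]
V = text_embeddings @ W_v   # [B, num_text_tokens, channels]

scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
scores -= scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)   # [B, H*W, num_text_tokens]

# Add the text-derived context back onto the image features.
conditioned = (pixels + weights @ V).reshape(batch, height, width, channels)
print(weights.shape, conditioned.shape)  # (1, 16, 7) (1, 4, 4, 32)
```

The two modalities start with different feature sizes (32 and 48 here); the projections are what make them comparable.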
