
Multi-Head Attention

Summary
How multi-head attention runs scaled dot-product attention in parallel across several representation subspaces to build context-aware token embeddings.

Deconstructing Multi-Head Self-Attention

The Multi-Head Self-Attention layer is arguably the most critical component driving the success of Transformer models in domains like Natural Language Processing (NLP). It allows the model to dynamically weigh the importance of different tokens in a sequence when updating the representation of a specific token, thereby creating highly context-aware embeddings. The single-head form of this mechanism is covered separately under self-attention.

The "Multi-Head" aspect enables the model to perform this attention process multiple times in parallel, each time potentially focusing on different types of relationships or information subspaces.
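In symbols, the standard formulation can be sketched as follows. The notation is an assumption made for illustration: X is the matrix of input embeddings (one row per token), h is the number of heads, d_k is the per-head key dimension, and the per-head matrices W_i^Q, W_i^K, W_i^V and the output matrix W^O correspond to the Wq, Wk, Wv, and Wo used in the walkthrough on this page.

```latex
\begin{aligned}
Q_i &= X W_i^{Q}, \quad K_i = X W_i^{K}, \quad V_i = X W_i^{V} \\
Z_i &= \operatorname{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i \\
\operatorname{MultiHead}(X) &= \operatorname{Concat}(Z_1, \dots, Z_h)\, W^{O}
\end{aligned}
```

Each Z_i is one head's output; the softmax is applied row by row, so every token's attention weights sum to 1.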

This page provides an interactive, step-by-step walkthrough of this mechanism. Use the visualization below to follow the calculations and build your intuition.

Setting the Scene: Input Embeddings & Query Token

  • Input: The attention layer receives a sequence of input embeddings (vectors). These typically represent tokens (like words or subwords) and often already include positional encoding information from previous steps.
  • Query Token: We focus the calculation from the perspective of one token at a time, referred to as the Query token. The goal is to compute an updated, contextualized embedding for this specific token.
  • Interaction: In the component below, select one of the example text sequences (e.g., "Simple", "Question"). The initial input embeddings are displayed. You can click on a token's embedding in the "Input Embeddings" step to select it as the Query token (its border will highlight).

The Multi-Head Attention Calculation: Step-by-Step Exploration

Now, let's walk through the process. Use the step indicator (1, 2, 3...) or the 'Next'/'Prev' buttons in the component below to advance through each stage.

  1. Input Embeddings: The starting point. Each token has an associated input vector. (Observe the initial vectors in the component).

  2. Project Q, K, V: This is where the "Multi-Head" aspect begins. For each attention head, the input embeddings are linearly projected using separate learned weight matrices (Wq, Wk, Wv) to create Query (Q), Key (K), and Value (V) vectors specific to that head. This allows each head to potentially focus on different aspects of the input. (See the component generate Q, K, V vectors for each head – notice the different colors representing different heads).

  3. Head-Specific Attention (repeated for each head): Within each head, the chosen token's Query is compared against every token's Key via a dot product, the scores are divided by √d_k (where d_k is the dimension of that head's Key vectors) and passed through softmax, and the resulting weights combine the Value vectors into that head's output vector Z. This inner computation is exactly scaled dot-product attention, run independently per head so that each head can specialize. (Step through the per-head scores, scaling, softmax, and weighted sum in the component.)

  4. Concatenate Heads: After each head has independently computed its output vector (Z1, Z2,... Zh), these vectors are concatenated together into one larger vector. This combines the different perspectives learned by each head. (See the Z vectors from all heads being combined).

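  5. Final Projection: The large concatenated vector is passed through one final linear projection layer (using weight matrix Wo). This mixes the information from all heads and maps the result to the model's embedding dimension, producing the final output embedding for the Query token. In the common configuration where each head has dimension d_model / h, the concatenated vector already has size d_model, so this step mixes the heads rather than shrinking the vector. (Observe the final projection step).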

  6. Output Embedding: This final vector is the contextualized embedding for the original Query token. It now incorporates information gathered from other relevant tokens in the sequence, weighted according to the attention mechanism across multiple heads. Compare this visually to the original input embedding.
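The six steps above can be sketched in a few lines of code. This is a minimal illustration for a single Query token, not the implementation behind the interactive component: it uses plain Python lists to stay dependency-free (a real model would use a tensor library), random weights instead of learned ones, and hypothetical sizes (4 tokens, d_model = 8, 2 heads). The helper names `vec_mat`, `dot`, and `random_matrix` are made up for this example.

```python
import math
import random

def vec_mat(vec, mat):
    """Row vector (d_in,) times matrix (d_in x d_out) -> vector (d_out,)."""
    return [sum(v * row[j] for v, row in zip(vec, mat))
            for j in range(len(mat[0]))]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def multi_head_attention(X, Wq, Wk, Wv, Wo, query_index):
    """Return (output embedding, per-head attention weights) for one Query token.

    X:          list of n_tokens input embeddings, each of size d_model (step 1)
    Wq, Wk, Wv: one (d_model x d_k) matrix per head
    Wo:         (n_heads * d_k x d_model) output projection
    """
    d_k = len(Wq[0][0])
    concatenated, all_weights = [], []
    for Wq_h, Wk_h, Wv_h in zip(Wq, Wk, Wv):
        # Step 2: project into this head's Q, K, V spaces.
        q = vec_mat(X[query_index], Wq_h)
        keys = [vec_mat(x, Wk_h) for x in X]
        values = [vec_mat(x, Wv_h) for x in X]
        # Step 3: dot-product scores, divide by sqrt(d_k), then softmax.
        scores = [dot(q, k) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # The weighted sum of the Value vectors is this head's output Z.
        z = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_k)]
        # Step 4: concatenate the head outputs into one long vector.
        concatenated.extend(z)
        all_weights.append(weights)
    # Steps 5 and 6: the final projection with Wo gives the output embedding.
    return vec_mat(concatenated, Wo), all_weights

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

# Hypothetical sizes; random values stand in for learned weights.
rng = random.Random(0)
n_tokens, d_model, n_heads = 4, 8, 2
d_k = d_model // n_heads
X = random_matrix(n_tokens, d_model, rng)
Wq = [random_matrix(d_model, d_k, rng) for _ in range(n_heads)]
Wk = [random_matrix(d_model, d_k, rng) for _ in range(n_heads)]
Wv = [random_matrix(d_model, d_k, rng) for _ in range(n_heads)]
Wo = random_matrix(n_heads * d_k, d_model, rng)

output, attention = multi_head_attention(X, Wq, Wk, Wv, Wo, query_index=1)
print(len(output))                            # 8, the same size as the input embedding
print([round(sum(w), 6) for w in attention])  # [1.0, 1.0], each head's weights sum to 1
```

Changing `query_index` recomputes the output from another token's perspective, which mirrors clicking a different token in the component. Production implementations compute all Query tokens and all heads at once with batched matrix multiplications instead of Python loops.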

Key Concepts & Further Exploration

  • Scaled Dot-Product Attention: The core calculation involving Q, K, V, scaling, and softmax.
  • Multi-Head: The strategy of running the attention mechanism multiple times in parallel with different projections (Wq, Wk, Wv) to capture diverse relationships. Concatenation and final projection integrate these diverse perspectives.
  • Contextualization: The key outcome is an output embedding that reflects not just the token itself, but also its context within the sequence, as determined by the attention weights.
  • Efficiency: Multi-query attention shares one K/V projection across all heads to shrink the KV cache at inference.

For concise definitions of these concepts, expand the 'Multi-Head Attention Concepts' section within the interactive tool.
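To make the efficiency point concrete, the following back-of-the-envelope sketch counts how many numbers the KV cache must hold with one K/V projection per head (standard multi-head attention) versus a single shared K/V projection (multi-query attention). The model shape is hypothetical, the function name `kv_cache_elements` is made up for this example, and a batch size of 1 is assumed.

```python
def kv_cache_elements(n_layers, n_tokens, n_kv_heads, d_head):
    """Cached scalars: one Key and one Value vector per layer, token, and K/V head."""
    return 2 * n_layers * n_tokens * n_kv_heads * d_head

# Hypothetical model shape, chosen only for illustration.
n_layers, n_heads, d_head, n_tokens = 12, 8, 64, 1024

mha = kv_cache_elements(n_layers, n_tokens, n_kv_heads=n_heads, d_head=d_head)
mqa = kv_cache_elements(n_layers, n_tokens, n_kv_heads=1, d_head=d_head)

print(mha)         # 12582912
print(mqa)         # 1572864
print(mha // mqa)  # 8, the cache shrinks by a factor equal to the number of heads
```

The Query projections stay separate per head in both cases; only the Keys and Values are shared, which is why the saving applies to the cache and not to the Query computation.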

Conclusion: The Engine of Context

Multi-Head Self-Attention is the mechanism through which Transformers incorporate context. By projecting inputs into multiple Query, Key, and Value spaces (one set per head), weighting each head's Value vectors by Query-Key similarity, and merging the heads through a final projection, it produces rich, contextualized representations for each element in a sequence. This ability to dynamically weigh information across the entire input is fundamental to the power of Transformer models in various domains. Exploring the step-by-step calculation interactively helps demystify this complex but crucial mechanism.
