Masked and Causal Attention

Summary
Learn how masked attention enables autoregressive generation and prevents information leakage in transformers and language models.

Masked and Causal Attention: Preserving Causality in Generation

Masked attention is the key mechanism that allows transformers to generate sequences one token at a time. It ensures each position attends only to itself and earlier tokens, which maintains the autoregressive property essential for generation tasks.

Why Masked Attention?

The Information Leakage Problem

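In standard self-attention, every position can attend to every other position, including later ones. For an autoregressive model this causes two problems:

  • During training: the model can "cheat" by reading the very tokens it is supposed to predict
  • During inference: future tokens don't exist yet, so a model that relied on them in training has nothing to attend to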

Solution: apply masks to prevent attending to future positions.
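
To make the leakage concrete, here is a minimal sketch in plain Python. It assumes a toy single-head attention in which each token vector serves as its own query, key, and value (no learned projections); the `attend` helper is hypothetical. Dropping the keys after position i has the same effect as masking them with −∞. Changing only the last token alters the output at position 0 without a mask, and leaves it untouched with one.

```python
import math

def attend(i, xs, causal):
    """Attention output for position i, using xs as queries, keys, and values."""
    keys = xs[: i + 1] if causal else xs  # causal: ignore positions after i
    scale = math.sqrt(len(xs[i]))
    scores = [sum(a * b for a, b in zip(xs[i], k)) / scale for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    dim = len(xs[i])
    return [sum(e / total * k[c] for e, k in zip(exps, keys)) for c in range(dim)]

# Two sequences that differ only in their last (future) token.
seq_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
seq_b = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]

# Without a mask, position 0 is affected by the future token.
leaks = attend(0, seq_a, causal=False) != attend(0, seq_b, causal=False)
# With a causal mask, position 0 is identical in both sequences.
safe = attend(0, seq_a, causal=True) == attend(0, seq_b, causal=True)
print(leaks, safe)  # True True
```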

Types of Masking

  1. Causal Mask: for autoregressive generation (GPT-style)
  2. Padding Mask: for variable-length sequences
  3. Custom Masks: for specific attention patterns
  4. Combined Masks: multiple masks applied together

How Causal Masking Works

The Causal Mask

For a sequence of length n, the causal mask is an n × n additive matrix whose allowed (zero) entries form a lower triangle:

M_{ij} = \begin{cases} 0 & \text{if } i \ge j \\ -\infty & \text{if } i < j \end{cases}

This ensures position i can only attend to positions 0 through i.
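
As an illustration, the mask definition translates directly into code. This is a small sketch using nested lists rather than a tensor library; the function name `causal_mask` is made up for the example.

```python
import math

def causal_mask(n):
    """Additive n x n mask: 0.0 where i >= j (allowed), -inf where i < j (future)."""
    return [[0.0 if i >= j else -math.inf for j in range(n)] for i in range(n)]

for row in causal_mask(4):
    print(row)
# [0.0, -inf, -inf, -inf]
# [0.0, 0.0, -inf, -inf]
# [0.0, 0.0, 0.0, -inf]
# [0.0, 0.0, 0.0, 0.0]
```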

Applying the Mask

  1. Compute attention scores: QK^\top / \sqrt{d_k}.
  2. Add the mask: forbidden positions become -∞.
  3. Apply softmax: -∞ becomes 0, so no attention flows to those positions.
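
The three steps above can be sketched in plain Python as follows. This is an illustrative single-head version without batching or learned projections; `causal_attention_weights` and the toy Q and K values are assumptions made for the example.

```python
import math

def softmax(row):
    top = max(row)  # subtract the row maximum for numerical stability
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention_weights(Q, K):
    """Q, K: lists of n vectors of length d_k. Returns the n x n attention weights."""
    d_k = len(K[0])
    n = len(Q)
    weights = []
    for i in range(n):
        # Step 1: scaled dot-product scores of query i against every key.
        scores = [sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k) for j in range(n)]
        # Step 2: add the causal mask (0 for j <= i, -inf for j > i).
        masked = [s + (0.0 if j <= i else -math.inf) for j, s in enumerate(scores)]
        # Step 3: softmax; exp(-inf) is 0, so future positions get zero weight.
        weights.append(softmax(masked))
    return weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = causal_attention_weights(Q, K)
print(W[0])  # [1.0, 0.0, 0.0]: the first position can only attend to itself
```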

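Key insight: exp(−∞) = 0, so after softmax a masked position receives exactly zero weight. A literal −∞ is safe as long as every row keeps at least one unmasked score, which the causal mask guarantees because each position can always attend to itself. It produces NaNs when an entire row is masked (for example, a padded query row once a padding mask is added), since softmax then evaluates 0/0. For that reason many implementations use a large finite negative value instead, typically the dtype minimum. Note that −1e9 is not representable in float16, where it overflows to −∞.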
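
A short sketch of the NaN case, assuming the usual max-subtracted softmax: when every score in a row is −∞ the computation becomes NaN, while a large finite negative value degrades to a uniform distribution instead.

```python
import math

def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]  # -inf - (-inf) is NaN
    total = sum(exps)
    return [e / total for e in exps]

fully_masked_inf = softmax([-math.inf, -math.inf, -math.inf])
fully_masked_finite = softmax([-1e9, -1e9, -1e9])

print(fully_masked_inf)     # [nan, nan, nan]
print(fully_masked_finite)  # uniform weights, one third each
```

A uniform row is still meaningless for a fully masked query, so such rows are usually ignored downstream; the finite value only keeps NaNs from spreading through the rest of the batch.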

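Modern attention kernels fuse this masking into the computation. FlashAttention-style kernels apply the causal mask without ever materializing the n × n matrix; in PyTorch this is requested with scaled_dot_product_attention(..., is_causal=True), and the flash-attn library's equivalent flag is causal=True.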

Training vs Inference

Training: Parallel Processing

With the causal mask applied, the whole sequence is processed in one forward pass: logits for every position are computed in parallel, position t is trained to predict token t+1 from the ground-truth tokens up to t (teacher forcing), and the loss is taken across all positions simultaneously. Because the work parallelizes across the sequence dimension, training is far faster than sequential generation.
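
A minimal sketch of the training objective, assuming the model has already produced one row of vocabulary logits per input position. The logits here are made-up placeholders, not real model output, and `next_token_loss` is a hypothetical helper. The targets are simply the input tokens shifted left by one.

```python
import math

def log_softmax(logits):
    top = max(logits)
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_z for x in logits]

def next_token_loss(logits, tokens):
    """Mean cross-entropy where the logits at position t are scored against tokens[t + 1]."""
    losses = []
    for t in range(len(tokens) - 1):
        target = tokens[t + 1]  # teacher forcing: the target is the true next token
        losses.append(-log_softmax(logits[t])[target])
    return sum(losses) / len(losses)

tokens = [2, 0, 1, 3]                     # token ids, vocabulary of size 4
logits = [[0.0] * 4 for _ in range(3)]    # placeholder logits for inputs tokens[:-1]
loss = next_token_loss(logits, tokens)    # uniform predictions give log(4)
```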

Inference: Sequential Generation

  1. Start with the prompt tokens.
  2. Run a forward pass and read the last position's logits.
  3. Sample the next token (greedy, top-k, or nucleus sampling).
  4. Append it and repeat, stopping at a max length or an end-of-sequence token.
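
The loop above might look like the following sketch. `toy_logits` is a hypothetical stand-in for a real model's forward pass (it simply favors "last token + 1"), and only greedy and top-k selection are shown; nucleus sampling would replace the top-k step.

```python
import math
import random

def toy_logits(tokens, vocab_size=5):
    """Hypothetical stand-in for a forward pass: returns the last position's logits."""
    last = tokens[-1]
    return [1.0 if v == (last + 1) % vocab_size else 0.0 for v in range(vocab_size)]

def generate(prompt, max_new_tokens, eos_id, top_k=1, seed=0):
    rng = random.Random(seed)
    tokens = list(prompt)                      # 1. start with the prompt
    for _ in range(max_new_tokens):
        logits = toy_logits(tokens)            # 2. forward pass, last position
        # 3. keep the top_k candidates and sample among them (top_k=1 is greedy).
        top = sorted(range(len(logits)), key=lambda v: logits[v], reverse=True)[:top_k]
        weights = [math.exp(logits[v]) for v in top]
        next_token = rng.choices(top, weights=weights, k=1)[0]
        tokens.append(next_token)              # 4. append and repeat
        if next_token == eos_id:
            break
    return tokens

print(generate([0], max_new_tokens=10, eos_id=4))  # [0, 1, 2, 3, 4]
```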

In practice each step reuses a KV cache: the keys and values of earlier tokens are stored, so only the new token's Q, K, and V are computed instead of re-running the whole sequence.
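
A toy sketch of the caching idea, assuming single-head attention in which each token vector acts as its own query, key, and value (a real model would apply learned projections first). The `KVCache` class is hypothetical. The cached, step-by-step outputs match a from-scratch causal recomputation, and no explicit mask is needed because future keys are not in the cache yet.

```python
import math

def attend(q, keys, values):
    """Softmax attention of one query over the given keys and values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[c] for e, v in zip(exps, values)) for c in range(len(values[0]))]

class KVCache:
    def __init__(self):
        self.keys, self.values = [], []

    def step(self, q, k, v):
        # Store only the new token's key and value, then attend over everything cached.
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)

xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache = KVCache()
cached_out = [cache.step(x, x, x) for x in xs]

# Reference: recompute causal attention from scratch for every position.
full_out = [attend(xs[i], xs[: i + 1], xs[: i + 1]) for i in range(len(xs))]
```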

Types of Attention Masks

The causal mask is one pattern; the same "set forbidden scores to −∞ before softmax" trick gives several others.

Causal Mask

Lower-triangular — each position attends to itself and earlier positions. Used for language modeling and text generation (GPT-style).

Padding Mask

Masks positions beyond a sequence's real length so different-length inputs can be batched together efficiently.
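
As a sketch, a padding mask can be built from the sequence lengths and combined with the causal pattern by a logical AND. Boolean masks are used here (True means attention is allowed), right-padding is assumed, and both function names are illustrative.

```python
def padding_mask(lengths, max_len):
    """mask[b][j] is True when position j of sequence b holds a real token."""
    return [[j < n for j in range(max_len)] for n in lengths]

def causal_and_padding(lengths, max_len):
    """allowed[b][i][j] is True when query i of sequence b may attend to key j."""
    pad = padding_mask(lengths, max_len)
    return [
        [[j <= i and pad[b][j] for j in range(max_len)] for i in range(max_len)]
        for b in range(len(lengths))
    ]

pad = padding_mask([2, 3], max_len=3)
allowed = causal_and_padding([2, 3], max_len=3)
print(pad)  # [[True, True, False], [True, True, True]]
```

Converting the boolean result to an additive mask is then a matter of mapping True to 0 and False to a large negative value.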

Prefix-LM Mask

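Bidirectional (full) attention within a given prefix, then causal attention for the generated continuation. The T5 paper evaluated this as an architecture variant (the standard T5 models are encoder-decoder), and UL2 uses it in its decoder-only configuration.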
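
A minimal sketch of this pattern as a boolean mask (True means attention is allowed); `prefix_lm_mask` and its `prefix_len` parameter are names chosen for the example.

```python
def prefix_lm_mask(n, prefix_len):
    """allowed[i][j]: key j is visible if it lies in the prefix or is not after query i."""
    return [[j < prefix_len or j <= i for j in range(n)] for i in range(n)]

for row in prefix_lm_mask(4, prefix_len=2):
    print(row)
# [True, True, False, False]
# [True, True, False, False]
# [True, True, True, False]
# [True, True, True, True]
```

With prefix_len=0 this reduces to the ordinary causal mask.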

Block-Sparse Mask

Local attention within blocks or sliding windows plus global attention to a few special tokens (BigBird also adds random connections). Long-context models such as Longformer and BigBird use it to cut attention cost from O(n²) to roughly linear in n. See also sparse attention and sliding-window attention.
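
The following is a simplified sketch of a block-local plus global pattern. It is not the exact layout used by Longformer or BigBird, and the function and parameter names are illustrative. Counting the allowed entries shows how much of the full n × n matrix is skipped.

```python
def block_global_mask(n, block_size, global_positions):
    """allowed[i][j]: same block, or either position is a designated global token."""
    g = set(global_positions)
    return [
        [i // block_size == j // block_size or i in g or j in g for j in range(n)]
        for i in range(n)
    ]

mask = block_global_mask(8, block_size=2, global_positions=[0])
allowed_count = sum(sum(row) for row in mask)
print(allowed_count, "of", 8 * 8)  # 28 of 64
```

For a fixed block size and a fixed number of global tokens, the allowed count grows linearly with n rather than quadratically.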

Common Applications

Language Modeling (GPT)

A stack of decoder blocks, each applying masked self-attention with a causal mask, pre-norm layer normalization (as in GPT-2 and later; the original GPT used post-norm), and residual connections around the attention and feedforward sub-layers.

Decoder in Seq2Seq

The decoder pairs masked (causal) self-attention on its own inputs with cross-attention to the encoder outputs. The cross-attention is not causally masked, since the decoder may see the entire source; only source padding is hidden. Used in translation and summarization models like T5 and BART.
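
To illustrate the difference between the two masks in a decoder layer, here is a small sketch with boolean masks (True means attention is allowed). The `src_real_len` parameter, which hides right-padded source positions, is an assumption added for the example.

```python
def decoder_masks(tgt_len, src_len, src_real_len):
    """Boolean masks for one decoder layer."""
    # Self-attention over the target: causal, so target i sees targets 0..i.
    self_mask = [[j <= i for j in range(tgt_len)] for i in range(tgt_len)]
    # Cross-attention to the encoder: no causal constraint, only source padding is hidden.
    cross_mask = [[j < src_real_len for j in range(src_len)] for _ in range(tgt_len)]
    return self_mask, cross_mask

self_mask, cross_mask = decoder_masks(tgt_len=3, src_len=4, src_real_len=3)
print(self_mask[0])   # [True, False, False]
print(cross_mask[0])  # [True, True, True, False]
```

Note the shapes: the self-attention mask is square (target by target), while the cross-attention mask is target by source.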
