Masked and Causal Attention: Preserving Causality in Generation
Masked attention lets transformers generate sequences one token at a time: each position may attend only to itself and earlier tokens, which preserves the autoregressive property that generation depends on.
Why Masked Attention?
The Information Leakage Problem
In standard self-attention, every position can attend to every other position:
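- During training: the model could "cheat" by attending to the very tokens it is supposed to predict.
- During inference: future tokens don't exist yet, so anything the model learned to read from them is unavailable when generating.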
Solution: apply masks to prevent attending to future positions.
Types of Masking
- Causal Mask: for autoregressive generation (GPT-style)
- Padding Mask: for variable-length sequences
- Custom Masks: for specific attention patterns
- Combined Masks: multiple masks applied together
How Causal Masking Works
The Causal Mask
For a sequence of length n, the causal mask is a lower-triangular n × n matrix M: M_ij = 1 (attention allowed) when j ≤ i, and M_ij = 0 (blocked) when j > i.
This ensures position i can only attend to positions 0 through i.
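A minimal sketch of building this mask in plain Python, with no framework assumed; `causal_mask` is a hypothetical helper name, and 1 means "allowed", 0 means "blocked":

```python
def causal_mask(n):
    """Return an n x n lower-triangular mask: 1 where key j <= query i, else 0."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

for row in causal_mask(4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

Row i lists which keys query i may see; frameworks typically store the same pattern as a boolean or additive tensor.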
Applying the Mask
- Compute attention scores: QKᵀ / √d_k.
- Add the mask: allowed positions get 0, forbidden positions get −∞.
- Apply softmax: exp(−∞) = 0, so no attention flows to the forbidden positions.
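Key insight: because exp(−∞) = 0, a masked score receives exactly zero softmax weight, blocking attention to that position entirely. In practice, use a large finite negative value (−1e9, or the dtype's minimum, since −1e9 overflows to −∞ in float16) rather than a literal −∞: the weight still underflows to ~0, but true infinities can produce NaNs, for example when every key in a row is masked and the softmax has nothing finite to normalize.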
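The three steps can be sketched directly in plain Python. This is a toy, single-head version under stated assumptions: `causal_attention_weights` and `NEG_LARGE` are hypothetical names, Q and K are small hand-written lists, and the mask is applied additively with a large finite negative value as recommended above.

```python
import math

NEG_LARGE = -1e9  # large finite negative used instead of a literal -inf

def causal_attention_weights(q, k):
    """Softmax(QK^T / sqrt(d_k) + M), where M[i][j] = 0 if j <= i else NEG_LARGE."""
    n, d_k = len(q), len(q[0])
    weights = []
    for i in range(n):
        # Step 1: scaled dot-product scores for query i against every key.
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_k) for j in range(n)]
        # Step 2: add the mask (0 for allowed keys, NEG_LARGE for future keys).
        masked = [s + (0.0 if j <= i else NEG_LARGE) for j, s in enumerate(scores)]
        # Step 3: numerically stable softmax; masked entries underflow to exactly 0.0.
        mx = max(masked)
        exps = [math.exp(x - mx) for x in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w = causal_attention_weights(q, k)
print(w[0])  # [1.0, 0.0, 0.0]: position 0 can only attend to itself
```

Because the diagonal is never masked, each row's maximum is always a finite score, so the max subtraction stays well defined.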
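Modern attention kernels fuse this masking into the computation: PyTorch's scaled_dot_product_attention with is_causal=True (which can dispatch to a FlashAttention kernel), or the flash-attn library with causal=True, applies the causal mask without ever materializing the n × n mask matrix.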
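The toy sketch below illustrates only the idea of never building the mask: each query streams over keys j ≤ i with an online (running-max) softmax, so blocked positions are simply never visited. It is not FlashAttention's actual kernel, which tiles Q, K, and V into blocks held in on-chip memory; `causal_attention_streaming` is a hypothetical name.

```python
import math

def causal_attention_streaming(q, k, v):
    """Causal attention with no explicit mask: query i only visits keys 0..i.

    Uses an online softmax (running max m, running denominator l, running
    weighted sum acc) so each row is computed in a single pass over its keys.
    """
    d_k = len(q[0])
    out = []
    for i in range(len(q)):
        m = float("-inf")          # running max of the scores seen so far
        l = 0.0                    # running sum of exp(score - m)
        acc = [0.0] * len(v[0])    # running sum of exp(score - m) * v_j
        for j in range(i + 1):     # the causal structure lives in this loop bound
            s = sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_k)
            m_new = max(m, s)
            scale = math.exp(m - m_new)  # rescale old accumulators to the new max
            p = math.exp(s - m_new)
            l = l * scale + p
            acc = [a * scale + p * vj for a, vj in zip(acc, v[j])]
            m = m_new
        out.append([a / l for a in acc])
    return out

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = causal_attention_streaming(q, k, v)
print(out[0])  # [1.0, 2.0]: query 0 sees only key 0, so it returns v[0] unchanged
```

The result matches masked softmax attention, but memory for the scores is O(1) per query instead of a full n × n matrix.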
Training vs Inference
Training: Parallel Processing
With the causal mask applied, the whole sequence is processed at once: logits for every position are computed in parallel and the loss is taken across all positions simultaneously (teacher forcing). The GPU parallelizes across the sequence dimension, so training on an n-token sequence takes one forward pass, whereas generating n tokens takes n sequential steps.
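A toy check of both claims, assuming a single attention layer with identity Q/K/V projections and made-up sin/cos embeddings (hypothetical simplifications, not a real model): inputs and targets are the same sequence shifted by one, one call produces outputs for every position, and output i does not change when tokens after i change.

```python
import math

def causal_self_attention(x):
    """Single-head causal self-attention with identity projections (toy).

    Restricting j to 0..i is equivalent to applying the causal mask.
    """
    n, d = len(x), len(x[0])
    out = []
    for i in range(n):
        scores = [sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d) for j in range(i + 1)]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * x[j][c] for j in range(i + 1)) / z for c in range(d)])
    return out

tokens = [5, 2, 7, 1, 3]
inputs, targets = tokens[:-1], tokens[1:]  # position t is trained to predict token t+1
print(list(zip(inputs, targets)))          # [(5, 2), (2, 7), (7, 1), (1, 3)]

# Hypothetical 2-d embeddings for the inputs; one call covers all positions.
emb = [[math.sin(t), math.cos(t)] for t in inputs]
h = causal_self_attention(emb)

# Changing the last input leaves all earlier outputs untouched (no leakage).
emb_changed = emb[:-1] + [[10.0, -10.0]]
h_changed = causal_self_attention(emb_changed)
print(h[:-1] == h_changed[:-1])  # True
```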
Inference: Sequential Generation
- Start with the prompt tokens.
- Run a forward pass and read the last position's logits.
- Choose the next token (greedy decoding, or top-k or nucleus sampling).
- Append it and repeat, stopping at a max length or an end-of-sequence token.
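The loop above can be sketched as follows. The model is replaced by a hypothetical `toy_next_token_logits` function that simply favours (last token + 1); only the control flow (prompt, forward pass, choose, append, stop) mirrors the steps listed, and only greedy and top-k are shown.

```python
import math
import random

def toy_next_token_logits(tokens, vocab_size=6):
    """Hypothetical stand-in for a forward pass: logits for the last position only,
    favouring (last_token + 1) % vocab_size."""
    last = tokens[-1]
    return [3.0 if v == (last + 1) % vocab_size else 0.0 for v in range(vocab_size)]

def top_k_sample(logits, k, rng):
    """Sample from the k highest-scoring tokens, renormalised with a softmax."""
    top = sorted(range(len(logits)), key=lambda v: logits[v], reverse=True)[:k]
    mx = max(logits[v] for v in top)
    weights = [math.exp(logits[v] - mx) for v in top]
    return rng.choices(top, weights=weights, k=1)[0]

def generate(prompt, max_new_tokens, eos_id, greedy=True, k=3, seed=0):
    rng = random.Random(seed)
    tokens = list(prompt)                          # 1. start with the prompt
    for _ in range(max_new_tokens):
        logits = toy_next_token_logits(tokens)     # 2. forward pass, last position
        if greedy:                                 # 3. choose the next token
            nxt = max(range(len(logits)), key=lambda v: logits[v])
        else:
            nxt = top_k_sample(logits, k, rng)
        tokens.append(nxt)                         # 4. append and repeat...
        if nxt == eos_id:                          #    ...until EOS or max length
            break
    return tokens

print(generate([1], max_new_tokens=10, eos_id=5))  # [1, 2, 3, 4, 5]
```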
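In practice each step reuses a KV cache: the keys and values of earlier tokens are stored, so only the new token's query, key, and value are computed instead of reprocessing the whole sequence. Because the cache holds only past positions, a single-token step needs no explicit causal mask.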
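A minimal sketch of why caching is exact, assuming identity Q/K/V projections (a toy simplification; `KVCache` and `attend` are hypothetical names): feeding tokens one at a time through a cache reproduces, bit for bit, the rows of full causal attention recomputed from scratch.

```python
import math

def attend(q, keys, values):
    """One query attending to a list of keys/values (softmax over scaled dot products)."""
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

def full_causal(xs):
    """Recompute everything: query i attends to positions 0..i."""
    return [attend(xs[i], xs[: i + 1], xs[: i + 1]) for i in range(len(xs))]

class KVCache:
    """Stores keys/values of past tokens so each step only processes the new one."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        # Identity projections (toy assumption): q = k = v = x.
        self.keys.append(x)
        self.values.append(x)
        # The cache holds only past and current positions, so no mask is needed.
        return attend(x, self.keys, self.values)

xs = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8], [1.0, 1.0]]
cache = KVCache()
incremental = [cache.step(x) for x in xs]
print(incremental == full_causal(xs))  # True
```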
Types of Attention Masks
The causal mask is one pattern; the same "set forbidden scores to −∞ before softmax" trick gives several others.
Causal Mask
Lower-triangular — each position attends to itself and earlier positions. Used for language modeling and text generation (GPT-style).
Padding Mask
Masks the padded positions beyond each sequence's real length, so sequences of different lengths can be batched into one tensor without any token attending to padding.
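A sketch of a padding mask and of combining it with the causal mask by a logical AND, in plain Python with hypothetical helper names. The batch here is two sequences of lengths 3 and 5, padded to length 5.

```python
def padding_mask(lengths, max_len):
    """mask[b][j] = 1 if key position j is a real token of sequence b, else 0."""
    return [[1 if j < length else 0 for j in range(max_len)] for length in lengths]

def combined_mask(length, max_len):
    """Causal AND padding: query i may see key j only if j <= i and j is not padding."""
    return [[1 if (j <= i and j < length) else 0 for j in range(max_len)]
            for i in range(max_len)]

print(padding_mask([3, 5], max_len=5))  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
for row in combined_mask(3, 4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 0]
```

Only key columns are masked, so even padded query rows (the last row above) keep at least one allowed key; this avoids the fully masked rows that produce NaNs, and their outputs are simply ignored by the loss.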
Prefix-LM Mask
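Every position attends bidirectionally to the whole prefix, while positions in the generated continuation attend causally to earlier continuation tokens. Prefix-LM was compared as an architecture variant in the T5 paper, and a prefix-LM style objective (S-denoising) is one of UL2's mixture of denoisers.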
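A sketch of the prefix-LM pattern as a 0/1 mask; `prefix_lm_mask` is a hypothetical name. Keys inside the prefix are visible to every query, and keys in the continuation are visible only causally.

```python
def prefix_lm_mask(prefix_len, total_len):
    """1 = may attend. Key j is visible to query i if j is in the prefix or j <= i."""
    return [[1 if (j < prefix_len or j <= i) else 0 for j in range(total_len)]
            for i in range(total_len)]

for row in prefix_lm_mask(prefix_len=2, total_len=4):
    print(row)
# [1, 1, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

With prefix_len=0 this reduces to the causal mask, and with prefix_len equal to the sequence length it becomes full bidirectional attention.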
Block-Sparse Mask
Local attention within a sliding window or block, plus global attention to and from a few special tokens (BigBird also adds random connections). Long-context models such as Longformer and BigBird use it to cut the cost from O(n²) to roughly linear in n. See also sparse attention and sliding-window attention.
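A rough sketch of the local-plus-global pattern, loosely in the Longformer style (symmetric window, symmetric global tokens); BigBird's block and random components are not shown, and `local_global_mask` is a hypothetical name. For an 8-token toy with window 1 and one global token, it counts how many query-key pairs remain allowed.

```python
def local_global_mask(n, window, global_ids):
    """1 = may attend. Each query sees keys within `window` positions on either side;
    global tokens attend to, and are attended by, every position."""
    g = set(global_ids)
    return [[1 if (abs(i - j) <= window or i in g or j in g) else 0 for j in range(n)]
            for i in range(n)]

n = 8
m = local_global_mask(n, window=1, global_ids=[0])
print(sum(map(sum, m)), "allowed entries vs", n * n, "for full attention")
# 34 allowed entries vs 64 for full attention
```

The allowed count grows like n · (2 · window + 1) plus a term proportional to n per global token, which is linear in n for a fixed window. For a causal decoder, the same mask would additionally be ANDed with j ≤ i.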
Common Applications
Language Modeling (GPT)
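A stack of decoder blocks, each applying masked self-attention with a causal mask followed by a position-wise feedforward network, with residual connections around both sub-layers and layer normalization at the input of each (pre-norm, as in GPT-2 onward; the original GPT used post-norm).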
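A heavily simplified sketch of a pre-norm decoder stack, under loud assumptions: one attention head with identity projections, layer norm without learned scale or bias, and a toy ReLU "feedforward" standing in for two linear layers. All names are hypothetical. It checks that stacking blocks keeps the model causal: outputs for the first two tokens match whether or not a third token follows.

```python
import math

def layer_norm(v, eps=1e-5):
    """Normalise one vector to zero mean and unit variance (no learned scale/bias, toy)."""
    mu = sum(v) / len(v)
    var = sum((x - mu) ** 2 for x in v) / len(v)
    return [(x - mu) / math.sqrt(var + eps) for x in v]

def causal_self_attention(x):
    """Single head, identity Q/K/V projections (toy); query i sees keys 0..i."""
    d = len(x[0])
    out = []
    for i in range(len(x)):
        s = [sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d) for j in range(i + 1)]
        mx = max(s)
        w = [math.exp(t - mx) for t in s]
        z = sum(w)
        out.append([sum(w[j] * x[j][c] for j in range(i + 1)) / z for c in range(d)])
    return out

def feedforward(v):
    """Toy position-wise FFN: ReLU and a fixed scale (stands in for two linear layers)."""
    return [0.5 * max(0.0, x) for x in v]

def decoder_block(x):
    """Pre-norm block: h = x + Attn(LN(x)), then out = h + FFN(LN(h))."""
    attn = causal_self_attention([layer_norm(v) for v in x])
    h = [[a + b for a, b in zip(xv, av)] for xv, av in zip(x, attn)]
    return [[a + b for a, b in zip(hv, feedforward(layer_norm(hv)))] for hv in h]

def gpt_stack(x, n_layers=2):
    for _ in range(n_layers):
        x = decoder_block(x)
    return x

x = [[0.1, 0.9, -0.4], [1.2, -0.3, 0.5], [-0.7, 0.2, 0.8]]
y = gpt_stack(x)
y_prefix = gpt_stack(x[:2])
print(y[:2] == y_prefix)  # True: later tokens never influence earlier outputs
```

Layer norm and the feedforward act on each position independently, so the causal mask in attention is the only place where positions interact, and causality survives any number of stacked blocks.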
Decoder in Seq2Seq
The decoder pairs masked (causal) self-attention over its own inputs with cross-attention to the encoder outputs. The cross-attention needs no causal mask (at most a padding mask over the source), since the decoder may see the entire source. This design is used in translation and summarization models such as T5 and BART.
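A sketch of the two masks an encoder-decoder decoder layer uses, as 0/1 lists with a hypothetical `decoder_masks` helper. The example assumes a 3-token target and a source of 4 real tokens padded to length 5.

```python
def decoder_masks(tgt_len, src_len, src_real_len):
    """Return (self_mask, cross_mask); 1 = may attend.

    self_mask  (tgt_len x tgt_len): causal, decoder query i sees decoder keys 0..i.
    cross_mask (tgt_len x src_len): no causal constraint; every decoder query sees
                                    every real source token, only padding is blocked.
    """
    self_mask = [[1 if j <= i else 0 for j in range(tgt_len)] for i in range(tgt_len)]
    cross_mask = [[1 if j < src_real_len else 0 for j in range(src_len)] for _ in range(tgt_len)]
    return self_mask, cross_mask

self_mask, cross_mask = decoder_masks(tgt_len=3, src_len=5, src_real_len=4)
print(self_mask)   # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
print(cross_mask)  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 0]]
```

Note that the cross-attention mask is rectangular (target length × source length), and every row is identical because the source is fully visible.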
