Linear Attention: From O(n²) to O(n)
Linear attention approximations break the quadratic complexity barrier of self-attention, enabling transformers to process sequences with millions of tokens while maintaining reasonable quality.
Visualizing the n² Bottleneck
Standard attention computes (Q × K^T) × V, which materializes the full n × n score matrix. For a toy 8-token sequence:
- It must compute 8×8 = 64 attention scores
- It must store all 8×8 = 64 values in memory
- For n = 16,384 tokens that becomes ~268M values, over 1 GB of memory for a single attention matrix in fp32
The Kernel Trick: Reverse the Multiplication Order
Linear attention replaces the softmax with a feature map φ and computes φ(Q) × (φ(K)^T × V) instead of (φ(Q) × φ(K)^T) × V. With head dimension d = 4 in the toy example, the intermediate shrinks from an 8×8 matrix to a 4×4 one:
- Standard order: creates an n × n matrix, O(n²) memory, O(n²d) computation
- Reordered: creates a d × d matrix, O(d²) memory, O(nd²) computation
The Linearization Problem
Standard attention computes Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V, which takes O(n²d) time and O(n²) memory.
The bottleneck is the n × n attention matrix. Linear methods approximate or avoid computing it explicitly.
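A minimal PyTorch sketch of the reordering that makes this possible; the relu(x) + 1 feature map is only a placeholder to keep the normalizer positive, not any particular method's kernel:

```python
import torch

def phi(x):
    # Any positive feature map works for the illustration; relu(x) + 1 here.
    return torch.relu(x) + 1.0

torch.manual_seed(0)
n, d = 1024, 64
Q, K, V = torch.randn(3, n, d)

# Quadratic order: materializes the n x n score matrix.
scores = phi(Q) @ phi(K).T                           # (n, n)
out_quadratic = (scores @ V) / scores.sum(dim=1, keepdim=True)

# Linear order: associativity lets us form a d x d matrix instead.
kv = phi(K).T @ V                                    # (d, d)
k_sum = phi(K).sum(dim=0)                            # (d,)
out_linear = (phi(Q) @ kv) / (phi(Q) @ k_sum).unsqueeze(-1)

# Same result, but O(nd²) work and no n x n matrix.
assert torch.allclose(out_quadratic, out_linear, atol=1e-4)
```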
Major Linear Attention Methods
1. Performer (FAVOR+)
Key Idea: Approximate the softmax kernel with random feature maps (FAVOR+ uses positive random features, a refinement of classic random Fourier features)
Architecture:
- Use random projection matrix for feature mapping
- Orthogonal or Gaussian random features
- FAVOR+ algorithm for positive random features
Core Steps:
- Create random projection matrix (orthogonal preferred)
- Project Q and K to random feature space: φ(Q), φ(K)
- Apply feature map (ReLU + small constant for positivity)
- Compute using associative property: φ(Q) @ (φ(K)^T @ V) instead of (φ(Q) @ φ(K)^T) @ V
- Normalize with sum of K features
Complexity: O(nkd) where k = number of random features (typically 256)
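A hedged sketch of the steps above, using the ReLU-style positive features the list mentions; the exact FAVOR+ estimator additionally rescales the random projections and offers an exponential feature map, so treat the dimensions and eps value here as illustrative only:

```python
import torch

def orthogonal_projection(d, k):
    # Stack QR-orthogonalized Gaussian blocks until k random directions exist.
    blocks = []
    while sum(b.shape[1] for b in blocks) < k:
        q, _ = torch.linalg.qr(torch.randn(d, d))
        blocks.append(q)
    return torch.cat(blocks, dim=1)[:, :k]            # (d, k)

def feature_map(x, omega, eps=1e-3):
    # Project into the random feature space, keep features strictly positive.
    return torch.relu(x @ omega) + eps                 # (n, k)

torch.manual_seed(0)
n, d, k = 1024, 64, 256
Q, K, V = torch.randn(3, n, d)
Q, K = Q / d**0.25, K / d**0.25    # split the usual 1/sqrt(d) between Q and K

omega = orthogonal_projection(d, k)
phi_q, phi_k = feature_map(Q, omega), feature_map(K, omega)   # (n, k)

kv = phi_k.T @ V                                       # (k, d) -- O(nkd)
normalizer = phi_q @ phi_k.sum(dim=0)                  # (n,)
out = (phi_q @ kv) / (normalizer.unsqueeze(-1) + 1e-6) # (n, d), never n x n
```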
2. Linformer
Key Idea: Attention is approximately low-rank, so project K,V to smaller dimension
Architecture:
- Learnable projection matrices E and F
- Project sequence length dimension from n → k
- Apply standard attention in lower-dimensional space
Core Steps:
- Compute Q, K, V normally
- Project K using E: K' = E × K (reduces seq_len from n to k)
- Project V using F: V' = F × V (reduces seq_len from n to k)
- Compute attention: softmax(Q × K'^T / √d_k) × V'
- Output has original sequence length n
Complexity: O(nkd) where k = projection dimension (typically 256-512)
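A hedged single-head sketch of these projections; in a real module E and F would be learned parameters, and here they are random stand-ins:

```python
import torch

n, k, d = 4096, 256, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))
E = torch.randn(k, n) / n**0.5    # projects keys along the sequence axis
F = torch.randn(k, n) / n**0.5    # projects values along the sequence axis

K_proj = E @ K                    # (k, d): sequence length reduced n -> k
V_proj = F @ V                    # (k, d)

scores = Q @ K_proj.T / d**0.5    # (n, k) instead of (n, n)
attn = torch.softmax(scores, dim=-1)
out = attn @ V_proj               # (n, d): original sequence length preserved
```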
3. Linear Transformer
Key Idea: Simple kernel trick with φ(x) = elu(x) + 1 feature map
Architecture:
- Minimal changes to standard attention
- Use ELU activation + 1 as feature map
- Leverages associative property of matrix multiplication
Core Steps:
- Compute Q, K, V normally
- Apply feature map: φ(Q) = elu(Q) + 1, φ(K) = elu(K) + 1
- Non-causal: Compute KV = φ(K)^T @ V, then output = φ(Q) @ KV
- Causal: Use cumulative sum trick for O(1) per-token generation
- Normalize by sum of K features
Complexity: O(nd²) where d = head_dim
Special Feature: Causal masking with constant-time decoding via cumulative sums
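A hedged sketch of that constant-time decoding path (per head; the class and method names are illustrative). The state carries the two cumulative sums, so the cost of each generated token does not grow with context length:

```python
import torch

def phi(x):
    return torch.nn.functional.elu(x) + 1    # the elu(x) + 1 feature map

class LinearAttentionState:
    """Running sums for constant-time-per-token causal decoding."""
    def __init__(self, d, d_v):
        self.kv = torch.zeros(d, d_v)         # cumulative sum of phi(k_i) v_i^T
        self.k = torch.zeros(d)               # cumulative sum of phi(k_i)

    def step(self, q_t, k_t, v_t, eps=1e-6):
        q_t, k_t = phi(q_t), phi(k_t)
        self.kv = self.kv + torch.outer(k_t, v_t)
        self.k = self.k + k_t
        return (q_t @ self.kv) / (q_t @ self.k + eps)

# Usage: one state per head; each new token costs a single O(d * d_v) update.
state = LinearAttentionState(d=64, d_v=64)
out_t = state.step(torch.randn(64), torch.randn(64), torch.randn(64))
```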
4. Cosformer
Key Idea: Decompose attention using cosine basis functions
Architecture:
- Apply cosine-based reweighting to Q and K
- Use position-dependent weights
- Combines benefits of positional encoding with linear complexity
Core Steps:
- Pre-compute position weights cos(πi/2M) and sin(πi/2M) for each position i, where M is the maximum sequence length
- Reweight Q and K with both the cosine and sine terms, so the relative factor cos(π(i−j)/2M) decomposes via cos(x−y) = cos(x)cos(y) + sin(x)sin(y)
- Apply the linear attention mechanism to each of the two reweighted terms and sum the results
- Normalize output
Complexity: O(nd²), the same as plain linear attention, with roughly a 2× constant factor from the sine/cosine split
Benefit: Better captures positional information than simple feature maps
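A hedged non-causal sketch of this decomposition; cosFormer itself uses ReLU features and also defines a causal variant, and M is taken to be the sequence length here purely for illustration:

```python
import torch

def cosformer_attention(Q, K, V, eps=1e-6):
    n, d = Q.shape
    M = n                                             # maximum sequence length
    q, k = torch.relu(Q), torch.relu(K)               # ReLU feature map
    idx = torch.arange(n, dtype=Q.dtype)
    cos_w = torch.cos(torch.pi * idx / (2 * M)).unsqueeze(-1)   # (n, 1)
    sin_w = torch.sin(torch.pi * idx / (2 * M)).unsqueeze(-1)

    # The cos(pi*(i-j)/2M) reweighting splits into cos*cos + sin*sin,
    # and each term admits the usual Q (K^T V) reordering.
    q_cos, q_sin = q * cos_w, q * sin_w
    k_cos, k_sin = k * cos_w, k * sin_w

    numerator = q_cos @ (k_cos.T @ V) + q_sin @ (k_sin.T @ V)          # (n, d_v)
    denominator = q_cos @ k_cos.sum(dim=0) + q_sin @ k_sin.sum(dim=0)  # (n,)
    return numerator / (denominator.unsqueeze(-1) + eps)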
Comparison of Methods
| Method | Kernel φ(x) | Time | Memory | Quality | Key Idea |
|---|---|---|---|---|---|
| Performer | Positive random features | O(nkd) | O(nk) | 95% | FAVOR+ algorithm |
| Linformer | Low-rank projection | O(nkd) | O(nk) | 92% | Attention is low-rank |
| Linear | elu(x) + 1 | O(nd²) | O(d²) | 90% | Simple feature map |
| Cosformer | Cosine reweighting | O(nd²) | O(nd) | 93% | Decomposition |
| Flash* | Exact softmax | O(n²d) | O(n) | 100% | IO-aware |
*Flash is not linear but achieves linear memory through tiling.
Advanced Techniques
Hybrid Approaches
Combine Local and Global Attention:
- Split heads between local windowed attention and linear global attention
- Local attention: High quality for nearby tokens
- Global attention: Linear complexity for long-range dependencies
- Combine outputs additively or with learned gating
Benefits:
- Best of both worlds: quality + efficiency
- Tuneable quality-speed trade-off
- Related designs appear in Longformer and BigBird, which combine local windows with sparse global attention
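A hedged sketch of the learned-gating combination; the local and global outputs are assumed to come from whatever attention implementations are in use, and HybridGate is an illustrative name:

```python
import torch
import torch.nn as nn

class HybridGate(nn.Module):
    """Blend local-window and linear-global attention outputs per token."""
    def __init__(self, d_model):
        super().__init__()
        self.gate = nn.Linear(d_model, 1)

    def forward(self, local_out, global_out):
        # Per-token scalar in [0, 1] decides how much of each branch to keep.
        g = torch.sigmoid(self.gate(local_out + global_out))
        return g * local_out + (1 - g) * global_out
```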
Learned Feature Maps
Adaptive Kernel Learning:
- Replace fixed feature maps (elu, ReLU) with learned neural networks
- Small MLP projects queries/keys to feature space
- Use softplus activation to ensure positivity
- Can adapt to data distribution during training
Trade-off: More parameters but potentially better approximation
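A hedged sketch of such a learned feature map: a small MLP with a softplus output, so features stay positive and can stand in for elu(x) + 1 in the linear attention above (the layer sizes are illustrative):

```python
import torch.nn as nn

class LearnedFeatureMap(nn.Module):
    """Small MLP that maps queries/keys to a positive feature space."""
    def __init__(self, head_dim, feature_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(head_dim, feature_dim),
            nn.GELU(),
            nn.Linear(feature_dim, feature_dim),
            nn.Softplus(),   # positivity keeps the normalizer well-behaved
        )

    def forward(self, x):        # x: (..., head_dim)
        return self.net(x)       # (..., feature_dim)
```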
Implementation Tips
1. Numerical Stability
Add Small Epsilon to Normalizer:
- Prevent division by zero
- Use eps=1e-6 typically
- Apply before division: out / (normalizer + eps)
Stable Computation Order:
- Compute KV matrix first: einsum('...nd,...ne->...de', k, v)
- Then apply queries: einsum('...nd,...de->...ne', q, kv)
- Normalize with sum of k features
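A hedged sketch that puts the two tips together; the einsum strings match the bullets above, and q and k are assumed to have already passed through a positive feature map:

```python
import torch

def linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, n, d); q and k are already feature-mapped.
    kv = torch.einsum('...nd,...ne->...de', k, v)     # KV matrix first
    out = torch.einsum('...nd,...de->...ne', q, kv)   # then apply queries
    normalizer = torch.einsum('...nd,...d->...n', q, k.sum(dim=-2))
    return out / (normalizer.unsqueeze(-1) + eps)     # eps avoids divide-by-zero
```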
2. Memory-Efficient Training
Chunked Processing:
- Process queries in chunks of 64-128 tokens
- Compute full KV matrix once
- Apply each query chunk separately
- Concatenate results
Benefits:
- Reduces peak memory usage
- Allows training with limited GPU memory
- Minimal overhead
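A hedged single-head sketch of this chunking scheme; the 128-token chunk size is illustrative:

```python
import torch

def chunked_linear_attention(q, k, v, chunk_size=128, eps=1e-6):
    # q, k: (n, d) already passed through a positive feature map; v: (n, d_v)
    kv = k.T @ v                               # full KV matrix, computed once
    k_sum = k.sum(dim=0)                       # (d,)
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        q_chunk = q[start:start + chunk_size]                     # (c, d)
        normalizer = q_chunk @ k_sum + eps                        # (c,)
        outputs.append((q_chunk @ kv) / normalizer.unsqueeze(-1))
    return torch.cat(outputs, dim=0)           # concatenate chunk results
```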
3. Causal Masking
Cumulative Sum Trick:
- Maintain running sums: kv_cumsum and k_cumsum
- For each position i, update sums with current K,V
- Compute output using only cumulative values
- O(1) complexity per token for generation
Implementation:
- Initialize sums to zero
- Loop through sequence positions
- Update sums incrementally
- Apply query to cumulative sums
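A hedged sketch following the implementation bullets above; the Python loop makes the recurrence explicit, while production kernels fuse or vectorize it:

```python
import torch

def causal_linear_attention(q, k, v, eps=1e-6):
    # q, k: (n, d) already passed through a positive feature map; v: (n, d_v)
    n, d = q.shape
    kv_cumsum = torch.zeros(d, v.shape[-1])    # initialize sums to zero
    k_cumsum = torch.zeros(d)
    out = torch.zeros_like(v)
    for i in range(n):                         # loop through sequence positions
        kv_cumsum = kv_cumsum + torch.outer(k[i], v[i])   # update incrementally
        k_cumsum = k_cumsum + k[i]
        out[i] = (q[i] @ kv_cumsum) / (q[i] @ k_cumsum + eps)  # apply query
    return out
```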
Production Configurations
Performer
Typical Configuration:
- d_model: 1024
- n_heads: 16
- nb_features: 256 (random feature dimension)
- use_orthogonal: True (orthogonal random features)
- redraw_features: True (periodically redraw for better approximation)
- redraw_interval: 1000 steps
Linformer
Typical Configuration:
- d_model: 768
- n_heads: 12
- seq_len: 4096 (fixed maximum sequence length)
- k: 256 (projection dimension, ~6% of seq_len)
- share_projections: True (share E, F across layers to save parameters)
Linear Transformer
Typical Configuration:
- d_model: 512
- n_heads: 8
- feature_map: "elu" (can use "relu", "gelu")
- eps: 1e-6 (numerical stability)
- use_rotary_embeddings: True (better positional encoding)
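One hedged way to carry such settings around is a plain dataclass; the field names mirror the Linear Transformer list above and are not tied to any particular library's API:

```python
from dataclasses import dataclass

@dataclass
class LinearAttentionConfig:
    d_model: int = 512
    n_heads: int = 8
    feature_map: str = "elu"          # "relu" and "gelu" are common alternatives
    eps: float = 1e-6                 # numerical stability
    use_rotary_embeddings: bool = True

config = LinearAttentionConfig()
head_dim = config.d_model // config.n_heads   # 64 per head
```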
Best Practices
Choosing the Right Method
Decision Guide:
Quality > 98% required:
- Use FlashAttention (exact attention with linear memory)
Very long sequences (>10K tokens):
- Use Performer (best scaling for ultra-long contexts)
Quality acceptable at 90-92%:
- Use Linear Transformer (simplest, fastest)
Tight memory budget:
- Use Linformer (most memory efficient, low-rank projections)
Balanced requirements:
- Use Cosformer (good quality-efficiency trade-off)
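A hedged helper that encodes the decision guide above; the thresholds and returned names are illustrative only:

```python
def choose_attention_method(min_quality: float, seq_len: int,
                            memory_constrained: bool = False) -> str:
    if min_quality >= 0.98:
        return "flash"        # exact attention, linear memory via tiling
    if seq_len > 10_000:
        return "performer"    # best scaling for ultra-long contexts
    if memory_constrained:
        return "linformer"    # low-rank projections, smallest footprint
    if min_quality <= 0.92:
        return "linear"       # simplest and fastest
    return "cosformer"        # balanced quality/efficiency trade-off
```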
Combining with Other Techniques
Layerwise Strategy:
- Early layers: Local/sliding window attention (capture fine-grained patterns)
- Middle layers: Linear attention (efficient long-range modeling)
- Final layers: Full or sparse attention (high-quality output)
Benefits:
- Leverages strengths of each method
- Progressive refinement of representations
- Optimal quality-speed trade-off
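A hedged sketch of such a layerwise schedule; the thirds split and the returned labels are illustrative, not a prescribed recipe:

```python
def attention_type_for_layer(layer_idx: int, n_layers: int) -> str:
    if layer_idx < n_layers // 3:
        return "sliding_window"   # early layers: fine-grained local patterns
    if layer_idx < 2 * n_layers // 3:
        return "linear"           # middle layers: efficient long-range mixing
    return "full"                 # final layers: high-quality output refinement

schedule = [attention_type_for_layer(i, 24) for i in range(24)]
```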
