Linear Attention Approximations

Explore linear complexity attention mechanisms including Performer, Linformer, and other efficient transformers that scale to very long sequences.

Linear Attention: From O(n²) to O(n)

Linear attention approximations replace the quadratic cost of self-attention with a cost that grows linearly in sequence length. This lets transformers process much longer sequences, at the price of some quality loss relative to exact softmax attention.

The Linearization Problem

Standard attention has quadratic complexity:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V$$

The bottleneck is the n × n attention matrix. Linear methods approximate or avoid computing it explicitly.
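
As a point of reference, the quadratic computation can be sketched in NumPy as follows. The function name `softmax_attention` and the single-head (n, d) shapes are illustrative assumptions, not taken from any particular library.

```python
import numpy as np

def softmax_attention(q, k, v):
    """Exact scaled dot-product attention for a single head.

    q, k: (n, d) arrays; v: (n, d_v) array.
    Returns the (n, d_v) output and the (n, n) attention weights.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                 # (n, n): the quadratic bottleneck
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
n, d = 6, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out, weights = softmax_attention(q, k, v)
print(out.shape, weights.shape)  # the weights array has n * n entries
```

The `weights` array is what every method below tries to avoid materializing.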

Major Linear Attention Methods

1. Performer (FAVOR+)

Key Idea: Approximate the softmax kernel with positive random features (FAVOR+), so attention can be computed without forming the n × n matrix

Architecture:

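  • Use a random projection matrix for the feature mapping
  • Draw the projection rows as Gaussian vectors, optionally orthogonalized to reduce variance
  • FAVOR+ uses positive random features; trigonometric (random Fourier) features can produce negative values and unstable estimates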

Core Steps:

  1. Create random projection matrix (orthogonal preferred)
  2. Project Q and K into the random feature space
  3. Apply a positive feature map to obtain φ(Q), φ(K): exp(ωᵀx − ‖x‖²/2) for the softmax kernel, or ReLU plus a small constant for the generalized-attention variant
  4. Compute using associative property: φ(Q) @ (φ(K)^T @ V) instead of (φ(Q) @ φ(K)^T) @ V
  5. Normalize each output row by φ(Q) applied to the sum of the φ(K) rows

Complexity: O(nkd) where k = number of random features (typically 256)
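
The steps above can be sketched roughly as follows for the softmax-kernel case. This is a simplified illustration under stated assumptions: it uses plain Gaussian projections (the orthogonalization step is omitted), it skips the extra max-subtraction stabilizer that practical implementations add, and the names `positive_random_features` and `performer_attention` are hypothetical.

```python
import numpy as np

def positive_random_features(x, proj):
    """Map rows of x (n, d) to positive features (n, m) using proj (m, d).

    phi(x) = exp(proj @ x - |x|^2 / 2) / sqrt(m). With i.i.d. N(0, 1) entries
    in proj, the expected value of phi(q) . phi(k) is exp(q . k).
    """
    m = proj.shape[0]
    half_sq_norm = 0.5 * np.sum(x * x, axis=-1, keepdims=True)
    return np.exp(x @ proj.T - half_sq_norm) / np.sqrt(m)

def performer_attention(q, k, v, proj, eps=1e-6):
    """Approximate softmax attention without forming the (n, n) matrix."""
    d = q.shape[-1]
    scale = d ** -0.25  # splits the 1/sqrt(d) scaling evenly between q and k
    q_feat = positive_random_features(q * scale, proj)   # (n, m)
    k_feat = positive_random_features(k * scale, proj)   # (n, m)
    kv = k_feat.T @ v                                    # (m, d_v)
    normalizer = q_feat @ k_feat.sum(axis=0)             # (n,)
    return (q_feat @ kv) / (normalizer[:, None] + eps)

rng = np.random.default_rng(0)
n, d, m = 16, 8, 256   # m plays the role of k, the number of random features
q, k, v = (0.5 * rng.standard_normal((n, d)) for _ in range(3))
proj = rng.standard_normal((m, d))  # assumption: Gaussian, not orthogonal
out = performer_attention(q, k, v, proj)
print(out.shape)
```

Because the features are strictly positive, the implied attention weights are positive and each output row is a weighted average of the rows of V, as in exact softmax attention.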

2. Linformer

Key Idea: Attention is approximately low-rank, so project K,V to smaller dimension

Architecture:

  • Learnable projection matrices E and F, each of shape k × n
  • Project the sequence length dimension of K and V from n → k
  • Apply standard attention in lower-dimensional space

Core Steps:

  1. Compute Q, K, V normally
  2. Project K using E: K' = E × K (reduces seq_len from n to k)
  3. Project V using F: V' = F × V (reduces seq_len from n to k)
  4. Compute attention: softmax(Q × K'^T / √d_k) × V'
  5. Output has original sequence length n

Complexity: O(nkd) where k = projection dimension (typically 256-512)
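
A minimal single-head sketch of these steps, assuming E and F have shape (k, n) and act on the sequence axis. In a trained model they would be learned; here they are random placeholders, and the function name `linformer_attention` is hypothetical.

```python
import numpy as np

def linformer_attention(q, k, v, e_proj, f_proj):
    """q, k, v: (n, d); e_proj, f_proj: (k_dim, n) sequence-axis projections."""
    d = q.shape[-1]
    k_low = e_proj @ k                            # (k_dim, d)
    v_low = f_proj @ v                            # (k_dim, d)
    scores = q @ k_low.T / np.sqrt(d)             # (n, k_dim) instead of (n, n)
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_low                        # (n, d): original length n

rng = np.random.default_rng(0)
n, d, k_dim = 32, 8, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
e_proj = rng.standard_normal((k_dim, n)) / np.sqrt(n)  # placeholder for learned E
f_proj = rng.standard_normal((k_dim, n)) / np.sqrt(n)  # placeholder for learned F
out = linformer_attention(q, k, v, e_proj, f_proj)
print(out.shape)
```

Since E and F have a fixed second dimension n, the maximum sequence length has to be chosen in advance.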

3. Linear Transformer

Key Idea: Simple kernel trick with φ(x) = elu(x) + 1 feature map

Architecture:

  • Minimal changes to standard attention
  • Use ELU activation + 1 as feature map
  • Leverages associative property of matrix multiplication

Core Steps:

  1. Compute Q, K, V normally
  2. Apply feature map: φ(Q) = elu(Q) + 1, φ(K) = elu(K) + 1
  3. Non-causal: Compute KV = φ(K)^T @ V, then output = φ(Q) @ KV
  4. Causal: Use the cumulative sum trick, so each generated token costs O(d²) regardless of how many tokens came before it
  5. Normalize by sum of K features

Complexity: O(nd²) where d = head_dim

Special Feature: Causal decoding with a fixed-size running state, so the per-token cost does not grow with sequence length
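
The non-causal case can be illustrated with a short NumPy sketch that also checks the reordered product against the quadratic form. The helper names `elu_plus_one`, `linear_attention`, and `linear_attention_quadratic` are hypothetical, and the code assumes a single head.

```python
import numpy as np

def elu_plus_one(x):
    # elu(x) + 1 equals x + 1 for x > 0 and exp(x) otherwise; always positive
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention(q, k, v, eps=1e-6):
    """O(n d^2) form: build the (d, d_v) summary first, then apply queries."""
    q_feat, k_feat = elu_plus_one(q), elu_plus_one(k)
    kv = k_feat.T @ v                        # (d, d_v)
    normalizer = q_feat @ k_feat.sum(axis=0) # (n,)
    return (q_feat @ kv) / (normalizer[:, None] + eps)

def linear_attention_quadratic(q, k, v, eps=1e-6):
    """Same result computed the O(n^2) way, for comparison only."""
    q_feat, k_feat = elu_plus_one(q), elu_plus_one(k)
    sim = q_feat @ k_feat.T                  # (n, n)
    return (sim @ v) / (sim.sum(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 4)) for _ in range(3))
print(np.allclose(linear_attention(q, k, v), linear_attention_quadratic(q, k, v)))
```

The two functions agree up to floating-point error because matrix multiplication is associative; only the order of operations, and therefore the cost, differs.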

4. Cosformer

Key Idea: Combine a ReLU feature map with a cosine-based re-weighting that decomposes into per-position terms

Architecture:

  • Apply cosine-based reweighting to Q and K
  • Use position-dependent weights
  • Combines benefits of positional encoding with linear complexity

Core Steps:

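  1. Pre-compute position weights cos(πi/2n) and sin(πi/2n) for each position i
  2. Apply ReLU to Q and K, then multiply by the cosine and sine weights to get two weighted copies of each
  3. Run the linear attention mechanism on the cosine pair and on the sine pair, and add the results
  4. Normalize output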

Complexity: O(nd²), the same order as other kernel-based linear attention; computing the position weights adds only O(nd) work

Benefit: Better captures positional information than simple feature maps
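
A rough sketch of the re-weighting, assuming the pairwise weight is cos(π/2 · (i − j)/n), which splits into a cosine term and a sine term via cos(a − b) = cos a cos b + sin a sin b. The function name `cosformer_attention`, the 1-based positions, and the use of n as the normalizing length are assumptions made for illustration.

```python
import numpy as np

def cosformer_attention(q, k, v, eps=1e-6):
    """Linear attention with ReLU features and cosine position re-weighting."""
    n = q.shape[0]
    q_feat, k_feat = np.maximum(q, 0.0), np.maximum(k, 0.0)  # ReLU features
    angle = (np.pi / 2.0) * np.arange(1, n + 1) / n          # pi * i / (2n)
    cos_w, sin_w = np.cos(angle)[:, None], np.sin(angle)[:, None]
    q_cos, q_sin = q_feat * cos_w, q_feat * sin_w
    k_cos, k_sin = k_feat * cos_w, k_feat * sin_w
    # Two linear-attention passes, one per trigonometric term
    numerator = q_cos @ (k_cos.T @ v) + q_sin @ (k_sin.T @ v)
    normalizer = q_cos @ k_cos.sum(axis=0) + q_sin @ k_sin.sum(axis=0)
    return numerator / (normalizer[:, None] + eps)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((12, 4)) for _ in range(3))
print(cosformer_attention(q, k, v).shape)
```

The weight cos(π/2 · (i − j)/n) is largest when i = j and decays with distance, which is how the method favors nearby tokens without an explicit n × n matrix.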

Comparison of Methods

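| Method | Kernel φ(x) | Time | Memory | Quality (approx.) | Key Idea |
|---|---|---|---|---|---|
| Performer | Positive random features | O(nkd) | O(nk) | 95% | FAVOR+ algorithm |
| Linformer | Low-rank projection | O(nkd) | O(nk) | 92% | Attention is low-rank |
| Linear | elu(x) + 1 | O(nd²) | O(d²) | 90% | Simple feature map |
| Cosformer | ReLU with cosine reweighting | O(nd²) | O(nd) | 93% | Decomposition |
| Flash* | Exact softmax | O(n²d) | O(n) | 100% | IO-aware |

*FlashAttention is not linear in time: it computes exact attention, but achieves linear memory through tiling.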

Advanced Techniques

Hybrid Approaches

Combine Local and Global Attention:

  • Split heads between local windowed attention and linear global attention
  • Local attention: High quality for nearby tokens
  • Global attention: Linear complexity for long-range dependencies
  • Combine outputs additively or with learned gating

Benefits:

  • Best of both worlds: quality + efficiency
  • Tuneable quality-speed trade-off
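  • Related in spirit to Longformer and BigBird, which pair local windowed attention with a small set of global tokens (sparse rather than kernel-based global attention)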
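
One way to picture the combination is the sketch below, which mixes a windowed softmax branch and an elu-based linear branch with a scalar gate. It is a simplified illustration under several assumptions: both branches run on the same head rather than on separate heads, the gate is a fixed scalar rather than a learned parameter, the local branch builds a masked n × n matrix for brevity instead of a banded computation, and all function names are hypothetical.

```python
import numpy as np

def local_window_attention(q, k, v, window):
    """Softmax attention restricted to positions with |i - j| <= window."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= window
    scores = np.where(mask, scores, -np.inf)      # masked entries get zero weight
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def linear_global_attention(q, k, v, eps=1e-6):
    """Kernel-based global attention with the elu(x) + 1 feature map."""
    q_feat = np.where(q > 0, q + 1.0, np.exp(np.minimum(q, 0.0)))
    k_feat = np.where(k > 0, k + 1.0, np.exp(np.minimum(k, 0.0)))
    normalizer = q_feat @ k_feat.sum(axis=0)
    return (q_feat @ (k_feat.T @ v)) / (normalizer[:, None] + eps)

def hybrid_attention(q, k, v, window, gate_logit=0.0):
    """Blend the two branches; gate_logit would be learned in a real model."""
    gate = 1.0 / (1.0 + np.exp(-gate_logit))
    local_out = local_window_attention(q, k, v, window)
    global_out = linear_global_attention(q, k, v)
    return gate * local_out + (1.0 - gate) * global_out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 4)) for _ in range(3))
print(hybrid_attention(q, k, v, window=2).shape)
```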

Learned Feature Maps

Adaptive Kernel Learning:

  • Replace fixed feature maps (elu, ReLU) with learned neural networks
  • Small MLP projects queries/keys to feature space
  • Use softplus activation to ensure positivity
  • Can adapt to data distribution during training

Trade-off: More parameters but potentially better approximation
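
A small sketch of such a feature map, assuming a two-layer MLP with a tanh hidden layer and a softplus output. The class name `LearnedFeatureMap`, the layer sizes, and the random initialization are illustrative assumptions; in practice the weights would be trained with the rest of the model.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably; the result is always positive
    return np.logaddexp(0.0, x)

class LearnedFeatureMap:
    """Tiny MLP mapping (n, d_in) inputs to positive (n, d_feat) features."""

    def __init__(self, d_in, d_hidden, d_feat, rng):
        # Random initialization stands in for trained parameters
        self.w1 = rng.standard_normal((d_in, d_hidden)) / np.sqrt(d_in)
        self.b1 = np.zeros(d_hidden)
        self.w2 = rng.standard_normal((d_hidden, d_feat)) / np.sqrt(d_hidden)
        self.b2 = np.zeros(d_feat)

    def __call__(self, x):
        hidden = np.tanh(x @ self.w1 + self.b1)
        return softplus(hidden @ self.w2 + self.b2)

rng = np.random.default_rng(0)
phi = LearnedFeatureMap(d_in=8, d_hidden=16, d_feat=32, rng=rng)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
q_feat, k_feat = phi(q), phi(k)
# Plug the learned features into the usual linear attention formula
out = (q_feat @ (k_feat.T @ v)) / (q_feat @ k_feat.sum(axis=0) + 1e-6)[:, None]
print(out.shape)
```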

Implementation Tips

1. Numerical Stability

Add Small Epsilon to Normalizer:

  • Prevent division by zero
  • Use eps=1e-6 typically
  • Apply before division: out / (normalizer + eps)

Stable Computation Order:

  • Compute KV matrix first: einsum('...nd,...ne->...de', k, v)
  • Then apply queries: einsum('...nd,...de->...ne', q, kv)
  • Normalize with sum of k features
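
Put together, the two einsum calls and the epsilon guard might look like the sketch below. It assumes the feature map has already been applied and that inputs carry arbitrary leading batch and head dimensions; the function name `normalized_linear_attention` is hypothetical.

```python
import numpy as np

def normalized_linear_attention(q_feat, k_feat, v, eps=1e-6):
    """q_feat, k_feat: (..., n, d) non-negative features; v: (..., n, e)."""
    kv = np.einsum('...nd,...ne->...de', k_feat, v)      # KV summary first
    out = np.einsum('...nd,...de->...ne', q_feat, kv)    # then apply queries
    k_sum = k_feat.sum(axis=-2)                          # (..., d)
    normalizer = np.einsum('...nd,...d->...n', q_feat, k_sum)
    return out / (normalizer[..., None] + eps)           # eps guards zero rows

rng = np.random.default_rng(0)
batch, heads, n, d = 2, 3, 10, 4
q_feat = rng.random((batch, heads, n, d))
k_feat = rng.random((batch, heads, n, d))
v = rng.standard_normal((batch, heads, n, d))
q_feat[0, 0, 0] = 0.0   # a query whose features are all zero
out = normalized_linear_attention(q_feat, k_feat, v)
print(out.shape, np.isfinite(out).all())
```

Without the epsilon, the all-zero query row would divide 0 by 0 and produce NaN values that then propagate through later layers.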

2. Memory-Efficient Training

Chunked Processing:

  • Process queries in chunks of 64-128 tokens
  • Compute full KV matrix once
  • Apply each query chunk separately
  • Concatenate results

Benefits:

  • Reduces peak memory usage
  • Allows training with limited GPU memory
  • Minimal overhead
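
The chunking scheme can be sketched as below, assuming single-head inputs whose feature map has already been applied. The function name `chunked_linear_attention` and the default chunk size are illustrative.

```python
import numpy as np

def chunked_linear_attention(q_feat, k_feat, v, chunk_size=64, eps=1e-6):
    """Apply queries in chunks against a KV summary computed once."""
    kv = k_feat.T @ v                # (d, d_v), shared by every chunk
    k_sum = k_feat.sum(axis=0)       # (d,), shared normalizer term
    outputs = []
    for start in range(0, q_feat.shape[0], chunk_size):
        q_chunk = q_feat[start:start + chunk_size]       # at most chunk_size rows
        normalizer = q_chunk @ k_sum
        outputs.append((q_chunk @ kv) / (normalizer[:, None] + eps))
    return np.concatenate(outputs, axis=0)

rng = np.random.default_rng(0)
n, d = 200, 8
q_feat, k_feat = rng.random((n, d)), rng.random((n, d))
v = rng.standard_normal((n, d))
out = chunked_linear_attention(q_feat, k_feat, v, chunk_size=64)
print(out.shape)
```

Each query row only depends on the shared `kv` and `k_sum`, so the chunked result matches the unchunked one; the last chunk is simply shorter when n is not a multiple of the chunk size.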

3. Causal Masking

Cumulative Sum Trick:

  • Maintain running sums: kv_cumsum and k_cumsum
  • For each position i, update sums with current K,V
  • Compute output using only cumulative values
  • Per-token cost during generation is constant in sequence length (O(d²) for the state update and readout)

Implementation:

  • Initialize sums to zero
  • Loop through sequence positions
  • Update sums incrementally
  • Apply query to cumulative sums
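
A straightforward loop version of this recurrence is sketched below, assuming single-head inputs with the feature map already applied. The variable names `kv_cumsum` and `k_cumsum` follow the list above; the function name `causal_linear_attention` is hypothetical, and the explicit Python loop is for clarity rather than speed.

```python
import numpy as np

def causal_linear_attention(q_feat, k_feat, v, eps=1e-6):
    """Position i attends only to positions 0..i via running sums."""
    n, d = q_feat.shape
    d_v = v.shape[-1]
    kv_cumsum = np.zeros((d, d_v))   # running sum of outer(k_j, v_j)
    k_cumsum = np.zeros(d)           # running sum of k_j
    out = np.empty((n, d_v))
    for i in range(n):
        kv_cumsum += np.outer(k_feat[i], v[i])
        k_cumsum += k_feat[i]
        out[i] = (q_feat[i] @ kv_cumsum) / (q_feat[i] @ k_cumsum + eps)
    return out

rng = np.random.default_rng(0)
n, d = 12, 4
q_feat, k_feat = rng.random((n, d)) + 0.1, rng.random((n, d)) + 0.1
v = rng.standard_normal((n, d))
out = causal_linear_attention(q_feat, k_feat, v)
print(out.shape)
```

During generation only `kv_cumsum` and `k_cumsum` need to be kept between steps, so the state size is fixed no matter how long the sequence gets.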

Production Configurations

Performer

Typical Configuration:

  • d_model: 1024
  • n_heads: 16
  • nb_features: 256 (random feature dimension)
  • use_orthogonal: True (orthogonal random features)
  • redraw_features: True (periodically redraw for better approximation)
  • redraw_interval: 1000 steps
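
As an illustration, these settings could be grouped in a small dataclass. The class `PerformerConfig` and its helper methods are hypothetical and do not correspond to any specific library's API; only the field values come from the list above.

```python
from dataclasses import dataclass

@dataclass
class PerformerConfig:
    d_model: int = 1024
    n_heads: int = 16
    nb_features: int = 256        # random feature dimension
    use_orthogonal: bool = True   # orthogonal random features
    redraw_features: bool = True  # periodically resample the projection
    redraw_interval: int = 1000   # in training steps

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def should_redraw(self, step: int) -> bool:
        """True on steps where the random projection would be resampled."""
        if not self.redraw_features or step <= 0:
            return False
        return step % self.redraw_interval == 0

config = PerformerConfig()
print(config.head_dim, config.should_redraw(1000), config.should_redraw(1500))
```

With these values each head has dimension 1024 / 16 = 64, so 256 random features is four times the head dimension.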

Linformer

Typical Configuration:

  • d_model: 768
  • n_heads: 12
  • seq_len: 4096 (fixed maximum sequence length)
  • k: 256 (projection dimension, ~6% of seq_len)
  • share_projections: True (share E, F across layers to save parameters)
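
The effect of these choices can be checked with some quick arithmetic. The sketch below assumes one E/F pair per layer (shared across heads) and a 12-layer model; the layer count is an assumption, not part of the configuration above, and the function name is hypothetical.

```python
def linformer_projection_params(seq_len, k, n_layers, share_across_layers):
    """Parameter count for the E and F projections, each of shape (k, seq_len)."""
    per_pair = 2 * seq_len * k
    return per_pair if share_across_layers else per_pair * n_layers

seq_len, k, n_layers = 4096, 256, 12   # n_layers is an assumed value

shared = linformer_projection_params(seq_len, k, n_layers, True)
unshared = linformer_projection_params(seq_len, k, n_layers, False)
print(shared, unshared)

# Attention matrix entries per head: (n, k) for Linformer vs (n, n) for full attention
print(seq_len * k, seq_len * seq_len, (seq_len * seq_len) // (seq_len * k))
print(k / seq_len)
```

Under these assumptions, sharing keeps the projections at about 2.1M parameters instead of about 25.2M, and the attention matrix shrinks by a factor of 16, matching the ratio k / seq_len = 0.0625.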

Linear Transformer

Typical Configuration:

  • d_model: 512
  • n_heads: 8
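  • feature_map: "elu" (uses elu(x) + 1; "relu" is a common alternative, while "gelu" can produce negative values and needs an offset to keep features positive)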
  • eps: 1e-6 (numerical stability)
  • use_rotary_embeddings: True (better positional encoding)

Best Practices

Choosing the Right Method

Decision Guide:

Quality > 98% required:

  • Use FlashAttention (exact attention with linear memory)

Very long sequences (>10K tokens):

  • Use Performer (linear in sequence length, with approximation quality tunable through the number of random features)

Quality acceptable at 90-92%:

  • Use Linear Transformer (simplest, fastest)

Tight memory budget:

  • Use Linformer (most memory efficient, low-rank projections)

Balanced requirements:

  • Use Cosformer (good quality-efficiency trade-off)

Combining with Other Techniques

Layerwise Strategy:

  • Early layers: Local/sliding window attention (capture fine-grained patterns)
  • Middle layers: Linear attention (efficient long-range modeling)
  • Final layers: Full or sparse attention (high-quality output)

Benefits:

  • Leverages strengths of each method
  • Progressive refinement of representations
  • Lets the quality-speed trade-off be tuned layer by layer
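
A layer plan of this kind could be generated with a helper like the one below. The function name `layer_schedule`, the string labels, and the example split of 3 local and 2 full layers are illustrative assumptions rather than recommended settings.

```python
def layer_schedule(n_layers, n_local, n_full):
    """Return one attention type per layer: local first, linear in the middle, full last."""
    if n_local < 0 or n_full < 0 or n_local + n_full > n_layers:
        raise ValueError("n_local + n_full must be between 0 and n_layers")
    n_linear = n_layers - n_local - n_full
    return ["local"] * n_local + ["linear"] * n_linear + ["full"] * n_full

schedule = layer_schedule(n_layers=12, n_local=3, n_full=2)
for index, kind in enumerate(schedule):
    print(index, kind)
```

How many layers to give each type is an empirical choice that depends on the task and the sequence lengths involved.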
