Multi-Query Attention (MQA)

Learn Multi-Query Attention (MQA), the optimization that shares keys and values across attention heads for massive memory savings.

Multi-Query Attention: Maximum Efficiency Through Sharing

Multi-Query Attention (MQA) is a radical simplification of multi-head attention that shares a single set of keys and values across all query heads, achieving dramatic memory savings with acceptable quality trade-offs.

The Core Insight

Traditional Multi-Head Attention (MHA) maintains separate K, V projections for each head:

  • Memory: O(n · h · d), where n = sequence length, h = number of heads, d = head dimension
  • Redundancy: Similar patterns learned across heads

MQA's breakthrough: One K, V pair serves all heads

  • Memory: O(n · d)
  • Efficiency: KV cache shrinks by a factor of n_heads (e.g., 32× for a 32-head model)

How MQA Works

The Architecture

MQA(X) = Concat(head_1, ..., head_h) W^O

Where each head computes:

head_i = Attention(Q_i, K_shared, V_shared)

Key differences from MHA:

  • Queries: Still head-specific (Q_1, Q_2, ..., Q_h)
  • Keys/Values: Shared across all heads (K_shared, V_shared)

Implementation Details

Key Architecture Components

Projection Layers:

  • Queries: Separate projection per head (d_model → n_heads × head_dim)
  • Keys: Single shared projection (d_model → head_dim)
  • Values: Single shared projection (d_model → head_dim)
  • Output: Standard projection combining all heads (d_model → d_model)

Forward Pass Steps:

  1. Project input X to multi-head queries Q
  2. Project input X to single K and V (shared across heads)
  3. Expand K, V to match all query heads via broadcasting
  4. Compute scaled dot-product attention for each head
  5. Concatenate head outputs and apply final projection
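
The five steps above can be condensed into a minimal PyTorch sketch. Names such as q_proj and o_proj are illustrative rather than taken from any particular library, and attention is written out explicitly so the size-1 head axis of K and V broadcasts against all query heads:

```python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    """Minimal MQA sketch: per-head queries, a single shared K/V head."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)        # step 1: all query heads at once
        self.k_proj = nn.Linear(d_model, self.head_dim)  # step 2: one shared key head
        self.v_proj = nn.Linear(d_model, self.head_dim)  # step 2: one shared value head
        self.o_proj = nn.Linear(d_model, d_model)        # step 5: output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        # [b, n_heads, s, head_dim]
        q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
        # [b, 1, s, head_dim]; the size-1 head axis broadcasts over all query heads (step 3)
        k = self.k_proj(x).unsqueeze(1)
        v = self.v_proj(x).unsqueeze(1)
        # step 4: scaled dot-product attention, with K/V expanded only logically
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5   # [b, n_heads, s, s]
        out = scores.softmax(dim=-1) @ v                          # [b, n_heads, s, head_dim]
        # step 5: merge heads and project
        return self.o_proj(out.transpose(1, 2).reshape(b, s, -1))

# Example: MultiQueryAttention(512, 8)(torch.randn(2, 16, 512)) -> shape [2, 16, 512]
```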

KV Cache Management:

  • Cache stores only one K,V pair per layer (not one per head)
  • Cache shape per layer: [batch, 1, max_seq_len, head_dim] for each of K and V
  • Dramatically reduced memory: the cache shrinks by 1 − 1/n_heads versus MHA (≈87% for 8 heads, ≈97% for 32)
  • New tokens are added by simple concatenation during incremental decoding
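
As a concrete sketch of the incremental-decoding path, the hypothetical helper below appends the newly generated token's shared K and V to a per-layer cache of shape [batch, 1, seq_so_far, head_dim]:

```python
import torch

def append_to_cache(cache_k, cache_v, new_k, new_v):
    """Append the latest token's shared K/V to a per-layer MQA cache.

    cache_k, cache_v: [batch, 1, seq_so_far, head_dim]
    new_k, new_v:     [batch, 1, 1, head_dim]
    """
    cache_k = torch.cat([cache_k, new_k], dim=2)  # grow along the sequence axis
    cache_v = torch.cat([cache_v, new_v], dim=2)
    return cache_k, cache_v
```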

Memory Savings Analysis

KV Cache Comparison

For a model with 32 heads, 40 layers, sequence length 2048, head dimension 128:

Method | Cache Size per Token | Total for 2K Context | Reduction
MHA | 2 × 40 × 32 × 128 = 327,680 floats | 640 MB | 0%
MQA | 2 × 40 × 1 × 128 = 10,240 floats | 20 MB | 96.9%
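
A quick back-of-the-envelope check of the per-token figures (counting stored values, not bytes):

```python
n_layers, n_heads, head_dim = 40, 32, 128

mha_per_token = 2 * n_layers * n_heads * head_dim  # 327,680 values (K and V, every head)
mqa_per_token = 2 * n_layers * 1 * head_dim        # 10,240 values (K and V, one shared head)

print(f"{1 - mqa_per_token / mha_per_token:.1%}")  # 96.9% reduction
```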

Batch Serving Benefits

Calculating Maximum Batch Size:

Given GPU memory and model size, the KV cache determines maximum batch capacity:

Example: 80GB A100 GPU with 30GB Model

  • Available memory for cache: 50GB
  • MHA cache per sequence: 640 MB → Max batch: ~78 sequences
  • MQA cache per sequence: 20 MB → Max batch: ~2,500 sequences

Impact: MQA enables 32× larger batches for the same hardware, dramatically improving serving throughput.
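
The estimate above amounts to dividing free memory by the per-sequence cache size; a minimal sketch using the figures quoted here (cache sizes per sequence in MB):

```python
def max_batch_size(free_memory_gb: float, cache_per_seq_mb: float) -> int:
    """Rough upper bound on concurrent sequences; ignores activations and fragmentation."""
    return int(free_memory_gb * 1e9 // (cache_per_seq_mb * 1e6))

print(max_batch_size(50, 640))  # MHA: 78 sequences
print(max_batch_size(50, 20))   # MQA: 2500 sequences
```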

Quality Considerations

The Trade-off

MQA trades expressiveness for efficiency:

Parameter Count Comparison:

  • MHA: n_heads × d_model × d_head × 3 (separate Q, K, V per head)
  • MQA: n_heads × d_model × d_head + 2 × d_model × d_head (Q per head, shared K,V)
  • Reduction: Approximately two-thirds fewer parameters in the Q/K/V projections (the output projection is unchanged)
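
Plugging in illustrative sizes (d_model = 4096, n_heads = 32, head_dim = 128) makes the reduction concrete:

```python
d_model, n_heads, head_dim = 4096, 32, 128

mha_qkv = 3 * n_heads * head_dim * d_model                       # separate Q, K, V per head
mqa_qkv = n_heads * head_dim * d_model + 2 * head_dim * d_model  # per-head Q, shared K and V

print(f"{1 - mqa_qkv / mha_qkv:.1%}")  # 64.6%, close to the roughly two-thirds noted above
```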

Empirical Results

From the original paper (Shazeer, 2019):

Model | Attention Type | Perplexity | Speed
Base | MHA | 10.2 | 1.0×
Base | MQA | 10.4 | 1.8×
Large | MHA | 8.1 | 1.0×
Large | MQA | 8.3 | 2.4×

Key findings:

  • Small quality loss (~2% perplexity increase)
  • Significant speed gains (1.8-2.4×)
  • Benefits scale with model size

Training Strategies

Training from Scratch

Configuration Adjustments:

  • Increase model dimension by ~10% to compensate for reduced parameters
  • Keep same number of query heads
  • Use single K,V projection (n_kv_heads = 1)
  • Reduce the learning rate for stability (see Pitfall 3 below for concrete values)
  • Increase training steps by ~10% to reach the same convergence
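
A hypothetical from-scratch MQA training configuration reflecting these adjustments; the field names and baseline values are illustrative only, not from any specific framework:

```python
mqa_config = dict(
    d_model=4480,          # ~10% wider than a 4096-wide MHA baseline, kept divisible by n_heads
    n_heads=32,            # number of query heads unchanged
    n_kv_heads=1,          # single shared K/V head = MQA
    learning_rate=5e-4,    # reduced from a 1e-3 baseline (see Pitfall 3 below)
    train_steps=110_000,   # ~10% more steps than a 100k-step baseline
)
```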

Fine-tuning from MHA

Conversion Strategy:

  1. Copy query projections directly from MHA
  2. Average K,V projections across all heads to create the shared K,V (sketched below)
  3. Keep the output projection unchanged
  4. Fine-tune with a lower learning rate (e.g., 1e-5)
  5. Usually converges after additional training equal to only 5-10% of the original pre-training
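
A minimal sketch of the averaging step in PyTorch; it assumes the usual nn.Linear weight layout of [n_heads * head_dim, d_model], and the module names in the usage comment are hypothetical:

```python
import torch

def mha_kv_to_mqa(w: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Average an MHA key (or value) projection over heads to get the shared MQA projection.

    The head axis is split out of the output rows and mean-pooled,
    yielding a [head_dim, d_model] weight matrix shared by all heads.
    """
    d_model = w.shape[1]
    return w.view(n_heads, head_dim, d_model).mean(dim=0)

# Hypothetical usage with an existing MHA checkpoint:
# mqa.k_proj.weight.data = mha_kv_to_mqa(mha.k_proj.weight, n_heads=32, head_dim=128)
# mqa.v_proj.weight.data = mha_kv_to_mqa(mha.v_proj.weight, n_heads=32, head_dim=128)
```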

Benefits of Uptraining:

  • Start from strong pretrained weights
  • Faster than training from scratch
  • Often achieves better final quality

Optimization Techniques

1. Fused Kernels

Efficient Broadcasting Strategy:

  • Insert a singleton head axis with unsqueeze(1) so K,V are expanded via broadcasting rather than copied (sketched below)
  • Compute attention scores with broadcasting instead of explicit expansion
  • Single fused kernel reduces memory allocations
  • TorchScript compilation for additional speedup
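
A sketch of the broadcasting trick; tensor names and shapes are illustrative:

```python
import torch

q = torch.randn(2, 8, 128, 64)  # [batch, n_heads, seq, head_dim] per-head queries
k = torch.randn(2, 128, 64)     # [batch, seq, head_dim] single shared key head

# unsqueeze(1) adds a size-1 head axis; matmul broadcasts it across all 8 query heads,
# so attention scores are computed without allocating 8 copies of K.
scores = q @ k.unsqueeze(1).transpose(-2, -1) / q.shape[-1] ** 0.5
print(scores.shape)  # torch.Size([2, 8, 128, 128])
```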

Benefits:

  • Reduced memory bandwidth usage
  • Fewer kernel launches
  • Better cache utilization
  • 20-30% faster than naive implementation

2. Memory Layout Optimization

Optimized Weight Storage:

  • Pack query weights for better cache locality
  • Store K,V weights contiguously
  • Use einsum for efficient multi-head query computation
  • Minimize memory fragmentation
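
For instance, storing the per-head query weights as one contiguous [n_heads, d_model, head_dim] tensor lets a single einsum produce every query head; the shapes below are illustrative:

```python
import torch

batch, seq, d_model, n_heads, head_dim = 2, 16, 512, 8, 64
x = torch.randn(batch, seq, d_model)
w_q = torch.randn(n_heads, d_model, head_dim)  # per-head query weights, stored contiguously

q = torch.einsum("bsd,hde->bhse", x, w_q)  # all query heads in one call: [2, 8, 16, 64]
```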

Performance Impact:

  • Improved CPU-to-GPU transfer efficiency
  • Better CUDA kernel occupancy
  • Reduced register pressure

3. Dynamic Batching

Adaptive Batch Processing:

  • Group requests by similar sequence lengths
  • Bucket lengths to reduce padding waste (e.g., round to nearest 128)
  • Pad sequences within each bucket
  • Process each bucket in single forward pass
  • Unpad results before returning
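
A minimal sketch of the bucketing step; `requests` is a hypothetical list of (request_id, token_ids) pairs, and each bucket can then be padded to its rounded length and run in one forward pass:

```python
from collections import defaultdict

def bucket_by_length(requests, bucket_size=128):
    """Group requests whose lengths round up to the same multiple of bucket_size."""
    buckets = defaultdict(list)
    for req_id, tokens in requests:
        padded_len = ((len(tokens) + bucket_size - 1) // bucket_size) * bucket_size
        buckets[padded_len].append((req_id, tokens))
    return buckets
```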

Throughput Gains:

  • Minimize wasted computation on padding
  • Better GPU utilization
  • 2-3× higher throughput for variable-length batches

When to Use MQA

Ideal Use Cases

Large-scale serving

  • Maximize throughput
  • Minimize latency
  • Large batch sizes

Memory-constrained environments

  • Edge devices
  • Mobile deployment
  • Limited GPU memory

Long context applications

  • Document processing
  • Multi-turn dialogue
  • Code generation

When to Avoid

Research/Experimentation

  • Need best quality
  • Small-scale deployment
  • Abundant resources

Latency-insensitive training

  • Pre-training from scratch
  • Have sufficient memory
  • Quality is paramount

Comparison with Alternatives

Feature | MHA | GQA-8 | MQA
KV Parameters | 100% | 25% | 3.1%
Cache Size | 100% | 25% | 3.1%
Quality | Best | Near-best | Good
Inference Speed | 1.0× | ~1.5× | 1.8-2.4×
Implementation | Complex | Moderate | Simple

Production Examples

PaLM (Google)

Configuration:

  • d_model: 18,432
  • n_heads: 48
  • n_kv_heads: 1 (MQA)
  • layers: 118
  • context: 2,048 tokens

Memory Impact:

  • MHA cache: ~35.9 GB per sequence
  • MQA cache: ~0.75 GB per sequence
  • Savings: 97.9%
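
These figures can be checked with a short calculation; assuming fp32 (4 bytes per value), it lands close to the numbers quoted above:

```python
def kv_cache_gb(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_value=4):
    # the factor of 2 accounts for storing both keys and values
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value / 1e9

d_model, n_heads, n_layers, context = 18_432, 48, 118, 2_048
head_dim = d_model // n_heads  # 384

print(kv_cache_gb(n_layers, n_heads, head_dim, context))  # ~35.6 GB per sequence with MHA
print(kv_cache_gb(n_layers, 1, head_dim, context))        # ~0.74 GB per sequence with MQA
```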

Falcon (TII)

Falcon-40B Configuration:

  • d_model: 8,192
  • n_heads: 64
  • attention_type: multi_query
  • n_kv_heads: 1

Performance: At release, Falcon-40B topped the Hugging Face Open LLM Leaderboard while remaining efficient to serve thanks to MQA

Common Pitfalls and Solutions

Pitfall 1: Incorrect Broadcasting

Problem: Using repeat() creates unnecessary copies, consuming extra memory

Solution: Use unsqueeze() with broadcasting for memory-efficient expansion

Impact: Reduces memory usage by 8× (for 8 heads) during forward pass
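
A small check of the difference, with illustrative shapes: repeat() materializes copies, while unsqueeze() plus expand()/broadcasting only creates a view over the original buffer.

```python
import torch

k = torch.randn(1, 2048, 128)                    # shared keys: [batch, seq, head_dim]

k_copied = k.unsqueeze(1).repeat(1, 8, 1, 1)     # allocates 8 physical copies of K
k_viewed = k.unsqueeze(1).expand(-1, 8, -1, -1)  # broadcast view, no new memory

print(k_copied.numel() // k.numel())             # 8    -> eight times the storage
print(k_viewed.data_ptr() == k.data_ptr())       # True -> same underlying buffer
```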

Pitfall 2: Cache Shape Mismatch

Problem: Caching K,V with head dimension wastes memory

Correct approach: Cache shape should be (batch, 1, seq_len, head_dim) not (batch, n_heads, seq_len, head_dim)

Impact: Ensures cache stays small and shared across heads

Pitfall 3: Learning Rate

Problem: Using same learning rate as MHA can cause training instability

Solution: Reduce learning rate by 40-50% for MQA training (e.g., 5e-4 instead of 1e-3)

Reason: Shared K,V parameters update from all heads simultaneously, requiring more careful optimization

Future Directions

Hybrid Approaches

  • First/last layers with MHA
  • Critical layers with GQA
  • Others with MQA

Learned Sharing

  • Dynamic K,V sharing
  • Attention-based routing
  • Adaptive compression
