Grouped-Query Attention: The Best of Both Worlds
Grouped-Query Attention (GQA) is an attention mechanism that strikes an optimal balance between the quality of Multi-Head Attention (MHA) and the efficiency of Multi-Query Attention (MQA), making it the preferred choice for modern large language models.
The Evolution: MHA → MQA → GQA
Multi-Head Attention (MHA)
- Every head has its own Q, K, V
- Best quality but highest memory usage
- KV cache size: 2 × L × H × D
Multi-Query Attention (MQA)
- All heads share single K, V
- Most efficient but quality degradation
- KV cache size: 2 × L × D
Grouped-Query Attention (GQA)
- Groups of heads share K, V
- Balanced quality and efficiency
- KV cache size: 2 × L × G × D
Where L = sequence length, H = number of attention heads, G = number of KV groups, and D = head dimension (sizes in elements, per layer and per sequence)
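To make these formulas concrete, here is a small, self-contained sketch comparing the three cache sizes; the configuration values are illustrative:

```python
def kv_cache_elements(seq_len: int, kv_heads: int, head_dim: int) -> int:
    """Per-layer KV cache size in elements: 2 (K and V) x L x kv_heads x D."""
    return 2 * seq_len * kv_heads * head_dim

L, H, G, D = 4096, 32, 8, 128      # illustrative configuration

mha = kv_cache_elements(L, H, D)   # every head keeps its own K,V
gqa = kv_cache_elements(L, G, D)   # one K,V per group of heads
mqa = kv_cache_elements(L, 1, D)   # one K,V shared by all heads

print(f"MHA: {mha:>12,} elements (100.0%)")
print(f"GQA: {gqa:>12,} elements ({100 * gqa / mha:.1f}%)")
print(f"MQA: {mqa:>12,} elements ({100 * mqa / mha:.1f}%)")
```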
How GQA Works
The Grouping Mechanism
Instead of H separate KV pairs (MHA) or 1 shared KV pair (MQA), GQA uses G groups:
Configuration Example (32 heads, 8 groups):
- Total attention heads: 32
- Number of KV groups: 8
- Group size: 4 heads per group
- Memory savings: 75% reduction compared to MHA
Mathematical Formulation
For head h in group g:

Attention(Q_h, K_g, V_g) = softmax(Q_h K_g^T / √D) V_g

Where:
- Q_h is the query for head h
- K_g, V_g are the keys and values shared by all heads in group g
- Group assignment: g = ⌊h × G / H⌋
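A quick illustration of the group-assignment formula for the 32-head, 8-group example above (pure Python; the variable names are ours):

```python
H, G = 32, 8  # query heads, KV groups (group size = H // G = 4)

# g = floor(h * G / H): consecutive heads share a group.
groups = [h * G // H for h in range(H)]
print(groups)  # [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]
assert all(0 <= g < G for g in groups)
```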
Implementation
Key Architecture Components
Projection Layers:
- Query projections: one per head; weight shape (num_heads × head_dim) × d_model
- Key projection: shared within each group; weight shape (num_kv_heads × head_dim) × d_model
- Value projection: shared within each group; weight shape (num_kv_heads × head_dim) × d_model
- Output projection: Standard linear layer combining all heads
Forward Pass Steps:
- Project input to multi-head queries Q
- Project input to grouped K and V (fewer projections than queries)
- Repeat/expand K, V to match query head count using efficient views
- Compute scaled dot-product attention for each head
- Concatenate outputs and apply final projection
Key Constraint: the number of query heads must be divisible by the number of KV heads to ensure even grouping (enforced by the assert in the sketch below)
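Putting the steps together, here is a minimal, self-contained PyTorch sketch. It is illustrative rather than any particular library's implementation; step 3 uses a simple copy-based repeat for clarity, and the zero-copy view version appears later in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA sketch: num_heads query heads share num_kv_heads K/V heads."""

    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        # Key constraint: query heads must divide evenly into KV groups.
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.group_size = num_heads // num_kv_heads

        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        # Steps 1-2: per-head queries, grouped keys/values.
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Step 3: expand K,V so each group serves group_size query heads
        # (simple copy here; see the view-based expansion later in this section).
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Step 4: scaled dot-product attention per head (causal for decoding).
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Step 5: concatenate heads and apply the output projection.
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))

# Usage: 32 query heads sharing 8 KV heads, as in the example above.
attn = GroupedQueryAttention(d_model=4096, num_heads=32, num_kv_heads=8)
y = attn(torch.randn(2, 16, 4096))
```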
Efficient KV Cache Management
Cache Structure:
- Pre-allocated tensors for keys and values
- Shape: [num_layers, batch_size, num_kv_heads, max_seq_len, head_dim]
- Significantly smaller than MHA cache (uses num_kv_heads instead of num_heads)
- Incremental updates during autoregressive generation
Cache Operations:
- Store only num_kv_heads KV pairs per layer (not num_heads)
- Write each new token's K,V into the pre-allocated buffer in place (concatenation is the slower, dynamically allocated alternative); see the sketch below
- Return the cached K,V, expanded to match the query head count, during attention
- Enables efficient long-context inference
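A sketch of these cache operations with a pre-allocated per-layer buffer in the shape given above (the class and method names are ours):

```python
import torch

class KVCache:
    """Pre-allocated GQA cache: sized by num_kv_heads, not num_heads."""

    def __init__(self, num_layers: int, batch_size: int, num_kv_heads: int,
                 max_seq_len: int, head_dim: int,
                 dtype=torch.float16, device: str = "cpu"):
        shape = (num_layers, batch_size, num_kv_heads, max_seq_len, head_dim)
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        self.seq_len = 0  # number of positions cached so far

    def update(self, layer: int, k: torch.Tensor, v: torch.Tensor):
        """Write new K,V of shape [batch, num_kv_heads, t_new, head_dim]
        in place, then return everything cached so far for this layer."""
        end = self.seq_len + k.shape[2]
        self.k[layer, :, :, self.seq_len:end] = k
        self.v[layer, :, :, self.seq_len:end] = v
        return self.k[layer, :, :, :end], self.v[layer, :, :, :end]

    def advance(self, t_new: int = 1):
        """Call once per decoding step, after every layer has updated."""
        self.seq_len += t_new
```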
Memory and Compute Analysis
Memory Comparison
| Method | KV Cache Size | Relative Size | Quality |
|---|---|---|---|
| MHA | 2 × L × H × D | 100% | Best |
| GQA-8 | 2 × L × 8 × D | 25% (H=32) | Near-MHA |
| GQA-4 | 2 × L × 4 × D | 12.5% (H=32) | Good |
| MQA | 2 × L × 1 × D | 3.1% (H=32) | Degraded |
Compute Overhead
GQA adds minimal compute overhead:
- Projections: slightly cheaper than MHA, since the K and V projection matrices are smaller
- Repeat operation: the unsqueeze/expand view is O(1) and copies no data; only a downstream reshape materializes a copy
- Attention computation: identical to MHA once K and V are expanded
Efficient K,V Expansion (sketched below):
- Use unsqueeze and expand instead of copy-based operations such as repeat_interleave
- expand creates a stride-0 broadcast view over existing memory, with no new allocation
- The only potential copy is the final reshape that merges the group dimension; kernels that broadcast grouped K/V directly avoid even that
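A view-based expansion sketch (the function name mirrors the repeat_kv helper common in open-source implementations, but this version is ours):

```python
import torch

def repeat_kv(kv: torch.Tensor, group_size: int) -> torch.Tensor:
    """Expand [batch, num_kv_heads, seq, dim] to
    [batch, num_kv_heads * group_size, seq, dim]."""
    if group_size == 1:
        return kv
    b, h_kv, t, d = kv.shape
    # unsqueeze + expand: a stride-0 broadcast view, no memory allocated.
    kv = kv[:, :, None, :, :].expand(b, h_kv, group_size, t, d)
    # reshape merges (h_kv, group_size); this is the step that materializes
    # a copy, because broadcast strides cannot express the merged dimension.
    return kv.reshape(b, h_kv * group_size, t, d)
```

Kernels that accept grouped K/V directly (for example, F.scaled_dot_product_attention with enable_gqa=True in recent PyTorch releases) make even the final reshape unnecessary.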
GQA in Production Models
Llama 2 Configuration
Llama 2 70B Architecture:
- Query heads: 64
- KV heads: 8 (GQA-8)
- Group size: 8 heads per group
- Context length: 4096 tokens
- Head dimension: 128
Memory Impact (fp16, 80 layers, full 4096-token context):
- MHA cache: 2 × 80 × 4096 × 64 × 128 × 2 bytes ≈ 10.7 GB per sequence
- GQA-8 cache: 2 × 80 × 4096 × 8 × 128 × 2 bytes ≈ 1.3 GB per sequence
- 87.5% memory reduction (8 KV heads instead of 64)
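The arithmetic behind these figures, assuming fp16 (2 bytes per element):

```python
layers, seq_len, head_dim, bytes_per = 80, 4096, 128, 2  # Llama 2 70B, fp16

def cache_gb(kv_heads: int) -> float:
    # 2 tensors (K and V) per layer, per sequence
    return 2 * layers * seq_len * kv_heads * head_dim * bytes_per / 1e9

print(f"MHA   (64 KV heads): {cache_gb(64):.1f} GB")  # ~10.7 GB
print(f"GQA-8 ( 8 KV heads): {cache_gb(8):.1f} GB")   # ~1.3 GB
```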
Mistral Configuration
Mistral 7B Architecture:
- Query heads: 32
- KV heads: 8 (GQA-8)
- Sliding window: 4096 tokens
- GQA and sliding-window attention compound: the window bounds how many positions are cached, while GQA shrinks each position's footprint
- Enables efficient long-context processing
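A rough bound on Mistral 7B's cache, assuming the published configuration (32 layers, head dimension 128) and fp16:

```python
layers, window, kv_heads, head_dim, bytes_per = 32, 4096, 8, 128, 2  # fp16

# The sliding window caps cached positions at `window`, so the KV cache is
# bounded regardless of total sequence length:
cache_gb = 2 * layers * window * kv_heads * head_dim * bytes_per / 1e9
print(f"Bounded KV cache: {cache_gb:.2f} GB per sequence")  # ~0.54 GB
```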
Training Considerations
Converting MHA to GQA
Uptraining Strategy - Convert pre-trained MHA models to GQA:
Conversion Steps:
- Keep query projections unchanged (all heads)
- Group K,V projections and average weights within each group
- Reshape averaged weights to match num_kv_heads dimension
- Fine-tune with lower learning rate to recover quality
- Quality typically recovers with roughly 5% of the original pre-training steps (the proportion reported in the GQA uptraining paper); a sketch of the pooling step follows
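A sketch of the mean-pooling step, assuming K/V projection weights stored as [num_heads × head_dim, d_model], the layout torch.nn.Linear uses:

```python
import torch

def pool_kv_weight(w: torch.Tensor, num_heads: int, num_kv_heads: int,
                   head_dim: int) -> torch.Tensor:
    """Average an MHA K or V projection [num_heads*head_dim, d_model]
    down to a GQA projection [num_kv_heads*head_dim, d_model]."""
    group_size = num_heads // num_kv_heads
    d_model = w.shape[1]
    # Split rows into (kv_head, group member, head_dim), then average the
    # members of each group; consecutive heads share a group, matching
    # the assignment g = floor(h * G / H).
    w = w.view(num_kv_heads, group_size, head_dim, d_model)
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, d_model)
```

Query and output projections are left unchanged; a short fine-tune then recovers most of the remaining quality gap.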
Benefits:
- Leverage existing pre-trained weights
- Faster than training GQA from scratch
- Often achieves comparable final quality to original MHA
Training from Scratch
GQA can be trained directly:
- Similar convergence to MHA
- Slightly faster training (less memory movement)
- Regularization effect from parameter sharing
Choosing the Right Configuration
Decision Guide
Research and Experimentation:
- Use MHA when quality is paramount
- Acceptable when memory is abundant
Cloud Serving (Large Models, >30B):
- Use GQA-8 for memory-critical scenarios
- Balances quality with serving efficiency
Cloud Serving (Smaller Models, <30B):
- Use GQA-16 or higher
- Can afford more groups for better quality
Edge Devices:
- Use MQA or GQA-4 for maximum efficiency
- Prioritize memory savings over quality
Batch Serving:
- Use GQA-8 to GQA-16
- Balance based on batch size and throughput requirements
Recommended Configurations
| Model Size | Batch Size | Recommended | Groups | Rationale |
|---|---|---|---|---|
| <7B | 1 | MHA or GQA-16 | 16-32 | Quality focus |
| 7B-13B | 1-8 | GQA-8 | 8 | Balanced |
| 13B-70B | 1-4 | GQA-8 | 8 | Memory critical |
| >70B | 1-2 | GQA-4 or MQA | 1-4 | Extreme efficiency |
Performance Tips
1. Hardware Considerations
GPU-Specific Optimization:
- A100 GPUs: Use num_kv_heads as multiples of 8 (tensor core optimization)
- V100 GPUs: Prefer fewer groups (4-8) due to limited memory
- Align group sizes with hardware warp/thread block sizes
2. Dynamic Group Size
Adaptive Strategy:
- Use fewer groups for long sequences (>2048 tokens) to save memory
- Use more groups for short sequences to maintain quality
- Switch dynamically based on input length (this assumes the model was trained or uptrained to tolerate multiple group counts)
- Balance memory usage with quality requirements
3. Optimize Memory Layout
Cache Organization:
- Use contiguous memory layout: [layer, batch, num_kv_heads, seq_len, head_dim]
- Optimize for sequential access patterns during autoregressive generation
- Pre-allocate cache to avoid dynamic memory allocation
- Keep cache aligned to cache line boundaries
Common Pitfalls
Pitfall 1: Uneven Group Sizes
Problem: Number of query heads not divisible by number of KV heads
- Example: 32 query heads with 7 KV heads (32 % 7 ≠ 0)
- Solution: Ensure num_heads is divisible by num_kv_heads
- Recommended: Use powers of 2 for both (8, 16, 32, 64)
Pitfall 2: Inefficient Repeat
Problem: Using operations that copy data instead of creating views
- Copy-based expansion (e.g., repeat_interleave) wastes memory and bandwidth
- Solution: use unsqueeze + expand to create a broadcast view over existing memory
- The final reshape still copies; attention kernels that accept grouped K/V directly (such as PyTorch's scaled_dot_product_attention with enable_gqa=True in recent releases) avoid duplication entirely
Pitfall 3: Wrong Cache Shape
Problem: Allocating cache with num_heads instead of num_kv_heads
- Wastes memory by storing unnecessary duplicates
- Solution: Cache shape should use num_kv_heads dimension
- Correct: [batch, seq_len, num_kv_heads, head_dim]
- Wrong: [batch, seq_len, num_heads, head_dim]
