
Grouped-Query Attention (GQA)

Learn how Grouped-Query Attention (GQA) balances Multi-Head quality with Multi-Query efficiency for faster LLM inference.


Grouped-Query Attention: The Best of Both Worlds

Grouped-Query Attention (GQA) is an attention mechanism that strikes an optimal balance between the quality of Multi-Head Attention (MHA) and the efficiency of Multi-Query Attention (MQA), making it the preferred choice for modern large language models.

The Evolution: MHA → MQA → GQA

Multi-Head Attention (MHA)

  • Every head has its own Q, K, V
  • Best quality but highest memory usage
  • KV cache size: 2 × L × H × D

Multi-Query Attention (MQA)

  • All heads share single K, V
  • Most efficient but quality degradation
  • KV cache size: 2 × L × D

Grouped-Query Attention (GQA)

  • Groups of heads share K, V
  • Balanced quality and efficiency
  • KV cache size: 2 × L × G × D

Where L = sequence length, H = num heads, G = num groups, D = head dimension

How GQA Works

The Grouping Mechanism

Instead of H separate KV pairs (MHA) or 1 shared KV pair (MQA), GQA uses G groups:

Configuration Example (32 heads, 8 groups):

  • Total attention heads: 32
  • Number of KV groups: 8
  • Group size: 4 heads per group
  • Memory savings: 75% reduction compared to MHA

Mathematical Formulation

For head h in group g:

Attention_h = softmax(Q_h K_g^T / √d_k) V_g

Where:

  • Q_h is the query for head h
  • K_g, V_g are the shared keys/values for group g
  • Group assignment: g = ⌊h × G / H⌋
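To make the assignment concrete, here is a tiny Python sketch (variable names are illustrative) that applies the formula to the 32-head, 8-group configuration from the example above:

```python
H, G = 32, 8                                 # query heads, KV groups
group_of = [h * G // H for h in range(H)]    # g = floor(h * G / H)

# Heads 0-3 map to group 0, heads 4-7 to group 1, and so on
assert group_of[:5] == [0, 0, 0, 0, 1]
assert all(group_of.count(g) == H // G for g in range(G))   # 4 heads per group
```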

Implementation

Key Architecture Components

Projection Layers:

  • Query projections: Separate for each head (num_heads × d_model)
  • Key projections: Shared across groups (num_kv_heads × d_model)
  • Value projections: Shared across groups (num_kv_heads × d_model)
  • Output projection: Standard linear layer combining all heads

Forward Pass Steps:

  1. Project input to multi-head queries Q
  2. Project input to grouped K and V (fewer projections than queries)
  3. Repeat/expand K, V to match query head count using efficient views
  4. Compute scaled dot-product attention for each head
  5. Concatenate outputs and apply final projection

Key Constraint: Number of query heads must be divisible by number of KV heads to ensure even grouping
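A minimal PyTorch sketch of this layer, following the five steps above, is shown below. The class name and constructor arguments are illustrative rather than taken from any particular library, and the K/V expansion uses `repeat_interleave` for readability (the zero-copy view variant is sketched in the K,V expansion section further down).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA layer: num_heads query heads share num_kv_heads K/V heads."""

    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "query heads must divide evenly into KV groups"
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.group_size = num_heads // num_kv_heads
        self.head_dim = d_model // num_heads
        # Steps 1-2: per-head query projection, grouped key/value projections
        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Step 3: expand K, V so each group serves group_size query heads
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)
        # Step 4: scaled dot-product attention per head (causal for decoding)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Step 5: concatenate heads and apply the output projection
        out = out.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_dim)
        return self.o_proj(out)

# Example: 32 query heads sharing 8 KV heads (GQA-8)
layer = GroupedQueryAttention(d_model=4096, num_heads=32, num_kv_heads=8)
y = layer(torch.randn(1, 16, 4096))   # -> [1, 16, 4096]
```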

Efficient KV Cache Management

Cache Structure:

  • Pre-allocated tensors for keys and values
  • Shape: [num_layers, batch_size, num_kv_heads, max_seq_len, head_dim]
  • Significantly smaller than MHA cache (uses num_kv_heads instead of num_heads)
  • Incremental updates during autoregressive generation

Cache Operations (sketched below):

  • Store only num_kv_heads KV pairs per layer (not num_heads)
  • Concatenate with previous cache for each new token
  • Return expanded cache to match query head count during attention
  • Enables efficient long-context inference
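A simplified sketch of such a cache follows. The class and method names are hypothetical; real serving stacks add batching, paging, and positional-embedding handling on top of this basic pattern.

```python
import torch

class GQAKVCache:
    """Pre-allocated grouped KV cache: stores num_kv_heads entries, not num_heads."""

    def __init__(self, num_layers, batch_size, num_kv_heads, max_seq_len, head_dim,
                 dtype=torch.float16, device="cpu"):
        shape = (num_layers, batch_size, num_kv_heads, max_seq_len, head_dim)
        # Allocate once up front to avoid dynamic allocation during generation
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        self.seq_len = 0  # number of tokens already cached

    def update(self, layer, k_new, v_new):
        """Write K/V for the newest token(s) and return the cache seen so far.

        k_new, v_new: [batch, num_kv_heads, t, head_dim] for t new tokens.
        """
        t = k_new.shape[2]
        self.k[layer, :, :, self.seq_len:self.seq_len + t] = k_new
        self.v[layer, :, :, self.seq_len:self.seq_len + t] = v_new
        return (self.k[layer, :, :, :self.seq_len + t],
                self.v[layer, :, :, :self.seq_len + t])

    def advance(self, t=1):
        """Call once per decoding step, after all layers have been updated."""
        self.seq_len += t
```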

Memory and Compute Analysis

Memory Comparison

| Method | KV Cache Size | Relative Size | Quality |
|--------|---------------|---------------|---------|
| MHA | 2 × L × H × D | 100% | Best |
| GQA-8 | 2 × L × 8 × D | 25% (H=32) | Near-MHA |
| GQA-4 | 2 × L × 4 × D | 12.5% (H=32) | Good |
| MQA | 2 × L × 1 × D | 3.1% (H=32) | Degraded |
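The relative sizes follow directly from the cache formula; a few lines of Python reproduce them (the values are per layer, and precision cancels out in the ratio):

```python
def kv_cache_elements(seq_len, kv_heads, head_dim):
    """Per-layer KV cache elements: 2 (K and V) × L × kv_heads × D."""
    return 2 * seq_len * kv_heads * head_dim

L, H, D = 4096, 32, 128
mha = kv_cache_elements(L, H, D)
for name, g in [("GQA-8", 8), ("GQA-4", 4), ("MQA", 1)]:
    print(f"{name}: {kv_cache_elements(L, g, D) / mha:.1%} of MHA")
# GQA-8: 25.0% of MHA, GQA-4: 12.5% of MHA, MQA: 3.1% of MHA
```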

Compute Overhead

GQA adds minimal compute overhead:

  • Projections: comparable to MHA; the K and V projection matrices are smaller, so slightly less compute
  • Repeat operation: O(1) using efficient tensor views (no data copying)
  • Attention computation: Identical to MHA

Efficient K,V Expansion (sketched below):

  • Use unsqueeze and expand operations instead of copying
  • Creates view over existing memory
  • No additional memory allocation during forward pass
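A small demonstration of this pattern, including a check that the expanded tensor really shares storage with the original. Note that a final reshape before a kernel that requires contiguous memory may still trigger a copy; whether it does depends on the downstream attention implementation.

```python
import torch

k = torch.randn(1, 8, 16, 64)      # [batch, num_kv_heads, seq, head_dim]
group_size = 4

k_view = k.unsqueeze(2).expand(1, 8, group_size, 16, 64)   # view, no allocation
assert k_view.data_ptr() == k.data_ptr()                    # same underlying storage

# Flatten the group dimension into the head dimension when the attention kernel
# expects [batch, num_heads, seq, head_dim]; this step may materialize a copy.
k_expanded = k_view.reshape(1, 8 * group_size, 16, 64)
```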

GQA in Production Models

Llama 2 Configuration

Llama 2 70B Architecture:

  • Query heads: 64
  • KV heads: 8 (GQA-8)
  • Group size: 8 heads per group
  • Context length: 4096 tokens
  • Head dimension: 128

Memory Impact:

  • MHA cache: ~8.4 GB per sequence
  • GQA cache: ~1.0 GB per sequence
  • 88% memory reduction

Mistral Configuration

Mistral 7B Architecture:

  • Query heads: 32
  • KV heads: 8 (GQA-8)
  • Sliding window: 4096 tokens
  • Combined with sliding window attention, so the two memory savings compound
  • Enables efficient long-context processing

Training Considerations

Converting MHA to GQA

Uptraining Strategy: convert a pre-trained MHA model to GQA (a weight-averaging sketch follows the steps below):

Conversion Steps:

  1. Keep query projections unchanged (all heads)
  2. Group K,V projections and average weights within each group
  3. Reshape averaged weights to match num_kv_heads dimension
  4. Fine-tune with lower learning rate to recover quality
  5. Typically converges after only about 5-10% of the original pre-training steps
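Step 2, averaging the K/V projection weights within each group, can be sketched as follows. The function name and weight layout are assumptions for illustration; real checkpoints may store the projections fused or transposed differently.

```python
import torch

def mha_to_gqa_kv_weight(w_kv, num_heads, num_kv_heads, head_dim):
    """Mean-pool an MHA K (or V) projection weight into grouped KV heads.

    w_kv: [num_heads * head_dim, d_model] from the pre-trained MHA checkpoint.
    Returns [num_kv_heads * head_dim, d_model] for the GQA model (steps 2-3);
    query and output projections are left untouched (step 1).
    """
    group_size = num_heads // num_kv_heads
    d_model = w_kv.shape[1]
    w = w_kv.view(num_kv_heads, group_size, head_dim, d_model)
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, d_model)

# Example: a 32-head MHA checkpoint converted to GQA-8 (head_dim 128, d_model 4096)
w_k_mha = torch.randn(32 * 128, 4096)
w_k_gqa = mha_to_gqa_kv_weight(w_k_mha, num_heads=32, num_kv_heads=8, head_dim=128)
assert w_k_gqa.shape == (8 * 128, 4096)
```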

Benefits:

  • Leverage existing pre-trained weights
  • Faster than training GQA from scratch
  • Often achieves comparable final quality to original MHA

Training from Scratch

GQA can be trained directly:

  • Similar convergence to MHA
  • Slightly faster training (less memory movement)
  • Regularization effect from parameter sharing

Choosing the Right Configuration

Decision Guide

Research and Experimentation:

  • Use MHA when quality is paramount
  • Acceptable when memory is abundant

Cloud Serving (Large Models, >30B):

  • Use GQA-8 for memory-critical scenarios
  • Balances quality with serving efficiency

Cloud Serving (Smaller Models, <30B):

  • Use GQA-16 or higher
  • Can afford more groups for better quality

Edge Devices:

  • Use MQA or GQA-4 for maximum efficiency
  • Prioritize memory savings over quality

Batch Serving:

  • Use GQA-8 to GQA-16
  • Balance based on batch size and throughput requirements

| Model Size | Batch Size | Recommended | Groups | Rationale |
|------------|------------|-------------|--------|-----------|
| <7B | 1 | MHA or GQA-16 | 16-32 | Quality focus |
| 7B-13B | 1-8 | GQA-8 | 8 | Balanced |
| 13B-70B | 1-4 | GQA-8 | 8 | Memory critical |
| >70B | 1-2 | GQA-4 or MQA | 4-1 | Extreme efficiency |

Performance Tips

1. Hardware Considerations

GPU-Specific Optimization:

  • A100 GPUs: Use num_kv_heads as multiples of 8 (tensor core optimization)
  • V100 GPUs: Prefer fewer groups (4-8) due to limited memory
  • Align group sizes with hardware warp/thread block sizes

2. Dynamic Group Size

Adaptive Strategy:

  • Use fewer groups for long sequences (>2048 tokens) to save memory
  • Use more groups for short sequences to maintain quality
  • Switch dynamically based on input length
  • Balance memory usage with quality requirements

3. Optimize Memory Layout

Cache Organization:

  • Use contiguous memory layout: [layer, batch, num_kv_heads, seq_len, head_dim]
  • Optimize for sequential access patterns during autoregressive generation
  • Pre-allocate cache to avoid dynamic memory allocation
  • Keep cache aligned to cache line boundaries

Common Pitfalls

Pitfall 1: Uneven Group Sizes

Problem: Number of query heads not divisible by number of KV heads

  • Example: 32 query heads with 7 KV heads (32 % 7 ≠ 0)
  • Solution: Ensure num_heads is divisible by num_kv_heads
  • Recommended: Use powers of 2 for both (8, 16, 32, 64)
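A minimal configuration check along these lines catches the problem early:

```python
def validate_gqa_config(num_heads, num_kv_heads):
    """Return the group size, or raise if the heads cannot be grouped evenly."""
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"{num_heads} query heads are not divisible by {num_kv_heads} KV heads")
    return num_heads // num_kv_heads

validate_gqa_config(32, 8)    # OK: group size 4
# validate_gqa_config(32, 7)  # raises ValueError, matching the example above
```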

Pitfall 2: Inefficient Repeat

Problem: Using operations that copy data instead of creating views

  • Copying data wastes memory and bandwidth
  • Solution: Use unsqueeze + expand + reshape for zero-copy expansion
  • Creates view over existing memory without duplication

Pitfall 3: Wrong Cache Shape

Problem: Allocating cache with num_heads instead of num_kv_heads

  • Wastes memory by storing unnecessary duplicates
  • Solution: Cache shape should use num_kv_heads dimension
  • Correct: [batch, seq_len, num_kv_heads, head_dim]
  • Wrong: [batch, seq_len, num_heads, head_dim]
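A quick comparison using the example dimensions from earlier shows how much memory the wrong shape wastes:

```python
import torch

batch, seq_len, num_heads, num_kv_heads, head_dim = 1, 4096, 32, 8, 128

correct = torch.zeros(batch, seq_len, num_kv_heads, head_dim, dtype=torch.float16)
wrong = torch.zeros(batch, seq_len, num_heads, head_dim, dtype=torch.float16)
print(wrong.numel() / correct.numel())   # 4.0 — four times the necessary memory
```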

If you found this explanation helpful, consider sharing it with others.
