KV Cache: Accelerating LLM Inference
The Key-Value (KV) cache is a critical optimization that makes autoregressive text generation practical. Without it, generating each new token would require recomputing attention for the entire sequence, making long-form generation prohibitively expensive.
The Recomputation Problem
Without Caching
During autoregressive generation, producing the token at position t requires re-running attention over all t tokens so far, recomputing every key and value from scratch:

Cost per step ∝ L · H · d · t

Total computation for generating n tokens:

Total ∝ L · H · d · (1 + 2 + ⋯ + n) = O(L · H · d · n²)

Where:
- L = number of layers
- H = number of attention heads
- d = head dimension
- n = sequence length
With KV Caching
Each step computes keys and values only for the newly generated token:

Cost per step ∝ L · H · d

Total computation becomes linear in the sequence length:

Total ∝ L · H · d · n
How KV Cache Works
The Attention Mechanism
Standard attention computation:
```python
import math
import torch

def attention(Q, K, V):
    # Q: [batch, seq_len, d_model]
    # K, V: [batch, seq_len, d_model]
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
```
Caching Strategy
During generation with caching:
```python
import torch

class KVCache:
    def __init__(self, max_length, n_layers, n_heads, head_dim):
        # Pre-allocated K and V buffers: one slot per (layer, position)
        self.cache_k = torch.zeros(n_layers, max_length, n_heads, head_dim)
        self.cache_v = torch.zeros(n_layers, max_length, n_heads, head_dim)
        self.seq_len = 0  # number of fully written positions

    def prefill(self, layer_idx, k, v):
        # k, v: [prompt_len, n_heads, head_dim], written in one shot
        self.cache_k[layer_idx, :k.size(0)] = k
        self.cache_v[layer_idx, :v.size(0)] = v
        self.seq_len = k.size(0)

    def update(self, layer_idx, new_k, new_v):
        # new_k, new_v: [n_heads, head_dim] for the single new token;
        # every layer writes to the same position within one step
        self.cache_k[layer_idx, self.seq_len] = new_k
        self.cache_v[layer_idx, self.seq_len] = new_v

    def advance(self):
        # Call once per generated token, after all layers have written
        self.seq_len += 1

    def get(self, layer_idx):
        # Include the position written by the most recent update()
        end = self.seq_len + 1
        return self.cache_k[layer_idx, :end], self.cache_v[layer_idx, :end]
```
Generation Loop
```python
def generate_with_cache(model, prompt_tokens, max_new_tokens):
    kv_cache = KVCache(max_length=len(prompt_tokens) + max_new_tokens,
                       n_layers=len(model.layers),
                       n_heads=model.n_heads, head_dim=model.head_dim)
    tokens = list(prompt_tokens)

    # Process the prompt: all positions in parallel, one pass per layer
    hidden_states = model.embed(prompt_tokens)
    for layer in model.layers:
        k, v = layer.compute_kv(hidden_states)
        kv_cache.prefill(layer.idx, k, v)
        hidden_states = layer(hidden_states)

    # Generate new tokens (sequential)
    for _ in range(max_new_tokens):
        new_token = sample(hidden_states[-1])  # predict from the last position
        tokens.append(new_token)
        new_hidden = model.embed([new_token])
        for layer in model.layers:
            # Only compute K/V for the new token
            new_k, new_v = layer.compute_kv(new_hidden)
            kv_cache.update(layer.idx, new_k, new_v)
            # Reuse cached K/V for attention over the whole prefix
            cached_k, cached_v = kv_cache.get(layer.idx)
            new_hidden = layer.attention(new_hidden, cached_k, cached_v)
        kv_cache.advance()  # advance the write position once per token
        hidden_states = new_hidden
    return tokens
```
Memory Requirements
Cache Size Formula
Cache Size = 2 × B × L × S × H × D × bytes_per_element

Where:
- 2 = keys + values
- B = batch size
- L = number of layers
- S = sequence length
- H = number of heads
- D = head dimension
- bytes_per_element = 2 for FP16
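As a sanity check, a small helper (the function name is illustrative) evaluates the formula; for LLaMA-7B at its full 4K context with batch size 1, it reproduces the 2.0 GB entry in the table below:

```python
def kv_cache_bytes(batch, layers, seq_len, heads, head_dim, bytes_per_elem=2):
    # 2x for keys + values; bytes_per_elem=2 corresponds to FP16
    return 2 * batch * layers * seq_len * heads * head_dim * bytes_per_elem

print(kv_cache_bytes(1, 32, 4096, 32, 128) / 2**30)  # LLaMA-7B: 2.0 (GiB)
```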
Real-World Examples
| Model | Context | Layers | Heads | Dim | Cache Size (FP16) |
|---|---|---|---|---|---|
| GPT-3 175B | 2K | 96 | 96 | 128 | 9.0 GB |
| LLaMA-7B | 4K | 32 | 32 | 128 | 2.0 GB |
| LLaMA-70B | 4K | 80 | 64 | 128 | 10.0 GB |
| GPT-4* | 32K | 120 | 120 | 128 | 225 GB |
*Estimated configuration
Optimization Techniques
1. Multi-Query Attention (MQA)
Share keys and values across heads:
```python
# Standard Multi-Head Attention
K, V: [batch, seq_len, n_heads, head_dim]

# Multi-Query Attention
K, V: [batch, seq_len, 1, head_dim]  # shared across all heads
```
Memory reduction: H×, where H is the number of query heads. With 32 heads, for example, the KV cache shrinks 32×: LLaMA-7B's 2.0 GB cache would drop to roughly 64 MB.
2. Grouped-Query Attention (GQA)
Balance between MHA and MQA by sharing key-value heads across groups of query heads, significantly reducing KV cache memory. For more on how Grouped-Query Attention achieves this balance, see the dedicated concept page.
```python
# n_kv_heads < n_heads
K, V: [batch, seq_len, n_kv_heads, head_dim]
# Each KV head serves n_heads / n_kv_heads query heads
```
Used in: LLaMA-2 70B, Mistral
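A minimal sketch of the mechanism, assuming 32 query heads served by 8 KV heads (all shapes here are illustrative): the smaller cached K/V are expanded to match the query heads at attention time.

```python
import torch

batch, seq_len, n_heads, n_kv_heads, head_dim = 1, 128, 32, 8, 64
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)  # cached at 1/4 the size
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Each KV head serves n_heads // n_kv_heads query heads
repeat = n_heads // n_kv_heads
k = k.repeat_interleave(repeat, dim=1)  # -> [batch, n_heads, seq_len, head_dim]
v = v.repeat_interleave(repeat, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
output = torch.softmax(scores, dim=-1) @ v  # standard attention from here on
```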
3. Sliding Window Cache
Only cache recent tokens:
```python
from collections import deque

class SlidingKVCache:
    def __init__(self, window_size):
        self.window_size = window_size
        self.cache = deque(maxlen=window_size)  # one (k, v) pair per token

    def update(self, k, v):
        # Appending beyond maxlen automatically drops the oldest entry
        self.cache.append((k, v))
```
4. PagedAttention
Virtual memory for KV cache:
- Non-contiguous memory allocation
- Dynamic growth
- Memory sharing across sequences
- Used in vLLM
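A toy sketch of the idea (the class and names are hypothetical; vLLM's real kernels perform this lookup inside the attention kernel itself): logical token positions map to fixed-size physical blocks through a block table, so the cache grows without contiguous allocation.

```python
import torch

BLOCK_SIZE = 16  # tokens per physical block

class PagedKVCache:
    def __init__(self, n_blocks, n_heads, head_dim):
        self.k_blocks = torch.zeros(n_blocks, BLOCK_SIZE, n_heads, head_dim)
        self.v_blocks = torch.zeros(n_blocks, BLOCK_SIZE, n_heads, head_dim)
        self.free_blocks = list(range(n_blocks))
        self.block_table = []  # logical block index -> physical block id

    def append(self, pos, k, v):
        if pos % BLOCK_SIZE == 0:
            # Grab a new physical block only when the previous one fills up
            self.block_table.append(self.free_blocks.pop())
        block = self.block_table[pos // BLOCK_SIZE]
        self.k_blocks[block, pos % BLOCK_SIZE] = k
        self.v_blocks[block, pos % BLOCK_SIZE] = v
```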
Cache Management Strategies
1. Static Allocation
Pre-allocate maximum size:
```python
cache = torch.zeros(max_batch, max_length, ...)
```
- ✅ Simple, no fragmentation
- ❌ Wastes memory for short sequences
2. Dynamic Growth
Grow cache as needed:
```python
if seq_len > cache_size:
    cache = torch.cat([cache, new_allocation], dim=1)
```
- ✅ Memory efficient
- ❌ Reallocation overhead
3. Block-wise Allocation
Allocate in fixed-size blocks:
```python
blocks = []
while need_more_space:
    blocks.append(allocate_block())
```
- ✅ Balance of efficiency and flexibility
- ❌ More complex implementation
Advanced Techniques
Quantized KV Cache
Reduce precision to save memory:
```python
import torch

def quantize_cache(cache, bits=8):
    # Symmetric per-tensor quantization; int8 storage shown
    # (a 4-bit variant would also need to pack two values per byte)
    scale = cache.abs().max() / (2 ** (bits - 1) - 1)
    quantized = torch.round(cache / scale).to(torch.int8)
    return quantized, scale

def dequantize_cache(quantized, scale):
    return quantized.to(torch.float16) * scale
```
Memory savings: 50% (FP16 → INT8) or 75% (FP16 → INT4)
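A quick round trip with the functions above shows the trade: half the storage for a small reconstruction error.

```python
cache = torch.randn(1024, 32, 128, dtype=torch.float16)
quantized, scale = quantize_cache(cache)
restored = dequantize_cache(quantized, scale)

print(quantized.element_size())               # 1 byte per value vs 2 for FP16
print((cache - restored).abs().max().item())  # small worst-case error
```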
Hierarchical Caching
Cache at multiple granularities:
- Token-level: Full resolution
- Segment-level: Compressed representations
- Document-level: Summary embeddings
Speculative Decoding Cache
Separate caches for draft and target models:
```python
draft_cache = KVCache(small_model_config)
target_cache = KVCache(large_model_config)

# Generate candidate tokens with the cheap draft model
draft_tokens = draft_model.generate(prompt, cache=draft_cache)

# Verify them in one pass with the target model; entries for
# rejected draft tokens must be rolled back from both caches
verified_tokens = target_model.verify(draft_tokens, cache=target_cache)
```
Performance Impact
Generation Speed
Without cache:
- Time per token: O(n), since the full sequence of n tokens is re-processed at each step
- Total time: O(n²)
- Example: generating 1,000 tokens re-processes 1 + 2 + ⋯ + 1000 = 500,500 ≈ 500,000 token positions
With cache:
- Time per token: O(1) key/value computations (only the new token's K and V are computed; attending over the cached keys remains, but is far cheaper)
- Total time: O(n)
- Example: generating 1,000 tokens processes just 1,000 new token positions
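The gap is easy to check directly:

```python
n = 1000
without_cache = sum(range(1, n + 1))  # step t re-processes t tokens
with_cache = n                        # each step processes one new token
print(without_cache, with_cache)      # 500500 vs 1000: a ~500x reduction
```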
Throughput Comparison
| Sequence Length | Without Cache | With Cache | Speedup |
|---|---|---|---|
| 128 tokens | 1.2 sec | 0.13 sec | 9.2× |
| 512 tokens | 19.5 sec | 0.51 sec | 38.2× |
| 2048 tokens | 312 sec | 2.05 sec | 152× |
| 8192 tokens | ~5000 sec | 8.2 sec | 610× |
Common Issues and Solutions
1. Memory Overflow
Problem: Cache exceeds available memory.

Solutions:
- Use sliding window cache
- Implement cache eviction
- Quantize cache values
- Use CPU offloading
2. Cache Invalidation
Problem: Prompt changes require a cache reset.

Solutions:
- Incremental cache updates
- Prefix caching for common prompts
- Cache versioning
3. Batch Processing
Problem: Different sequences in a batch have different lengths.

Solutions:
- Padding and masking
- Dynamic batching
- Continuous batching (vLLM)
Implementation Best Practices
1. Memory Pool Management
```python
class CachePool:
    # Sketch: available(), _get_from_pool(), and _return_to_pool() are elided
    def __init__(self, total_memory):
        self.total_memory = total_memory
        self.pool = []        # free, reusable cache buffers
        self.allocated = {}   # request_id -> buffer currently in use

    def allocate(self, request_id, size):
        if size <= self.available():
            cache = self._get_from_pool(size)
            self.allocated[request_id] = cache
            return cache
        return None  # caller must evict, wait, or fall back

    def free(self, request_id):
        cache = self.allocated.pop(request_id)
        self._return_to_pool(cache)
```
2. Cache Warming
Pre-compute common prefixes:
```python
common_prefixes = ["You are a helpful", "Please analyze", ...]

cache_store = {}
for prefix in common_prefixes:
    cache_store[hash(prefix)] = compute_kv_cache(prefix)

# At request time, reuse the precomputed cache when the prompt
# starts with a known prefix
cached = cache_store.get(hash(prompt_prefix))
```
3. Monitoring
Track cache metrics:
- Hit rate
- Memory usage
- Eviction rate
- Recomputation frequency
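A minimal sketch of a metrics holder the serving loop could increment (the names are illustrative):

```python
from dataclasses import dataclass

@dataclass
class KVCacheMetrics:
    hits: int = 0        # requests served from a cached prefix
    misses: int = 0      # requests that required full recomputation
    evictions: int = 0   # cache entries dropped under memory pressure

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0
```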
Related Concepts
- Context Windows - Maximum cache size limits
- Flash Attention - Memory-efficient attention
- Tokenization - What gets cached
- Attention Mechanisms - What we're caching
Conclusion
KV caching transforms LLM inference from quadratic to linear complexity, making real-time text generation feasible. Understanding cache dynamics is essential for optimizing inference performance and managing memory constraints in production deployments.
