KV Cache: The Secret to Fast LLM Inference

Summary
How key-value caching in transformer LLMs enables fast text generation by avoiding quadratic recomputation.

KV Cache: Accelerating LLM Inference

The Key-Value (KV) cache is a critical optimization that makes autoregressive text generation practical. Without it, generating each new token would mean re-running the model over the entire sequence so far and recomputing the keys and values of every earlier token, which makes long-form generation prohibitively expensive.

The Recomputation Problem

Without Caching

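During autoregressive generation without a cache, step i re-runs the model over all i tokens seen so far, recomputing every one of their keys and values. Modeling this work as L × H × d² per token, step i costs:

$$\text{Compute}_{\text{step } i} = i \times L \times H \times d^2$$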

Total computation for generating n tokens:

$$\text{Total} = \sum_{i=1}^{n} i \times L \times H \times d^2 = \frac{n(n+1)}{2} \times L \times H \times d^2 = O(n^2)$$

Where:

  • L = number of layers
  • H = number of attention heads
  • d = head dimension
  • n = sequence length

With KV Caching

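Each step computes keys and values only for the new token and reads all earlier ones from the cache, so the projection work per step no longer depends on sequence length:

$$\text{Compute}_{\text{step}} = L \times H \times d^2 = O(1)$$

Total projection work becomes linear:

$$\text{Total}_{\text{cached}} = n \times L \times H \times d^2 = O(n)$$

Attention itself is not free: at step i the new query still scores against all i cached keys, which costs L × H × i × d, so that term still grows with the sequence. What the cache eliminates is re-projecting every earlier token at every step.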
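Plugging numbers into the two totals makes the gap concrete. This sketch assumes the text's simplified model (only K/V projection work, L × H × d² per token) and borrows LLaMA-7B-like shapes (32 layers, 32 heads, head dimension 128) purely for illustration; `kv_projection_cost` is a hypothetical helper name.

```python
def kv_projection_cost(n, n_layers, n_heads, head_dim, cached):
    """Total K/V projection work for n generated tokens, in the text's cost units."""
    per_token = n_layers * n_heads * head_dim ** 2   # L x H x d^2
    if cached:
        return n * per_token                         # each token is projected once
    # Without a cache, step i re-projects all i tokens seen so far.
    return sum(i * per_token for i in range(1, n + 1))


n = 1000
uncached = kv_projection_cost(n, 32, 32, 128, cached=False)
cached = kv_projection_cost(n, 32, 32, 128, cached=True)
print(uncached / cached)  # 500.5, i.e. (n + 1) / 2
```

The ratio is independent of the model shape: without a cache, generating n tokens does (n + 1) / 2 times more projection work.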

How KV Cache Works

The Attention Mechanism

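Each layer computes scaled dot-product attention between the Query of the current token and the Keys and Values of all tokens so far. Because attention is causal, a token's hidden state at every layer depends only on that token and the ones before it, so the Keys and Values of earlier tokens never change as generation proceeds. That is exactly what makes them cacheable.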
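The claim that earlier Keys and Values never change can be checked numerically. This is a minimal pure-Python sketch, assuming a toy model of stacked single-head causal attention layers with residual connections and random weights; it omits MLPs and normalization, which act on each position independently and so would not change the conclusion. `causal_forward` and the other helpers are hypothetical names.

```python
import math
import random


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention of a single query over keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]


def causal_forward(xs, layers):
    """Run stacked single-head causal attention layers; return each layer's keys."""
    keys_per_layer, h = [], xs
    for Wq, Wk, Wv in layers:
        qs = [matvec(Wq, x) for x in h]
        ks = [matvec(Wk, x) for x in h]
        vs = [matvec(Wv, x) for x in h]
        keys_per_layer.append(ks)
        # Causal mask: position t attends only to positions 0..t, then a residual add
        h = [[a + b for a, b in zip(h[t], attend(qs[t], ks[:t + 1], vs[:t + 1]))]
             for t in range(len(h))]
    return keys_per_layer


rng = random.Random(0)
d, n_layers = 4, 3
layers = [tuple(rand_matrix(d, d, rng) for _ in range(3)) for _ in range(n_layers)]
inputs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(6)]

prefix = causal_forward(inputs[:5], layers)   # keys once 5 tokens have been seen
extended = causal_forward(inputs, layers)     # keys after a 6th token arrives
for i in range(n_layers):
    # The first 5 positions' keys are bit-for-bit identical at every layer
    assert extended[i][:5] == prefix[i]
```

Values follow the same argument, since they are computed from the same per-position hidden states as the keys.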

Caching Strategy

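A minimal cache preallocates key and value tensors for every layer and fills them in as tokens arrive (one sequence, no batch dimension):

```python
import torch


class KVCache:
    def __init__(self, max_length, n_layers, n_heads, head_dim):
        self.cache_k = torch.zeros(n_layers, max_length, n_heads, head_dim)
        self.cache_v = torch.zeros(n_layers, max_length, n_heads, head_dim)
        # Filled length per layer: every layer is updated once per token, so a
        # single shared counter would advance n_layers times per token.
        self.seq_len = [0] * n_layers

    def update(self, layer_idx, new_k, new_v):
        # new_k, new_v: [T, n_heads, head_dim] (T = prompt length, or 1 when decoding)
        start = self.seq_len[layer_idx]
        end = start + new_k.shape[0]
        if end > self.cache_k.shape[1]:
            raise ValueError("KV cache is full")
        self.cache_k[layer_idx, start:end] = new_k
        self.cache_v[layer_idx, start:end] = new_v
        self.seq_len[layer_idx] = end

    def get(self, layer_idx):
        # Keys/values of every token cached so far: [seq_len, n_heads, head_dim]
        n = self.seq_len[layer_idx]
        return self.cache_k[layer_idx, :n], self.cache_v[layer_idx, :n]
```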

Generation Loop

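The loop assumes a hypothetical model interface; only the caching pattern matters:

```python
def generate_with_cache(model, prompt_tokens, max_new_tokens, kv_cache):
    # Hypothetical interface: model.embed, model.layers (each with .idx,
    # .compute_kv, .attention and a full forward via __call__), and
    # model.sample (LM head + sampling from the last position's output).
    tokens = []

    # Prefill: process the whole prompt in one parallel pass, caching its K/V
    hidden = model.embed(prompt_tokens)             # [T, d_model]
    for layer in model.layers:
        k, v = layer.compute_kv(hidden)             # [T, n_heads, head_dim]
        kv_cache.update(layer.idx, k, v)
        hidden = layer(hidden)                      # causal attention + MLP

    # Decode: strictly sequential, one token per step
    for _ in range(max_new_tokens):
        new_token = model.sample(hidden[-1])        # sample from the latest output
        tokens.append(new_token)
        hidden = model.embed([new_token])           # [1, d_model]
        for layer in model.layers:
            # Only the new token's K/V are computed...
            new_k, new_v = layer.compute_kv(hidden)
            kv_cache.update(layer.idx, new_k, new_v)
            # ...and attention reuses every cached K/V (then the layer's MLP runs)
            cached_k, cached_v = kv_cache.get(layer.idx)
            hidden = layer.attention(hidden, cached_k, cached_v)
    return tokens
```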
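To see the loop's payoff end to end without a real model, here is a self-contained pure-Python toy, assuming a single one-head attention layer with random weights and greedy decoding. `generate_no_cache`, `generate_with_cache`, and the weight names are hypothetical. Both functions should emit the same tokens, but the cached one projects each token's K/V only once.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


rng = random.Random(42)
VOCAB, D = 8, 4


def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


EMBED, W_Q, W_K, W_V, W_OUT = (rand_matrix(VOCAB, D), rand_matrix(D, D),
                               rand_matrix(D, D), rand_matrix(D, D), rand_matrix(VOCAB, D))


def next_token(x, keys, values):
    """One attention layer (with residual), then greedy argmax over the vocabulary."""
    h = [a + b for a, b in zip(x, attend(matvec(W_Q, x), keys, values))]
    logits = matvec(W_OUT, h)
    return max(range(VOCAB), key=lambda t: logits[t])


def generate_no_cache(prompt, n_new):
    tokens, projections = list(prompt), 0
    for _ in range(n_new):
        xs = [EMBED[t] for t in tokens]
        keys = [matvec(W_K, x) for x in xs]      # every token re-projected each step
        values = [matvec(W_V, x) for x in xs]
        projections += len(xs)
        tokens.append(next_token(xs[-1], keys, values))
    return tokens[len(prompt):], projections


def generate_with_cache(prompt, n_new):
    tokens, keys, values, projections = list(prompt), [], [], 0
    for _ in range(n_new):
        # Project only tokens not yet cached: the whole prompt on the first
        # step (prefill), then just the token sampled on the previous step.
        for t in tokens[len(keys):]:
            keys.append(matvec(W_K, EMBED[t]))
            values.append(matvec(W_V, EMBED[t]))
            projections += 1
        tokens.append(next_token(EMBED[tokens[-1]], keys, values))
    return tokens[len(prompt):], projections


print(generate_no_cache([1, 5, 2], 5))
print(generate_with_cache([1, 5, 2], 5))
```

With a 3-token prompt and 5 new tokens, the uncached loop performs 3 + 4 + 5 + 6 + 7 = 25 K/V projections while the cached one performs 7, and the generated tokens are identical because the cached keys and values are exactly the ones the uncached loop recomputes.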

Memory Requirements

Cache Size Formula

$$\text{Memory}_{\text{KV}} = 2 \times B \times L \times S \times H \times D \times \text{sizeof}(\text{dtype})$$

Where:

  • 2 = Keys + Values
  • B = batch size
  • L = number of layers
  • S = sequence length
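  • H = number of key/value heads (the same as the number of attention heads in standard multi-head attention)
  • D = head dimension
  • sizeof(dtype) = bytes per element (2 for FP16)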

Real-World Examples

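| Model | Context | Layers | Heads | Head dim | KV cache (FP16, batch 1) |
|---|---|---|---|---|---|
| GPT-3 175B | 2K | 96 | 96 | 128 | 9.0 GiB |
| LLaMA-7B | 4K | 32 | 32 | 128 | 2.0 GiB |
| LLaMA-70B | 4K | 80 | 64 | 128 | 10.0 GiB |
| GPT-4* | 32K | 120 | 120 | 128 | 225.0 GiB |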

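*Estimated configuration; GPT-4's architecture has not been published. The LLaMA-70B row counts all 64 heads, but LLaMA 2 70B actually uses grouped-query attention with 8 K/V heads, which brings its real cache down to 1.25 GiB.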
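The formula is easy to check in code. This sketch assumes batch size 1 and FP16 (2 bytes per element), as in the table, and reports sizes in GiB (2^30 bytes); `kv_cache_bytes` is a hypothetical helper name.

```python
GIB = 2 ** 30  # 1 GiB = 1024**3 bytes


def kv_cache_bytes(batch, n_layers, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    """Memory_KV = 2 x B x L x S x H x D x sizeof(dtype); the leading 2 counts keys and values."""
    return 2 * batch * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem


# (layers, context, heads, head_dim) for each table row
TABLE = {
    "GPT-3 175B": (96, 2048, 96, 128),
    "LLaMA-7B": (32, 4096, 32, 128),
    "LLaMA-70B": (80, 4096, 64, 128),
    "GPT-4* (estimated)": (120, 32768, 120, 128),
}

for name, (layers, context, heads, dim) in TABLE.items():
    size = kv_cache_bytes(1, layers, context, heads, dim) / GIB
    print(f"{name:20s} {size:6.1f} GiB")
```

For the four configurations it prints 9.0, 2.0, 10.0 and 225.0 GiB. Because B multiplies every other term, serving 16 concurrent 4K-token LLaMA-7B requests already needs `kv_cache_bytes(16, 32, 4096, 32, 128)`, i.e. 32 GiB, of cache.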

Shrinking the cache

The cache grows linearly with both sequence length and batch size, which often makes it the main memory bottleneck for long-context and high-throughput inference, so most systems shrink it:

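  • Multi-query attention (MQA) shares a single K/V head across all query heads, and grouped-query attention (GQA) shares one K/V head per group of query heads; either cuts the cache by the ratio of query heads to K/V heads.
  • Sliding-window attention keeps only the most recent w tokens' K/V, capping the cache at w positions regardless of sequence length.
  • Quantizing the cache to INT8 or INT4 roughly halves or quarters its footprint relative to FP16.
  • PagedAttention (vLLM) stores the cache in fixed-size, non-contiguous blocks, much like OS virtual memory pages; this nearly eliminates fragmentation and lets requests with a common prefix share blocks.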
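The block-table idea behind PagedAttention can be sketched in a few dozen lines. This is a toy, assuming fixed-size blocks, bookkeeping only (no K/V tensors are stored), and copy-on-write when a sequence appends into a block it shares with another; `PagedKVAllocator` and its methods are hypothetical and do not mirror vLLM's actual code or API.

```python
class PagedKVAllocator:
    """Toy block-table bookkeeping in the spirit of PagedAttention."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks - 1, -1, -1))  # pop() hands out 0, 1, 2, ...
        self.refcount = [0] * num_blocks
        self.tables = {}   # seq_id -> physical block ids, in logical order
        self.lengths = {}  # seq_id -> number of tokens stored

    def _alloc(self):
        if not self.free:
            raise MemoryError("out of KV cache blocks")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def append_token(self, seq_id):
        """Reserve room for one token's K/V; return (physical_block, offset)."""
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        offset = n % self.block_size
        if offset == 0:
            table.append(self._alloc())  # no block yet, or the last one is full
        elif self.refcount[table[-1]] > 1:
            # Last block is shared with another sequence: copy-on-write.
            # (A real system would also copy the block's existing K/V entries.)
            self.refcount[table[-1]] -= 1
            table[-1] = self._alloc()
        self.lengths[seq_id] = n + 1
        return table[-1], offset

    def fork(self, parent, child):
        """Let `child` share all of `parent`'s blocks, e.g. a common prompt prefix."""
        self.tables[child] = list(self.tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.tables[child]:
            self.refcount[block] += 1

    def release(self, seq_id):
        """Drop a finished sequence; blocks return to the pool once unreferenced."""
        for block in self.tables.pop(seq_id):
            self.refcount[block] -= 1
            if self.refcount[block] == 0:
                self.free.append(block)
        del self.lengths[seq_id]


alloc = PagedKVAllocator(num_blocks=8, block_size=4)
for _ in range(6):                # a 6-token prompt fills block 0 and half of block 1
    alloc.append_token("a")
alloc.fork("a", "b")              # request "b" reuses the same prompt blocks
print(alloc.append_token("b"))    # (2, 2): the shared half-full block was copied
print(alloc.tables)               # {'a': [0, 1], 'b': [0, 2]}
```

Full blocks stay shared between requests and only a partially filled block is ever duplicated, so each sequence wastes at most the unused tail of its last block.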
