
Flash Attention: IO-Aware Exact Attention

Summary
Interactive Flash Attention visualization - the IO-aware algorithm achieving memory-efficient exact attention through tiling and kernel fusion.

Flash Attention: Revolutionizing Transformer Efficiency

Flash Attention is an IO-aware attention algorithm that achieves a 2-9× speedup over standard attention and uses orders of magnitude less memory, while still computing exact attention rather than an approximation. For a formal analysis of why data movement dominates transformer costs, see the data movement analysis in transformers paper.


The Memory Bandwidth Bottleneck

The Hidden Problem

Modern GPUs have massive compute power but limited memory bandwidth:

Resource                   A100 GPU                               Compute-to-bandwidth ratio (FLOPs per byte)
Compute                    312 TFLOPS (FP16/BF16 Tensor Core)     -
Memory Bandwidth (HBM)     1.5 TB/s                               208:1
SRAM Bandwidth             19 TB/s                                16:1

Most attention implementations are memory-bound, not compute-bound!
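To make the ratio concrete, here is a small back-of-the-envelope sketch. It compares the A100's compute-to-HBM-bandwidth ratio with the arithmetic intensity of an unfused softmax over FP16 values. The per-element cost (`flops_per_element = 5`) and the helper names are illustrative assumptions, not measured values.

```python
# Rough roofline check: is an operation compute-bound or memory-bound on an A100?
PEAK_FLOPS = 312e12       # FP16 Tensor Core peak (figure quoted in this article)
HBM_BANDWIDTH = 1.5e12    # bytes per second (figure quoted in this article)


def ridge_point(peak_flops: float, bandwidth: float) -> float:
    """FLOPs per byte an operation needs before compute becomes the limit."""
    return peak_flops / bandwidth


def softmax_intensity(flops_per_element: float = 5.0, bytes_per_value: int = 2) -> float:
    """Arithmetic intensity of an unfused softmax.

    Assumption: each element is read once and written once, and costs about
    flops_per_element operations (max, subtract, exp, sum, divide).
    """
    bytes_moved = 2 * bytes_per_value  # one read + one write per element
    return flops_per_element / bytes_moved


ridge = ridge_point(PEAK_FLOPS, HBM_BANDWIDTH)
intensity = softmax_intensity()
print(f"ridge point: {ridge:.0f} FLOPs/byte")
print(f"softmax intensity: {intensity:.2f} FLOPs/byte")
print("memory-bound" if intensity < ridge else "compute-bound")
```

Softmax does not run on Tensor Cores, so the peak used here overstates what it could reach. The gap is large enough that the conclusion should hold with a lower peak as well.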

Standard Attention's Inefficiency

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V$$

Standard implementation:

  1. Compute $S = QK^\top$ → Store N×N matrix in HBM
  2. Compute P = softmax(S) → Read/write N×N matrix
  3. Compute O = PV → Read N×N matrix

Total HBM accesses: $O(N^2)$
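For reference, the three steps can be written as a minimal NumPy sketch. It is a CPU illustration of the data flow rather than a GPU implementation, and the function name `standard_attention` is chosen here for illustration. Note that both `S` and `P` are full N×N arrays.

```python
import numpy as np


def standard_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Reference attention that materializes the full N x N score and probability matrices."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                 # step 1: N x N scores
    S = S - S.max(axis=1, keepdims=True)     # subtract the row max for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)     # step 2: N x N probabilities
    return P @ V                             # step 3: N x d output


rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out = standard_attention(Q, K, V)
print(out.shape)  # (8, 4)
```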

Flash Attention's Innovation

Core Idea: Stay in SRAM

Instead of materializing the full attention matrix:

  1. Tile the computation into blocks that fit in SRAM
  2. Fuse operations to minimize memory transfers
  3. Recompute instead of storing intermediate results

The Tiling Strategy

    import numpy as np

    def flash_attention(Q, K, V, block_size):
        """Tiled exact attention with an online softmax (NumPy reference, not a GPU kernel)."""
        N, d = Q.shape
        O = np.zeros((N, V.shape[1]))
        for i in range(0, N, block_size):
            # Load Q block to SRAM
            Q_block = Q[i:i + block_size]
            rows = Q_block.shape[0]  # the last block may be shorter
            # Initialize block outputs: unnormalized output, row sums, row maxima
            O_block = np.zeros((rows, V.shape[1]))
            l_block = np.zeros(rows)
            m_block = np.full(rows, -np.inf)
            for j in range(0, N, block_size):
                # Load K, V blocks to SRAM
                K_block = K[j:j + block_size]
                V_block = V[j:j + block_size]
                # Compute attention scores in SRAM
                S_block = Q_block @ K_block.T / np.sqrt(d)
                # Online softmax update (row-wise max and sum)
                m_new = np.maximum(m_block, S_block.max(axis=1))
                P_block = np.exp(S_block - m_new[:, None])
                scale = np.exp(m_block - m_new)  # rescales the old statistics
                l_block = scale * l_block + P_block.sum(axis=1)
                # Update output
                O_block = scale[:, None] * O_block + P_block @ V_block
                m_block = m_new
            # Normalize and write back to HBM
            O[i:i + rows] = O_block / l_block[:, None]
        return O

Real implementations fuse these steps into a single CUDA kernel and use warp-level reductions for the softmax — see kernel fusion for that technique.

Mathematical Foundation

Online Softmax

Key insight: Compute softmax without materializing the full matrix using the online softmax algorithm:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}$$

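Where $m = \max_j x_j$. Both $m$ and the denominator can be accumulated in a single pass: whenever the running maximum grows, the running sum is rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$.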
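As a minimal sketch of the single-pass idea, the following pure-Python function keeps a running maximum and a running sum and rescales the sum whenever the maximum changes. The name `online_softmax` is illustrative. Only the statistics are gathered in one pass; producing the final probabilities still reads the inputs once more.

```python
import math


def online_softmax(xs):
    """Softmax with the max and the normalizer accumulated in one pass over xs."""
    m = -math.inf   # running maximum
    l = 0.0         # running sum of exp(x - m)
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new maximum, then add the new term.
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Second read of xs, only to emit the normalized probabilities.
    return [math.exp(x - m) / l for x in xs]


probs = online_softmax([1.0, 2.0, 3.0])
print(probs)
```

Because every exponent is taken relative to the running maximum, large inputs such as `[1000.0, 1000.0]` do not overflow.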

Safe Softmax Update

When processing blocks incrementally:

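$$l_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \cdot l_{\text{old}} + l_{\text{current}}$$
$$O_{\text{new}} = \frac{e^{m_{\text{old}} - m_{\text{new}}} \cdot l_{\text{old}} \cdot O_{\text{old}} + l_{\text{current}} \cdot O_{\text{current}}}{l_{\text{new}}}$$

Here $m_{\text{new}} = \max(m_{\text{old}}, m_{\text{current}})$, $l_{\text{current}}$ is the current block's sum of exponentials taken relative to $m_{\text{new}}$, and $O_{\text{old}}$ and $O_{\text{current}}$ are the normalized outputs of the blocks seen so far and of the current block.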
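The incremental update can be checked numerically. The sketch below handles one query row and assumes each block reports its own maximum, its sum of exponentials relative to that maximum, and its normalized output; rescaling the current block's sum to the shared maximum gives the current-block term of the update. The names `block_stats` and `merge_blocks` are illustrative.

```python
import numpy as np


def block_stats(scores, values):
    """Per-block statistics for one query row: max, sum of exponentials, normalized output."""
    m = scores.max()
    p = np.exp(scores - m)
    l = p.sum()
    return m, l, (p @ values) / l


def merge_blocks(old, cur):
    """Combine two blocks' (m, l, O) into the statistics of their union."""
    m_old, l_old, O_old = old
    m_cur, l_cur, O_cur = cur
    m_new = max(m_old, m_cur)
    l_old_scaled = np.exp(m_old - m_new) * l_old   # old sum, relative to m_new
    l_cur_scaled = np.exp(m_cur - m_new) * l_cur   # current sum, relative to m_new
    l_new = l_old_scaled + l_cur_scaled
    O_new = (l_old_scaled * O_old + l_cur_scaled * O_cur) / l_new
    return m_new, l_new, O_new


rng = np.random.default_rng(0)
scores = rng.standard_normal(8)          # one query row against 8 keys
values = rng.standard_normal((8, 4))     # 8 value vectors of dimension 4

merged = merge_blocks(block_stats(scores[:4], values[:4]),
                      block_stats(scores[4:], values[4:]))
full = block_stats(scores, values)
print(np.allclose(merged[2], full[2]))  # True
```

The merged result matches the statistics computed over all eight keys at once, which is what allows the blocks to be processed one at a time.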

Memory Complexity

Standard attention:

$$M_{\text{standard}} = O(N^2 + Nd)$$

Flash Attention:

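$$M_{\text{flash}} = O(Nd)$$

This covers the inputs and the output plus $O(N)$ softmax statistics; no $N \times N$ matrix is ever stored in HBM.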

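The $N^2$ term exceeds the $Nd$ term by a factor of $N/d$ (64× at $N = 4096$, $d = 64$; 256× at $N = 16384$), so the saving grows linearly with sequence length and reaches the order of 100× for long contexts.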
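A short calculation makes the scaling visible. This sketch compares the size of one N×d tensor with one N×N matrix for a single head, assuming 2 bytes per value (FP16/BF16). The helper name `tensor_bytes` and the chosen sequence lengths are illustrative.

```python
def tensor_bytes(rows: int, cols: int, bytes_per_value: int = 2) -> int:
    """Size of a dense rows x cols tensor; 2 bytes per value assumes FP16/BF16."""
    return rows * cols * bytes_per_value


d = 64
for n in (1024, 4096, 16384):
    nd = tensor_bytes(n, d)      # one of Q, K, V or O
    nn = tensor_bytes(n, n)      # the score matrix S (P is the same size)
    print(f"N={n:>5}: N x d = {nd / 2**20:5.2f} MiB, "
          f"N x N = {nn / 2**20:6.1f} MiB, ratio = {nn // nd}x")
```

The ratio column is simply N/d, so it quadruples each time the sequence length does.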

IO Complexity Analysis

Standard Attention IO

$$IO_{\text{standard}} = O(Nd + N^2)$$

Breakdown:

  • Load Q, K, V: $3Nd$
  • Store/load S: $2N^2$
  • Store/load P: $2N^2$
  • Store O: $Nd$

Flash Attention IO

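$$IO_{\text{flash}} = O\!\left(\frac{N^2 d^2}{M_{\text{SRAM}}}\right)$$

Breakdown:

  • Load each block of K, V once per Q block: $O(Nd)$ per pass, over $N / B$ passes, where the block size $B = \Theta(M_{\text{SRAM}} / d)$ is the largest that fits in SRAM
  • Load Q once: $O(Nd)$
  • Write O once: $O(Nd)$

Reduction factor: roughly $M_{\text{SRAM}} / d^2$ (typically 16-64×)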
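The two IO counts can be compared with a rough model. The sketch below counts HBM accesses in values rather than bytes and follows the FlashAttention paper's $\Theta(N^2 d^2 / M)$ bound for the tiled case. The block sizing (four tiles sharing SRAM equally), the SRAM capacity, and the function names are assumptions made for illustration, not properties of any particular kernel.

```python
def standard_io(n: int, d: int) -> int:
    """HBM accesses (in values) for standard attention: Q, K, V, O plus S and P round trips."""
    return 4 * n * d + 4 * n * n


def flash_io(n: int, d: int, sram_values: int) -> int:
    """Rough HBM accesses for tiled attention with Q blocks in the outer loop.

    Assumption: a Q block has about sram_values / (4 * d) rows (Q, K, V and O
    tiles share SRAM equally), and K and V are read in full once per Q block.
    """
    block_rows = max(1, sram_values // (4 * d))
    passes = -(-n // block_rows)            # ceil(n / block_rows)
    return 2 * n * d + passes * 2 * n * d   # Q read + O write, then K and V per pass


n, d = 4096, 64
sram_values = 96 * 1024   # assumed: about 192 KB of SRAM holding FP16 values
print(standard_io(n, d), flash_io(n, d, sram_values))
print(f"reduction: {standard_io(n, d) / flash_io(n, d, sram_values):.1f}x")
```

Larger SRAM allows taller Q blocks and therefore fewer passes over K and V, which is where the dependence on SRAM size comes from.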

Variants

Follow-ups refine the same idea: FlashAttention-2 parallelises across the sequence dimension for a further ~2× speedup, Flash-Decoding splits K/V across thread blocks to accelerate long-context inference, and block-sparse variants skip blocks the attention mask zeroes out.

When to use Flash Attention (and when standard attention is fine)

Flash Attention is the right call almost any time you control the kernel and the sequence length is long enough that HBM bandwidth is the bottleneck. It is the wrong call when the model is already tiny, when you cannot guarantee the hardware/dtype combination it supports, or when swapping in a new kernel would replace a code path you have not budgeted time to debug.

Use Flash Attention when:

  • Sequence length is ≥ 1024 tokens and you train or serve at scale — the memory and speed wins compound with context length.
  • You are running on Ampere (A100/A10/A40), Ada (RTX 40-series), or Hopper (H100) with FP16 or BF16 — these are the paths the kernel is optimised for.
  • You are already using HuggingFace Transformers, PyTorch ≥ 2.0 SDPA, or vLLM — you get Flash Attention by setting attn_implementation="flash_attention_2" or relying on torch.nn.functional.scaled_dot_product_attention to pick the FA backend automatically.
  • You need longer context windows than vanilla attention can fit in HBM at your batch size: Flash Attention's memory grows as O(N) in sequence length rather than O(N^2), which is often enough to push past that wall without sharding.

Use standard attention (or a different optimisation) when:

  • Sequence length is < 512 — kernel-launch and bookkeeping overhead can erase the bandwidth win on short sequences.
  • You need fully custom attention biases, masks, or score modifications that the FA kernel does not support — try torch.nn.functional.scaled_dot_product_attention first; it falls back gracefully.
  • You are on FP32 or a GPU older than Ampere: FA2 paths assume FP16/BF16 and SM ≥ 8.0. On consumer hardware that does not qualify, the standard PyTorch kernel is often faster than a forced FP32 FA fallback.
  • The bottleneck is decode-time KV-cache reads, not the attention compute — reach for paged KV cache or sliding-window attention instead.

When in doubt, call PyTorch's scaled_dot_product_attention with the default backend selector. It picks the FA path when the inputs and hardware support it and falls back to another kernel when they do not, which is almost always the right default.
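A minimal sketch of that default path, assuming PyTorch 2.0 or newer is installed. The tensor sizes are arbitrary illustration values. Which backend actually runs depends on the device, dtype, and inputs, so this example does not assert that the Flash Attention kernel is used.

```python
import torch
import torch.nn.functional as F

# Shapes follow the (batch, heads, sequence, head_dim) layout that SDPA expects.
batch, heads, seq_len, head_dim = 2, 4, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
# FP16 on GPU is what the fused kernels target; FP32 keeps the CPU path simple.
dtype = torch.float16 if device == "cuda" else torch.float32

q, k, v = (torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
           for _ in range(3))

# PyTorch chooses the backend (fused or math) from the inputs and the hardware.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 128, 64])
```

Passing `is_causal=True` applies a causal mask without building an explicit mask tensor, which keeps the call eligible for the fused kernels.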
