FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

How FlashAttention computes exact attention with linear memory by tiling Q, K, and V into SRAM-resident blocks and fusing the softmax, avoiding the quadratic HBM cost of materializing the full attention matrix.

TL;DR

  • FlashAttention computes exact attention (no approximation) while cutting memory from O(N²) to O(N) by never writing the full N×N score matrix to HBM.
  • It tiles Q, K, and V into blocks small enough to live in fast on-chip SRAM, computes attention block-by-block, and fuses the softmax into the same kernel.
  • An online softmax — a running max and running sum rescaled as new blocks arrive — makes block-wise accumulation numerically exact.
  • The result is a 2–4× wall-clock speedup and much longer feasible context lengths, because the kernel moves far fewer bytes between HBM and SRAM and so spends less time stalled on memory traffic.

The memory wall of attention

Standard attention forms the full N×N score matrix, applies softmax over each row, and multiplies by V. That intermediate matrix is quadratic in sequence length and lives in high-bandwidth memory (HBM), which is large but much slower than on-chip SRAM. At long context the attention operator is memory-bound: the GPU spends most of its time moving the score matrix to and from HBM, not computing.
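To make the quadratic intermediate concrete, here is a minimal NumPy sketch of standard attention for a single head. The function names are illustrative, the 1/sqrt(d) scaling is the usual convention rather than something stated above, and the 2-byte element size assumes fp16 storage.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference attention that materializes the full N x N score matrix."""
    d = q.shape[-1]
    # (N, N) scores: this is the quadratic intermediate discussed above.
    scores = (q @ k.T) / np.sqrt(d)  # assumed standard 1/sqrt(d) scaling
    # Subtract each row's max before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v  # (N, d_v)

def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed for one N x N score matrix (2 bytes/element assumes fp16)."""
    return seq_len * seq_len * bytes_per_element
```

Under these assumptions, score_matrix_bytes(65536) comes to 8 GiB for a single head of a single sequence, before counting the softmax output, which is the same size.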

Tiling: keep the work in SRAM

FlashAttention splits Q, K, and V into blocks and walks the grid of (Q-block, K-block) tiles, loading only the few blocks it needs into SRAM. The full score matrix is never assembled — each tile is consumed and discarded once its contribution is accumulated.
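The tile walk can be sketched in a few lines of plain Python. This only enumerates the grid and measures the largest tile. The helper names, the block sizes, and the loop order (Q-blocks outer, K-blocks inner) are assumptions made for illustration, not details taken from the text above.

```python
def iter_tiles(seq_len, block_q, block_k):
    """Yield (q_slice, k_slice) pairs that cover the N x N score grid tile by tile."""
    for q_start in range(0, seq_len, block_q):
        q_slice = slice(q_start, min(q_start + block_q, seq_len))
        for k_start in range(0, seq_len, block_k):
            k_slice = slice(k_start, min(k_start + block_k, seq_len))
            yield q_slice, k_slice

def peak_tile_elements(seq_len, block_q, block_k):
    """Largest number of score entries alive at once if each tile is discarded after use."""
    return max(
        (qs.stop - qs.start) * (ks.stop - ks.start)
        for qs, ks in iter_tiles(seq_len, block_q, block_k)
    )
```

For example, with seq_len = 4096 and 128×128 blocks, the peak is 16,384 score entries at a time, against 16,777,216 for the full matrix. A real kernel would pick block sizes from the SRAM capacity of the target GPU.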

Online softmax makes it exact

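Processing keys block-by-block would normally break the row-wise softmax normalization, because the normalizer depends on every score in the row. FlashAttention keeps a running max m and a running sum ℓ of exponentials for each row; when a new block raises the max, it rescales the running sum and the accumulated output by exp(m_old − m_new). The final result, normalized by ℓ, is mathematically identical to a full-row softmax (up to floating-point rounding): exact, not approximate.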
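The running-max and running-sum update can be sketched in NumPy as follows. This is a CPU illustration of the arithmetic only, not the fused GPU kernel. The function name, the default block sizes, the 1/sqrt(d) scaling, and the choice to defer normalization to the end of each Q-block are assumptions of this sketch.

```python
import numpy as np

def tiled_attention(q, k, v, block_q=4, block_k=4):
    """Block-wise attention with an online softmax; never forms the N x N matrix."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)  # assumed standard scaling
    out = np.empty((n, v.shape[1]))
    for qs in range(0, n, block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = np.full(rows, -np.inf)           # running row max
        l = np.zeros(rows)                   # running row sum of exp(score - m)
        acc = np.zeros((rows, v.shape[1]))   # unnormalized output accumulator
        for ks in range(0, k.shape[0], block_k):
            # One (rows, block_k) tile of scores; discarded after this iteration.
            s = (q_blk @ k[ks:ks + block_k].T) * scale
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)        # rescale factor exp(m_old - m_new)
            p = np.exp(s - m_new[:, None])
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v[ks:ks + block_k]
            m = m_new
        out[qs:qs + block_q] = acc / l[:, None]  # normalize once per Q-block
    return out
```

On the first K-block the running max is negative infinity, so the rescale factor is exp(−∞) = 0 and the empty accumulator contributes nothing. Comparing the output against a full-row softmax with np.allclose is a quick way to check that the two agree up to rounding.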

Why it mattered

FlashAttention reframed attention as an IO problem. By optimizing for bytes moved rather than FLOPs, it delivered exact attention that is both faster and far more memory-efficient, making long-context training and inference practical. A FlashAttention kernel is now one of the backends behind PyTorch's scaled_dot_product_attention, and its successors (FlashAttention-2/3, Flash-Decoding) extend the same idea to inference and newer hardware.
