
Flash Attention: IO-Aware Exact Attention


Flash Attention: Revolutionizing Transformer Efficiency

Flash Attention is an IO-aware attention algorithm that achieves a 2-9× speedup over standard attention while using memory that grows linearly, rather than quadratically, with sequence length. It computes exact attention, not an approximation. For a formal analysis of why data movement dominates transformer costs, see the paper on data movement analysis in transformers.


The Memory Bandwidth Bottleneck

The Hidden Problem

Modern GPUs have massive compute power but limited memory bandwidth:

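| Resource | A100 GPU | Compute-to-bandwidth ratio (FLOPs per byte) |
|---|---|---|
| Compute (FP16/BF16 Tensor Cores) | 312 TFLOPS | - |
| HBM bandwidth | 1.5 TB/s | 208:1 |
| SRAM bandwidth | 19 TB/s | 16:1 |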

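As a result, standard attention implementations are memory-bound rather than compute-bound: the GPU can perform about 208 FLOPs in the time it takes to move one byte from HBM, while the elementwise steps of attention (softmax, masking, dropout) do only a few FLOPs per byte they read and write.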
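A rough roofline check makes these ratios concrete. The sketch below uses the A100 peak figures from the table, assumes FP16 storage (2 bytes per element) and example sizes N = 4096, d = 64, and counts only the HBM traffic of materializing $S = QK^\top$ and of one elementwise softmax pass. The 5 FLOPs per softmax element is an illustrative assumption, not a measured count.

```python
# Peak A100 figures from the table (FP16 Tensor Core compute, HBM bandwidth)
PEAK_FLOPS = 312e12      # FLOP/s
HBM_BANDWIDTH = 1.5e12   # bytes/s
BYTES_PER_ELEMENT = 2    # assumption: FP16 storage

def ridge_point(peak_flops, bandwidth):
    """FLOPs per byte a kernel needs before it stops being memory-bound."""
    return peak_flops / bandwidth

def qk_intensity(n, d):
    # S = Q @ K^T: 2*N*N*d FLOPs; read Q and K (2*N*d elements), write S (N*N) to HBM
    flops = 2 * n * n * d
    bytes_moved = BYTES_PER_ELEMENT * (2 * n * d + n * n)
    return flops / bytes_moved

def softmax_intensity(n, flops_per_element=5):
    # Elementwise softmax: read S and write P, both N*N elements.
    # flops_per_element=5 is an illustrative assumption (max, subtract, exp, sum, divide)
    flops = flops_per_element * n * n
    bytes_moved = BYTES_PER_ELEMENT * 2 * n * n
    return flops / bytes_moved

ridge = ridge_point(PEAK_FLOPS, HBM_BANDWIDTH)
for name, intensity in [("QK^T with S written to HBM", qk_intensity(4096, 64)),
                        ("softmax", softmax_intensity(4096))]:
    bound = "memory-bound" if intensity < ridge else "compute-bound"
    print(f"{name}: {intensity:.1f} FLOPs/byte vs ridge {ridge:.0f} -> {bound}")
```

Under these assumptions even the matrix multiply lands well below the ridge point once its N×N result has to go to HBM, and the softmax is below it by two orders of magnitude.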

Standard Attention's Inefficiency

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

Standard implementation:

  1. Compute $S = QK^\top/\sqrt{d}$ → write the N×N matrix S to HBM
  2. Compute $P = \mathrm{softmax}(S)$ → read S, write the N×N matrix P to HBM
  3. Compute $O = PV$ → read P back from HBM, write O

Total HBM accesses: $O(N^2)$
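The three steps can be sketched in pure Python, which keeps the N×N intermediates visible as ordinary lists. On a GPU each of S and P would be a separate round trip through HBM. `standard_attention`, `matmul`, and `transpose` are illustrative names, not library functions.

```python
import math

def matmul(A, B):
    """Plain list-of-lists matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def standard_attention(Q, K, V):
    d = len(Q[0])
    # Step 1: S = Q K^T / sqrt(d), a full N x N matrix (written to HBM on a GPU)
    S = [[s / math.sqrt(d) for s in row] for row in matmul(Q, transpose(K))]
    # Step 2: row-wise softmax produces P, another N x N matrix (read S, write P)
    P = []
    for row in S:
        m = max(row)                     # subtract the row max for numerical safety
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # Step 3: O = P V (read P again)
    O = matmul(P, V)
    return O, S, P

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O, S, P = standard_attention(Q, K, V)
print(len(S), "x", len(S[0]), "score matrix materialized")
```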

Flash Attention's Innovation

Core Idea: Stay in SRAM

Instead of materializing the full attention matrix:

  1. Tile the computation into blocks that fit in SRAM
  2. Fuse operations into one kernel to minimize memory transfers
  3. Recompute the attention matrix in the backward pass instead of storing it

The Tiling Strategy

```python
import numpy as np

def flash_attention(Q, K, V, block_size):
    """Tiled exact attention with online softmax (forward pass, NumPy sketch)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    for i in range(0, N, block_size):
        # Load a Q block into SRAM
        Q_block = Q[i:i + block_size]
        rows = Q_block.shape[0]          # the last block may be smaller
        # Per-row running state: output accumulator, softmax denominator, max
        O_block = np.zeros((rows, d))
        l_block = np.zeros(rows)
        m_block = np.full(rows, -np.inf)
        for j in range(0, N, block_size):
            # Load K, V blocks into SRAM
            K_block = K[j:j + block_size]
            V_block = V[j:j + block_size]
            # Compute attention scores for this tile only
            S_block = Q_block @ K_block.T * scale
            # Online softmax update: new row-wise max
            m_new = np.maximum(m_block, S_block.max(axis=1))
            P_block = np.exp(S_block - m_new[:, None])
            # Rescale the previous state to the new max, then add this tile
            alpha = np.exp(m_block - m_new)
            l_block = alpha * l_block + P_block.sum(axis=1)
            O_block = alpha[:, None] * O_block + P_block @ V_block
            m_block = m_new
        # Normalize once and write the finished block back to HBM
        O[i:i + block_size] = O_block / l_block[:, None]
    return O
```

Mathematical Foundation

Online Softmax

Key insight: Compute softmax without materializing the full matrix using the online softmax algorithm:

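$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}$$

where $m = \max_j x_j$. Subtracting $m$ leaves the result unchanged but keeps every exponent at most 0, so nothing overflows. Computed naively, this needs one pass to find $m$ and another for the sum; the online algorithm does both in a single pass by keeping a running max and rescaling the running sum by $e^{m_{\text{old}} - m_{\text{new}}}$ whenever the max increases.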
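A minimal pure-Python sketch of the single-pass statistics (function names are illustrative). The final normalization still touches each value once more; Flash Attention avoids a separate pass by folding that step into the output accumulation.

```python
import math

def online_softmax_stats(xs):
    """One pass over xs: return (m, l) with m = max(xs) and l = sum(exp(x - m))."""
    m = -math.inf
    l = 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new max, then add the new term
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, l

def softmax(xs):
    m, l = online_softmax_stats(xs)
    return [math.exp(x - m) / l for x in xs]

probs = softmax([1.0, 3.0, 2.0, 1000.0])   # no overflow despite the large input
```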

Safe Softmax Update

When processing blocks incrementally:

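$$\ell^{\text{new}} = e^{m^{\text{old}} - m^{\text{new}}}\,\ell^{\text{old}} + \ell^{\text{cur}}$$

$$O^{\text{new}} = \frac{e^{m^{\text{old}} - m^{\text{new}}}\,\ell^{\text{old}}\,O^{\text{old}} + \ell^{\text{cur}}\,O^{\text{cur}}}{\ell^{\text{new}}}$$

Here, for one query row with scores $s_j$ in the current block, $m^{\text{new}} = \max\left(m^{\text{old}}, \max_j s_j\right)$, $\ell^{\text{cur}} = \sum_j e^{s_j - m^{\text{new}}}$, and $O^{\text{cur}} = \frac{1}{\ell^{\text{cur}}}\sum_j e^{s_j - m^{\text{new}}} v_j$ is the current block's normalized output.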
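To check the update rules numerically, the pure-Python sketch below splits one query's scores into two blocks, folds the second block into the first with the rules above, and compares the result with a single softmax over all scores. `block_stats` and `merge` are illustrative names.

```python
import math

def block_stats(scores, values):
    """Partial result for one block: (m, l, O) with O already normalized by l."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    dim = len(values[0])
    O = [sum(w * v[c] for w, v in zip(weights, values)) / l for c in range(dim)]
    return m, l, O

def merge(m_old, l_old, O_old, scores, values):
    """Fold a new block into the running (m, l, O) using the update rules."""
    m_new = max(m_old, max(scores))
    weights = [math.exp(s - m_new) for s in scores]   # e^{s_j - m_new}
    l_cur = sum(weights)
    dim = len(values[0])
    O_cur = [sum(w * v[c] for w, v in zip(weights, values)) / l_cur for c in range(dim)]
    alpha = math.exp(m_old - m_new)
    l_new = alpha * l_old + l_cur
    O_new = [(alpha * l_old * o + l_cur * oc) / l_new for o, oc in zip(O_old, O_cur)]
    return m_new, l_new, O_new

scores = [0.5, 2.0, -1.0, 3.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0]]
m1, l1, O1 = block_stats(scores[:2], values[:2])      # first block
m, l, O = merge(m1, l1, O1, scores[2:], values[2:])   # fold in the second block
_, _, O_full = block_stats(scores, values)            # one softmax over everything
```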

Memory Complexity

Standard attention:

$$M_{\text{standard}} = O(N^2 + Nd)$$

Flash Attention:

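$$M_{\text{flash}} = O(Nd)$$

Beyond the inputs and the output, Flash Attention needs only $O(N)$ extra memory, for the per-row softmax statistics $m$ and $\ell$.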

Because the $N^2$ term disappears, the savings grow with sequence length; for long sequences this amounts to roughly a 100× memory reduction.
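To see where the quadratic term bites, this sketch compares the bytes needed to materialize one N×N score matrix with the $O(N)$ softmax statistics kept instead, per head and per sequence. FP16 scores and FP32 statistics are assumptions; batch size and head count multiply both numbers, and standard attention typically keeps P as well as S.

```python
def attention_matrix_bytes(n, bytes_per_element=2):
    """Memory for one N x N score matrix (one head, one sequence), FP16 by default."""
    return n * n * bytes_per_element

def softmax_stats_bytes(n, bytes_per_element=4):
    """Memory for the per-row statistics m and l (assumed FP32 here)."""
    return 2 * n * bytes_per_element

for n in (1024, 4096, 16384):
    full = attention_matrix_bytes(n)
    stats = softmax_stats_bytes(n)
    print(f"N={n}: N x N matrix {full / 2**20:.0f} MiB vs statistics {stats / 2**10:.0f} KiB")
```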

IO Complexity Analysis

Standard Attention IO

$$\text{IO}_{\text{standard}} = O(Nd + N^2)$$

Breakdown:

  • Load Q, K, V: $3Nd$
  • Store/load S: $2N^2$
  • Store/load P: $2N^2$
  • Store O: $Nd$

Flash Attention IO

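$$\text{IO}_{\text{flash}} = O\left(\frac{N^2 d^2}{M_{\text{SRAM}}}\right)$$

where $M_{\text{SRAM}}$ is the SRAM size in elements, with $d \le M_{\text{SRAM}} \le Nd$.

Breakdown:

  • Load K and V once per block of $B_r$ query rows: $\frac{N}{B_r} \cdot 2Nd$ in total, which is $O(N^2 d^2 / M_{\text{SRAM}})$ for $B_r = \Theta(M_{\text{SRAM}}/d)$
  • Load Q once: $O(Nd)$
  • Write O once: $O(Nd)$

Reduction factor: roughly $M_{\text{SRAM}}/d^2$. For typical head dimensions (64-128) and around 100 KB of SRAM, $d^2$ is many times smaller than $M_{\text{SRAM}}$, so Flash Attention needs many times fewer HBM accesses than standard attention.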
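The two breakdowns can be turned into element counts. This sketch assumes a query tile height of 128 rows (`block_rows` is an assumed parameter, not a fixed property of the algorithm) and counts elements rather than bytes.

```python
import math

def hbm_accesses_standard(n, d):
    # Load Q, K, V (3Nd) + store/load S (2N^2) + store/load P (2N^2) + store O (Nd)
    return 3 * n * d + 2 * n * n + 2 * n * n + n * d

def hbm_accesses_flash(n, d, block_rows):
    # Load Q once (Nd) + write O once (Nd)
    # + reload all of K and V (2Nd) for each of ceil(N / block_rows) query blocks
    return n * d + n * d + math.ceil(n / block_rows) * 2 * n * d

n, d = 4096, 64
standard = hbm_accesses_standard(n, d)
flash = hbm_accesses_flash(n, d, block_rows=128)   # assumed tile height
print(f"standard: {standard:,} elements, flash: {flash:,} elements, "
      f"ratio {standard / flash:.1f}x")
```

Larger tiles, which more SRAM allows, shrink the K/V reload term further; that is where the $M_{\text{SRAM}}$ in the denominator comes from.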

Implementation Optimizations

1. Kernel Fusion

Fuse multiple operations into a single CUDA kernel (see kernel fusion for a deeper look at this technique):

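```cuda
// Tile sizes are compile-time constants (values here are illustrative)
#define BLOCK_SIZE 64
#define HEAD_DIM 64

__global__ void fused_attention_kernel(
    const float* Q, const float* K, const float* V, float* O,
    int N, int d, int block_size
) {
    // Shared memory (on-chip SRAM) for one tile each of Q, K, V
    __shared__ float Q_smem[BLOCK_SIZE][HEAD_DIM];
    __shared__ float K_smem[BLOCK_SIZE][HEAD_DIM];
    __shared__ float V_smem[BLOCK_SIZE][HEAD_DIM];

    // Compute attention with all ops fused: QK^T, softmax, and PV
    // run on tiles held in shared memory and registers.
    // No intermediate writes to global memory
}
```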

2. Warp-Level Primitives

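Use warp-level GPU operations:

  • __shfl_sync() / __shfl_xor_sync() to exchange values between lanes without going through shared memory
  • Tensor Cores for the matrix multiplies ($QK^\top$ and $PV$)
  • Warp-wide reductions for the softmax row max and row sum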
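Warp-wide reductions typically follow a butterfly pattern: at each step every lane combines its value with the lane whose index differs in one bit, so after five steps all 32 lanes hold the result. The Python below only simulates the data movement of __shfl_xor_sync on a list of 32 lane values; it is not CUDA code, and the score values are made up.

```python
WARP_SIZE = 32

def shfl_xor(values, lane_mask):
    """Simulate __shfl_xor_sync: each lane reads the value held by lane ^ lane_mask."""
    return [values[lane ^ lane_mask] for lane in range(len(values))]

def warp_allreduce(values, op):
    """Butterfly reduction: after log2(32) = 5 steps every lane holds op over all lanes."""
    assert len(values) == WARP_SIZE
    offset = WARP_SIZE // 2
    while offset > 0:
        partner = shfl_xor(values, offset)
        values = [op(a, b) for a, b in zip(values, partner)]
        offset //= 2
    return values

row_scores = [float((7 * lane) % 32) for lane in range(WARP_SIZE)]  # one score per lane
row_max = warp_allreduce(row_scores, max)                 # softmax row max
row_sum = warp_allreduce(row_scores, lambda a, b: a + b)  # softmax row sum
```

Because every lane ends up with the full result, no lane has to broadcast it afterwards, which is why this pattern suits the per-row max and sum in the online softmax.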

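3. Work Partitioning (Flash Attention 2)

Better parallelization:

  • Parallelize over the sequence-length dimension, in addition to batch and heads
  • Within a thread block, split Q across warps so each warp handles different query rows while sharing K and V
  • Warps no longer exchange partial results through shared memory, which reduces synchronization overhead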

Variants and Extensions

Flash Attention 2

Improvements over Flash Attention:

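  1. Better parallelization: 2× speedup over Flash Attention
  2. Fewer non-matmul FLOPs: for example, the output is divided by the softmax denominator once at the end instead of being rescaled at every step
  3. Support for head dimensions up to 256
  4. Multi-query and grouped-query attention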

Flash Decoding

Optimized for inference:

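  • Parallel decoding even at batch size 1, where there are too few queries to keep the GPU busy
  • Split K and V into chunks processed by different thread blocks, then combine the partial results using each chunk's log-sum-exp
  • Efficient for long-context generation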
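A pure-Python sketch of the split-K/V reduction for a single query, assuming each chunk reports its normalized output plus the log-sum-exp of its scores. The chunking and function names are illustrative, and the "parallel" chunks simply run in a loop here. With one split the function reduces to plain attention, so both results should agree.

```python
import math

def chunk_attention(q, keys, values):
    """Attention of one query over one K/V chunk: (normalized output, log-sum-exp)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[c] for w, v in zip(weights, values)) / total
           for c in range(len(values[0]))]
    return out, m + math.log(total)

def split_kv_decode(q, keys, values, num_splits):
    """Attend to each chunk independently, then reduce with log-sum-exp weights."""
    n = len(keys)
    size = math.ceil(n / num_splits)
    partials = [chunk_attention(q, keys[i:i + size], values[i:i + size])
                for i in range(0, n, size)]          # on a GPU these run in parallel
    lse_max = max(lse for _, lse in partials)
    lse_total = lse_max + math.log(sum(math.exp(lse - lse_max) for _, lse in partials))
    dim = len(values[0])
    # Each chunk's output is weighted by exp(lse_chunk - lse_total)
    return [sum(math.exp(lse - lse_total) * out[c] for out, lse in partials)
            for c in range(dim)]

q = [0.3, -0.2, 0.5, 0.1]
keys = [[math.sin(i + c) for c in range(4)] for i in range(10)]
values = [[math.cos(i * c) for c in range(3)] for i in range(10)]
split = split_kv_decode(q, keys, values, num_splits=4)
single = split_kv_decode(q, keys, values, num_splits=1)
```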

Block-Sparse Flash Attention

Combine with sparsity patterns:

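```python
def sparse_flash_attention(Q, K, V, sparsity_mask):
    # Only compute tiles (i, j) where the block-level mask is non-zero.
    # sparse_blocks and compute_block_attention are placeholders: the first yields
    # the non-zero (query block, key block) pairs, the second runs the
    # online-softmax tile update from flash_attention on that pair.
    for i, j in sparse_blocks(sparsity_mask):
        compute_block_attention(Q[i], K[j], V[j])
```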
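Causal masking is the most common block-sparse pattern. This sketch builds a block-level causal mask and enumerates the tiles a kernel would visit; `sparse_blocks` here is one possible implementation of the helper called in `sparse_flash_attention`, assumed rather than taken from a library. With T blocks only T(T+1)/2 of the T² tiles are computed. Diagonal tiles still need element-wise masking inside the tile, which this sketch ignores.

```python
def causal_block_mask(num_blocks):
    """Block-level causal mask: query block i only needs key blocks j <= i."""
    return [[j <= i for j in range(num_blocks)] for i in range(num_blocks)]

def sparse_blocks(mask):
    """Hypothetical helper: yield (i, j) for every tile the mask keeps."""
    for i, row in enumerate(mask):
        for j, keep in enumerate(row):
            if keep:
                yield i, j

mask = causal_block_mask(8)
kept = list(sparse_blocks(mask))
print(f"{len(kept)} of {8 * 8} tiles computed")
```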

Performance Characteristics

Speedup by Sequence Length

| Sequence length | Speedup | Memory savings |
|---|---|---|
| 512 | 2.4× | 10× |
| 1024 | 3.0× | 15× |
| 2048 | 3.9× | 20× |
| 4096 | 5.1× | 40× |
| 8192 | 7.6× | 60× |
| 16384 | 9.5× | 100× |

Hardware Scaling

Performance on different GPUs:

  • V100: 2-3× speedup (limited SRAM; note that the official Flash Attention kernels do not support Volta GPUs)
  • A100: 3-5× speedup (more SRAM)
  • H100: 5-9× speedup (faster HBM)

Backward Pass

Flash Attention also optimizes the backward pass:

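  1. Recomputation: instead of storing the N×N attention matrix, recompute it block by block from Q, K and the softmax statistics saved in the forward pass
  2. Gradient accumulation: accumulate each tile's gradient contributions in SRAM
  3. Fused operations: a single kernel computes the backward pass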
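
```python
def flash_attention_backward(dO, Q, K, V, O, L):
    # L: per-row log-sum-exp saved by the forward pass; these softmax statistics
    # are what allow P to be recomputed instead of stored.
    # Recompute attention blocks on the fly, accumulate gradients in SRAM,
    # single pass through the sequence.
    # compute_gradients_tiled is a placeholder for the tiled gradient kernel.
    dQ, dK, dV = compute_gradients_tiled(dO, Q, K, V, O, L)
    return dQ, dK, dV
```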
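The recomputation idea can be checked on a tiny example. This pure-Python sketch saves only the per-row log-sum-exp L in the forward pass and rebuilds P from Q, K and L during the backward pass, using the standard softmax-attention gradients with $D_i = \sum_c dO_{ic} O_{ic}$. It is untiled for readability, so it shows what is recomputed rather than how a GPU kernel schedules it; all names and values are illustrative.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def scores(Q, K):
    scale = 1.0 / math.sqrt(len(Q[0]))
    return [[s * scale for s in row] for row in matmul(Q, transpose(K))], scale

def attention_forward(Q, K, V):
    """Return O and the per-row log-sum-exp L; P itself is not kept."""
    S, _ = scores(Q, K)
    L = []
    for row in S:
        m = max(row)
        L.append(m + math.log(sum(math.exp(s - m) for s in row)))
    P = [[math.exp(s - L[i]) for s in row] for i, row in enumerate(S)]
    return matmul(P, V), L

def attention_backward(dO, Q, K, V, O, L):
    """Gradients of sum(dO * O) w.r.t. Q, K, V, recomputing P from Q, K and L."""
    S, scale = scores(Q, K)
    P = [[math.exp(s - L[i]) for s in row] for i, row in enumerate(S)]  # recomputed
    dV = matmul(transpose(P), dO)                                     # P^T dO
    dP = matmul(dO, transpose(V))                                     # dO V^T
    D = [sum(a * b for a, b in zip(dO[i], O[i])) for i in range(len(O))]
    dS = [[P[i][j] * (dP[i][j] - D[i]) for j in range(len(P[i]))] for i in range(len(P))]
    dQ = [[x * scale for x in row] for row in matmul(dS, K)]             # dS K / sqrt(d)
    dK = [[x * scale for x in row] for row in matmul(transpose(dS), Q)]  # dS^T Q / sqrt(d)
    return dQ, dK, dV

Q = [[0.1, 0.4], [-0.3, 0.2], [0.5, -0.1]]
K = [[0.2, -0.5], [0.3, 0.1], [-0.4, 0.6]]
V = [[1.0, 0.5], [-0.2, 0.3], [0.7, -0.8]]
dO = [[0.3, -0.1], [0.2, 0.4], [-0.5, 0.6]]
O, L = attention_forward(Q, K, V)
dQ, dK, dV = attention_backward(dO, Q, K, V, O, L)
```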

Practical Considerations

When to Use Flash Attention

Use when:

  • Long sequences (>512 tokens)
  • Memory-constrained settings
  • Training large models
  • Need exact attention

Don't use when:

  • Very short sequences (fewer than 128 tokens)
  • Custom attention patterns the kernel doesn't support
  • Unsupported hardware (e.g., older GPUs)

Integration

PyTorch and the Hugging Face Transformers library both expose Flash Attention:

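```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# PyTorch: Q, K, V have shape (batch, heads, seq_len, head_dim).
# The Flash Attention backend is eligible for fp16/bf16 CUDA tensors without an
# explicit attn_mask, so causal masking is requested with is_causal=True.
Q, K, V = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))
out = scaled_dot_product_attention(Q, K, V, dropout_p=0.1, is_causal=True)

# Transformers library (requires the flash-attn package)
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "model-name",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```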

Future Directions

Flash Attention 3

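Flash Attention 3 has since been released for Hopper GPUs such as the H100, exploiting asynchronous execution and FP8 support. Further directions include: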

  • Persistent kernels: Keep data in SRAM across layers
  • Cross-attention optimization: For encoder-decoder
  • Dynamic sparsity: Adaptive attention patterns

Hardware Co-design

Future hardware optimizations:

  • Larger SRAM (>256KB)
  • Higher SRAM bandwidth
  • Hardware attention units
  • Near-memory computing

Common Misconceptions

"Flash Attention is approximate"

False: Flash Attention computes exact attention, numerically identical to standard attention (within floating-point precision).
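A quick numerical check of exactness: the pure-Python sketch below compares a direct softmax-attention reference against a tiled version that only ever sees one block of K and V at a time (processing one query row at a time for brevity). The two differ only by floating-point rounding. Sizes, seed, and function names are arbitrary choices for illustration.

```python
import math
import random

def naive_attention(Q, K, V):
    """Reference: build each full score row, softmax it, multiply by V."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_size):
    """Same result, visiting K/V one tile at a time with online-softmax rescaling."""
    d = len(Q[0])
    out = []
    for q in Q:
        m, l, acc = -math.inf, 0.0, [0.0] * len(V[0])
        for j in range(0, len(K), block_size):
            K_tile, V_tile = K[j:j + block_size], V[j:j + block_size]
            s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K_tile]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)          # rescale what was accumulated so far
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(pi * v[c] for pi, v in zip(p, V_tile))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])         # normalize once at the end
    return out

random.seed(0)
N, d = 37, 8   # N deliberately not a multiple of the block size
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
max_err = max(abs(a - b)
              for ra, rb in zip(naive_attention(Q, K, V), tiled_attention(Q, K, V, 8))
              for a, b in zip(ra, rb))
print(f"max |naive - tiled| = {max_err:.2e}")
```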

"Flash Attention only helps with memory"

False: the primary benefit is speed (2-9×); the memory savings are secondary.

"Flash Attention requires special hardware"

False: it runs on mainstream NVIDIA GPUs. Flash Attention 1.x supports CUDA compute capability ≥ 7.5 (Turing and newer), and Flash Attention 2 supports Ampere and newer.

Conclusion

Flash Attention demonstrates that algorithmic innovation can overcome hardware limitations. By reformulating attention as an IO-aware problem rather than a pure compute problem, it enables training and inference of models with much longer contexts than previously possible.
