Glossary

FlashAttention Internals

Standard attention computes $\mathrm{Attn}(Q,K,V) = \mathrm{softmax}(QK^\top / \sqrt{d}) V$ in three passes: form $S = QK^\top$ ($N^2$ scores), apply softmax, multiply by $V$. The intermediate $N \times N$ matrix $S$ occupies $2N^2$ bytes in BF16 and is written to and re-read from HBM between steps. For $N = 8192$ that is 128 MiB, so the write and read-back alone cost 256 MiB of HBM traffic per head per layer. That is far more than the roughly 30 MB of on-chip SRAM summed across all of an H100's SMs.
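The arithmetic above can be checked with a few lines of Python. This is a back-of-the-envelope sketch under stated assumptions: batch size 1, a single head, BF16 scores (2 bytes each), and only one write plus one read of $S$. An unfused implementation also writes and re-reads the softmax output, so its real traffic is higher.

```python
# Back-of-the-envelope HBM traffic for materialising S (one head, one layer,
# batch size 1). Assumes BF16 storage and one write plus one read of S.
MIB = 1024 ** 2
BYTES_BF16 = 2


def score_matrix_bytes(n: int, bytes_per_elem: int = BYTES_BF16) -> int:
    """Bytes needed to store an n x n score matrix."""
    return n * n * bytes_per_elem


n = 8192
s_bytes = score_matrix_bytes(n)
round_trip = 2 * s_bytes  # written to HBM once, read back once

print(f"S:            {s_bytes // MIB} MiB")     # S:            128 MiB
print(f"write + read: {round_trip // MIB} MiB")  # write + read: 256 MiB
```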

FlashAttention (Dao et al., 2022) reorganises the computation so that $S$ is never materialised in HBM. The key ideas are tiling and an online (streaming) softmax.

Tile-based forward pass: split $Q$ into row blocks $Q_i$ of size $B_r \times d$ (typically $B_r = 64$ and $d = 128$, so 16 KB in BF16). Split $K$ and $V$ into blocks $K_j$, $V_j$ of size $B_c \times d$, which correspond to the column blocks of $S$. Load each $Q_i$ into SRAM, then iterate over all $K_j, V_j$ blocks, accumulating the partial output. Taking $B_c = B_r$, the SRAM footprint of $Q_i$, $K_j$ and $V_j$ is $\approx 3 \times B_r \times d \times 2$ bytes $= 48$ KB. That is well within the H100's 228 KB of shared memory per SM.
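The footprint estimate can be reproduced the same way. This sketch assumes $B_r = B_c = 64$, $d = 128$ and BF16 tiles. It counts only $Q_i$, $K_j$ and $V_j$ and ignores the score tile and the output accumulator. The 228 KB figure is the per-SM shared memory quoted above.

```python
# SRAM footprint of the Q_i, K_j, V_j tiles for one inner-loop step.
# Assumes B_c = B_r and BF16 (2 bytes/element); the score tile S_ij and the
# output accumulator are not counted.
KIB = 1024
BYTES_BF16 = 2
H100_SMEM_PER_SM = 228 * KIB


def tile_bytes(rows: int, d: int, bytes_per_elem: int = BYTES_BF16) -> int:
    """Bytes for a rows x d tile."""
    return rows * d * bytes_per_elem


b_r = b_c = 64
d = 128
q_tile = tile_bytes(b_r, d)          # Q_i
kv_tiles = 2 * tile_bytes(b_c, d)    # K_j and V_j
total = q_tile + kv_tiles

print(q_tile // KIB, total // KIB)   # 16 48
print(f"{total / H100_SMEM_PER_SM:.0%} of one SM's shared memory")  # 21% of one SM's shared memory
```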

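Online softmax: a naive (numerically safe) softmax needs the row's global maximum and sum of exponentials before it can normalise. The streaming algorithm instead tracks running statistics $(m, \ell)$: the max so far, and the sum of exponentials so far taken relative to $m$. Suppose a new tile has local max $\tilde m$ and local sum $\tilde\ell = \sum_j e^{s_j - \tilde m}$. The statistics are then merged as $$m_{\mathrm{new}} = \max(m, \tilde m), \quad \ell_{\mathrm{new}} = e^{m - m_{\mathrm{new}}} \ell + e^{\tilde m - m_{\mathrm{new}}} \tilde\ell$$ The unnormalised output accumulator is rescaled by the same factors, $O_{\mathrm{new}} = e^{m - m_{\mathrm{new}}} O + e^{\tilde m - m_{\mathrm{new}}} \tilde P V_j$ with $\tilde P_j = e^{s_j - \tilde m}$, and is divided by $\ell$ after the last tile. This is mathematically exact, not an approximation.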
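The tiled loop and the online-softmax update can be combined into a small reference implementation. This is a plain-Python sketch for readability, not the CUDA kernel. It processes one query row at a time (a real kernel handles a $B_r$-row block in parallel) and uses FP64 throughout. The function names are hypothetical. It checks that streaming over $K_j, V_j$ blocks of any size gives the same result as materialising each full score row.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Reference: build each full score row, softmax it, then multiply by V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * dot(q, k) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        ell = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / ell
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q: Matrix, K: Matrix, V: Matrix, block_c: int) -> Matrix:
    """Stream over K/V blocks of block_c rows, keeping per query row only the
    running max m, running sum ell and an unnormalised output accumulator."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for q in Q:
        m, ell, acc = -math.inf, 0.0, [0.0] * d_v
        for start in range(0, len(K), block_c):
            K_j, V_j = K[start:start + block_c], V[start:start + block_c]
            s = [scale * dot(q, k) for k in K_j]        # one tile of scores
            m_tile = max(s)
            p = [math.exp(x - m_tile) for x in s]       # tile-local exponentials
            ell_tile = sum(p)
            m_new = max(m, m_tile)
            alpha = math.exp(m - m_new)                 # 0.0 on the first tile
            beta = math.exp(m_tile - m_new)
            ell = alpha * ell + beta * ell_tile
            acc = [alpha * acc[c] + beta * sum(p[j] * V_j[j][c] for j in range(len(V_j)))
                   for c in range(d_v)]
            m = m_new
        out.append([a / ell for a in acc])              # normalise once at the end
    return out


# Usage: compare both versions on a small random problem.
rng = random.Random(0)
N, D = 10, 4
Q = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(N)]
K = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(N)]
V = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(N)]

ref = naive_attention(Q, K, V)
max_err = max(
    abs(a - b)
    for block_c in (1, 3, 4, N)
    for row_ref, row_tiled in zip(ref, tiled_attention(Q, K, V, block_c))
    for a, b in zip(row_ref, row_tiled)
)
print(f"max abs difference: {max_err:.1e}")  # floating-point rounding only
```

Deferring the division by $\ell$ to the end means each tile costs only one rescale of the accumulator, which is why the sketch keeps `acc` unnormalised.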

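Backward pass: rather than storing $S$ (or the softmax probabilities) from the forward pass, FlashAttention keeps only the output and the per-row softmax statistics, which take $O(N)$ memory. It then recomputes $S$ tile by tile during the backward pass. The cost is extra FLOPs to recompute $QK^\top$ and the exponentials; the saving is $O(N^2)$ HBM reads and writes. Net: the backward pass is faster despite the recomputation, because HBM bandwidth, not FLOPs, was the bottleneck.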
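The recomputation idea can be illustrated for the $V$ gradient, $dV = P^\top dO$. In this plain-Python sketch the forward pass keeps only the per-row log-sum-exp $L_i = m_i + \log \ell_i$, which is $O(N)$ numbers. The backward loop rebuilds each probability $P_{ij} = e^{s_{ij} - L_i}$ one $K$ block at a time instead of reading a stored $N \times N$ matrix. The sketch rests on several simplifications:

- The function names are hypothetical.
- Only $dV$ is shown. $dQ$ and $dK$ need an extra per-row term but follow the same recompute pattern.
- $L_i$ is computed directly rather than taken from a tiled forward pass, for brevity.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def row_logsumexp(Q: Matrix, K: Matrix) -> List[float]:
    """What the forward pass keeps: L_i = m_i + log(ell_i), one float per row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    L = []
    for q in Q:
        s = [scale * dot(q, k) for k in K]
        m = max(s)
        L.append(m + math.log(sum(math.exp(x - m) for x in s)))
    return L


def grad_v_recompute(Q: Matrix, K: Matrix, dO: Matrix, L: List[float], block_c: int) -> Matrix:
    """dV = P^T dO, recomputing P_ij = exp(s_ij - L_i) block by block."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(dO[0])
    dV = [[0.0] * d_v for _ in K]
    for start in range(0, len(K), block_c):                   # outer loop over K blocks
        for i, q in enumerate(Q):                             # inner loop over query rows
            for j in range(start, min(start + block_c, len(K))):
                p_ij = math.exp(scale * dot(q, K[j]) - L[i])  # recomputed, never stored
                for c in range(d_v):
                    dV[j][c] += p_ij * dO[i][c]
    return dV


def grad_v_naive(Q: Matrix, K: Matrix, dO: Matrix) -> Matrix:
    """Reference: materialise the full N x N probability matrix P first."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    P = []
    for q in Q:
        s = [scale * dot(q, k) for k in K]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        total = sum(e)
        P.append([x / total for x in e])
    return [[sum(P[i][j] * dO[i][c] for i in range(len(Q))) for c in range(len(dO[0]))]
            for j in range(len(K))]


# Usage: both routes give the same dV on a small random problem.
rng = random.Random(1)
N, D = 8, 4
Q = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(N)]
K = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(N)]
dO = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(N)]

L = row_logsumexp(Q, K)
dV_ref = grad_v_naive(Q, K, dO)
dV_tiled = grad_v_recompute(Q, K, dO, L, block_c=3)
max_err = max(abs(a - b) for ra, rb in zip(dV_ref, dV_tiled) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.1e}")
```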

FlashAttention-3 (Shah et al., 2024) targets Hopper and exploits H100-specific features:

  • Warp specialisation: producer warps issue async TMA loads from HBM to shared memory; consumer warps run wgmma on tensor cores. Producer and consumer overlap, hiding memory latency.
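  • wgmma asynchrony: the warpgroup matmul instruction is issued asynchronously, so subsequent instructions can proceed while it runs. FlashAttention-3 uses this to overlap the softmax of one block with the matmuls of another.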
  • FP8 path: $Q$, $K$ and $V$ are held in FP8 (E4M3), with block quantisation and incoherent processing to limit quantisation error. FP8 doubles peak dense tensor-core throughput, from 989 TFLOP/s in BF16 to ~1.98 PFLOP/s.
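A small simulation shows why scaling granularity matters for FP8. The FlashAttention-3 paper describes block quantisation for its FP8 path: one scale per block rather than one per tensor. The sketch below illustrates that general idea in plain Python, assuming the common E4M3 variant whose largest finite value is 448. Its limitations:

- Rounding to the E4M3 grid is simulated in FP64, and special values are ignored.
- The helper names and the block size of 32 are hypothetical.
- Incoherent processing is not modelled.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0       # largest finite E4M3 magnitude
E4M3_MIN_EXP = -6      # exponent of the smallest normal number (2**-6)
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Simulated round-to-nearest onto the E4M3 grid, saturating at +/-448.
    NaN/Inf inputs are not handled."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))                # abs(x) = f * 2**e with 0.5 <= f < 1
    exp = max(e - 1, E4M3_MIN_EXP)           # subnormals reuse the 2**-6 spacing
    step = 2.0 ** (exp - MANTISSA_BITS)      # spacing of representable values
    q = round(abs(x) / step) * step
    return math.copysign(min(q, E4M3_MAX), x)


def quantise_blocks(xs: List[float], block: int) -> Tuple[List[float], List[float]]:
    """Per-block scaling: each block's largest magnitude is mapped to 448."""
    codes, scales = [], []
    for start in range(0, len(xs), block):
        chunk = xs[start:start + block]
        amax = max(abs(v) for v in chunk)
        scale = amax / E4M3_MAX if amax > 0 else 1.0
        scales.append(scale)
        codes.extend(round_to_e4m3(v / scale) for v in chunk)
    return codes, scales


def dequantise_blocks(codes: List[float], scales: List[float], block: int) -> List[float]:
    return [c * scales[i // block] for i, c in enumerate(codes)]


def total_error(xs: List[float], block: int) -> float:
    codes, scales = quantise_blocks(xs, block)
    return sum(abs(a - b) for a, b in zip(xs, dequantise_blocks(codes, scales, block)))


# Usage: one large outlier plus many small values.
xs = [1.0e6] + [0.001 * (k + 1) for k in range(63)]
per_tensor = total_error(xs, block=len(xs))  # one scale for everything
per_block = total_error(xs, block=32)        # one scale per 32 values
print(f"per-tensor error {per_tensor:.3f}, per-block error {per_block:.3f}")
```

With a single per-tensor scale, the outlier pushes every small value so far below the smallest E4M3 subnormal that it rounds to zero. With per-block scales, only the values sharing a block with the outlier are lost.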

Measured performance: FlashAttention-2 on A100 reaches 50–73 % of peak BF16 throughput. FlashAttention-3 on H100 reaches up to 740 TFLOP/s (75 % of peak BF16) and close to 1.2 PFLOP/s in FP8. End-to-end transformer training speedups of 1.5–2× over a standard (unfused) PyTorch attention implementation are commonly reported at long context.
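The utilisation figures follow from dividing achieved throughput by peak throughput. A quick check, assuming the dense H100 SXM peaks quoted above (989 TFLOP/s BF16, about 1979 TFLOP/s FP8) and the roughly 740 TFLOP/s BF16 figure reported in the FlashAttention-3 paper:

```python
# Achieved / peak tensor-core throughput, all in TFLOP/s (dense H100 SXM peaks).
PEAK_BF16 = 989.0
PEAK_FP8 = 1979.0


def utilisation(achieved: float, peak: float) -> float:
    """Fraction of peak throughput actually reached."""
    return achieved / peak


print(f"FA3 BF16: {utilisation(740.0, PEAK_BF16):.1%}")   # FA3 BF16: 74.8%
print(f"FA3 FP8:  {utilisation(1200.0, PEAK_FP8):.1%}")   # FA3 FP8:  60.6%
```

The FP8 kernel's 1.2 PFLOP/s is higher in absolute terms. Even so, it is a smaller fraction of its own peak than the BF16 kernel achieves.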


Related terms: FlashAttention, GPU Memory Hierarchy, Tensor Cores, Attention Mechanism


This site is currently in Beta. Contact: Chris Paton



AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).