13.14 FlashAttention: IO-aware exact attention

The previous section showed that the cost of self-attention grows quadratically with sequence length, and that the bottleneck is not the multiplications themselves but the $n \times n$ attention matrix that has to be written down, normalised, and read back. FlashAttention, introduced by Tri Dao and colleagues in 2022, attacks exactly this problem. It does not change the mathematics of attention: the output is the same softmax attention you have already met, computed exactly up to floating-point rounding. What changes is how the computation is laid out on the GPU. By fusing the entire attention operation into a single kernel and never materialising the full matrix in slow memory, FlashAttention achieves two- to fourfold wall-clock speed-ups for training and even larger gains for inference, while supporting much longer contexts. It is now one of the fused backends behind torch.nn.functional.scaled_dot_product_attention in PyTorch 2, chosen automatically when the inputs and GPU support it, and it is integrated into the major transformer libraries. If your code computes softmax(QK^T / sqrt(d_k)) V through that function, or through a library that calls it, the chances are very high that FlashAttention is what actually runs. Writing the formula out by hand as separate matrix multiplications and a softmax in eager mode does not, on its own, give you the fused kernel.

The quadratic wall has two faces: memory and compute. Linear-attention methods (§13.15) attack the compute wall by changing the algorithm. FlashAttention attacks the memory wall by leaving the algorithm alone and changing the data movement: it still performs a quadratic number of floating-point operations, but it never stores or repeatedly moves a quadratic-sized matrix. Of the two, FlashAttention is by far the easier sell because it costs nothing in accuracy.

Symbols Used Here
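$n$: sequence length
$d_k$: dimension of each query and key vector in one head
$M$: fast-memory size (SRAM per streaming multiprocessor)
$B_r$, $B_c$: row-block (query) and column-block (key and value) sizes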

Memory bandwidth as the bottleneck

A modern GPU is not a single homogeneous device. It has a deep memory hierarchy. At the bottom is high-bandwidth memory, or HBM: tens of gigabytes, with bandwidth measured in terabytes per second. On an A100 80GB this is around 2.0 TB/s; on an H100 SXM, 3.35 TB/s; on H200, 4.8 TB/s; on B200, around 7.7 TB/s on HBM3e. Above that sits a much smaller but much faster on-chip memory: static random-access memory, or SRAM, of roughly 200 KB per streaming multiprocessor (192 KB on an A100, 256 KB on an H100), a few tens of megabytes across the whole chip, with aggregate bandwidth approaching twenty terabytes per second. The compute units themselves peak at roughly one petaflop per second of dense half-precision matrix multiplication on an H100. The arithmetic intensity required to keep them busy is therefore very high: every byte fetched from HBM must support hundreds of floating-point operations, or the tensor cores will sit idle waiting for data.
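The "hundreds of floating-point operations per byte" figure comes from dividing peak compute by HBM bandwidth, a quantity often called the ridge point of a roofline model. The following is a minimal sketch using the approximate H100 SXM figures quoted above. The function name is hypothetical, and the result is only an order-of-magnitude guide, since real kernels never reach peak on either axis.

```python
def ridge_point(peak_flops_per_s: float, hbm_bytes_per_s: float) -> float:
    """FLOPs a kernel must perform per byte read from HBM to be compute-bound."""
    return peak_flops_per_s / hbm_bytes_per_s


# Approximate H100 SXM figures from the text (dense half-precision matmul).
h100_ridge = ridge_point(peak_flops_per_s=1.0e15, hbm_bytes_per_s=3.35e12)
print(f"H100: about {h100_ridge:.0f} FLOPs per byte")  # H100: about 299 FLOPs per byte
```

Any kernel whose arithmetic intensity falls below this value is limited by bandwidth rather than by the tensor cores. A row-wise softmax is one example, since it does only a few operations per element it reads and writes.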

This is what engineers mean by a memory-bandwidth bottleneck. The chip is fast enough; the pipe feeding it is not. Naive attention is a textbook case. To compute attention on a single head, the standard recipe writes the score matrix $\mathbf{S} = \mathbf{Q}\mathbf{K}^\top$ to HBM, reads it back to apply the row-wise softmax, writes the normalised attention matrix $\mathbf{P}$ back to HBM, and then reads it once more to multiply by $\mathbf{V}$. The intermediate matrices $\mathbf{S}$ and $\mathbf{P}$ are both of size $n \times n$. For a sequence of length 8000, that is sixty-four million entries each, or about 128 MB per matrix per head in fp16, repeated for every head and every layer. The traffic for these intermediates alone comes to roughly $4n^2$ element reads and writes per head. By comparison, reading $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ and writing the output costs only $O(n d_k)$. At long sequence lengths the matrix multiplications finish quickly. The softmax, however, performs only a few operations per element it reads and writes, and so does any masking or dropout applied to the full matrix. The GPU therefore spends most of its time waiting on HBM. Profiling a long-context transformer that uses unfused attention typically shows tensor-core utilisation far below peak, with much of the time spent on memory traffic.
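The arithmetic behind the 8000-token example can be checked directly. The sketch below makes three assumptions: fp16 storage (2 bytes per element), a head dimension of $d_k = 128$ (an illustrative choice, not stated in the text), and the approximate H100 figures quoted earlier. It compares the time spent moving $\mathbf{S}$ and $\mathbf{P}$ through HBM with the time the two matrix multiplications would need at peak throughput. The helper names are hypothetical.

```python
def intermediate_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Size of one n x n intermediate (S or P) for a single head."""
    return n * n * bytes_per_elem


def intermediate_traffic_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Write S, read S, write P, read P: four n x n transfers per head."""
    return 4 * intermediate_bytes(n, bytes_per_elem)


def matmul_flops(n: int, d_k: int) -> int:
    """S = Q K^T and O = P V each cost about 2 * n * n * d_k FLOPs."""
    return 2 * (2 * n * n * d_k)


n, d_k = 8000, 128              # d_k = 128 is an illustrative assumption
hbm_bw, peak = 3.35e12, 1.0e15  # approximate H100 SXM figures from the text

print(intermediate_bytes(n) / 1e6)          # 128.0  (MB per matrix per head)
print(intermediate_traffic_bytes(n) / 1e6)  # 512.0  (MB moved per head)
print(f"{intermediate_traffic_bytes(n) / hbm_bw * 1e3:.3f} ms moving S and P")  # 0.153 ms
print(f"{matmul_flops(n, d_k) / peak * 1e3:.3f} ms of matmul at peak")           # 0.033 ms
```

Under these simplified assumptions, which ignore caching and assume peak rates, moving the two intermediates takes several times longer than computing them. That gap is the memory wall in miniature.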

The conclusion is uncomfortable. Throwing more compute at the problem will not help; the cores are already idle. What needs to shrink is the number of bytes moved, and that is exactly what FlashAttention does.

Tile decomposition

The central idea is to compute attention block by block, keeping every intermediate result inside SRAM, and only writing the final per-row output back to HBM. Concretely, partition the queries into row-blocks $\mathbf{Q}_i$ of size $B_r \times d_k$ and the keys and values into column-blocks $\mathbf{K}_j$, $\mathbf{V}_j$ of size $B_c \times d_k$, with the block sizes chosen so that all the working tensors for one outer-loop iteration fit comfortably inside the on-chip SRAM available to a streaming multiprocessor. Typical numbers on an A100 are $B_r = B_c = 64$ or 128.

The kernel then runs two nested loops. The outer loop iterates over query blocks $i$, loading $\mathbf{Q}_i$ from HBM into SRAM. For each query block, an inner loop sweeps over key-and-value blocks $j$. At each step it loads $\mathbf{K}_j$ and $\mathbf{V}_j$ into SRAM and computes the small score tile $\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^\top$ entirely on chip. It then updates the running softmax statistics for this query block and accumulates a partial output $\mathbf{O}_i$, also on chip. When the inner loop has consumed every key-and-value block, the accumulated $\mathbf{O}_i$ equals the exact attention output for the queries in block $i$, and it is written once to HBM. Nothing of size $n \times n$ is ever stored in slow memory. Apart from reading $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$, the only large tensor that touches HBM is the output, which has the same shape as $\mathbf{Q}$.

The HBM traffic for a single head drops from $\Theta(n d_k + n^2)$ to $\Theta(n^2 d_k^2 / M)$, where $M$ is the SRAM size. For typical values ($d_k$ between 64 and 128, $M$ around 100 KB), $d_k^2$ is many times smaller than $M$, so in practice this is a large reduction. More importantly, fewer bytes mean fewer round-trips, so the cores stay fed and tensor-core utilisation rises sharply. The FlashAttention-2 paper reports reaching 50 to 73 percent of the A100's theoretical peak throughput, against 25 to 40 percent for the original FlashAttention. The fact that all of this is done in a single fused kernel matters too. There are no intermediate launches, no synchronisation barriers between kernels, no memory allocations for temporary buffers. The whole attention operation becomes one launch.
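Counting the element transfers implied by the two nested loops shows where the saving comes from. The sketch below makes two simplifications. It assumes, as in FlashAttention-2, that $\mathbf{Q}_i$ stays in SRAM for the whole inner loop, and it ignores the small per-row statistics. Under these assumptions each outer iteration streams all of $\mathbf{K}$ and $\mathbf{V}$ once, so the traffic is about $2 n^2 d_k / B_r$ elements. Because $\mathbf{Q}_i$ must fit in SRAM, $B_r$ is at most about $M / d_k$, which gives the $n^2 d_k^2 / M$ scaling of the original FlashAttention analysis. The function names are hypothetical. The naive count covers only Q, K, V, O and the four $n \times n$ transfers.

```python
def naive_elements(n: int, d_k: int) -> int:
    """Read Q, K, V and write O (4 n d_k), plus write/read S and write/read P (4 n^2)."""
    return 4 * n * d_k + 4 * n * n


def tiled_elements(n: int, d_k: int, block_r: int, block_c: int) -> int:
    """Element transfers for the loop nest: Q_i and O_i once, K_j and V_j per inner step."""
    moved = 0
    for i0 in range(0, n, block_r):        # outer loop over query blocks
        rows = min(block_r, n - i0)
        moved += rows * d_k                # load Q_i into SRAM once
        for j0 in range(0, n, block_c):    # inner loop over key/value blocks
            cols = min(block_c, n - j0)
            moved += 2 * cols * d_k        # load K_j and V_j
        moved += rows * d_k                # write the finished O_i
    return moved


n, d_k = 8192, 64
naive = naive_elements(n, d_k)
tiled = tiled_elements(n, d_k, block_r=128, block_c=128)
print(f"naive / tiled traffic: {naive / tiled:.2f}x")  # naive / tiled traffic: 3.97x
```

Doubling $B_r$ roughly halves the tiled traffic, which is one reason kernels favour the largest blocks that fit in SRAM. The count also ignores the extra $n \times n$ passes that unfused implementations make for masking or dropout. Including them would only widen the gap.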

Backpropagation uses the same trick, but in reverse. Storing the attention matrix from the forward pass for use in the backward pass would defeat the entire purpose. Instead, FlashAttention saves only the output and a small set of softmax statistics for each query row, $O(n)$ numbers in total, and recomputes the score and probability tiles on the fly during the backward pass. This adds floating-point operations, but recomputation is cheap once the tiles are in SRAM, and the saving in HBM traffic more than pays for it. This is a clean instance of the activation-recomputation trade-off you met in §9.13.
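The recomputation depends on a compact per-row summary of the forward-pass softmax. The following is a minimal sketch, assuming the forward pass stores a single log-sum-exp value $L_i$ per query row, which is the form FlashAttention-2 uses. Given that number, any tile of a probability row can be rebuilt exactly as $P_{ik} = e^{s_{ik} - L_i}$ without seeing the rest of the row. In a real kernel the scores would themselves be recomputed as $\mathbf{Q}_i \mathbf{K}_j^\top$ in SRAM. Here they are passed in directly for brevity, and the helper names are hypothetical.

```python
import math


def row_logsumexp(scores: list[float]) -> float:
    """Stable log(sum(exp(s))) for one row, saved once per query row in the forward pass."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def recompute_prob_tile(score_tile: list[float], lse: float) -> list[float]:
    """Rebuild one tile of a softmax row from its scores and the saved log-sum-exp."""
    return [math.exp(s - lse) for s in score_tile]


row = [0.3, -1.2, 2.5, 0.0, 1.1]  # scores for one query row
lse = row_logsumexp(row)          # one float of storage instead of a full row
# Backward pass: rebuild the row tile by tile, never holding all of P at once.
probs = recompute_prob_tile(row[:2], lse) + recompute_prob_tile(row[2:], lse)
print(round(sum(probs), 12))      # 1.0
```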

Online softmax

The technical wrinkle in tiling attention is the softmax. Softmax along a row of $\mathbf{S}$ requires the maximum and the sum-of-exponentials of that whole row, but the row is split across the inner loop. You cannot normalise a tile until you have seen every other tile in the same row. The classical fix is a two-pass algorithm: one pass to find the maximum, a second to compute the exponentials and sum. That doubles the HBM reads, which is exactly what we are trying to avoid.

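The online softmax algorithm, which dates back to a 2018 NVIDIA paper by Maxim Milakov and Natalia Gimelshein, removes the extra pass. The trick is a one-pass running update. For each query row, maintain two running statistics as the inner loop progresses: a running maximum $m_i$ and a running denominator $\ell_i$. When a new score tile arrives, compute its row maximum $\tilde m$ and update the running maximum to $m_i^{\text{new}} = \max(m_i, \tilde m)$. The previous denominator is rescaled by $e^{m_i - m_i^{\text{new}}}$ to account for the change of pivot. The new tile's exponentials, computed with $m_i^{\text{new}}$ as the stable subtractive offset, are then added to it: $\ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}}\,\ell_i + \sum_k e^{s_k - m_i^{\text{new}}}$, where $s_k$ ranges over this row's scores in the new tile. The corresponding row $\mathbf{o}$ of the unnormalised partial output $\mathbf{O}_i$ is rescaled by the same factor and then receives the new tile's contribution: $\mathbf{o}^{\text{new}} = e^{m_i - m_i^{\text{new}}}\,\mathbf{o} + \sum_k e^{s_k - m_i^{\text{new}}}\,\mathbf{v}_k$, where $\mathbf{v}_k$ is the matching row of $\mathbf{V}_j$. After the last tile, dividing each row of $\mathbf{O}_i$ by its $\ell_i$ gives the exact attention output. Subtracting the running maximum before exponentiating is the same numerical-stability trick you use for any softmax: every exponent is at most zero, so nothing overflows.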
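The running update can be checked on a single row split into two tiles. The sketch below tracks only $m_i$ and $\ell_i$, since the output accumulator receives the same rescaling, and compares them with the two-pass result. The scores are deliberately large, so exponentiating them directly would overflow (in Python, math.exp(1000.0) raises OverflowError), while the running-maximum version stays finite. The function names are illustrative.

```python
import math


def two_pass_stats(scores: list[float]) -> tuple[float, float]:
    """Classical recipe: one pass for the maximum, another for the shifted sum."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)


def online_stats(tiles: list[list[float]]) -> tuple[float, float]:
    """One pass over the tiles, rescaling the running sum whenever the maximum grows."""
    m, ell = -math.inf, 0.0
    for tile in tiles:
        m_new = max(m, max(tile))
        # exp(m - m_new) <= 1 corrects the old sum for the new pivot; it is 0.0 on the first tile.
        ell = ell * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in tile)
        m = m_new
    return m, ell


scores = [1000.0, 1001.5, 999.0, 1002.0]
print(two_pass_stats(scores))
print(online_stats([scores[:2], scores[2:]]))  # same maximum, same denominator up to rounding
```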

This algorithm is exact. It produces the same result as the two-pass version, up to floating-point rounding from the different order of operations. Combined with tiling, it gives a single-pass, single-kernel attention that never writes the full attention matrix and never compromises on numerical correctness. The mathematics is unchanged; only the data flow has been redrawn.
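The sketch below puts the two pieces together in a pure-Python rendering of the tiled forward pass. It walks query blocks and key-and-value blocks in the nested-loop order described earlier, applies the online-softmax rescaling to both $\ell_i$ and the unnormalised output, and never builds more than a $B_r \times B_c$ score tile. It also includes the $1/\sqrt{d_k}$ scaling that the tile description left implicit. Its only purpose is to show that the result matches naive attention. Plain Python lists stand in for SRAM tiles, and the function names and default block sizes are arbitrary choices, not FlashAttention's own.

```python
import math
import random


def matmul_t(a, b):
    """Return a @ b^T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def naive_attention(q, k, v):
    """Reference: materialise all scores, softmax each row, then multiply by V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for s_row in matmul_t(q, k):
        s_row = [s * scale for s in s_row]
        m = max(s_row)
        p = [math.exp(s - m) for s in s_row]
        total = sum(p)
        out.append([sum(pk * vk[c] for pk, vk in zip(p, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_r=2, block_c=3):
    """Tile-by-tile attention with online softmax; only B_r x B_c scores at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for i0 in range(0, len(q), block_r):        # outer loop: query block Q_i
        q_blk = q[i0:i0 + block_r]
        m = [-math.inf] * len(q_blk)            # running row maxima
        ell = [0.0] * len(q_blk)                # running denominators
        acc = [[0.0] * d_v for _ in q_blk]      # unnormalised partial output O_i
        for j0 in range(0, len(k), block_c):    # inner loop: blocks K_j, V_j
            k_blk, v_blk = k[j0:j0 + block_c], v[j0:j0 + block_c]
            for r, s_row in enumerate(matmul_t(q_blk, k_blk)):
                s_row = [s * scale for s in s_row]
                m_new = max(m[r], max(s_row))
                alpha = math.exp(m[r] - m_new)  # rescales old statistics to the new pivot
                p = [math.exp(s - m_new) for s in s_row]
                ell[r] = alpha * ell[r] + sum(p)
                acc[r] = [alpha * a + sum(pk * vk[c] for pk, vk in zip(p, v_blk))
                          for c, a in enumerate(acc[r])]
                m[r] = m_new
        out.extend([a / ell[r] for a in acc[r]] for r in range(len(q_blk)))
    return out


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
n, d = 7, 4  # 7 is not a multiple of either block size, so edge tiles are exercised
q, k, v = rand_matrix(n, d), rand_matrix(n, d), rand_matrix(n, d)
diff = max(abs(a - b)
           for ra, rb in zip(naive_attention(q, k, v), tiled_attention(q, k, v))
           for a, b in zip(ra, rb))
print(f"max abs difference: {diff:.1e}")
```

The two results differ only by rounding, which is the practical meaning of "exact" here. A real kernel performs the same bookkeeping on blocks held in registers and shared memory rather than on Python lists.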

Speedups

The numbers depend on context length, head dimension, hardware and precision, but the headline figures are robust. On training workloads with sequence lengths from 1k to 4k, FlashAttention typically delivers a two- to fourfold wall-clock speed-up over a well-tuned PyTorch baseline, with the larger gains at longer sequences where the memory wall hurts most. On inference, where batch sizes are smaller and bandwidth dominates more strongly, the gains can rise to five- to tenfold. The extra memory needed for attention drops from $O(n^2)$ to $O(n)$, which is what allows context windows to grow from a few thousand tokens to tens or hundreds of thousands without exhausting HBM.

Because FlashAttention is exact, it produces the same outputs and the same gradients as standard attention, up to floating-point rounding. There is no quality loss to worry about, no accuracy regression to monitor, no separate hyper-parameter to tune. You drop it in and your model trains faster and fits more tokens. Of all the systems-level changes of the deep-learning era, this one is unusually free of trade-offs.

Where it's used

FlashAttention is now the default attention backend in nearly every production transformer stack. PyTorch 2 wires it into torch.nn.functional.scaled_dot_product_attention, and HuggingFace Transformers, vLLM, TGI and DeepSpeed all dispatch to it when the inputs are compatible. Subsequent versions have continued to extend it. FlashAttention-2, released by Dao in 2023, improves work partitioning across streaming multiprocessors. It parallelises over the sequence dimension (the outer loop over query blocks) as well as over batch and heads, yielding roughly another twofold improvement on long sequences. FlashAttention-3, released in 2024 for Hopper-class GPUs, exploits the asynchronous tensor cores and TMA (Tensor Memory Accelerator) hardware on H100 to overlap data movement with compute. FlashAttention-4 is written in CuTeDSL and was announced at Hot Chips in August 2025, with a paper in March 2026. It reaches 1,605 TFLOPS on a Blackwell B200, roughly 3.6 times FlashAttention-2's forward-pass throughput at sequence length 32k. The pattern is the same across all four: deeper integration with the hardware, with no change to the mathematics.

What you should take away

  1. Naive attention is bandwidth-bound on modern GPUs because it materialises an $n \times n$ matrix in HBM and reads it back through a pipe that is much narrower than the compute units demand.
  2. FlashAttention tiles the computation, keeping every intermediate inside the streaming multiprocessor's SRAM, and writes only the final output to HBM.
  3. The algorithmic enabler is the online softmax, a one-pass running-max and running-denominator update that gives the exact normalised result without ever seeing a whole row at once.
  4. The result is an exact, drop-in replacement for standard attention with two- to tenfold wall-clock gains and $O(n)$ extra memory, enabling much longer context windows.
  5. You almost certainly use it already: it is the preferred backend behind scaled_dot_product_attention in PyTorch 2 whenever the inputs and hardware allow, and a standard backend in HuggingFace and vLLM.

This site is currently in Beta. Contact: Chris Paton

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).