Glossary

FlashAttention

FlashAttention, introduced by Tri Dao, Daniel Fu, Stefano Ermon, Atri Rudra and Christopher Ré in 2022, is an algorithm for computing exact attention with substantially fewer reads and writes to GPU memory. The algorithm is tile-based: it loads small blocks of the query, key and value matrices from HBM (high-bandwidth memory, which is large but comparatively slow) into fast on-chip SRAM, computes the partial attention output for those blocks, and writes the result back to HBM. The crucial insight is to compute softmax in a streaming fashion, using the online softmax algorithm, which keeps a running maximum and a running normalising sum and rescales earlier partial results as each new block arrives, so that the full attention matrix never needs to be materialised.
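The streaming softmax idea can be illustrated with a minimal sketch in plain Python. It is not the FlashAttention kernel itself: it handles a single query row, uses Python lists instead of GPU tiles, and the function names (naive_attention, streaming_attention) and the block_size parameter are hypothetical, chosen for illustration only. It shows only that processing keys and values block by block, with a running maximum and running sum, gives the same result as computing the full row of scores at once.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . K^T / sqrt(d)) . V for one query row,
    computing every score before normalising."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim_v)]


def streaming_attention(q, keys, values, block_size=2):
    """Same result, but keys/values are visited one block at a time.
    Only a running max, a running sum and one output accumulator are kept."""
    d = len(q)
    dim_v = len(values[0])
    running_max = -math.inf      # largest score seen so far
    running_sum = 0.0            # sum of exp(score - running_max) so far
    acc = [0.0] * dim_v          # unnormalised weighted sum of values

    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
                  for k in k_blk]

        # Update the maximum and rescale everything accumulated so far.
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)   # 0.0 on the first block
        running_sum *= rescale
        acc = [a * rescale for a in acc]

        # Fold the current block into the running statistics.
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim_v):
                acc[j] += w * v[j]
        running_max = new_max

    # A single normalisation at the end yields the softmax-weighted average.
    return [a / running_sum for a in acc]


# Illustrative data (made up for this example).
q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, keys, values))
print(streaming_attention(q, keys, values, block_size=2))
```

The two printed rows should agree up to floating-point rounding. The real algorithm applies the same rescaling to whole tiles of queries held in SRAM, which is where the savings in memory traffic come from.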

FlashAttention achieves 2-4× speedups on common attention shapes versus naive implementations, with no change to the computed result (it is exact attention, not an approximation). The algorithm is now standard in nearly every Transformer training and inference framework, including PyTorch's scaled_dot_product_attention, JAX's Flash Attention library, the Triton implementation shipped with major training stacks, and the optimised attention layers in vLLM, TensorRT-LLM, SGLang and other inference servers.

Successor algorithms have continued to push performance: FlashAttention-2 (Dao 2023) further optimises the original, and FlashAttention-3 (Shah et al. 2024) targets Hopper GPUs specifically, with FP8 support. The technique has also been extended to other operations: FlashConv for convolutions, Flash Linear Attention for linear-attention variants, and FlashGEMM for matrix multiplications more generally.

Related terms: Tri Dao, Attention Mechanism, Transformer

This site is currently in Beta. Contact: Chris Paton
