10.11 Mixed precision training

Modern GPUs (A100, H100, H200, B200, GB200) and TPUs have hardware acceleration for half-precision arithmetic. FP16 (or BF16) operations run $2$–$8\times$ faster than FP32 on tensor cores. Mixed-precision training runs forward and backward passes in FP16/BF16 while keeping a master copy of weights in FP32, capturing most of the speedup with minimal accuracy loss.

Numerical formats

  • FP32: 1 sign + 8 exponent + 23 mantissa bits. Range $\sim 10^{-38}$ to $10^{38}$, $\sim 7$ decimal digits.
  • FP16: 1 + 5 + 10. Range $\sim 6 \times 10^{-5}$ to $6.5 \times 10^4$ for normal values (subnormals reach $\sim 6 \times 10^{-8}$), $\sim 3$ decimal digits. Narrow dynamic range: small gradients can underflow.
  • BF16: 1 + 8 + 7. Same exponent range as FP32, only $\sim 2$–$3$ decimal digits of precision. Wider dynamic range than FP16; the coarse mantissa is nonetheless adequate for training.
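These figures can be checked directly in PyTorch with torch.finfo, which reports the smallest normal value, the maximum, and the machine epsilon of each format (no GPU required):

```python
import torch

# Print the smallest normal value, the maximum, and the machine epsilon
# for the three formats discussed above.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} min normal={info.tiny:.3e}  "
          f"max={info.max:.3e}  eps={info.eps:.3e}")
```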

BF16 has largely replaced FP16 for training because its dynamic range matches FP32, gradients rarely underflow, and loss scaling is unnecessary.
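As an illustration (a minimal sketch, assuming an Ampere-or-newer GPU and a placeholder model and batch), a BF16 training step in PyTorch needs only autocast; no gradient scaler is involved:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()          # placeholder model
optimizer = torch.optim.AdamW(model.parameters())   # FP32 parameters and state
x = torch.randn(32, 1024, device="cuda")
target = torch.randn(32, 1024, device="cuda")

# BF16 autocast: matmuls run in BF16, parameters stay FP32, and no loss
# scaling is needed because BF16 shares FP32's exponent range.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), target)

loss.backward()        # gradients of the FP32 parameters are FP32
optimizer.step()
optimizer.zero_grad()
```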

Loss scaling (FP16 only)

Many gradient values fall below FP16's smallest normal number ($\sim 6 \times 10^{-5}$): they lose precision as subnormals, and values below $\sim 6 \times 10^{-8}$ flush to zero. Loss scaling multiplies the loss by a large factor $S$ (typically $S = 2^{16}$ or higher) before backpropagation. By the chain rule every gradient is scaled by $S$ as well, which shifts small gradients up into FP16's representable range. Before the optimiser update, gradients are unscaled by dividing by $S$.
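A minimal sketch of the scale-then-unscale mechanics on a toy model (the dtypes here are incidental; a real implementation would also check the unscaled gradients for infs or NaNs before stepping):

```python
import torch

# Toy model; in real FP16 training the forward/backward would run in half precision.
model = torch.nn.Linear(512, 512)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(64, 512)

S = 2.0 ** 16                          # fixed loss scale

loss = model(x).pow(2).mean()
(loss * S).backward()                  # every gradient is now multiplied by S

with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(S)             # unscale before the optimiser update

optimizer.step()
optimizer.zero_grad()
```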

Dynamic loss scaling (NVIDIA Apex, PyTorch AMP) automates this: increase $S$ when no overflow occurs for $N$ steps; decrease $S$ when an overflow is detected (skip the update for that step). The scale converges to a value just below the overflow threshold.
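PyTorch exposes this through torch.cuda.amp.GradScaler; a sketch with a synthetic batch (by default the scale starts at $2^{16}$, doubles after a long run of overflow-free steps, and halves on overflow):

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.cuda.amp.GradScaler()                  # dynamic loss scaling

for step in range(100):
    x = torch.randn(32, 1024, device="cuda")          # synthetic batch
    target = torch.randn(32, 1024, device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                   # FP16 forward pass
        loss = torch.nn.functional.mse_loss(model(x), target)

    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads; skips the step on inf/NaN
    scaler.update()                 # grows the scale periodically, shrinks on overflow
```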

FP32 master weights and accumulation

Crucial detail: even with FP16/BF16 forward and backward, weights and optimiser state are kept in FP32. The Adam first and second moments (which decay slowly) and the parameters themselves (which receive small updates) need FP32 precision to accumulate correctly over millions of steps. Operations that can lose precision (softmax, layer norm, loss reduction) are also performed in FP32 by default in modern AMP libraries.
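A hand-rolled sketch of the master-weight pattern (AMP libraries do this internally; the model, data, and loss are placeholders, and loss scaling is omitted for brevity):

```python
import torch

master = torch.nn.Linear(512, 512).cuda()               # FP32 master weights
optimizer = torch.optim.Adam(master.parameters())       # FP32 Adam moments
fp16_model = torch.nn.Linear(512, 512).half().cuda()    # FP16 working copy

for step in range(10):
    # 1. Sync the FP16 working copy from the FP32 master weights.
    with torch.no_grad():
        for w16, w32 in zip(fp16_model.parameters(), master.parameters()):
            w16.copy_(w32)

    # 2. FP16 forward/backward on the working copy; reduce the loss in FP32.
    x = torch.randn(64, 512, device="cuda", dtype=torch.float16)
    loss = fp16_model(x).float().pow(2).mean()
    loss.backward()

    # 3. Copy FP16 gradients into the FP32 master parameters and update there.
    with torch.no_grad():
        for w16, w32 in zip(fp16_model.parameters(), master.parameters()):
            w32.grad = w16.grad.float()
            w16.grad = None
    optimizer.step()
    optimizer.zero_grad()
```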

Memory savings

Mixed precision roughly halves activation memory, which matters because activations dominate memory in deep models. It does not halve total memory, because the FP32 master weights and optimiser state remain. For Adam, the per-parameter memory is roughly $2 + 2 + 4 + 4 + 4 = 16$ bytes (FP16 forward weight, FP16 gradient, FP32 master weight, FP32 first moment, FP32 second moment), essentially the same as the $4 + 4 + 4 + 4 = 16$ bytes for FP32-only training. The savings come from activations and the faster arithmetic, not from the parameter state.
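A back-of-the-envelope check, assuming a hypothetical 7-billion-parameter model:

```python
# Rough per-parameter accounting for Adam (bytes); activations are extra.
mixed = {"fp16 weight": 2, "fp16 grad": 2,
         "fp32 master": 4, "fp32 m": 4, "fp32 v": 4}        # 16 bytes/param
fp32_only = {"fp32 weight": 4, "fp32 grad": 4,
             "fp32 m": 4, "fp32 v": 4}                       # 16 bytes/param

n_params = 7e9                                               # hypothetical 7B model
print(f"mixed:     {sum(mixed.values()) * n_params / 2**30:.0f} GiB")
print(f"fp32 only: {sum(fp32_only.values()) * n_params / 2**30:.0f} GiB")
# Both print ~104 GiB: the savings come from activations, not parameter state.
```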
