3.13 Numerical considerations

For most of this chapter we have written gradients on paper, manipulated them as if they were exact mathematical objects, and assumed that whatever the calculus said, the computer would faithfully execute. That is a useful fiction, and for understanding the algorithms it is the right one. The moment we run those algorithms on real silicon, however, the fiction starts to leak. Every number is stored in finite precision, every arithmetic operation rounds, and the rounding errors do not stay small. They drift, they accumulate, and over the millions of operations involved in a single training step they can compound in surprising and sometimes catastrophic ways. A network can train smoothly for hours and then suddenly print NaN. A loss can plateau because gradients have silently underflowed to zero. A handwritten layer can give plausible-looking outputs while computing entirely the wrong derivative. The mathematics is clean; the implementation is delicate.

This section catalogues the most common numerical pitfalls that beset gradient computation, and the standard responses to them. The closely related material on floating-point representation, conditioning and numerical stability lives in §2.11; readers who are unsure what fp32 and machine epsilon mean should glance back there before continuing.

Symbols Used Here
$\varepsilon_{\text{mach}}$ : machine epsilon (smallest representable relative gap)
$\nabla$ : gradient
$|x|$ : absolute value
$\sigma$ : gradient-clipping threshold
$\mathbf{z}$ : vector of pre-softmax logits

Catastrophic cancellation in gradient checking

When you write a custom layer with a hand-derived backward pass, the standard sanity check is to compare the analytical gradient against a finite-difference approximation. The naive form is

$$ \frac{\partial f}{\partial x} \;\approx\; \frac{f(x + h) - f(x)}{h}. $$

Algebraically this is correct in the limit $h \to 0$. Numerically it is a trap. As $h$ shrinks, $f(x+h)$ and $f(x)$ become almost equal, and the floating-point subtraction in the numerator cancels their leading digits. If $f(x) = 1.234567890$ and $f(x+h) = 1.234567891$, then $f(x+h) - f(x)$ has only one significant digit, even though both inputs were stored to full precision. This is catastrophic cancellation, and it means that decreasing $h$ does not always improve the estimate; below some threshold it makes things much worse.
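
The effect is easy to reproduce. Here is a small sketch in plain double-precision Python, using $\sin$ so the true derivative is known exactly; the function and step sizes are chosen only for illustration.

```python
import math

f, x, true = math.sin, 1.0, math.cos(1.0)   # true derivative of sin at x = 1

for h in (1e-2, 1e-5, 1e-8, 1e-11):
    estimate = (f(x + h) - f(x)) / h        # naive forward difference
    print(f"h = {h:.0e}   error = {abs(estimate - true):.1e}")

# The error keeps improving down to roughly h = 1e-8, then grows sharply at h = 1e-11:
# below that point round-off in the subtraction dominates the truncation error.
```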

The remedy has two parts. First, use a centred difference:

$$ \frac{\partial f}{\partial x} \;\approx\; \frac{f(x + h) - f(x - h)}{2 h}, $$

which is symmetric about $x$ and has truncation error $O(h^2)$ instead of $O(h)$. Second, choose $h$ to balance truncation against round-off. A useful rule of thumb for the centred formula is $h \approx \sqrt[3]{\varepsilon_{\text{mach}}}$: roughly $5 \times 10^{-3}$ in single precision (fp32) and roughly $6 \times 10^{-6}$ in double precision (fp64), which is why serious gradient checks are run in fp64 with $h$ on the order of $10^{-6}$. PyTorch's torch.autograd.gradcheck uses double precision and a centred difference for exactly this reason.

The acceptance criterion is a relative error of around $10^{-7}$ for an fp64 implementation, or about $10^{-3}$ for fp32. Absolute thresholds will mislead, because a layer with very large outputs can have very large gradients, and an absolute error of $10^{-3}$ in that setting is excellent rather than alarming. Always normalise.
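
Putting the pieces together, a hand-rolled check might look like the sketch below. The helpers centred_difference_grad and relative_error are introduced here purely for illustration, and the toy function f stands in for a custom layer; in practice torch.autograd.gradcheck does this job for you.

```python
import torch

def centred_difference_grad(f, x, h=1e-6):
    """Estimate each partial derivative of scalar-valued f with (f(x+h e_i) - f(x-h e_i)) / (2h)."""
    grad = torch.zeros_like(x)
    flat = grad.view(-1)
    for i in range(x.numel()):
        e = torch.zeros_like(x).view(-1)
        e[i] = h
        e = e.view_as(x)
        flat[i] = (f(x + e) - f(x - e)) / (2.0 * h)
    return grad

def relative_error(a, b, eps=1e-12):
    """Normalised comparison; absolute thresholds mislead when gradients are large."""
    return ((a - b).abs().max() / (a.abs().max() + b.abs().max() + eps)).item()

def f(x):
    """Toy differentiable function standing in for a layer with a hand-derived backward pass."""
    return (x * torch.tanh(x)).sum()

x = torch.randn(7, dtype=torch.float64, requires_grad=True)      # fp64, as gradcheck itself uses
analytic = torch.autograd.grad(f(x), x)[0]
numeric = centred_difference_grad(lambda t: f(t).item(), x.detach(), h=1e-6)
print(relative_error(analytic, numeric))   # well below 1e-7 when the backward pass is correct
```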

Gradient explosion and clipping

Recurrent networks and very deep feed-forward networks have a structural tendency toward exploding gradients. The chain rule multiplies Jacobian factors layer by layer; if each factor has spectral norm a little above one, the product grows geometrically with depth, and a single batch can produce a gradient norm of $10^{6}$ or more. The optimiser, given such a gradient, takes a giant step into a region of parameter space that the loss has never seen before, and the next forward pass overflows to NaN.

The standard response is gradient clipping. Before applying the optimiser update, compute the global norm of the gradient,

$$ \|\nabla\|_2 \;=\; \sqrt{\sum_i \nabla_i^2}, $$

and if it exceeds a threshold $\sigma$, scale the entire gradient down so that its norm is exactly $\sigma$. PyTorch exposes this as torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0), called between loss.backward() and optimizer.step(). Typical thresholds are $\sigma = 1.0$ for transformers and $\sigma = 0.25$ for recurrent networks; the right value is essentially folklore for each architecture family. Clipping by norm preserves the direction of the gradient and only rescales its magnitude, which is what we want. (There is also a per-parameter clip_grad_value_, which is cruder and rarely the better choice.) When you see a training loss that occasionally spikes by orders of magnitude before recovering, gradient clipping is what is keeping the run alive.
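
In code, the clip sits between the backward pass and the optimiser step. A minimal sketch of one training iteration, assuming model, loss_fn, optimizer and loader already exist:

```python
import torch

max_norm = 1.0   # the threshold sigma: ~1.0 for transformers, ~0.25 for recurrent networks

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()

    # Rescale the whole gradient if its global L2 norm exceeds max_norm.
    # The returned pre-clipping norm is worth logging: spikes here explain spiky loss curves.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

    optimizer.step()
```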

Numerically stable softmax cross-entropy

The single most common stability issue in practice arises from softmax followed by cross-entropy. The naive recipe is to compute probabilities $p_j = e^{z_j} / \sum_k e^{z_k}$, then take the log of the probability for the correct class. Both steps are numerically dangerous: the exponentials overflow when any logit is large, and the logarithm loses precision when the resulting probability is near zero or one.

The cure is to fuse the two operations into a single expression and rearrange:

$$ \mathcal{L} \;=\; -\log \frac{e^{z_y}}{\sum_j e^{z_j}} \;=\; \log \sum_j e^{z_j} \;-\; z_y \;=\; \mathrm{logsumexp}(\mathbf{z}) - z_y, $$

where the stable log-sum-exp factors out the maximum logit:

$$ \mathrm{logsumexp}(\mathbf{z}) \;=\; z_{\max} + \log \sum_j e^{z_j - z_{\max}}. $$

Subtracting $z_{\max}$ inside the exponent guarantees that the largest term is $e^0 = 1$, so the sum is bounded between $1$ and $n$ and never overflows. The leading $z_{\max}$ outside compensates exactly. The same trick stabilises the binary case via $\log(1 + e^z) = \max(z, 0) + \log(1 + e^{-|z|})$, which is what PyTorch's softplus and binary_cross_entropy_with_logits use internally.
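
To see the trick at work, here is a small sketch with deliberately large logits: the naive recipe overflows to NaN, while the rearranged form (and the framework's fused loss) gives the correct finite answer.

```python
import torch

z = torch.tensor([1000.0, 1001.0, 999.0])   # logits large enough that exp() overflows in fp32
y = 1                                        # index of the correct class

# Naive recipe: the exponentials overflow to inf, the ratio becomes nan, and so does the loss.
p = torch.exp(z) / torch.exp(z).sum()
print(-torch.log(p[y]))                      # nan

# Stable recipe: logsumexp(z) - z_y with the maximum logit factored out.
z_max = z.max()
loss = z_max + torch.log(torch.exp(z - z_max).sum()) - z[y]
print(loss)                                  # tensor(0.4076)

# The fused framework loss uses the same stable form internally and takes raw logits.
print(torch.nn.functional.cross_entropy(z.unsqueeze(0), torch.tensor([y])))   # tensor(0.4076)
```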

The practical advice is short and absolute: never compute log(softmax(x)) by hand. Always call the framework's fused entry point, F.cross_entropy(logits, targets) for multiclass classification, F.binary_cross_entropy_with_logits(logits, targets) for binary or multilabel, and feed it raw logits, not probabilities. The framework will use the stable form, your code will be shorter, and you will not spend an afternoon hunting a NaN that appeared in epoch 47.

Mixed-precision training

Modern accelerators are dramatically faster at 16-bit arithmetic than at 32-bit. Mixed-precision training exploits this by storing the master copy of weights and the optimiser state in fp32, while running the matrix multiplications and most activations in either fp16 or bf16. The forward and backward passes get the speed and memory savings; the parameter update keeps the precision needed for stable optimisation.

The two 16-bit formats behave very differently. fp16 keeps more mantissa bits than bf16, but its dynamic range is small (the smallest normal value is about $6 \times 10^{-5}$), so gradients routinely underflow to zero. The standard fix is loss scaling: multiply the loss by a large constant such as $2^{15}$ before calling backward(), which shifts every gradient up into the representable range; then unscale by dividing by the same constant before the optimiser step. PyTorch's torch.cuda.amp.GradScaler automates this and, importantly, dynamically adjusts the scale upward when training is calm and halves it whenever an inf or NaN is detected. bf16, by contrast, has the same exponent range as fp32 but fewer mantissa bits; gradients essentially never underflow, and no loss scaling is needed. On hardware that supports it (recent NVIDIA GPUs, TPUs, Apple silicon), bf16 is now the preferred default. Wrap the forward pass in a torch.cuda.amp.autocast() context and you are done.
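
A minimal fp16 training step with dynamic loss scaling might look like the sketch below; model, loss_fn, optimizer and loader are assumed to exist, and with bf16 you would simply drop the scaler.

```python
import torch

scaler = torch.cuda.amp.GradScaler()        # maintains the dynamic loss scale

for inputs, targets in loader:
    optimizer.zero_grad()

    # Run the forward pass and the loss in reduced precision where the framework deems it safe.
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)

    scaler.scale(loss).backward()           # backward on the scaled loss: gradients are scaled too
    scaler.unscale_(optimizer)              # unscale first so any clipping threshold is meaningful
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                  # skips the update if an inf or NaN slipped through
    scaler.update()                         # grows or halves the scale for the next iteration
```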

Gradient noise and stale statistics

When training runs on multiple devices, gradients are computed in parallel on different mini-batches and then aggregated. The default and safest scheme is synchronous all-reduce: each worker waits at the end of the backward pass, the gradients are averaged across workers, and every worker takes the same optimiser step. This behaves identically to a single large-batch run, up to numerical reordering. Asynchronous variants let workers update parameters with stale gradients computed from an older version of the weights; throughput is higher, but stale gradients cause divergence at high learning rates, and such schemes are now mostly out of fashion for deep learning.

Batch normalisation is a particular trap in distributed training. The statistics it normalises with (and the running mean and variance derived from them) are computed from the local mini-batch, which on each worker may be small enough to be noisy. Synchronised batch normalisation (SyncBN) gathers statistics across all workers before normalising, recovering the behaviour of a single large batch. For dense vision models with small per-device batch sizes, switching from local BN to SyncBN can change accuracy by several percentage points.
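
With DistributedDataParallel the switch is a one-line conversion applied before wrapping the model. A sketch, assuming the process group is already initialised and local_rank identifies the current device:

```python
import torch

# Replace every BatchNorm layer with its synchronised counterpart, so that normalisation
# statistics are gathered across all workers rather than computed per device.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model.to(local_rank), device_ids=[local_rank])
```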

Detecting and debugging numerical issues

When a training run misbehaves, the diagnostics that pay for themselves repeatedly are simple. Log the per-layer gradient norm, the per-layer weight norm, and the loss curve every few hundred steps. A NaN is detected with torch.isnan(loss).any(); a small wrapper that asserts on every batch will catch the moment things go wrong, rather than letting the corruption propagate silently for an hour. If gradient norms are healthy in most layers but enormous in one, the parameters or activations of that layer are misscaled, and you should look at its initialisation, its inputs, or any custom operation it contains. If the loss decreases on the training set but diverges on validation, you do not have a numerical problem at all; you have a data leak between splits, or a validation pipeline that uses different preprocessing. The numerical pathology and the methodological pathology look identical from a single loss curve, and disentangling them quickly is one of the marks of a fluent practitioner.

Pair these prints with the framework's anomaly mode (torch.autograd.set_detect_anomaly(True)) when you need a stack trace pointing at the operation that produced the first NaN. It is slow, so leave it off by default, but it can save hours when something is genuinely wrong.
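
A lightweight version of these diagnostics, as a sketch; the helper name and the logging cadence are arbitrary choices, and the function is meant to be called once per batch, right after loss.backward().

```python
import math
import torch

def check_and_log(model, loss, step, every=200):
    """Fail fast on a non-finite loss; periodically print per-layer weight and gradient norms."""
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss at step {step}: {loss.item()}")
    if step % every == 0:
        for name, p in model.named_parameters():
            grad_norm = p.grad.norm().item() if p.grad is not None else math.nan
            print(f"step {step:7d}  {name:40s}  |w| = {p.norm().item():9.3e}  |g| = {grad_norm:9.3e}")
```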

What you should take away

  1. Finite differences for gradient checking must use a centred formula and a sensibly chosen $h$, or catastrophic cancellation will hide bugs rather than reveal them.
  2. Recurrent and very deep networks need gradient-norm clipping ($\sigma \approx 1.0$ for transformers, $\sigma \approx 0.25$ for RNNs) to survive the occasional explosive batch.
  3. Always use the framework's fused cross_entropy and binary_cross_entropy_with_logits; behind them sits the stable $\mathrm{logsumexp}(\mathbf{z}) = z_{\max} + \log \sum_j e^{z_j - z_{\max}}$ identity.
  4. Mixed precision is now standard: prefer bf16 where the hardware supports it, and use fp16 with GradScaler otherwise.
  5. Log gradient and weight norms, assert against NaN, and remember that diverging validation with falling training loss is almost always a methodology bug, not a numerical one.
