10.16 Debugging training
You launch a long training run. The loss curve looks wrong. Something is broken. What do you do?
A practitioner's guide to interpreting training curves and gradient/weight statistics.
Loss curves
The shape of the loss curve diagnoses many problems. Plot per-step training loss (smoothed with a rolling average) and per-epoch validation loss on the same axes; a plotting sketch follows the list below.
- Loss diverges (goes to infinity). Learning rate too high, no warmup, an out-of-distribution example with a huge gradient, or numerical overflow in mixed precision. First check: lower the LR by 10x and see if training stabilises.
- Loss plateaus at chance level. Bug in the data pipeline (labels detached from inputs, all the same class, etc.) or in the loss (NaN-poisoning, wrong reduction). Sanity check: train on a tiny subset (32 examples) and ensure the loss reaches near zero; if it does not, you have a bug.
- Loss decreases, then plateaus far above the expected value. Model capacity insufficient, learning rate too low, schedule decayed too aggressively, or vanishing gradients in early layers.
- Training loss low, validation loss high. Classical overfitting. Add regularisation, augmentation, or reduce capacity.
- Training loss high, validation loss higher. Underfitting. Increase capacity, train longer, raise learning rate (carefully).
- Sudden loss spikes. Common in Transformer training. Usually a single bad batch (extreme outlier example) or numerical instability in attention. Solutions: stronger gradient clipping, per-layer LR, skip catastrophic batches.
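A minimal plotting sketch, assuming you have logged one training-loss value per step and periodic validation losses; the function name and arguments are illustrative, not from any library:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_loss_curves(train_losses, val_losses, val_steps, window=100):
    """Overlay smoothed per-step training loss and per-epoch validation loss."""
    # Rolling average: each point is the mean of the previous `window` steps.
    smoothed = np.convolve(train_losses, np.ones(window) / window, mode="valid")
    plt.plot(np.arange(len(smoothed)) + window - 1, smoothed, label="train (smoothed)")
    plt.plot(val_steps, val_losses, marker="o", label="validation")
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.yscale("log")  # log scale makes slow divergence and high plateaus easier to spot
    plt.legend()
    plt.show()
```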
Gradient norms
Plot $\|g\|_2$ (the global gradient norm before clipping) over time; a logging sketch follows the list below. Healthy training shows $\|g\|$ decreasing smoothly with the loss. Pathologies:
- $\|g\|$ exploding indicates instability. Apply gradient clipping (start with $c = 1$ for Transformers).
- $\|g\|$ vanishing in deep networks indicates initialisation or architecture problems. Check residual connections, normalisation layers, activation choice (ReLU dies for negative pre-activations; LeakyReLU/GELU help).
- $\|g\|$ stuck at clipping threshold indicates the threshold is too low. Raise it.
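A minimal logging sketch: in PyTorch, clip_grad_norm_ returns the global norm measured before clipping, so one call both clips and yields the statistic worth plotting. Call it between loss.backward() and optimizer.step():

```python
import torch

def clip_and_log_grad_norm(model, max_norm=1.0):
    """Clip gradients in place and return the pre-clipping global L2 norm."""
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    return total_norm.item()  # append this to your gradient-norm curve
```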
Weight norms
Track per-layer $\|\theta\|$ over time (a logging sketch follows this list). Norms should grow modestly from a small initialisation and then stabilise. Pathologies:
- $\|\theta\|$ growing without bound: weight decay too low, or a runaway feedback loop (e.g. softmax saturation pushing logits to infinity).
- $\|\theta\|$ shrinking aggressively: weight decay too high, model collapsing to zero.
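A sketch for weight-norm tracking; logging every few hundred steps is plenty:

```python
import torch

@torch.no_grad()
def weight_norms(model):
    """Per-parameter L2 norms, keyed by parameter name."""
    return {name: p.norm().item() for name, p in model.named_parameters()}
```

Plot each entry as its own curve; a layer whose norm diverges from its neighbours is the first place to look.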
Activation statistics
Sample activations from a few layers each step. Monitor the mean, standard deviation, and fraction of saturated units. For ReLU, a high "dead neuron" fraction (units that are always zero) is bad; consider LeakyReLU, GELU, or a different initialisation.
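One way to collect these statistics without touching the model code is a forward hook; a sketch, assuming you attach it to the modules you want to watch (e.g. ReLU layers):

```python
import torch

def watch_activations(module, stats):
    """Register a forward hook that appends mean, std, and dead fraction to `stats`."""
    def hook(mod, inputs, output):
        act = output.detach()
        stats.setdefault("mean", []).append(act.mean().item())
        stats.setdefault("std", []).append(act.std().item())
        # Fraction of outputs that are exactly zero: dead units under ReLU.
        stats.setdefault("dead_frac", []).append((act == 0.0).float().mean().item())
    return module.register_forward_hook(hook)  # keep the handle to remove it later
```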
For Transformers, attention entropy is informative. If attention is highly peaked (low entropy, near-Dirac), the head is dominated by a single token; this is sometimes correct (induction heads) and sometimes a sign of a degenerate attention pattern.
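Attention entropy is cheap to compute from the post-softmax weights; a sketch, assuming a weight tensor of shape (batch, heads, queries, keys):

```python
import torch

def attention_entropy(attn):
    """Mean entropy of the attention rows: 0 = one-hot, log(keys) = uniform."""
    eps = 1e-9  # guard against log(0) on exactly-zero weights
    return -(attn * (attn + eps).log()).sum(dim=-1).mean()
```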
Learning rate schedule
Plot the actual learning rate at each step. Two common bugs: the schedule is misaligned with the total step count (e.g. a cosine decay reaches zero before training ends), and the warmup phase is too short or too long.
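Both bugs are caught by dry-running the scheduler for the full run length before launching. A sketch, assuming a factory function that builds the same optimiser and scheduler your training script uses:

```python
import matplotlib.pyplot as plt

def plot_schedule(make_optimizer_and_scheduler, total_steps):
    """Record the LR the optimiser would actually see at every step."""
    optimizer, scheduler = make_optimizer_and_scheduler()
    lrs = []
    for _ in range(total_steps):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # no-op without gradients; keeps the step order PyTorch expects
        scheduler.step()
    plt.plot(lrs)
    plt.xlabel("step")
    plt.ylabel("learning rate")
    plt.show()
```

For example, the factory might return torch.optim.AdamW plus a torch.optim.lr_scheduler.CosineAnnealingLR whose T_max you expect to match total_steps; the plot immediately shows whether it actually does.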
Sanity checks before launching a long run
- Overfit a tiny subset. Take 32 examples; train without regularisation; the loss should reach near zero. If it does not, the model cannot fit the data and there is a bug somewhere (see the sketch after this list).
- Test the forward and backward passes separately. Check that gradients flow to all parameters (no accidental requires_grad=False).
- Check the loss at initialisation. For a $K$-class classifier with cross-entropy loss, the initial loss should be $\log K$; for a $K = 1000$ ImageNet classifier, $\log 1000 \approx 6.9$. If it differs, something is off in the data normalisation or the final layer.
- Match a known recipe first. If you are training a ResNet-50, reproduce the standard $\approx 76\%$ top-1 accuracy on ImageNet-1k before innovating. Most "clever ideas" turn out to be re-discoveries of bugs.
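A sketch combining the overfitting and initial-loss checks on a single fixed batch; model, inputs, and labels are whatever you are about to train with, and the model is assumed to return logits:

```python
import math
import torch
import torch.nn.functional as F

def pre_launch_check(model, inputs, labels, num_classes, steps=500):
    """Check loss at initialisation, then overfit one small batch."""
    model.train()
    with torch.no_grad():
        init_loss = F.cross_entropy(model(inputs), labels).item()
    print(f"init loss {init_loss:.3f}, expected ~{math.log(num_classes):.3f}")

    # No weight decay, no augmentation: the model should memorise 32 examples.
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        opt.step()
    print(f"loss after {steps} steps: {loss.item():.4f} (should be near zero)")
```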
Common pathologies
| Symptom | Likely cause |
|---|---|
| Loss = NaN after a few steps | Mixed-precision overflow; need loss scaling or BF16 |
| Loss exactly constant | Detached graph; an accidental loss.detach() or a forward pass under torch.no_grad() |
| Validation loss much higher than training | Forgot to switch BatchNorm to running stats at eval, or dropout still on |
| Loss jumps up at exactly epoch boundaries | Data shuffling bug; one ordering is much harder |
| Single GPU works, DDP doesn't | Missing DistributedSampler (all GPUs see the same data) or missing set_epoch() (same shuffle every epoch); see the sketch below the table |
| Adam diverges with high LR despite warmup | Try BF16; FP16 may be underflowing Adam's second-moment estimate $v_t$ |
| Lower LR helps for first 1k steps but plateaus | Need a warmup-then-decay schedule, not constant low LR |
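For the DDP row, the fix is two lines of data-loading code; a sketch with a placeholder dataset (assumes torch.distributed is already initialised):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(1024, 8), torch.randint(0, 10, (1024,)))  # placeholder

sampler = DistributedSampler(dataset, shuffle=True)  # partitions the data across ranks
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

for epoch in range(10):
    sampler.set_epoch(epoch)  # reseeds the shuffle; omitting it replays the same order every epoch
    for inputs, labels in loader:
        pass  # training step goes here
```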