9.14 Practical training tips
Training a neural network is, in the end, an empirical art built on top of mathematical foundations. The mathematics tells you what stochastic gradient descent does, why backpropagation computes the right derivatives, and how a residual connection or a layer-norm block reshapes the loss surface. None of this tells you what learning rate to set on a Tuesday morning when your loss is stuck at 6.8 and you have one GPU and three hours before the seminar. That second kind of knowledge, the kind that lives in lab notebooks, GitHub issues, conference workshop talks, and the heads of senior research engineers, is the subject of this section. It is the closest thing this textbook contains to a practical manual: how to actually train a network so that it converges, generalises, and does not waste a week of compute on a silent bug.
Sections 9.10 to 9.13 each fixed one piece of the puzzle. Section 9.10 explained how to initialise weights so that signals neither explode nor vanish on the first forward pass. Section 9.11 described the family of optimisers, SGD, momentum, Adam, that turn gradient information into parameter updates. Section 9.12 covered regularisation: weight decay, dropout, augmentation, label smoothing. Section 9.13 introduced batch norm, layer norm, and RMSNorm to keep activations well behaved at every depth. This section assembles those parts into a workflow. It assumes you already have a model that builds and runs end-to-end, and asks: how do I drive it to a useful answer in finite time, and how do I tell the difference between a run that is converging slowly and one that is broken?
Where there is theory (the linear-scaling rule for batch size, the variance arguments behind warmup), we will say so. Where there is only folklore (the standard transformer betas, the canonical clip threshold of 1.0), we will be honest that these are empirical defaults that have hardened over many papers rather than results derived from first principles. Both kinds of knowledge are worth having. Folklore that has survived a decade of replication is usually pointing at something real, even if no one has yet written down the theorem.
Pick a learning rate first
The learning rate is the single most consequential hyperparameter in deep learning. Other choices, depth, width, optimiser, schedule, produce a few percentage points of difference at the margin. The learning rate, set wrong, produces no model at all. Spend your first afternoon on it.
The two failure modes are easy to recognise once you have seen them. If $\eta$ is too high, the loss either diverges outright (jumping to NaN within a few hundred steps) or oscillates wildly without trending downward: the optimiser is bouncing across the loss landscape rather than descending into it. If $\eta$ is too low, the loss decreases, but at a glacial rate that is indistinguishable from staying still on a short timescale. You have not broken anything; you simply will not finish in this lifetime. The difference between these two pathologies is often a factor of one hundred or one thousand in $\eta$, which is why the right way to search is on a logarithmic scale.
The simplest method is a log-spaced sweep. Pick four or five candidates evenly spaced in log scale: a typical menu is $\eta \in \{10^{-5}, 10^{-4}, 10^{-3}, 10^{-2}\}$. Run each for one epoch, or for whatever short fraction of training you can afford, and plot the loss curves on the same axes. The curve that drops fastest without diverging tells you the right order of magnitude. Sweeping a finer grid around the winner, say multiplying by 0.3 and by 3, usually picks up the last factor of two. You rarely need more than half a day of compute to land within a factor of two of optimal, and that is enough for almost any project.
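In code, the sweep is just a loop; train_for_one_epoch here is a hypothetical helper standing in for your own training loop, assumed to train a freshly initialised model at the given rate and return its per-step losses:

candidates = [1e-5, 1e-4, 1e-3, 1e-2]   # log-spaced menu
histories = {}
for eta in candidates:
    # hypothetical helper: fresh model, one epoch at this learning rate
    histories[eta] = train_for_one_epoch(lr=eta)
# Plot every history on shared axes; the fastest non-diverging curve
# identifies the right order of magnitude for eta.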
A sharper variant is the learning-rate range test introduced by Smith (2017). Start with a tiny $\eta$ (perhaps $10^{-7}$) and exponentially increase it every step over a few hundred steps, all in a single short run. Plot the loss against $\eta$ on log axes. You will see a flat region where $\eta$ is too small to do anything, a downward-sloping region where $\eta$ is doing useful work, and a sharp upward spike where $\eta$ is now too big and the optimiser has started to diverge. A reasonable production learning rate sits at roughly one-tenth of the value at which the loss begins to spike upward; this gives a margin of safety against the noisier statistics of full training. The whole test takes minutes and gives you a single curve that contains all the information of a coarse sweep.
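A minimal sketch of the range test on a toy regression model, assuming nothing beyond core PyTorch; the architecture, data, and growth factor are illustrative:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
optim = torch.optim.SGD(model.parameters(), lr=1e-7)
eta, history = 1e-7, []

for step in range(400):                    # 1.05**400 spans ~1e-7 to ~1e+1
    x, y = torch.randn(32, 16), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    history.append((eta, loss.item()))
    eta *= 1.05                            # exponential increase each step
    for group in optim.param_groups:
        group['lr'] = eta
# Plot loss against eta on log axes and read off the spike; a production
# learning rate of about one-tenth of that value is the usual choice.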
Some empirically reasonable starting values, for orientation rather than as gospel: $\eta = 10^{-3}$ for Adam or AdamW on most tasks; $\eta = 5 \times 10^{-4}$ to $10^{-3}$ as a peak learning rate for transformers (with warmup, see below); $\eta = 10^{-1}$ for SGD with Nesterov momentum on ImageNet-style CNNs with batch norm; $\eta = 3 \times 10^{-4}$ as a small-transformer default that has the curious property of working surprisingly often. None of these numbers are derived from any theorem; they are the values that survived selection across thousands of published runs. Treat them as the centre of your sweep, not the answer.
A subtler point: the right learning rate depends on every other choice you have made, initialisation scale, normalisation, batch size, number of layers, optimiser, even the precision of your arithmetic. Change any of these and you should re-sweep $\eta$. The most common cause of a "broken" reimplementation is not a bug in the code but a learning rate that was correct for the original paper's setup and now is not.
Pick a batch size
Batch size is a hardware decision wearing the costume of a hyperparameter. There are two real pressures and they pull in opposite directions.
Larger batches average over more examples per step, which makes the gradient estimate smoother and closer to the true gradient over the data distribution. Smoother gradients let you use a higher learning rate without diverging, which means each step makes more progress. Larger batches also use modern GPUs more efficiently, because matrix multiplications hit higher arithmetic intensity when the batch dimension is large. The wall-clock time to traverse the dataset goes down.
Smaller batches produce noisier gradients. That noise is sometimes a feature rather than a bug: it acts as an implicit regulariser, helping the optimiser escape sharp minima and find flatter ones that generalise better. Smaller batches also use less memory, which matters because GPU memory is usually the binding constraint. A model that does not fit at $B = 64$ may fit at $B = 8$, and an out-of-memory crash is one hundred percent slower than any batch size that runs.
The dominant practical heuristic is therefore: use the largest batch that fits in memory, then tune $\eta$ around it. If validation quality drops at the larger batch, scale the learning rate accordingly. The linear-scaling rule due to Goyal et al. (2017) says that when you multiply the batch by $k$, you should multiply $\eta$ by $k$ as well, and add a brief warmup so the first few steps with the new larger learning rate do not destabilise the network. This rule has good empirical support up to a critical batch size beyond which the returns flatten: doubling the batch above critical no longer halves the time to a fixed loss, because the gradient is already nearly noise-free and adding more samples just averages over redundant information. McCandlish et al. (2018) characterised this critical batch size empirically and showed it grows with the difficulty of the task and the size of the model, from a few thousand examples for ImageNet to millions for large-scale language model pre-training.
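As a concrete sketch of the rule, with reference values that mirror the Goyal et al. (2017) ImageNet recipe (treat them as illustrative, not canonical):

base_eta, base_batch = 0.1, 256           # recipe tuned at batch size 256
new_batch = 2048
k = new_batch / base_batch                # k = 8
new_eta = base_eta * k                    # linear scaling: 0.1 -> 0.8
steps_per_epoch = 1_281_167 // new_batch  # ImageNet-1k at the new batch size
warmup_steps = 5 * steps_per_epoch        # ramp eta from 0 to new_eta first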
For a single-GPU project, none of this matters: pick the largest $B$ that fits, set $\eta$ as in the previous subsection, and move on. For multi-GPU training, the linear scaling rule plus warmup is the right starting point and is usually within reach of optimal. For training runs at the frontier of compute, picking batch size is a research project in its own right.
Choose an optimiser
The short answer to "which optimiser should I use?" is AdamW, unless you have a specific reason otherwise. The longer answer recognises a small number of well-established exceptions.
AdamW is the modern default for most deep learning, especially for transformers and language models. The hyperparameters that survive across nearly every published recipe are $\beta_1 = 0.9$ (the decay rate of the first-moment running average, i.e. the momentum on the gradient), $\beta_2 = 0.999$ for vision and most general-purpose tasks, and $\beta_2 = 0.95$ for transformer language models. The lower $\beta_2$ in language models exists because language gradients are heavy-tailed: the occasional rare token produces gradients many standard deviations larger than the typical one, and the running estimate of the second moment recovers from those spikes faster when $\beta_2$ is smaller. The remaining Adam constants are $\epsilon = 10^{-8}$ (a small floor in the denominator that prevents division by zero) and a peak $\eta$ in the $10^{-4}$ to $10^{-3}$ range. AdamW differs from plain Adam only in how it applies weight decay: AdamW decouples the decay from the gradient update, so the regularisation does not get rescaled by the second-moment estimate. This is the version you want: Loshchilov and Hutter (2019) showed that the original Adam-with-L2 can systematically under-decay parameters with large running second moments, which is the opposite of the desired behaviour.
SGD with Nesterov momentum remains the right choice for one specific niche: classical image classification with batch norm. Hyperparameters: $\eta = 10^{-1}$ as a peak, momentum $0.9$, weight decay $10^{-4}$, with step decay (drop $\eta$ by a factor of 10 at preset epoch boundaries). On ImageNet-style ResNets and their relatives, well-tuned SGD slightly beats Adam on top-1 accuracy. The catch is in the words "well-tuned": SGD has a narrower window of working learning rates and less tolerant interaction with weight decay, so it takes more careful sweeping to find that small advantage.
LARS (You et al., 2017) and LAMB (You et al., 2020) extend SGD and Adam respectively with per-layer adaptive learning rates that scale with the parameter norm. They were designed for very-large-batch training (tens of thousands of examples per step), where naive linear scaling of the learning rate breaks down. Outside that regime they offer no advantage.
Adafactor (Shazeer and Stern, 2018) is a memory-frugal alternative to Adam that factorises the second-moment estimate row-wise and column-wise instead of storing it per-parameter. It is the default optimiser in the T5 family and is worth knowing about when you cannot afford to keep two extra copies of every parameter in GPU memory.
Lion (Chen et al., 2023) takes only the sign of the momentum-smoothed gradient as the update direction, with a separate weight-decay term. It uses half the memory of AdamW (no second-moment buffer) and has been competitive or better on a range of vision and language benchmarks; it remains a sensible second choice rather than a default. Shampoo and its descendants approximate second-order preconditioning; SOAP (Vyas et al., 2024) and Muon (Jordan et al., 2024) have moved into some recent frontier training runs, but they are still research tools rather than production defaults.
The pragmatic policy: start with AdamW at $\eta = 3 \times 10^{-4}$, $\beta_1 = 0.9$, $\beta_2 = 0.95$ for transformers (or $0.999$ for vision), weight decay $0.1$, gradient clipping at $1.0$. If something has been working for the last decade with these numbers, it will probably keep working with them.
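In PyTorch that policy is a few lines; the nn.Linear is a stand-in for whatever model you are training:

import torch
import torch.nn as nn

model = nn.Linear(512, 512)      # stand-in for your model
optim = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,                     # peak learning rate; still worth sweeping
    betas=(0.9, 0.95),           # use (0.9, 0.999) for vision tasks
    eps=1e-8,
    weight_decay=0.1,            # decoupled decay, the AdamW formulation
)
# Pair with torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) each step.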
Pick a learning rate schedule
A learning rate schedule is a rule for how $\eta$ changes during training. The simplest schedule is no schedule, a constant $\eta$ throughout. This is rarely optimal, because the right step size early in training (when the loss surface is curved and parameters are moving fast) is not the right step size late in training (when you are fine-tuning the last digits of a near-converged solution).
The four schedules that account for almost every modern recipe are: step decay, cosine decay, linear decay, and warmup-plus-decay.
Step decay drops $\eta$ by a factor of 10 (or 5) at preset epoch boundaries, say, at epochs 30, 60, and 90 of a 100-epoch ImageNet run. It is simple, easy to reproduce, and was the standard recipe for image classification in the ResNet era.
Cosine decay (Loshchilov and Hutter, 2017) is smoother:
$$\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})\bigl(1 + \cos(\pi t / T)\bigr) ,$$
where $T$ is the total number of training steps. The learning rate starts at $\eta_{\max}$ and decays smoothly to $\eta_{\min}$ (typically zero or a small floor). It removes the discrete jumps of step decay and tends to outperform it slightly in practice.
Linear warmup followed by cosine decay is the standard transformer recipe. The warmup phase, usually the first 1 to 10 percent of total training, ramps $\eta$ from zero up to $\eta_{\max}$ linearly. The remaining steps follow cosine decay back down to a small fraction of $\eta_{\max}$ (often one-tenth or zero). The reason for warmup is that adaptive optimisers like Adam are unreliable in their first few hundred steps: the running estimates of the first and second moments have not stabilised yet, so the effective per-parameter learning rate is poorly calibrated. Starting at $\eta = 0$ and ramping up gives those estimates time to settle before the optimiser is allowed to take large steps.
A worked example, implemented in the code sketch after this list. Suppose you are training a transformer for $T = 100{,}000$ steps with $\eta_{\max} = 10^{-3}$ and $T_w = 5{,}000$ warmup steps.
- At step 1{,}000, you are still in warmup, and $\eta = 10^{-3} \cdot 1000 / 5000 = 2 \times 10^{-4}$.
- At step 5{,}000 you have just reached the peak, $\eta = 10^{-3}$.
- At step 50{,}000, roughly halfway through training, the decay phase gives $\eta = 0 + \tfrac{1}{2}(10^{-3} - 0)(1 + \cos(\pi \cdot 45000/95000))$, which works out to roughly $5.4 \times 10^{-4}$.
- At step 100{,}000, the cosine has reached $\cos(\pi) = -1$ and $\eta = 0$.
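A minimal sketch of the schedule as a pure function of the step counter; running it reproduces the four numbers above:

import math

def lr_at(t, T=100_000, T_w=5_000, eta_max=1e-3, eta_min=0.0):
    if t < T_w:
        return eta_max * t / T_w              # linear warmup
    progress = (t - T_w) / (T - T_w)          # 0 at the peak, 1 at the end
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress))

for t in (1_000, 5_000, 50_000, 100_000):
    print(t, lr_at(t))                        # 2.0e-4, 1.0e-3, ~5.4e-4, 0.0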
One-cycle (Smith, 2018) is a triangular schedule that ramps from a low rate up to a high peak and then back down to an even lower final value, often combined with an inverse cycle on momentum. On many vision tasks it converges in a fraction of the steps required by step decay. It is not the standard for transformers, but is worth knowing when you are trying to squeeze training time on a fixed-architecture image task.
A defensible default for any new project: linear warmup over the first 1 to 5 percent of training, followed by cosine decay to one-tenth of the peak. This has worked on enough tasks across enough labs that it is now the industry standard.
Use gradient clipping defensively
Gradient clipping is a cheap insurance policy that should be on by default. Compute the global gradient norm, the L2 norm of all gradients concatenated into a single vector, and if it exceeds a threshold $\sigma$, scale every gradient down by $\sigma / \|\nabla\|$ so that the new norm is exactly $\sigma$. Formally, $\mathbf{g} \leftarrow \mathbf{g} \cdot \min(1, \sigma / \|\mathbf{g}\|)$.
The point is not to prevent learning. It is to prevent a single freak batch, perhaps a minibatch with an unusually long sequence, or a numerical hiccup in mixed precision, or a rare label that produces a giant cross-entropy, from sending parameters off into nonsense regions of weight space. One bad step at a high learning rate can permanently break a run. Gradient clipping is the seat belt that prevents that.
Typical thresholds: $\sigma = 1.0$ for transformers and most modern architectures; $\sigma = 0.25$ to $0.5$ for recurrent networks, where gradient explosion through time is a chronic problem. The exact number rarely matters much; what matters is that clipping is enabled. Two implementation notes. First, the clip must use the global norm, computed across all parameters, not per-tensor: per-tensor clipping changes the direction of the update, not just its magnitude, which is rarely what you want. Second, clip before the optimiser step, after the gradients have been computed but before they are consumed.
The cost of gradient clipping is one extra norm computation per step, which is trivial. The benefit is the elimination of a whole class of training failures. Turn it on.
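A minimal sketch of the placement on a stand-in model, using PyTorch's built-in global-norm clip:

import torch
import torch.nn as nn

model = nn.Linear(16, 1)         # stand-in for your model
optim = torch.optim.SGD(model.parameters(), lr=1e-1)
x, y = torch.randn(8, 16), torch.randn(8, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()
# Clip on the global norm, after backward() and before step().
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optim.step()
optim.zero_grad()
# clip_grad_norm_ returns the pre-clip global norm, which is worth logging.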
Verify your gradients (when implementing from scratch)
If you are using PyTorch's autograd or any production framework, your gradients are correct by construction and you can skip this subsection. If you are implementing a layer or operator from scratch, writing a custom CUDA kernel, building a tiny autograd engine for teaching, or porting a model to a framework without automatic differentiation, you must verify your gradients before you trust any training run.
The technique is numerical gradient checking. For a single scalar parameter $w$ and a scalar loss $\mathcal{L}$, the symmetric finite-difference approximation to the derivative is
$$\nabla_w \mathcal{L} \approx \frac{\mathcal{L}(w + \epsilon) - \mathcal{L}(w - \epsilon)}{2\epsilon} ,$$
with $\epsilon$ small but not too small. The standard choice is $\epsilon = 10^{-7}$ in float64; in float32 you should use a larger $\epsilon$ around $10^{-4}$ because of round-off. Compute this numerical gradient for a handful of randomly chosen parameters and compare it to the analytic gradient your code produces. The diagnostic is the relative error,
$$\text{rel\_err} = \frac{|\nabla^{\text{analytic}} - \nabla^{\text{numeric}}|}{\max(|\nabla^{\text{analytic}}|, |\nabla^{\text{numeric}}|, 10^{-8})} ,$$
with a tiny floor in the denominator to keep things finite when both gradients are zero. A relative error below $10^{-7}$ means the analytic gradient is essentially correct. Errors between $10^{-5}$ and $10^{-7}$ are usually fine, especially with non-smooth activations such as ReLU at zero. Errors above $10^{-5}$ indicate a bug, almost always a missing minus sign, a forgotten transpose, an indexing mistake, or a non-differentiable operation introduced unintentionally.
Two practical points. First, perform gradient checks on a small network, perhaps two layers of width four, with a small batch. Numerical gradient checking is $O(P)$ where $P$ is the number of parameters, because each parameter requires two extra forward passes. On a real network this would be hopeless, but on a toy network it takes seconds. Second, fix the random seeds and disable any stochastic layers (dropout, batch norm in training mode) before checking, since gradient correctness is a property of the deterministic function you implemented, not its noisy training-time variant.
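A toy check in float64, comparing autograd's gradient with the symmetric finite difference on a single small weight matrix (the shapes and the function itself are illustrative):

import torch

torch.manual_seed(0)

def loss_fn(w, x, y):
    return ((torch.tanh(x @ w).sum(dim=1) - y) ** 2).mean()

x = torch.randn(4, 3, dtype=torch.float64)
y = torch.randn(4, dtype=torch.float64)
w = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

analytic = torch.autograd.grad(loss_fn(w, x, y), w)[0]

eps, numeric = 1e-7, torch.zeros_like(w)
with torch.no_grad():
    flat = w.view(-1)
    for i in range(flat.numel()):
        old = flat[i].item()
        flat[i] = old + eps
        lp = loss_fn(w, x, y).item()
        flat[i] = old - eps
        lm = loss_fn(w, x, y).item()
        flat[i] = old                       # restore before moving on
        numeric.view(-1)[i] = (lp - lm) / (2 * eps)

denom = torch.maximum(analytic.abs(), numeric.abs()).clamp(min=1e-8)
print((analytic - numeric).abs().div(denom).max().item())  # expect < 1e-7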
Debug a training run that isn't working
Most training failures fall into one of a small number of categories. Working through this checklist in order will catch the great majority of them and is far faster than randomly tweaking hyperparameters in the hope that the loss starts to fall.
1. Try to overfit a single mini-batch. Take one batch of, say, eight examples, and train on it repeatedly until the loss is essentially zero (a minimal sketch of this test follows the checklist). If you cannot drive the loss to zero on eight examples, the bug is in the model or the loss, not the data or the optimiser. A model with millions of parameters should be able to memorise eight examples in a few hundred steps.
2. Print the loss after the first step. If it is NaN, the gradients have already exploded; lower $\eta$ by a factor of ten, enable gradient clipping, and check the loss for numerical pathologies (a $\log(0)$ or $\sqrt{\text{negative}}$).
3. Check the order of magnitude of the initial loss. For cross-entropy with $K$ uniformly likely classes the expected loss at random initialisation is $\log K$. If you see $0.1$ when you expected $\log(1000) \approx 6.9$, your loss is normalised wrong (perhaps divided by the number of classes, or averaged when it should be summed).
4. Print per-layer gradient norms. A healthy network has gradient norms within an order of magnitude across layers. If layer 1 has gradient norm $10^{-9}$ while layer $L$ has gradient norm $1$, you have a vanishing-gradient problem; check initialisation, normalisation, and whether you have accidentally used sigmoid or tanh in deep layers.
5. Print activation distributions. Dead ReLUs, units that always output zero on every input in the batch, indicate that the upstream weights have collapsed. Check the fraction of zero outputs in each layer; if it is above 50 percent, your initialisation is too small or your learning rate has knocked the weights into a bad region.
6. Inspect the data loader. Print a handful of input-label pairs by hand. Verify labels match inputs. Verify normalisation (you have applied the same mean and standard deviation as in training). Verify shapes and dtypes. A surprising fraction of "broken" training runs are mis-ordered or mis-aligned data.
7. Disable the optimiser and run one forward pass with random weights. The loss should be roughly $\log K$ for $K$-class classification, roughly $\log 2 \approx 0.69$ for binary cross-entropy with sigmoid output, or whatever the trivial value is for your loss. If it is not, the loss or the model output head is wrong, and no amount of tuning $\eta$ will fix that.
8. Plot the learning curves on log axes. Plot training loss, validation loss, gradient norm, and learning rate against step. The curves should be informative: a flat training loss means nothing is happening; a falling training loss with rising validation loss means overfitting; a sudden NaN means a numerical event you can localise to a step number.
Most catastrophic-looking training runs are debugged by step 1 or step 6. The discipline of going through the checklist in order, before reaching for hyperparameter changes, will save days of compute over the course of a project.
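A minimal sketch of the single-batch test on a stand-in classifier; substitute your own model, batch, and loss:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10))
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))   # one fixed batch

for step in range(500):
    loss = nn.functional.cross_entropy(model(x), y)
    optim.zero_grad()
    loss.backward()
    optim.step()
print(loss.item())   # should be near zero; if not, suspect the model or loss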
Save and reload checkpoints
A checkpoint is a snapshot of everything you would need to resume training from where you left off. That is more than just the model weights. To resume reliably you need: the model state dictionary, the optimiser state (Adam's first- and second-moment running averages especially), the learning-rate scheduler state, the random number generator states (for both the framework and any data-loader workers), and the current step or epoch counter. Save all of them.
In PyTorch the canonical pattern is:
torch.save({
    'step': step,
    'model': model.state_dict(),
    'optim': optim.state_dict(),
    'sched': scheduler.state_dict(),
    'rng': torch.get_rng_state(),
}, path)
Reloading reverses the operation. Best practice is to keep two kinds of checkpoint: a rolling window of the last $K$ checkpoints (so you can roll back if a run goes bad), and a separate best-validation checkpoint that you only overwrite when validation loss improves. The latter is the one you ultimately deploy. Save often enough that an unexpected machine reboot costs you no more than an hour or so, for most training runs, every few thousand steps or every epoch is reasonable.
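The reload side, using the same keys as the save snippet above:

ckpt = torch.load(path)
model.load_state_dict(ckpt['model'])
optim.load_state_dict(ckpt['optim'])
scheduler.load_state_dict(ckpt['sched'])
torch.set_rng_state(ckpt['rng'])
step = ckpt['step']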
Use deterministic settings when debugging
Reproducibility helps diagnosis. If you cannot reproduce a bug, you cannot fix it. When you are debugging, investigating a NaN, comparing two implementations, bisecting a regression, set every available seed before the first random number is drawn:
import random
import numpy
import torch

seed = 1234  # any fixed value
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Two warnings. First, deterministic CUDA kernels are slower than their non-deterministic counterparts, sometimes by tens of percent. Once you have finished debugging, turn determinism off for production training. Second, even with all seeds set, distributed training across multiple GPUs and multi-worker data loading introduce non-determinism that is difficult to remove fully. The point is not bit-exact reproducibility, it is enough determinism that two runs of the same configuration produce learning curves close enough to compare meaningfully.
Monitor more than just loss
A loss curve alone is a thin diagnostic. Modern training infrastructure logs many quantities, and the cost of doing so is essentially zero. Tools that automate this include TensorBoard (built into PyTorch and TensorFlow), Weights and Biases (a hosted experiment tracker), and MLflow (an open-source alternative). Whichever you use, log at minimum the following; a minimal logging sketch follows the list:
- training loss per step;
- validation loss per evaluation point;
- learning rate over time (so you can see your schedule actually doing what you specified);
- global gradient norm per step;
- per-layer weight norms, every few hundred steps;
- the fraction of dead ReLUs (or the activation histogram for one or two representative layers);
- top-1 and top-5 accuracy, or whatever task-specific metric matters;
- throughput (examples per second, tokens per second) so that you notice when something has slowed down by 10 percent without warning.
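A minimal sketch of per-step scalar logging with TensorBoard's SummaryWriter; the run directory and tag names are arbitrary choices, and the model is a stand-in:

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/exp1')
model = nn.Linear(16, 1)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(100):
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optim.zero_grad()
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optim.step()
    writer.add_scalar('train/loss', loss.item(), step)
    writer.add_scalar('train/grad_norm', grad_norm.item(), step)
    writer.add_scalar('train/lr', optim.param_groups[0]['lr'], step)
writer.close()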
Set up early-warning thresholds where the framework supports them. Alert if validation loss rises for $K$ consecutive evaluations (early sign of overfitting or instability). Alert if any gradient or weight goes NaN. Alert if throughput drops sharply, which often means a hardware issue or a memory leak. The goal is for a long training run to tell you when something has gone wrong, rather than waiting for you to notice on a Monday morning that yesterday's run quietly produced garbage from step 12{,}000 onwards.
A second principle: keep all of these logs, indexed by run identifier and configuration hash. Six months from now you will want to compare against an older experiment, and the historical metrics will be the difference between a confident answer and a guess.
What you should take away
- The learning rate is the hyperparameter that matters most. Sweep it on a log scale, or use a learning-rate range test, before tuning anything else. Default to $\eta = 3 \times 10^{-4}$ for AdamW and $10^{-1}$ for SGD-with-momentum on ImageNet.
- Use AdamW with $\beta_1 = 0.9$, $\beta_2 = 0.95$ for transformers (or $0.999$ for vision), $\epsilon = 10^{-8}$, weight decay $0.1$, gradient clip $1.0$. These are not derived defaults but they have hardened across enough recipes to be the right starting point.
- Use linear warmup followed by cosine decay for almost any modern training run, with $T_w$ between 1 and 10 percent of total steps and a final learning rate around $\eta_{\max} / 10$.
- Turn gradient clipping on by default at threshold $\sigma = 1.0$ (transformers) or $0.25$–$0.5$ (RNNs), using the global gradient norm. The cost is nothing; the protection is real.
- When a training run misbehaves, work through the diagnostic checklist in order. Overfit one batch first, print loss and gradient norms, inspect the data loader, plot learning curves on log axes: almost every failure is caught by step 6, well before any hyperparameter retuning is justified.
- Log more than the loss, save more than the model, and reproduce when debugging. Gradient norms, weight norms, dead-ReLU fractions, throughput; full optimiser and scheduler state in checkpoints; deterministic seeds when investigating bugs but not in production. The discipline is cheap; the payoff is the difference between research that compounds and research that has to be redone.