9.13 Normalisation

Imagine you are trying to teach a deep neural network to recognise cats. The first layer receives pixel intensities: numbers between 0 and 255, perhaps, or between 0 and 1 if you have rescaled them. The second layer receives the outputs of the first, which are weighted sums passed through an activation function. The third layer receives the outputs of the second. And so on, possibly through dozens or hundreds of layers.

Here is the awkward part. Even if you carefully scale the input pixels, the activations deeper in the network can drift to enormous or tiny magnitudes as training progresses. Each weight update changes what every subsequent layer receives. A layer that was happily processing values near zero on Monday might be drowning in values of magnitude one thousand on Tuesday. The network spends much of its training time chasing these moving targets rather than learning anything useful.

Normalisation layers are a useful fix. They sit between the regular weight layers and actively rescale the activations as they flow through the network, holding them at a roughly fixed mean and variance regardless of what the weights happen to be doing. The activations stay in a comfortable range. Training becomes faster, more stable, and tolerant of sloppier hyperparameter choices.

The modern form of the idea was introduced by Ioffe and Szegedy in 2015 as batch normalisation, and it has since been refined into a small family of variants (layer norm, instance norm, group norm, RMSNorm), each suited to a particular kind of architecture or task. Today, normalisation appears in nearly every modern deep network, from convolutional image classifiers to the largest language models. It is one of the small handful of ideas that turned deep learning from "sometimes it works" into "it usually works".

This section connects to the surrounding chapter as follows. Section 9.10 (initialisation) sets the activation scale at the start of training: He, Xavier and similar schemes pick weight magnitudes so the variance of activations is roughly preserved across the very first forward pass. Section 9.11 explains why this scale matters: gradients vanish or explode when activations drift to extremes. Section 9.13, the present section, is the maintenance crew: it keeps the scale under control throughout training, not just at the start. Normalisation is also intimately tied to optimisation (Chapter 10): networks that include normalisation tolerate much higher learning rates and converge in fewer steps.

Symbols Used Here
  • $\mathbf{x}$: input tensor of shape $(B, C, H, W)$ for images, $(B, T, D)$ for sequences
  • $B$: batch size (number of independent examples processed together)
  • $C$: number of channels (image) or features
  • $T$: sequence length (number of tokens or time steps)
  • $D$: feature dimension (size of the embedding for one token)
  • $H, W$: height and width of a feature map
  • $\mu$, $\sigma^2$: mean and variance computed over a chosen normalisation set
  • $\gamma$, $\beta$: learnable scale and shift parameters per channel or per feature
  • $\epsilon$: small constant ($\sim 10^{-5}$) added to the variance for numerical stability
  • $G$: number of groups in group normalisation

The general normalisation recipe

Every normalisation method in this section is a variation on the same three-step procedure. Once you understand the recipe, the differences between methods come down to a single question: which set of activations do we use to compute the mean and variance?

Step 1. Pick a set $S$ of activations. Compute their mean and variance: $$\mu = \frac{1}{|S|} \sum_{i \in S} x_i, \qquad \sigma^2 = \frac{1}{|S|} \sum_{i \in S} (x_i - \mu)^2.$$ Here $|S|$ is the number of activations in the set.

Step 2. Normalise each activation in the set so it has mean zero and variance one: $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}.$$ The constant $\epsilon$ is a tiny number, typically $10^{-5}$, which prevents division by zero on the rare occasion when the variance happens to be zero (for example, if all activations in the set are identical).

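Step 3. Apply a learnable affine transform: $$y_i = \gamma_i \hat{x}_i + \beta_i.$$ Here $\gamma_i$ and $\beta_i$ are the scale and shift that apply to activation $i$: in batch norm, every activation in the set shares its channel's single pair $(\gamma_c, \beta_c)$, while in layer norm each feature has its own pair. These parameters are learned by gradient descent, just like ordinary weights and biases. They give the network the flexibility to un-normalise if that turns out to be useful: in batch norm, for example, setting $\gamma_c = \sqrt{\sigma_c^2 + \epsilon}$ and $\beta_c = \mu_c$ exactly undoes step 2 and reproduces the original activations. So normalisation does not fundamentally limit what the network can express; it mainly changes the parameterisation so optimisation is easier.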
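As a sketch of the three steps for a single normalisation set, here is a plain-Python version with no deep-learning framework. It assumes one scalar $\gamma$ and $\beta$ shared by the whole set (the batch-norm case); the function name `normalise_set` and the sample values are illustrative choices, not taken from any library.

```python
import math


def normalise_set(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Apply the three-step recipe to one normalisation set S.

    values: the activations x_i in S, as a list of floats.
    gamma, beta: a single learnable scale and shift shared by the set.
    """
    n = len(values)
    # Step 1: mean and variance of the set (dividing by |S|, not |S| - 1).
    mu = sum(values) / n
    var = sum((x - mu) ** 2 for x in values) / n
    # Step 2: shift to mean zero and scale to variance (almost exactly) one.
    denom = math.sqrt(var + eps)
    x_hat = [(x - mu) / denom for x in values]
    # Step 3: learnable affine transform.
    return [gamma * xh + beta for xh in x_hat]


xs = [0.5, 2.0, 2.5, 7.0]          # mean 3.0, variance 5.875
ys = normalise_set(xs)
print([round(y, 3) for y in ys])

# Choosing gamma = sqrt(var + eps) and beta = mu undoes the normalisation.
undone = normalise_set(xs, gamma=math.sqrt(5.875 + 1e-5), beta=3.0)
print([round(y, 6) for y in undone])
```

The second call is the "un-normalise" escape hatch: with the scale and shift set to the set's own statistics, the output is the input again.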

The main thing that varies between methods is step 1: specifically, which axes of the tensor we average over to compute $\mu$ and $\sigma^2$.

For an image tensor of shape $(B, C, H, W)$:

  • Batch norm averages over the batch and the spatial dimensions, separately for each channel. The set $S$ contains $B \cdot H \cdot W$ values; we get one $\mu$ and one $\sigma^2$ per channel.
  • Layer norm averages over the channel and spatial dimensions, separately for each example. $S$ contains $C \cdot H \cdot W$ values; we get one $\mu$ and one $\sigma^2$ per example.
  • Instance norm averages over the spatial dimensions only, separately for each (example, channel) pair. $S$ contains $H \cdot W$ values; one $\mu$ and one $\sigma^2$ per (example, channel).
  • Group norm splits the $C$ channels into $G$ groups, then averages over each group's channels and spatial positions, separately for each example. $S$ contains $(C/G) \cdot H \cdot W$ values per group.
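  • RMSNorm is a layer-norm variant that computes no mean at all: it skips the mean subtraction of step 2, divides by the root mean square of the activations instead of their standard deviation, and drops step 3's $\beta$.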
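Because the methods differ only in which activations share statistics, one way to see this is to write a single routine that takes a "which set does this position belong to?" function. The sketch below stores a small $(B, C, H, W)$ tensor as a Python dictionary keyed by index tuples (an assumption made to avoid external libraries; real code would use a tensor library), skips step 3, and covers only the four methods that subtract a mean. The names `normalise_by_set` and `SET_OF` are hypothetical.

```python
import itertools
import math


def normalise_by_set(x, shape, set_of, eps=1e-5):
    """Steps 1-2 of the recipe for a tensor stored as {(b, c, h, w): value}.

    set_of(b, c, h, w) returns a label for the normalisation set S that
    position belongs to; the methods differ only in this function.
    Returns the normalised tensor and the size of each set.
    """
    members = {}
    for idx in itertools.product(*(range(n) for n in shape)):
        members.setdefault(set_of(*idx), []).append(x[idx])
    stats = {}
    for label, vals in members.items():
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        stats[label] = (mu, var)
    out = {}
    for idx, v in x.items():
        mu, var = stats[set_of(*idx)]
        out[idx] = (v - mu) / math.sqrt(var + eps)
    sizes = {label: len(vals) for label, vals in members.items()}
    return out, sizes


B, C, H, W, G = 2, 4, 3, 3, 2
SET_OF = {
    "batch": lambda b, c, h, w: c,                   # one set per channel
    "layer": lambda b, c, h, w: b,                   # one set per example
    "instance": lambda b, c, h, w: (b, c),           # per (example, channel)
    "group": lambda b, c, h, w: (b, c // (C // G)),  # per (example, group)
}

# Arbitrary deterministic activations for illustration.
x = {
    (b, c, h, w): float((7 * b + 3 * c + h * h + 2 * w) % 11)
    for b, c, h, w in itertools.product(range(B), range(C), range(H), range(W))
}

for name, set_of in SET_OF.items():
    _, sizes = normalise_by_set(x, (B, C, H, W), set_of)
    print(f"{name:8s} {len(sizes)} sets of {sorted(set(sizes.values()))} values")
```

With $B = 2$, $C = 4$, $H = W = 3$ and $G = 2$, the set sizes match the list above: $B \cdot H \cdot W = 18$ for batch norm, $C \cdot H \cdot W = 36$ for layer norm, $H \cdot W = 9$ for instance norm and $(C/G) \cdot H \cdot W = 18$ for group norm.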

There is one further wrinkle. Batch norm computes statistics from the current minibatch, which means the layer's output for one example depends on which other examples happen to be in the batch. The other methods (layer norm, instance norm, group norm, RMSNorm) compute statistics from a single example, so the layer's output for a given input is the same whatever else is in the batch. This distinction matters enormously at inference time, when batches may be irregular or unavailable, as we shall see.

Batch normalisation

Batch normalisation, introduced by Ioffe and Szegedy in 2015, was the first member of the family and is still the standard choice for convolutional networks. The idea is direct: for each channel, compute the mean and variance of that channel's activations across the entire current minibatch, and use those to normalise.

For a convolutional input tensor of shape $(B, C, H, W)$ (that is, $B$ images, each with $C$ channels and a spatial map of size $H \times W$), batch norm computes one $\mu_c$ and one $\sigma^2_c$ for each channel $c$ by averaging over all $B \cdot H \cdot W$ activations in that channel: $$\mu_c = \frac{1}{B \cdot H \cdot W} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{b,c,h,w}.$$ The variance is computed analogously. Then each channel is normalised, and a per-channel learnable scale $\gamma_c$ and shift $\beta_c$ are applied.

Worked example

Take a tiny case: a batch of $B = 4$ images, $C = 2$ channels, with $H = W = 1$ so each channel has just one number per image. The pre-normalisation activations are:

Image | Channel 0 | Channel 1
1 | $1.0$ | $4.0$
2 | $3.0$ | $8.0$
3 | $5.0$ | $12.0$
4 | $7.0$ | $16.0$

Compute statistics for channel 0 across the four images: $$\mu_0 = \frac{1 + 3 + 5 + 7}{4} = \frac{16}{4} = 4.0.$$ $$\sigma^2_0 = \frac{(1 - 4)^2 + (3 - 4)^2 + (5 - 4)^2 + (7 - 4)^2}{4} = \frac{9 + 1 + 1 + 9}{4} = \frac{20}{4} = 5.0.$$ $$\sigma_0 = \sqrt{5.0} \approx 2.2361.$$ Normalising (with $\epsilon$ negligible), the channel-0 values become $$\frac{1 - 4}{2.2361}, \frac{3 - 4}{2.2361}, \frac{5 - 4}{2.2361}, \frac{7 - 4}{2.2361} \approx -1.342, -0.447, 0.447, 1.342.$$ A quick sanity check: these four numbers do indeed have mean zero and variance one (their sum is zero; their sum of squares is approximately $1.801 + 0.200 + 0.200 + 1.801 = 4.002$, divided by $4$ gives variance $\approx 1.0005$, off by rounding error only).

For channel 1, the values are $(4, 8, 12, 16)$, with $\mu_1 = 10$, $\sigma^2_1 = (36 + 4 + 4 + 36)/4 = 20$, $\sigma_1 \approx 4.4721$. The normalised values are $$\frac{4 - 10}{4.4721}, \frac{8 - 10}{4.4721}, \frac{12 - 10}{4.4721}, \frac{16 - 10}{4.4721} \approx -1.342, -0.447, 0.447, 1.342.$$ Notice that channel 1 ends up with the same normalised values as channel 0, even though its raw values are shifted upward and twice as spread out (each channel-1 value is $2 \times$ the channel-0 value $+ 2$, so $\mu_1 = 2\mu_0 + 2$ and $\sigma_1 = 2\sigma_0$). That is precisely the point: batch norm strips out each channel's offset and scale and leaves the relative pattern.
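The worked example can be checked mechanically. This sketch treats the table as a list of rows (images) and columns (channels), which is exactly the $H = W = 1$ case, and includes $\epsilon = 10^{-5}$, so its results differ from the hand calculation only far beyond the third decimal place. The function name `batch_norm_columns` is our own.

```python
import math

# Rows are images 1..4; columns are channels 0 and 1 (H = W = 1).
activations = [
    [1.0, 4.0],
    [3.0, 8.0],
    [5.0, 12.0],
    [7.0, 16.0],
]


def batch_norm_columns(rows, eps=1e-5):
    """Normalise each column (channel) with statistics taken over all rows (the batch)."""
    n_rows, n_cols = len(rows), len(rows[0])
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for c in range(n_cols):
        column = [row[c] for row in rows]
        mu = sum(column) / n_rows
        var = sum((v - mu) ** 2 for v in column) / n_rows
        for b in range(n_rows):
            out[b][c] = (rows[b][c] - mu) / math.sqrt(var + eps)
    return out


for row in batch_norm_columns(activations):
    print([round(v, 3) for v in row])
```

Each printed row shows the same value in both channels, which is the "same pattern, different scale" observation above.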

Finally, the affine transform $y = \gamma_c \hat{x} + \beta_c$ is applied with the layer's current learned $\gamma_c$ and $\beta_c$. At initialisation $\gamma_c = 1$, $\beta_c = 0$, so the output equals the normalised values; during training the network learns whatever scale and shift work best.

Inference time and running averages

At training time, batch norm uses the statistics of the current minibatch. At test time, this is awkward: an inference request might be for a single example, and you cannot compute a meaningful mean and variance over a batch of one. Worse, even if test-time batches are large, you usually want the network's prediction for a given input to be a deterministic function of that input, not to depend on which other examples happen to be queued up alongside it.

The standard fix is to maintain a running average of $\mu_c$ and $\sigma^2_c$ during training, updated each minibatch with an exponential moving average: $$\mu_c^{\text{run}} \leftarrow m \cdot \mu_c^{\text{run}} + (1 - m) \cdot \mu_c^{\text{batch}},$$ where $m$ is a momentum hyperparameter, typically $0.9$ or $0.99$; the running variance is updated in the same way. At inference, the layer switches over and uses these fixed running estimates instead of any batch statistics. This means batch norm behaves differently at training and test time, a frequent source of subtle bugs, particularly when someone forgets to put the model into the framework's evaluation mode before testing.
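To make the train/eval distinction concrete, here is a minimal single-channel sketch with an explicit `training` flag. The class `RunningBatchNorm` and its attributes are hypothetical, and the momentum follows this section's convention. Frameworks differ on this point: PyTorch's `BatchNorm` layers, for instance, define `momentum` (default 0.1) as the weight on the new batch statistic, so their 0.1 corresponds to $m = 0.9$ here, and modes are switched with `model.eval()` and `model.train()`.

```python
import math
import random


class RunningBatchNorm:
    """Hypothetical single-channel batch norm with an explicit training flag.

    Running statistics follow the text's convention:
    running <- m * running + (1 - m) * batch_statistic.
    """

    def __init__(self, momentum=0.9, eps=1e-5):
        self.m = momentum
        self.eps = eps
        self.gamma, self.beta = 1.0, 0.0          # learnable in a real layer
        self.running_mean, self.running_var = 0.0, 1.0
        self.training = True

    def __call__(self, batch):
        if self.training:
            # Training: normalise with this minibatch's own statistics...
            mu = sum(batch) / len(batch)
            var = sum((v - mu) ** 2 for v in batch) / len(batch)
            # ...and fold them into the exponential moving averages.
            self.running_mean = self.m * self.running_mean + (1 - self.m) * mu
            self.running_var = self.m * self.running_var + (1 - self.m) * var
        else:
            # Evaluation: fixed running estimates, so no dependence on the batch.
            mu, var = self.running_mean, self.running_var
        denom = math.sqrt(var + self.eps)
        return [self.gamma * (v - mu) / denom + self.beta for v in batch]


random.seed(0)
bn = RunningBatchNorm(momentum=0.9)
for _ in range(200):                      # training on batches of 32
    bn([random.gauss(5.0, 2.0) for _ in range(32)])

bn.training = False                       # the "eval mode" switch
print(round(bn.running_mean, 2), round(bn.running_var, 2))
print(bn([5.0]))                          # a batch of one is fine in eval mode
```

In eval mode the output for the input 5.0 is the same whether it arrives alone or alongside other examples, which is exactly the determinism the running averages buy. Forgetting the `training = False` step would instead normalise each test batch by its own statistics.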

Pros, cons, and gotchas

The advantages are genuinely large. Batch norm often makes training stable at learning rates many times higher than would otherwise work, can speed up the training of convolutional networks severalfold (Ioffe and Szegedy reported reaching their baseline image classifier's accuracy in 14 times fewer training steps), and has a mild regularisation effect: the noise from minibatch statistics acts as a stochastic perturbation, similar in spirit to dropout.

The disadvantages all stem from the dependence on batch statistics. With small batch sizes, say $B = 2$ or $B = 4$, common when training on high-resolution images with limited GPU memory, the per-channel mean and variance are estimated from very few values and become unreliably noisy. With distributed training across multiple GPUs, each worker computes statistics over its local subset of the batch; if you want them to agree, you must synchronise across workers (so-called SyncBN), which costs communication. With sequence data of variable length, as in NLP, the very notion of "batch dimension" becomes ambiguous, since different examples have different numbers of tokens. And because inference uses running averages rather than the statistics of the current batch, any mismatch between the two acts as a subtle distribution shift between training and evaluation that can degrade performance.

Despite these issues, batch norm remains the default for image-classification CNNs trained with reasonable batch sizes, where its dramatic optimisation benefits outweigh the bookkeeping complexity.

Layer normalisation

Layer normalisation, introduced by Ba, Kiros and Hinton in 2016, takes the obvious alternative: instead of averaging across the batch dimension to get a per-channel mean, average across the feature dimension to get a per-example mean. Each example is normalised entirely independently, with no reference to the rest of the batch.

For sequence data of shape $(B, T, D)$ (a batch of $B$ sequences, each with $T$ tokens, each token represented by a $D$-dimensional feature vector), layer norm computes one $\mu$ and one $\sigma^2$ per token by averaging over the $D$ features: $$\mu_{b,t} = \frac{1}{D} \sum_{i=1}^{D} x_{b,t,i}, \qquad \sigma^2_{b,t} = \frac{1}{D} \sum_{i=1}^{D} (x_{b,t,i} - \mu_{b,t})^2.$$ The learnable $\gamma$ and $\beta$ have shape $D$ (one parameter per feature dimension, shared across batch and sequence axes). Each token is normalised using only its own $D$ feature values; nothing depends on the rest of the batch or the rest of the sequence.

Worked example

Take a single token with $D = 4$ features and pre-normalisation values $$x = (2.0,\ 4.0,\ -1.0,\ 3.0).$$ Compute the mean across the four features: $$\mu = \frac{2.0 + 4.0 + (-1.0) + 3.0}{4} = \frac{8.0}{4} = 2.0.$$ Compute the variance: $$\sigma^2 = \frac{(2.0 - 2.0)^2 + (4.0 - 2.0)^2 + (-1.0 - 2.0)^2 + (3.0 - 2.0)^2}{4} = \frac{0 + 4 + 9 + 1}{4} = \frac{14}{4} = 3.5.$$ $$\sigma = \sqrt{3.5} \approx 1.8708.$$ Normalising: $$\hat{x} = \left( \frac{0}{1.8708},\ \frac{2}{1.8708},\ \frac{-3}{1.8708},\ \frac{1}{1.8708} \right) \approx (0.000,\ 1.069,\ -1.604,\ 0.535).$$ Sanity check: the four rounded values sum to $0.000 + 1.069 - 1.604 + 0.535 = 0$ (the exact values sum to zero by construction); their sum of squares is approximately $0 + 1.143 + 2.573 + 0.286 = 4.002$, which divided by $4$ gives variance $\approx 1.0005$, off by rounding error only. Mean zero, variance one, as it should be. The affine $y_i = \gamma_i \hat{x}_i + \beta_i$ then applies whatever scale and shift the network has learned.
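The same calculation in code: a plain-Python layer norm for one token, assuming per-feature lists `gamma` and `beta` initialised to ones and zeros (their values at initialisation). Applying it to every token of a small nested $(B, T, D)$ list shows that the result for a token does not depend on where it sits or what else is in the batch.

```python
import math


def layer_norm(token, gamma, beta, eps=1e-5):
    """Layer norm for one token: statistics over its own D features only."""
    d = len(token)
    mu = sum(token) / d
    var = sum((v - mu) ** 2 for v in token) / d
    denom = math.sqrt(var + eps)
    # gamma and beta hold one entry per feature dimension.
    return [g * (v - mu) / denom + b for v, g, b in zip(token, gamma, beta)]


D = 4
gamma, beta = [1.0] * D, [0.0] * D
x = [2.0, 4.0, -1.0, 3.0]
print([round(v, 3) for v in layer_norm(x, gamma, beta)])

# A (B, T, D) batch is just this function applied to every token independently.
batch = [[x, [0.0, 1.0, 0.0, 1.0]], [[5.0, 5.0, 6.0, 4.0], x]]
normed = [[layer_norm(tok, gamma, beta) for tok in seq] for seq in batch]
```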

The crucial thing is that we computed both $\mu$ and $\sigma^2$ using only the $D = 4$ features of this one token. There was no batch involved at all.

Why layer norm is the right choice for transformers

Layer norm has two huge practical advantages over batch norm.

First, training and inference are identical. There are no running averages to maintain, no momentum hyperparameter, no train-test discrepancy. The output of a layer-norm layer for a given input is a deterministic function of that input alone.

Second, sequence length is irrelevant. Whether the input has 5 tokens or 5,000, each token is normalised the same way using its own $D$ features. Padding a batch to a uniform length, or processing variable-length sequences without padding, both work transparently. This is exactly what NLP needs.

For these reasons, layer norm is the standard choice in transformers. In the original transformer (Vaswani et al. 2017), every attention and feed-forward sub-layer, in both the encoder and the decoder, is paired with a layer norm, and its descendants keep this pattern, often with RMSNorm (discussed below) in place of layer norm. In the pre-norm configuration used by GPT-2 and most later models, a transformer block reads: "LayerNorm $\to$ multi-head attention $\to$ residual add $\to$ LayerNorm $\to$ feed-forward $\to$ residual add". The original transformer instead placed each LayerNorm after the residual addition (the post-norm configuration), which we discuss further below.

The downside of layer norm is that it usually performs slightly worse than batch norm on convolutional image classifiers, where batch statistics genuinely capture useful information about what kinds of features tend to fire together across the dataset. But for transformer architectures, where each token embedding should be normalised on its own terms, layer norm is the clear choice.

Instance and group normalisation

These two methods, like layer norm, compute statistics within a single example, but over smaller sets of activations; each is useful in a specific niche.

Instance normalisation (Ulyanov, Vedaldi and Lempitsky, 2016) takes the spatial mean and variance separately for each (example, channel) pair. For an image tensor of shape $(B, C, H, W)$, the set $S$ used to compute statistics for channel $c$ of image $b$ contains $H \cdot W$ values, just the spatial positions of that one channel of that one image. Each instance and channel is independently normalised.

Why would you want this? Style transfer is the classic application. The intuition is that the per-channel statistics (mean and variance) of intermediate features in a CNN encode style information such as colour palette, brightness, and texture intensity. Stripping those statistics out of each image individually removes the input's style while preserving its content, which is exactly what a style-transfer network needs. Instance norm, and adaptive variants of it, also appear in some image-generation models, such as the generator of StyleGAN.

Group normalisation (Wu and He, 2018) is a compromise designed to recover batch-norm-like benefits when batch sizes must be small. Split the $C$ channels into $G$ groups of $C/G$ channels each. For each group, compute statistics over the $(C/G) \cdot H \cdot W$ activations in that group, separately for each example in the batch. The default is $G = 32$ groups; the choice is often robust over a range from $G = 8$ to $G = 32$.

Group norm, like layer norm, is independent of batch size; its statistics involve only one example at a time. But unlike layer norm, it does not pool statistics across all $C$ channels at once, which would conflate channels capturing very different kinds of feature. Empirically, group norm comes close to batch norm's accuracy on ImageNet-style image classifiers at typical batch sizes and substantially exceeds it when batch sizes drop below about $B = 8$. The most common application is high-resolution image tasks (object detection, semantic segmentation, medical imaging) where memory constraints force batch sizes of one or two per GPU.
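A sketch of group norm for a single example makes its relationship to the other per-example methods visible: with $G = 1$ the single group spans every channel, which is layer norm, and with $G = C$ each group is one channel, which is instance norm. The representation (a list of $C$ channels, each a flat list of $H \cdot W$ values), the sample data, and the omission of $\gamma$ and $\beta$ are simplifying assumptions.

```python
import math


def group_norm(example, num_groups, eps=1e-5):
    """Group norm for ONE example, given as C lists of H*W spatial values.

    Channels are split into num_groups contiguous groups of C // num_groups;
    each group is normalised with its own mean and variance. The affine
    gamma/beta step is omitted for brevity.
    """
    C = len(example)
    assert C % num_groups == 0, "C must be divisible by G"
    per_group = C // num_groups
    out = []
    for g in range(num_groups):
        chans = example[g * per_group:(g + 1) * per_group]
        vals = [v for ch in chans for v in ch]          # (C/G) * H * W values
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        denom = math.sqrt(var + eps)
        out.extend([[(v - mu) / denom for v in ch] for ch in chans])
    return out


# One example with C = 4 channels and H * W = 3 spatial positions.
example = [[1.0, 2.0, 6.0], [0.0, -3.0, 3.0], [10.0, 12.0, 11.0], [4.0, 4.0, 5.0]]
grouped = group_norm(example, num_groups=2)        # two groups of two channels
layer_like = group_norm(example, num_groups=1)     # G = 1: layer norm
instance_like = group_norm(example, num_groups=4)  # G = C: instance norm
```

With $G = 1$ the channels keep their relative offsets (channel 2, with its large raw values, stays well above zero), whereas with $G = C$ every channel is centred individually.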

Choosing between them in CNN settings is largely about batch size: large batches favour batch norm; tiny batches favour group norm; and instance norm is reserved for the special cases where you specifically want per-image style statistics removed.

RMSNorm

RMSNorm, introduced by Zhang and Sennrich in 2019, is a stripped-down version of layer norm. The observation is that layer norm subtracts the mean before dividing by the standard deviation, but the mean subtraction may not actually be doing anything important. RMSNorm tests this hypothesis by simply skipping it.

The formula is $$y_i = \frac{x_i}{\sqrt{\frac{1}{D} \sum_{j=1}^{D} x_j^2 + \epsilon}} \cdot \gamma_i.$$ There is no $\mu$: we never compute the mean. There is no $\beta$: there is no learnable shift. The denominator is the root mean square of the activations (hence the name), which is the same as the standard deviation only if the mean is already zero. The numerator keeps the raw input rather than the mean-subtracted input.
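The formula translates almost line for line into code. This sketch compares RMSNorm with a layer norm that has no affine step (both effectively with $\gamma = 1$); the function names are our own. On an input whose mean is already zero the two agree exactly, but after adding a constant offset RMSNorm no longer centres the values, because it never computes the mean.

```python
import math


def rms_norm(token, gamma, eps=1e-5):
    """RMSNorm: divide by the root mean square; no mean, no beta."""
    d = len(token)
    rms = math.sqrt(sum(v * v for v in token) / d + eps)
    return [g * v / rms for v, g in zip(token, gamma)]


def layer_norm_no_affine(token, eps=1e-5):
    """Layer norm steps 1-2 only, for comparison."""
    d = len(token)
    mu = sum(token) / d
    var = sum((v - mu) ** 2 for v in token) / d
    return [(v - mu) / math.sqrt(var + eps) for v in token]


ones = [1.0] * 4
centred = [-3.0, 1.0, 0.0, 2.0]          # already has mean zero
shifted = [v + 10.0 for v in centred]    # same pattern, mean ten
print(rms_norm(centred, ones), layer_norm_no_affine(centred))
print(rms_norm(shifted, ones), layer_norm_no_affine(shifted))
```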

Why this matters at scale

Mean computation has a real cost. For a hidden state of dimension $D = 4096$ (small by 2026 standards), a straightforward layer-norm implementation sums $D$ values to compute the mean, then sums $D$ squared deviations to compute the variance: two passes over the data. RMSNorm needs only one pass, summing the squares directly. The backward pass is similarly simplified, with one fewer term to differentiate through.

The empirical claim, supported by ablation studies in the original paper and many follow-ups, is that this simplification costs essentially nothing in model quality. RMSNorm reaches roughly the same loss as layer norm in the same number of training steps while running faster; the original paper reported running-time reductions of about 7 to 64 percent across different models, and the actual saving depends on the hardware, the implementation, and how much of each step the normalisation layers account for. At scale (a trillion-parameter model trained on tens of thousands of GPUs for months), even a modest saving is substantial.

For these reasons RMSNorm has become the default in most recent large language models, including T5 (whose simplified layer norm, with no mean subtraction and no bias, is RMSNorm in all but name), LLaMA (all generations), PaLM, and most of the major open-source LLMs of the 2024–2026 period. The choice between layer norm and RMSNorm is now largely a matter of convention; both work well, but RMSNorm wins on throughput.

Where to put the normalisation

Choosing the type of normalisation is one decision; choosing where in the residual block it goes is a separate one, and the right choice has shifted over time.

The original transformer (Vaswani et al. 2017) used the post-norm configuration: $$y = \operatorname{LayerNorm}(x + \operatorname{Sublayer}(x)).$$ That is: apply the sub-layer (attention or feed-forward) to $x$, add the residual connection, then normalise the sum.

Modern transformers from GPT-2 onwards prefer pre-norm: $$y = x + \operatorname{Sublayer}(\operatorname{LayerNorm}(x)).$$ That is: normalise first, pass through the sub-layer, then add back the un-normalised residual.
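The two placements differ only in the order of three operations, which a short sketch makes explicit. The sub-layer here is a hypothetical stand-in (multiplying by 0.5) for attention or the feed-forward network, and layer norm is used without $\gamma$ and $\beta$.

```python
import math


def layer_norm_no_affine(v, eps=1e-5):
    mu = sum(v) / len(v)
    var = sum((a - mu) ** 2 for a in v) / len(v)
    return [(a - mu) / math.sqrt(var + eps) for a in v]


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def post_norm_block(x, sublayer):
    # Original transformer: y = LayerNorm(x + Sublayer(x))
    return layer_norm_no_affine(add(x, sublayer(x)))


def pre_norm_block(x, sublayer):
    # GPT-2 onwards: y = x + Sublayer(LayerNorm(x))
    return add(x, sublayer(layer_norm_no_affine(x)))


def toy_sublayer(v):
    """Hypothetical stand-in for attention or feed-forward."""
    return [0.5 * a for a in v]


x = [2.0, 4.0, -1.0, 3.0]
post, pre = x, x
for _ in range(3):  # stack three blocks of each kind
    post = post_norm_block(post, toy_sublayer)
    pre = pre_norm_block(pre, toy_sublayer)
print(post)  # the residual stream itself is renormalised after every block
print(pre)   # the residual stream keeps its own scale and mean
```

If the sub-layer outputs zeros, the pre-norm block passes $x$ through unchanged while the post-norm block still normalises it; that untouched residual path is what the discussion of gradient flow below is about.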

The difference looks small but has dramatic consequences for stability. In post-norm, the residual stream itself is repeatedly normalised, which can interfere with the gradient flowing back through the residual connection (the very mechanism that lets ResNets and transformers train deep networks). The standard post-norm transformer typically needs a careful learning-rate warmup schedule (a linear ramp from zero over several thousand steps) to train stably; without warmup it often diverges.

In pre-norm, the residual stream itself is never normalised; only the input to each sub-layer is. The sub-layer's output is added back to the un-normalised residual, which preserves the clean gradient highway. Pre-norm transformers train reliably at much greater depths and are far less dependent on warmup.

The trade-off is that post-norm sometimes reaches slightly higher peak performance when it does train successfully; the repeated re-normalisation of the residual stream may regularise the model usefully. But the operational reliability of pre-norm has won out: nearly every modern transformer (LLaMA, GPT-3 onwards, PaLM, Gemini) uses pre-norm. The original post-norm style is now mostly of historical interest.

Why normalisation works (no clean theory)

When Ioffe and Szegedy introduced batch normalisation in 2015, they explained its effectiveness in terms of internal covariate shift: as the parameters of earlier layers update during training, the distribution of activations seen by later layers shifts, forcing those later layers to perpetually re-adapt. Batch norm, the argument went, fixes the activation distributions and removes this moving target.

The story was intuitive and convincing, and, as it turned out, largely wrong. Santurkar et al. (2018) constructed a careful set of ablation experiments. They artificially injected noise after the batch-norm layer to deliberately re-introduce distribution shift, while preserving the normalisation step itself. If the original story were correct, this should destroy batch norm's benefit. But it did not; networks with this artificial noise still trained nearly as well as standard batch-normed ones, and dramatically better than networks without batch norm.

The current best understanding is that batch norm primarily helps by smoothing the loss landscape. Without normalisation, the loss as a function of the parameters has sharp ridges and steep gradient cliffs; with normalisation, the same loss has a much gentler topology that gradient descent can navigate with larger steps. This is consistent with the empirical observation that networks with normalisation tolerate much higher learning rates: the optimiser is no longer afraid of overshooting into a region where the loss explodes.

But "smooths the loss landscape" is more an empirical description than a rigorous theory. There is no clean derivation that predicts, from architectural properties alone, how much batch norm will help, or which variant is best for which architecture. The full theoretical story remains an open research question. In practice we use normalisation because it works, and careful ablations have repeatedly confirmed that it does, even though we cannot yet say exactly why.

Practical guidance

A short decision tree gets you most of the way:

  • Convolutional network with batch size $B \geq 32$: use batch normalisation. This is what ResNet and most production image classifiers use.
  • Convolutional network with batch size $B < 16$ (high-resolution images, detection, segmentation, limited GPU memory): use group normalisation, with $G = 32$ as a reasonable default. Performance is competitive with batch norm and independent of batch size.
  • Transformer or large language model: use layer normalisation if following the canonical recipe, or RMSNorm if you care about throughput. Apply in pre-norm configuration unless you have a specific reason to do otherwise.
  • Style transfer, certain image-generation models: use instance normalisation.
  • Recurrent neural network: use layer normalisation along the feature dimension. Batch norm interacts badly with the sequential nature of RNNs.

A few additional rules of thumb. Always use the framework's built-in implementation rather than rolling your own: the numerical care taken in the backward pass, the handling of the running averages, and the fused kernel implementations all matter. When you see an unstable training run with diverging loss, or evaluation results that do not match training, suspect normalisation: a wrongly set momentum, a missing eval-mode switch, or a precision issue in mixed-precision training are all common culprits. And resist the urge to skip normalisation in the name of simplicity: the optimisation benefit is real and the cost is small.

What you should take away

  1. Normalisation actively rescales activations during the forward pass to keep their distributions roughly constant in mean and variance, regardless of how the weights of earlier layers happen to change during training.

  2. All normalisation methods follow the same recipe: pick a set of activations, compute their mean and variance, normalise to mean zero and variance one, then apply a learnable per-channel or per-feature scale $\gamma$ and shift $\beta$. The methods differ mainly in which set of activations is averaged.

  3. Batch norm averages across the batch and spatial dimensions per channel and is the default for CNNs with reasonable batch sizes; it depends on running averages at inference, which complicates train-test consistency. Layer norm averages across the feature dimension per example and is the default for transformers; it is identical at training and inference.

  4. RMSNorm drops the mean-subtraction step and is faster than layer norm at no measurable accuracy cost; it is now standard in large language models. Instance norm and group norm are batch-size-independent variants useful in style transfer and small-batch CNN settings respectively.

  5. Pre-norm placement (LayerNorm $\to$ sublayer $\to$ residual add) is more stable than post-norm for very deep transformers and is the modern default. The reason normalisation works at all is best understood as smoothing the loss landscape rather than the original "internal covariate shift" story, but a complete theory remains elusive.

This site is currently in Beta. Contact: Chris Paton

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).