10.12 Distributed training

Frontier-scale models (10B+ parameters, trillions of training tokens) cannot fit on a single GPU and would take centuries to train serially. Distributed training spreads the work across many GPUs, possibly across many nodes connected by high-speed networking.

There are three orthogonal axes of parallelism: data, tensor, and pipeline (plus expert parallelism for Mixture-of-Experts models). Modern training combines them. We start with the simplest: data parallelism.

Data parallelism (DDP)

Replicate the model on each GPU. Each GPU processes a different mini-batch slice. After the backward pass, gradients are averaged across GPUs (using all-reduce) and all replicas perform the same optimiser step.

Mathematically: with $K$ GPUs and per-GPU batch $B$, the effective batch size is $KB$. The gradient computed by all-reduce is

$$\hat g = \frac{1}{KB} \sum_{k=1}^K \sum_{i \in \mathcal{B}_k} \nabla \ell_i(\theta).$$

This is identical to a single-GPU gradient on a $KB$ batch. DDP gives perfect scaling so long as the all-reduce communication time is masked by the backward computation. For modern interconnects (NVLink, InfiniBand) this holds up to a few hundred GPUs.
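The equivalence between the all-reduced data-parallel gradient and a single-GPU gradient on the full $KB$ batch can be checked numerically. A minimal NumPy sketch, with the $K$ "GPUs" simulated as slices of one array and a toy least-squares loss standing in for the model:

```python
import numpy as np

def local_gradients(theta, batch):
    """Per-example gradients for a toy least-squares loss l_i = 0.5*(x_i . theta - y_i)^2."""
    X, y = batch
    residual = X @ theta - y          # shape (B,)
    return X * residual[:, None]      # one gradient row per example

rng = np.random.default_rng(0)
K, B, d = 4, 8, 3                     # 4 simulated workers, per-worker batch 8
theta = rng.normal(size=d)
X = rng.normal(size=(K * B, d))
y = rng.normal(size=K * B)

# Data-parallel: each worker sums gradients on its slice, then all-reduce (mean).
worker_grads = [local_gradients(theta, (X[k*B:(k+1)*B], y[k*B:(k+1)*B])).sum(0)
                for k in range(K)]
g_ddp = sum(worker_grads) / (K * B)   # the all-reduced gradient

# Single-GPU gradient on the full KB batch.
g_single = local_gradients(theta, (X, y)).mean(0)

assert np.allclose(g_ddp, g_single)
```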

PyTorch's DistributedDataParallel is the canonical implementation. It overlaps gradient communication with backward computation by bucketing parameters and launching all-reduce operations as soon as each bucket's gradients are ready.
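The bucketing logic can be sketched in a few lines. Buckets are filled in reverse registration order, because gradients become ready back-to-front during the backward pass; each full bucket would trigger an all-reduce immediately. (The 25M-element cap mirrors DDP's default `bucket_cap_mb` of 25 MB; the function below is an illustrative sketch, not DDP's actual implementation.)

```python
# Toy sketch of DDP-style gradient bucketing: parameters are assigned to buckets in
# reverse registration order, and a (simulated) all-reduce would fire per full bucket.
BUCKET_CAP = 25_000_000  # elements per bucket

def assign_buckets(param_sizes, cap=BUCKET_CAP):
    buckets, current, current_size = [], [], 0
    for idx, size in reversed(list(enumerate(param_sizes))):  # back-to-front
        current.append(idx)
        current_size += size
        if current_size >= cap:
            buckets.append(current)
            current, current_size = [], 0
    if current:
        buckets.append(current)
    return buckets

sizes = [30_000_000, 10_000_000, 10_000_000, 40_000_000]
print(assign_buckets(sizes))  # [[3], [2, 1, 0]] — the last layer's grads ship first
```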

ZeRO: redundant memory elimination

ZeRO (Zero Redundancy Optimizer; Rajbhandari et al. 2020) eliminates the memory duplication inherent in DDP. With $K$ replicas, each GPU stores a full copy of the model parameters $\theta$, gradients $g$, and optimiser state $(m, v)$. For mixed-precision Adam training this is about $16P$ bytes per GPU, where $P$ is the parameter count: FP16 parameters (2 bytes each), FP16 gradients (2 bytes), and FP32 optimiser state (a master parameter copy plus the first and second moments, 12 bytes) — all duplicated $K$ times across the cluster.

ZeRO partitions these across GPUs:

  • ZeRO-1: shard optimiser state. Each GPU holds $1/K$ of $(m, v)$. After all-reduce of gradients, each GPU updates its assigned shard, then all-gathers the updated parameters. Memory savings: $4\times$ for Adam.
  • ZeRO-2: shard optimiser state + gradients. Backward pass uses reduce-scatter instead of all-reduce. Memory savings: $8\times$.
  • ZeRO-3: shard optimiser state + gradients + parameters. Each GPU only holds $1/K$ of $\theta$. Forward and backward passes all-gather the parameters of the current layer just in time, then release them. Memory savings: $K\times$.

ZeRO-3 is the basis of FSDP (Fully Sharded Data Parallel) in PyTorch. For training a 70B model on 1024 GPUs with ZeRO-3, each GPU stores only $\sim 70$M parameters' worth of state, modest enough to fit comfortably in 80 GB of HBM with room for activations.

The trade-off is communication: ZeRO-3 requires roughly $1.5\times$ the inter-GPU traffic of plain DDP (an all-gather of parameters in forward, another in backward, and a reduce-scatter for gradients, versus DDP's single all-reduce). On NVLink-connected nodes the overhead is small; across InfiniBand it is more significant.
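The per-GPU memory figures follow directly from the $16P$-bytes-per-parameter accounting. A small calculator, assuming mixed-precision Adam as above (a sketch; real allocators add activation and fragmentation overhead):

```python
def zero_bytes_per_gpu(P, K, stage):
    """Per-GPU memory (bytes) for mixed-precision Adam under ZeRO stages 0-3.
    FP16 params (2 B) + FP16 grads (2 B) + FP32 optimiser state (12 B)."""
    params, grads, opt = 2 * P, 2 * P, 12 * P
    if stage >= 1: opt /= K       # ZeRO-1: shard optimiser state
    if stage >= 2: grads /= K     # ZeRO-2: also shard gradients
    if stage >= 3: params /= K    # ZeRO-3: also shard parameters
    return params + grads + opt

P, K = 70e9, 1024                 # 70B parameters, 1024 GPUs
for stage in range(4):
    gb = zero_bytes_per_gpu(P, K, stage) / 2**30
    print(f"ZeRO-{stage}: {gb:,.1f} GiB/GPU")
```

ZeRO-0 (plain DDP) needs over a terabyte per GPU for this model; ZeRO-3 brings it to about 1 GiB, consistent with the claim above.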

Pipeline parallelism

Split the model vertically by layer. GPU 1 holds layers 1–10, GPU 2 holds layers 11–20, etc. The forward pass flows from GPU 1 to GPU $K$; the backward pass flows back. This enables training models too large to fit on any single GPU's memory.

Naive pipeline has poor utilisation: GPU 2 is idle while GPU 1 computes the first batch's forward; GPU 3 is idle while GPU 1 and 2 compute. To fix this, GPipe (Huang et al. 2019) splits each batch into $M$ micro-batches and pipelines them. The schedule is:

GPU 1 | f1 f2 f3 f4 ... b4 b3 b2 b1
GPU 2 |    f1 f2 f3 f4 ... b4 b3 b2 b1
GPU 3 |       f1 f2 f3 f4 ... b4 b3 b2 b1
GPU 4 |          f1 f2 f3 f4 ... b4 b3 b2 b1

The empty cells at the start and end are the pipeline bubble. With $K$ stages and $M$ micro-batches, the bubble is a fraction $(K - 1)/(K + M - 1)$ of total time. Larger $M$ reduces the bubble but increases activation memory (you must store activations for all $M$ in-flight micro-batches until the backward pass). Typical choices set $M$ between $4K$ and $8K$, giving a bubble of roughly 10–20%.
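The bubble formula is worth evaluating for a concrete stage count:

```python
def bubble_fraction(K, M):
    """Fraction of pipeline time lost to the GPipe bubble: (K-1)/(K+M-1)."""
    return (K - 1) / (K + M - 1)

K = 8
for M in (K, 4 * K, 8 * K):
    print(f"K={K}, M={M}: bubble = {bubble_fraction(K, M):.1%}")
```

With $K = 8$ stages, $M = K$ wastes nearly half the time, $M = 4K$ about 18%, and $M = 8K$ just under 10%, matching the rule of thumb above.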

PipeDream (Narayanan et al. 2019) and the 1F1B (one-forward-one-backward) schedule interleave forward and backward passes: 1F1B caps activation memory at roughly $K$ in-flight micro-batches rather than $M$, and interleaved variants further shrink the bubble.

Tensor parallelism

Split each weight matrix across GPUs. For a Transformer linear layer $Y = XW$ where $W \in \mathbb{R}^{d_{\mathrm{in}} \times d_{\mathrm{out}}}$, partition $W$ into column blocks $W = [W_1 | W_2 | \ldots | W_K]$ across $K$ GPUs. Each GPU computes its local $Y_k = XW_k$. To get the full $Y$, all-gather the $Y_k$.
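The column-split scheme is easy to verify with NumPy, treating each column block as living on a separate GPU and the final concatenation as the all-gather:

```python
import numpy as np

rng = np.random.default_rng(0)
K, B, d_in, d_out = 4, 2, 8, 12       # d_out must be divisible by K
X = rng.normal(size=(B, d_in))
W = rng.normal(size=(d_in, d_out))

# Column parallelism: each "GPU" holds a column block W_k and computes X @ W_k.
W_blocks = np.split(W, K, axis=1)
Y_blocks = [X @ Wk for Wk in W_blocks]   # local matmuls, no communication yet
Y = np.concatenate(Y_blocks, axis=1)     # the all-gather

assert np.allclose(Y, X @ W)
```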

Megatron-LM (Shoeybi et al. 2019) showed how to compose tensor parallelism efficiently across Transformer components: column-parallel for the QKV projection, row-parallel for the output projection, arranged so that the forward pass needs only one all-reduce per attention block (and likewise one per MLP block).

Tensor parallelism is bandwidth-intensive: every layer requires an all-reduce or all-gather. It is therefore typically restricted to within a single high-bandwidth domain (NVLink-connected GPUs in one node, $\le 8$ GPUs).
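The Megatron pairing can be demonstrated concretely: a column-parallel layer feeds a row-parallel layer, so the intermediate activations never need gathering and a single summing all-reduce at the end suffices. A NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(1)
K, B, d, h = 4, 2, 8, 16       # h must be divisible by K
X = rng.normal(size=(B, d))
W = rng.normal(size=(d, h))    # first projection, split by columns
V = rng.normal(size=(h, d))    # second projection, split by rows

W_cols = np.split(W, K, axis=1)
V_rows = np.split(V, K, axis=0)

# Each "GPU": local column-parallel matmul, then local row-parallel matmul.
partials = [(X @ W_cols[k]) @ V_rows[k] for k in range(K)]
Z = sum(partials)              # the single all-reduce (a sum across GPUs)

assert np.allclose(Z, X @ W @ V)
```

The key design choice: because $W$'s column split matches $V$'s row split, each GPU's partial product $(XW_k)V_k$ is a full-sized matrix, and summing them is exactly $XWV$.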

Expert parallelism

For Mixture-of-Experts (MoE) models, different experts can live on different GPUs. The router sends each token to its top-$k$ experts; an all-to-all communication step routes activations to the right GPUs. Expert parallelism scales the parameter count without scaling the per-token compute, but the all-to-all is bandwidth-hungry and load-balancing the experts is delicate.
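A toy top-1 router makes the dispatch step and the load-balancing problem concrete. The routing weights here are random stand-ins for a learned router; grouping token indices by destination expert plays the role of the all-to-all:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, E = 16, 8, 4                       # tokens, hidden dim, experts
tokens = rng.normal(size=(T, d))
router_W = rng.normal(size=(d, E))       # routing weights (random here, learned in practice)

scores = tokens @ router_W               # (T, E) router logits
expert_of = scores.argmax(axis=1)        # top-1 expert per token

# The "all-to-all": group token indices by destination expert (= destination GPU).
dispatch = {e: np.flatnonzero(expert_of == e) for e in range(E)}
load = [len(dispatch[e]) for e in range(E)]
print("tokens per expert:", load)        # uneven in general — hence balancing losses
```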

3D parallelism

Frontier models combine all three: tensor parallelism within a node, pipeline parallelism across nodes within a "stage group", and data parallelism across stage groups. A 1024-GPU GPT-3-scale training run might use TP=8, PP=16, DP=8 for $8 \times 16 \times 8 = 1024$ GPUs. Picking the partitioning (which dimension gets which factor) is its own optimisation problem; tools like Alpa automate it.
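One common convention (a hypothetical mapping, not any particular framework's) puts tensor parallelism on the fastest-varying axis of the global rank, so that each TP group of 8 consecutive ranks lands inside one NVLink-connected node:

```python
# Hypothetical rank -> (dp, pp, tp) mapping for the TP=8, PP=16, DP=8 layout above.
TP, PP, DP = 8, 16, 8

def coords(rank):
    tp = rank % TP                  # fastest-varying: stays within a node
    pp = (rank // TP) % PP          # pipeline stage, across nodes
    dp = rank // (TP * PP)          # data-parallel replica group
    return dp, pp, tp

assert TP * PP * DP == 1024
assert coords(0) == (0, 0, 0)
assert coords(7) == (0, 0, 7)      # ranks 0-7 form one NVLink TP group
assert coords(8) == (0, 1, 0)      # rank 8 starts the next pipeline stage
```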
