10.12 Distributed training
Frontier-scale models (10B+ parameters, trillions of training tokens) cannot be trained on a single GPU: parameters, gradients, optimiser state and activations exceed its memory, and even if they fit, serial training would take centuries. Distributed training spreads the work across many GPUs, possibly across many nodes connected by high-speed networking.
There are three orthogonal axes of parallelism: data, tensor, and pipeline (plus expert parallelism for Mixture-of-Experts models, covered below). Modern training combines all three. We start with the simplest: data parallelism.
Data parallelism (DDP)
Replicate the model on each GPU. Each GPU processes a different mini-batch slice. After the backward pass, gradients are averaged across GPUs (using all-reduce) and all replicas perform the same optimiser step.
Mathematically: with $K$ GPUs and per-GPU batch $B$, the effective batch size is $KB$. The gradient computed by all-reduce is
$$\hat g = \frac{1}{KB} \sum_{k=1}^K \sum_{i \in \mathcal{B}_k} \nabla \ell_i(\theta).$$
This is identical to a single-GPU gradient on a $KB$ batch. DDP gives perfect scaling so long as the all-reduce communication time is masked by the backward computation. For modern interconnects (NVLink, InfiniBand) this holds up to a few hundred GPUs.
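This equivalence is easy to check in a single process. The following toy least-squares sketch stands the averaging step in for the all-reduce; the sizes and the loss are illustrative choices, not anything prescribed by DDP:

```python
import torch

# Check that averaging per-shard mean gradients (what DDP's all-reduce
# computes) equals the mean gradient over the full K*B batch.
torch.manual_seed(0)
K, B, d = 4, 8, 5
X, y = torch.randn(K * B, d), torch.randn(K * B)
theta = torch.zeros(d, requires_grad=True)

def mean_grad(Xb, yb):
    loss = ((Xb @ theta - yb) ** 2).mean()   # per-shard mean loss
    return torch.autograd.grad(loss, theta)[0]

# Per-"GPU" gradients on disjoint shards of B samples, then averaged.
shard_grads = [mean_grad(X[k*B:(k+1)*B], y[k*B:(k+1)*B]) for k in range(K)]
g_ddp = torch.stack(shard_grads).mean(dim=0)   # the all-reduce average
g_full = mean_grad(X, y)                       # single-GPU gradient on KB samples
print(torch.allclose(g_ddp, g_full, atol=1e-6))  # True
```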
PyTorch's DistributedDataParallel is the canonical implementation. It overlaps gradient communication with backward computation by bucketing parameters and launching all-reduce operations as soon as each bucket's gradients are ready.
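A minimal DDP training sketch, assuming a `torchrun --nproc_per_node=K train.py` launch (torchrun sets the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables); `MyModel` and `make_loader` are hypothetical placeholders for your own model and data pipeline:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = MyModel().cuda(local_rank)          # placeholder model
    # DDP registers gradient hooks that bucket parameters and launch
    # asynchronous all-reduces as each bucket's gradients become ready.
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Each rank should see a disjoint slice of the dataset (placeholder loader).
    loader = make_loader(rank=dist.get_rank(), world_size=dist.get_world_size())
    for x, y in loader:
        loss = model(x.cuda(local_rank), y.cuda(local_rank))
        opt.zero_grad(set_to_none=True)
        loss.backward()   # gradients are all-reduced (averaged) during backward
        opt.step()        # every replica applies the same update

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```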
ZeRO: redundant memory elimination
ZeRO (Zero Redundancy Optimizer; Rajbhandari et al. 2020) eliminates the memory duplication inherent in DDP. With $K$ replicas, each GPU stores a full copy of the model parameters $\theta$, gradients $g$, and optimiser state $(m, v)$. In the standard mixed-precision recipe (FP16 parameters and gradients, plus an FP32 master copy of the parameters and the FP32 Adam moments $m, v$) this is about $16P$ bytes per GPU for a model with $P$ parameters, and the cluster holds $K$ redundant copies of it.
ZeRO partitions these across GPUs:
- ZeRO-1: shard optimiser state. Each GPU holds $1/K$ of the FP32 optimiser state (master weights, $m$, $v$). After all-reduce of gradients, each GPU updates its assigned shard, then all-gathers the updated parameters. Memory savings: up to $4\times$ for Adam.
- ZeRO-2: shard optimiser state + gradients. The backward pass uses reduce-scatter instead of all-reduce. Memory savings: up to $8\times$.
- ZeRO-3: shard optimiser state + gradients + parameters. Each GPU only holds $1/K$ of $\theta$. Forward and backward passes all-gather the parameters of the current layer just in time, then release them. Memory savings: $K\times$.
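To make the stage-by-stage savings concrete, here is a rough memory model using the 16-bytes-per-parameter accounting above; it is a sketch that ignores activations, buffers, and fragmentation:

```python
def zero_bytes_per_gpu(P, K, stage):
    """Approximate per-GPU training-state memory in bytes.

    Mixed-precision accounting: 2 B FP16 params + 2 B FP16 grads
    + 12 B FP32 optimiser state (master weights, m, v) per parameter.
    """
    params, grads, opt = 2 * P, 2 * P, 12 * P
    if stage == 0:            # plain DDP: everything replicated
        return params + grads + opt
    if stage == 1:            # ZeRO-1: shard optimiser state
        return params + grads + opt / K
    if stage == 2:            # ZeRO-2: shard optimiser state + gradients
        return params + (grads + opt) / K
    if stage == 3:            # ZeRO-3: shard everything
        return (params + grads + opt) / K
    raise ValueError(stage)

P, K = 70e9, 1024             # 70B parameters on 1024 GPUs
for s in range(4):
    print(f"stage {s}: {zero_bytes_per_gpu(P, K, s) / 2**30:.1f} GiB per GPU")
# ~1043.1, 261.5, 131.3 and 1.0 GiB per GPU respectively
```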
ZeRO-3 is the basis of FSDP (Fully Sharded Data Parallel) in PyTorch. For training a 70B model on 1024 GPUs with ZeRO-3, each GPU stores only $\sim 70$M parameters' worth of state, modest enough to fit comfortably in 80 GB of HBM with room for activations.
The trade-off is communication: ZeRO-3 moves roughly $1.5\times$ the inter-GPU traffic of plain DDP (an all-gather of parameters in the forward pass, another in the backward pass, and a reduce-scatter for gradients, versus DDP's single all-reduce). On NVLink-connected nodes the overhead is small; across InfiniBand it is more significant.
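A minimal FSDP sketch, again assuming a `torchrun` launch as in the DDP example; `MyTransformer` is a hypothetical placeholder. PyTorch's `ShardingStrategy` values map roughly onto the ZeRO stages:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# FULL_SHARD ~ ZeRO-3, SHARD_GRAD_OP ~ ZeRO-2, NO_SHARD ~ plain DDP.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = FSDP(MyTransformer().cuda(),                      # placeholder model
             sharding_strategy=ShardingStrategy.FULL_SHARD)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
# The training loop itself is unchanged: FSDP all-gathers each layer's
# parameters just in time during forward/backward and reduce-scatters
# the gradients afterwards.
```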
Pipeline parallelism
Split the model by depth: GPU 1 holds layers 1–10, GPU 2 holds layers 11–20, and so on. The forward pass flows from GPU 1 to GPU $K$; the backward pass flows back. This enables training models too large to fit in any single GPU's memory.
Naive pipelining has poor utilisation: GPU 2 is idle while GPU 1 computes the first batch's forward pass; GPU 3 is idle while GPUs 1 and 2 compute. To fix this, GPipe (Huang et al. 2019) splits each batch into $M$ micro-batches and pipelines them. For $K = 4$ stages and $M = 4$ micro-batches the schedule is (time runs left to right; · marks an idle slot):
```
GPU 1 | f1 f2 f3 f4  ·  ·  ·  ·  ·  · b4 b3 b2 b1
GPU 2 |  · f1 f2 f3 f4  ·  ·  ·  · b4 b3 b2 b1  ·
GPU 3 |  ·  · f1 f2 f3 f4  ·  · b4 b3 b2 b1  ·  ·
GPU 4 |  ·  ·  · f1 f2 f3 f4 b4 b3 b2 b1  ·  ·  ·
```
The idle slots at the start and end of each row are the pipeline bubble. With $K$ stages and $M$ micro-batches, the bubble is a fraction $(K - 1)/(K + M - 1)$ of total time. Larger $M$ reduces the bubble but increases activation memory (you must store activations for all $M$ in-flight micro-batches until the backward pass). Typical choices are $M$ between $4K$ and $8K$, giving a bubble of roughly $10$–$20\%$.
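A two-line helper makes the bubble formula concrete; it assumes forward and backward micro-batch steps take equal time, and the values below are illustrative:

```python
def bubble_fraction(K, M):
    """GPipe idle fraction with K stages and M micro-batches."""
    return (K - 1) / (K + M - 1)

K = 16
for M in (K, 4 * K, 8 * K):
    print(f"M = {M:3d}: bubble = {bubble_fraction(K, M):.1%}")
# M =  16: bubble = 48.4%,  M =  64: bubble = 19.0%,  M = 128: bubble = 10.5%
```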
PipeDream (Narayanan et al. 2019) introduced the 1F1B (one-forward-one-backward) schedule, which starts backward passes as early as possible and alternates them with forwards; this caps the number of in-flight micro-batches at about $K$, cutting activation memory, and interleaved variants of the schedule also shrink the bubble.
Tensor parallelism
Split each weight matrix across GPUs. For a Transformer linear layer $Y = XW$ where $W \in \mathbb{R}^{d_{\mathrm{in}} \times d_{\mathrm{out}}}$, partition $W$ into column blocks $W = [W_1 | W_2 | \ldots | W_K]$ across $K$ GPUs. Each GPU computes its local $Y_k = XW_k$. To get the full $Y$, all-gather the $Y_k$.
Megatron-LM (Shoeybi et al. 2019) showed how to apply tensor parallelism efficiently across a Transformer block: column-parallel for the QKV projection, row-parallel for the attention output projection (and the same column-then-row pattern for the two MLP matrices), composed so that each attention or MLP block needs only one all-reduce in its forward pass (and one in its backward pass).
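The column-then-row composition is easy to verify numerically. The following single-process sketch simulates $K$ tensor-parallel shards of a toy two-layer MLP (the same pattern Megatron uses for its blocks); the shard outputs only need to be summed at the end, which is exactly what the one all-reduce computes. Shapes and weights are illustrative:

```python
import torch

torch.manual_seed(0)
K, d = 4, 32
X = torch.randn(8, d)
W1 = torch.randn(d, 4 * d)        # first linear: column-parallel (split columns)
W2 = torch.randn(4 * d, d)        # second linear: row-parallel (split rows)

W1_shards = W1.chunk(K, dim=1)
W2_shards = W2.chunk(K, dim=0)

# Each "GPU" k computes its partial output with no communication in between:
partials = [torch.relu(X @ W1_shards[k]) @ W2_shards[k] for k in range(K)]
Y_tp = torch.stack(partials).sum(dim=0)    # the single all-reduce

Y_ref = torch.relu(X @ W1) @ W2            # unsharded reference
print(torch.allclose(Y_tp, Y_ref, atol=1e-4))  # True
```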
Tensor parallelism is bandwidth-intensive: every layer requires an all-reduce or all-gather. It is therefore typically restricted to within a single high-bandwidth domain (NVLink-connected GPUs in one node, $\le 8$ GPUs).
Expert parallelism
For Mixture-of-Experts (MoE) models, different experts can live on different GPUs. The router sends each token to its top-$k$ experts; an all-to-all communication step routes activations to the right GPUs. Expert parallelism scales the parameter count without scaling the per-token compute, but the all-to-all is bandwidth-hungry and load-balancing the experts is delicate.
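A single-process sketch of the routing logic (top-1 for simplicity): the per-expert loop stands in for the all-to-all dispatch, since in expert parallelism each expert would live on its own GPU. All names and sizes are illustrative; real routers use top-$k$, capacity limits, and an auxiliary load-balancing loss:

```python
import torch

torch.manual_seed(0)
n_tokens, d, n_experts = 16, 8, 4
x = torch.randn(n_tokens, d)
router = torch.randn(d, n_experts)
experts = [torch.randn(d, d) for _ in range(n_experts)]   # one weight matrix per expert

gates = torch.softmax(x @ router, dim=-1)   # routing probabilities
choice = gates.argmax(dim=-1)               # top-1 expert per token

y = torch.zeros_like(x)
for e in range(n_experts):
    mask = choice == e                      # tokens routed to expert e
    if mask.any():
        # the all-to-all would ship x[mask] to expert e's GPU and the result back
        y[mask] = gates[mask, e:e + 1] * (x[mask] @ experts[e])

print(torch.bincount(choice, minlength=n_experts))  # per-expert token counts (load balance)
```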
3D parallelism
Frontier models combine all three: tensor parallelism within a node, pipeline parallelism across nodes within a "stage group", and data parallelism across stage groups. A 1024-GPU GPT-3-scale training run might use TP=8, PP=16, DP=8 for $8 \times 16 \times 8 = 1024$ GPUs. Picking the partitioning (which dimension gets which factor) is its own optimisation problem; tools like Alpa automate it.
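As a concrete illustration of such a layout, the bookkeeping below maps a global rank to its (data, pipeline, tensor) coordinates for the TP=8, PP=16, DP=8 example. The ordering, with tensor-parallel ranks varying fastest so that each TP group stays inside one NVLink-connected node, is one common convention, not the only one; real frameworks build the equivalent process groups for you:

```python
TP, PP, DP = 8, 16, 8

def coords(rank):
    """Map a global rank to its (data, pipeline, tensor) parallel coordinates."""
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return dp, pp, tp

assert TP * PP * DP == 1024
print(coords(0))     # (0, 0, 0): first replica, first stage, first shard
print(coords(1023))  # (7, 15, 7): last replica, last stage, last shard
```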