Fully Sharded Data Parallel (FSDP) is PyTorch's native implementation of parameter sharding, the same idea pioneered by DeepSpeed's ZeRO Stage 3. Where DDP replicates the entire model on every device, FSDP partitions each parameter tensor into $N$ flat shards, one per worker, so each device permanently holds only $1/N$ of the parameters, gradients, and optimiser state. For a 70B-parameter model trained on 8 GPUs in BF16 with Adam (16 bytes per parameter end-to-end), the per-GPU memory drops from ~1.1 TB (DDP, infeasible) to ~140 GB (FSDP, plausible with offloading).
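A quick back-of-the-envelope check of those numbers, assuming the common accounting of 16 bytes per parameter (BF16 weights and gradients plus FP32 master weights and Adam moments):

```python
# Rough memory arithmetic for the 70B example above.
params = 70e9
bytes_per_param = 2 + 2 + 4 + 4 + 4   # BF16 weights + BF16 grads + FP32 master + Adam m, v
n_gpus = 8

total = params * bytes_per_param
print(f"DDP, per GPU : {total / 1e12:.2f} TB")          # ~1.12 TB (full replica on every device)
print(f"FSDP, per GPU: {total / n_gpus / 1e9:.0f} GB")   # ~140 GB (1/N of everything)
```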
The execution pattern uses two collectives. Before each module's forward pass, FSDP all-gathers the shards back into the full parameter tensor on every device:
$$W = \mathrm{AllGather}(W_1, W_2, \dots, W_N).$$
The forward computation then proceeds locally, and FSDP immediately frees the gathered tensor (or keeps it for the backward pass, depending on the policy). After the backward pass produces the full local gradient, a reduce-scatter averages and shards it in one fused collective:
$$\bar{g}_i = \frac{1}{N} \sum_{j=1}^N g_j^{(i)},$$
so worker $i$ ends up holding only its own shard of the averaged gradient $\bar{g}_i$. Each worker then applies the optimiser update to its own shard locally, with no further communication.
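As a conceptual sketch (not FSDP's actual implementation), the per-unit pattern can be written with raw `torch.distributed` collectives; `compute_fn` here is a hypothetical stand-in for the unit's forward and backward passes, and the process group is assumed to be initialised:

```python
import torch
import torch.distributed as dist

def fsdp_step_sketch(local_shard: torch.Tensor, compute_fn) -> torch.Tensor:
    """local_shard: this rank's 1/N slice of the unit's flattened parameters."""
    world_size = dist.get_world_size()

    # 1. All-gather the shards into the full flat parameter tensor on every rank.
    full_params = torch.empty(local_shard.numel() * world_size,
                              device=local_shard.device, dtype=local_shard.dtype)
    dist.all_gather_into_tensor(full_params, local_shard)

    # 2. Run the local forward/backward with the materialised parameters;
    #    compute_fn is assumed to return the full flat gradient for this unit.
    full_grad = compute_fn(full_params)

    # 3. Reduce-scatter: sum each shard across ranks, keep only our slice,
    #    then divide by N to get the averaged gradient shard.
    grad_shard = torch.empty_like(local_shard)
    dist.reduce_scatter_tensor(grad_shard, full_grad, op=dist.ReduceOp.SUM)
    grad_shard /= world_size

    # 4. The optimiser update is applied locally to (local_shard, grad_shard).
    return grad_shard
```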
The total per-step communication is one all-gather plus one reduce-scatter, which together move the same volume as one all-reduce, so FSDP's communication cost is comparable to DDP's despite the much smaller memory footprint. The scheduling is more delicate, however: gathering parameters layer-by-layer overlaps the all-gather of layer $L+1$ with the compute of layer $L$, but bursty memory peaks during the gather can still spill if the wrap policy is wrong.
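FSDP exposes knobs for this scheduling; a minimal sketch, assuming `model` is already built and the process group initialised:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

sharded = FSDP(
    model,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # issue the next unit's all-gather before its gradient compute
    limit_all_gathers=True,                           # rate-limit gathers to cap transient memory peaks
    device_id=torch.cuda.current_device(),
)
```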
Wrap policy determines the granularity of sharding. Wrapping the entire model as one unit gives maximum memory savings but no overlap; wrapping every linear layer creates too many tiny collectives. The pragmatic choice for transformers is to wrap each transformer block as a single FSDP unit, so one all-gather brings in attention plus FFN parameters together and one reduce-scatter sends the gradients back.
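A sketch of that pragmatic choice using FSDP's `transformer_auto_wrap_policy`; `TransformerBlock` is a placeholder for the model's own block class (e.g. a decoder layer):

```python
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap each transformer block as one FSDP unit: one all-gather and one
# reduce-scatter per block, coarse enough to overlap with compute.
block_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)
sharded = FSDP(
    model,
    auto_wrap_policy=block_policy,
    device_id=torch.cuda.current_device(),
)
```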
FSDP exposes mixed precision as orthogonal knobs: the parameter compute dtype, the gradient-reduction dtype, and buffer dtypes can each independently be FP32, BF16, or FP16. The standard recipe keeps a master FP32 shard for the optimiser, gathers BF16 parameters for compute, and reduce-scatters BF16 gradients, matching the mixed-precision playbook with full sharding underneath.
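A sketch of that recipe with FSDP's `MixedPrecision` config, assuming the model itself is initialised in FP32 so the sharded master copy stays FP32:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,    # dtype of the gathered compute parameters
    reduce_dtype=torch.bfloat16,   # dtype used for the gradient reduce-scatter
    buffer_dtype=torch.bfloat16,   # e.g. norm layer buffers
)
sharded = FSDP(model, mixed_precision=bf16_policy)  # sharded FP32 master copy feeds the optimiser
```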
CPU offloading is the last memory lever: FSDP can keep parameter and optimiser shards on CPU and stream them to GPU just in time. This sacrifices throughput (PCIe bandwidth is far lower than HBM bandwidth) but enables training models that exceed total GPU memory. Combined with activation checkpointing, single-node FSDP with offloading can fine-tune models up to roughly the size of the host's RAM.
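A minimal sketch of enabling offload via `CPUOffload`:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

sharded = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),  # keep parameter shards (and gradients) in host RAM
)
```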
FSDP also composes with other parallelism schemes. In HSDP (Hybrid Sharded Data Parallel) the global mesh is split into intra-node sharding groups and inter-node replication groups, so parameter all-gathers and gradient reduce-scatters travel only over fast NVLink while only an all-reduce of the gradient shards spans the slower inter-node fabric. FSDP can likewise be stacked with tensor parallelism in a 2D mesh, with tensor-parallel groups inside a node and FSDP sharding across nodes.
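A sketch of HSDP, assuming the job is launched with `torchrun` across multiple nodes; with no explicit process groups passed, FSDP forms the intra-node shard groups and inter-node replication groups itself:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

sharded = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # shard intra-node, replicate inter-node
    device_id=torch.cuda.current_device(),
)
```

In practice the options shown in these sketches (wrap policy, prefetch, mixed precision, offload, sharding strategy) are combined on a single FSDP wrapper rather than applied one at a time.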
Related terms: Distributed Data Parallel, ZeRO, Tensor Parallelism, Pipeline Parallelism, Mixed Precision Training
Discussed in:
- Chapter 15: Modern AI, Engineering at Scale