10.17 A complete from-scratch training loop
We end with a complete PyTorch training loop that puts the chapter together: mini-batch SGD, Adam optimiser, cosine schedule with warmup, gradient clipping, mixed precision, checkpointing, and a sketch of distributed training.
import math
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
def cosine_with_warmup(step, total_steps, warmup_steps, lr_max, lr_min):
    """Learning rate at *step*: linear warmup to lr_max, then cosine decay to lr_min.

    Args:
        step: current optimiser step (0-based).
        total_steps: total length of the schedule in steps.
        warmup_steps: linear warmup length; 0 disables warmup.
        lr_max: peak learning rate reached at the end of warmup.
        lr_min: floor learning rate at the end of the cosine.

    Returns:
        The scheduled learning rate. Steps past ``total_steps`` clamp to
        ``lr_min`` (the original unclamped cosine would dip *below* lr_min).
    """
    if warmup_steps > 0 and step < warmup_steps:
        return lr_max * step / warmup_steps
    if total_steps <= warmup_steps:
        # Degenerate schedule (no decay phase); avoid division by zero.
        return lr_min
    # Clamp so overshooting total_steps holds the LR at its floor.
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
def train(
    model,
    train_dataset,
    val_dataset,
    *,
    total_steps=100_000,
    warmup_steps=1_000,
    batch_size=256,
    lr_max=3e-4,
    lr_min=3e-5,
    weight_decay=1e-2,
    grad_clip=1.0,
    log_every=50,
    eval_every=1_000,
    ckpt_path="checkpoint.pt",
    use_ddp=False,
    rank=0,
    world_size=1,
):
    """Mini-batch training loop: AdamW, warmup+cosine LR, AMP, clipping, checkpoints.

    Args:
        model: nn.Module to train; moved to ``cuda:{rank}``.
        train_dataset / val_dataset: map-style datasets yielding (x, y) pairs.
        total_steps: number of optimiser steps to run.
        warmup_steps: linear LR warmup length in steps.
        batch_size: *global* batch size; split evenly across ranks under DDP.
        lr_max / lr_min: peak and floor of the cosine schedule.
        weight_decay: decoupled decay (biases/LayerNorm excluded).
        grad_clip: maximum global gradient norm.
        log_every / eval_every: logging and evaluation periods in steps.
        ckpt_path: checkpoint file; training resumes from it if it exists.
        use_ddp / rank / world_size: distributed configuration.
    """
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
        sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None

    # Checkpoints always store/load the *unwrapped* module's state dict, so
    # they carry no DDP "module." prefix and are interchangeable between
    # single-GPU and distributed runs. (Loading ckpt["model"] into the DDP
    # wrapper would fail on mismatched keys.)
    raw_model = model.module if use_ddp else model

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size // world_size,  # per-rank share of the global batch
        sampler=sampler,
        shuffle=(sampler is None),  # DistributedSampler shuffles on its own
        num_workers=4,
        pin_memory=True,
        drop_last=True,  # constant batch shape -> stable throughput/memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    # AdamW with decoupled weight decay; by convention biases and LayerNorm
    # weights are excluded from decay.
    no_decay = {"bias", "LayerNorm.weight"}
    params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
    params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
    optimizer = optim.AdamW(
        [
            {"params": params_decay, "weight_decay": weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ],
        lr=lr_max,  # placeholder; overwritten by the schedule every step
        betas=(0.9, 0.95),
        eps=1e-8,
    )
    # GradScaler is only required for FP16; under BF16 autocast the scaling
    # is a harmless no-op-like pass-through, and keeping it makes the loop
    # format-agnostic.
    scaler = GradScaler()

    step = 0
    epoch = 0
    best_val = float("inf")

    # Resume from an existing checkpoint, if any.
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        raw_model.load_state_dict(ckpt["model"])  # unwrapped keys (see above)
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        step = ckpt["step"]
        epoch = ckpt["epoch"]
        best_val = ckpt["best_val"]

    while step < total_steps:
        if sampler is not None:
            # Re-seed the distributed shuffle: each epoch gets a fresh order
            # that is still identical across ranks.
            sampler.set_epoch(epoch)
        model.train()
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Schedule the learning rate for this step.
            lr = cosine_with_warmup(step, total_steps, warmup_steps, lr_max, lr_min)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.zero_grad(set_to_none=True)

            # Forward + loss under BF16 autocast.
            with autocast(dtype=torch.bfloat16):
                logits = model(x)
                loss = nn.functional.cross_entropy(logits, y)

            scaler.scale(loss).backward()

            # Unscale before clipping so the threshold applies to true grads.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(optimizer)
            scaler.update()

            if step % log_every == 0 and rank == 0:
                print(f"step {step:6d} lr {lr:.2e} loss {loss.item():.4f} |g| {grad_norm:.3f}")

            if step % eval_every == 0 and step > 0:
                val_loss = evaluate(model, val_loader, device)
                # evaluate() left the model in eval mode (dropout off, BN
                # frozen); restore train mode before continuing the epoch.
                model.train()
                if rank == 0:
                    print(f" val loss {val_loss:.4f}")
                if val_loss < best_val and rank == 0:
                    best_val = val_loss
                    torch.save({
                        "model": raw_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scaler": scaler.state_dict(),
                        "step": step,
                        "epoch": epoch,
                        "best_val": best_val,
                    }, ckpt_path)

            step += 1
            if step >= total_steps:
                break
        epoch += 1
@torch.no_grad()
def evaluate(model, loader, device):
model.eval()
total_loss, total_count = 0.0, 0
for x, y in loader:
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
with autocast(dtype=torch.bfloat16):
logits = model(x)
loss = nn.functional.cross_entropy(logits, y, reduction="sum")
total_loss += loss.item()
total_count += x.size(0)
return total_loss / total_count
Launching distributed training
# torchrun --nproc_per_node=8 train.py
def main():
    """Entry point for ``torchrun --nproc_per_node=8 train.py``.

    torchrun exports RANK (global rank across all nodes), LOCAL_RANK (GPU
    index within this node) and WORLD_SIZE. The process group must be
    initialised with the *global* rank; the CUDA device is selected with
    the *local* rank.
    """
    # Fall back to LOCAL_RANK if RANK is absent (single-node launchers).
    global_rank = int(os.environ.get("RANK", os.environ["LOCAL_RANK"]))
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    model = build_model()
    train_dataset, val_dataset = build_datasets()
    # NOTE(review): train() uses one `rank` for both the CUDA device and the
    # DistributedSampler shard, which is only correct single-node (where
    # global rank == local rank) — confirm before running multi-node.
    train(
        model, train_dataset, val_dataset,
        use_ddp=True, rank=local_rank, world_size=world_size,
        batch_size=4096,  # global batch, split across ranks by the loader
        lr_max=3e-4 * (4096 / 256),  # linear scaling rule
        warmup_steps=2_000,
    )
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
This loop incorporates almost every idea in the chapter: warmup-then-cosine schedule, AdamW with parameter groups (no weight decay on biases or LayerNorm), bfloat16 mixed precision, gradient clipping, periodic evaluation and checkpointing, a distributed sampler with set_epoch for reproducible shuffling, and a global batch size scaled with the linear rule for distributed training. With minor adjustments (an FSDP wrapper for ZeRO-3 sharding, gradient accumulation for effective batch sizes larger than memory allows) it scales from a single-GPU pilot run up to multi-node frontier training.