10.17 A complete from-scratch training loop

We end with a complete PyTorch training loop that puts the chapter together: mini-batch stochastic training, the AdamW optimiser (Adam with decoupled weight decay), a cosine learning-rate schedule with linear warmup, gradient clipping, mixed precision, checkpointing with resume, and a sketch of distributed data-parallel training.

import math
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist


def cosine_with_warmup(step, total_steps, warmup_steps, lr_max, lr_min):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    For ``step < warmup_steps`` the rate rises linearly from 0 towards
    ``lr_max``. After that it follows half a cosine from ``lr_max`` down to
    ``lr_min``, reaching ``lr_min`` at ``step == total_steps`` and staying
    there for any later step.

    Args:
        step: Current optimiser step (0-based).
        total_steps: Step at which the decay reaches ``lr_min``.
        warmup_steps: Length of the linear warmup phase; 0 disables warmup.
        lr_max: Peak learning rate, reached at the end of warmup.
        lr_min: Final (floor) learning rate.

    Returns:
        The learning rate as a float.
    """
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    # max(1, ...) avoids ZeroDivisionError when total_steps == warmup_steps.
    decay_steps = max(1, total_steps - warmup_steps)
    # Clamp so steps beyond total_steps stay at lr_min instead of the cosine
    # swinging back up towards lr_max.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


def _param_groups(model, weight_decay):
    """Split parameters into a weight-decayed group and a no-decay group."""
    # 1-D tensors (biases, normalisation-layer weights in standard torch.nn
    # layers) are excluded by shape so this works regardless of module names;
    # the name check keeps the Hugging Face-style naming convention working.
    no_decay_names = ("bias", "LayerNorm.weight")
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if p.ndim < 2 or any(nd in name for nd in no_decay_names):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def _save_checkpoint(path, model, optimizer, scaler, *, step, epoch, best_val):
    """Atomically write model/optimiser/scaler state and loop counters to ``path``."""
    tmp_path = f"{path}.tmp"
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "step": step,
            "epoch": epoch,
            "best_val": best_val,
        },
        tmp_path,
    )
    # Rename into place so a crash mid-write never corrupts the checkpoint
    # that a later run would resume from.
    os.replace(tmp_path, path)


def train(
    model,
    train_dataset,
    val_dataset,
    *,
    total_steps=100_000,
    warmup_steps=1_000,
    batch_size=256,
    lr_max=3e-4,
    lr_min=3e-5,
    weight_decay=1e-2,
    grad_clip=1.0,
    log_every=50,
    eval_every=1_000,
    ckpt_path="checkpoint.pt",
    use_ddp=False,
    rank=0,
    world_size=1,
    local_rank=None,
    amp_dtype=torch.bfloat16,
):
    """Train ``model`` with AdamW, warmup+cosine LR, gradient clipping and AMP.

    Runs ``total_steps`` optimiser steps over ``train_dataset``, cycling
    through epochs as needed. Evaluates on ``val_dataset`` every
    ``eval_every`` steps. Rank 0 writes a checkpoint to ``ckpt_path``
    whenever the validation loss improves. If ``ckpt_path`` already exists,
    training resumes from it. Only the best-validation checkpoint is kept,
    so a resumed run continues from that point (not from the last step
    reached) and restarts its epoch from the first batch.

    Args:
        model: Module mapping an input batch to logits accepted by
            ``nn.functional.cross_entropy``.
        train_dataset: Map-style dataset yielding ``(x, y)`` pairs.
        val_dataset: Map-style dataset yielding ``(x, y)`` pairs.
        total_steps: Number of optimiser steps to run.
        warmup_steps: Linear warmup length of the LR schedule.
        batch_size: Global batch size; each rank loads
            ``batch_size // world_size`` samples per step.
        lr_max: Peak learning rate.
        lr_min: Final learning rate.
        weight_decay: Decoupled weight decay for parameters with >= 2 dims
            whose names do not contain ``bias``/``LayerNorm.weight``.
        grad_clip: Maximum global L2 norm of the gradients.
        log_every: Print training stats every this many steps (rank 0).
        eval_every: Evaluate every this many steps.
        ckpt_path: Checkpoint file to resume from and save to.
        use_ddp: Wrap the model in DistributedDataParallel and shard data.
        rank: Global rank, used for data sharding and rank-0 logging/saving.
        world_size: Total number of processes.
        local_rank: CUDA device index for this process. Defaults to the
            ``LOCAL_RANK`` environment variable under DDP (set by torchrun),
            falling back to ``rank``.
        amp_dtype: Autocast dtype. Loss scaling is enabled only for
            ``torch.float16``; bfloat16 has fp32's exponent range and does
            not need it.

    Raises:
        ValueError: if the per-rank train loader yields no full batch
            (which would otherwise loop forever).
    """
    if local_rank is None:
        local_rank = int(os.environ.get("LOCAL_RANK", rank)) if use_ddp else rank
    device = torch.device(f"cuda:{local_rank}")
    model.to(device)
    # Checkpoints hold the *unwrapped* module's state_dict (no DDP "module."
    # prefix), so both saving and loading go through this handle.
    raw_model = model

    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size // world_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    if len(train_loader) == 0:
        raise ValueError(
            "train_loader yields no full batch (dataset smaller than the "
            "per-rank batch size with drop_last=True)"
        )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    # AdamW with decoupled weight decay
    optimizer = optim.AdamW(
        _param_groups(raw_model, weight_decay),
        lr=lr_max,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

    # A disabled scaler turns scale/unscale_/update into no-ops and step()
    # into a plain optimizer.step().
    scaler = GradScaler(enabled=(amp_dtype == torch.float16))
    step = 0
    epoch = 0
    best_val = float("inf")

    # Resume if checkpoint exists
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        raw_model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        step = ckpt["step"]
        epoch = ckpt["epoch"]
        best_val = ckpt["best_val"]

    while step < total_steps:
        if sampler is not None:
            # Reseeds the distributed shuffle so each epoch gets a new order
            # that is the same on every rank.
            sampler.set_epoch(epoch)
        model.train()

        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update LR
            lr = cosine_with_warmup(step, total_steps, warmup_steps, lr_max, lr_min)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.zero_grad(set_to_none=True)

            # Forward + loss in autocast
            with autocast(dtype=amp_dtype):
                logits = model(x)
                loss = nn.functional.cross_entropy(logits, y)

            scaler.scale(loss).backward()

            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            # Optimiser step
            scaler.step(optimizer)
            scaler.update()

            if step % log_every == 0 and rank == 0:
                print(f"step {step:6d}  lr {lr:.2e}  loss {loss.item():.4f}  |g| {grad_norm:.3f}")

            if step % eval_every == 0 and step > 0:
                val_loss = evaluate(model, val_loader, device)
                # evaluate() puts the model in eval mode; switch back so
                # dropout/batch-norm behave correctly for the rest of the epoch.
                model.train()
                if rank == 0:
                    print(f"  val loss {val_loss:.4f}")
                    if val_loss < best_val:
                        best_val = val_loss
                        # This step's update is already applied, so a resumed
                        # run must start at the next step.
                        _save_checkpoint(
                            ckpt_path, raw_model, optimizer, scaler,
                            step=step + 1, epoch=epoch, best_val=best_val,
                        )

            step += 1
            if step >= total_steps:
                break

        epoch += 1


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total_loss, total_count = 0.0, 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        with autocast(dtype=torch.bfloat16):
            logits = model(x)
            loss = nn.functional.cross_entropy(logits, y, reduction="sum")
        total_loss += loss.item()
        total_count += x.size(0)
    return total_loss / total_count

Launching distributed training

# torchrun --nproc_per_node=8 train.py
def main():
    """Entry point for one torchrun worker process.

    torchrun sets RANK (global rank across all nodes), LOCAL_RANK (GPU index
    on this node) and WORLD_SIZE for every process. The process group and
    the data sharding need the global rank; the CUDA device needs the local
    one. On a single node the two coincide, but on multiple nodes using
    LOCAL_RANK as the process-group rank would create duplicate ranks.
    """
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Select this process's GPU before NCCL initialises its communicators.
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    try:
        model = build_model()
        train_dataset, val_dataset = build_datasets()

        train(
            model, train_dataset, val_dataset,
            use_ddp=True, rank=rank, world_size=world_size,
            batch_size=4096,  # global batch
            lr_max=3e-4 * (4096 / 256),  # linear scaling rule
            warmup_steps=2_000,
        )
    finally:
        # Tear down the process group even if training raises.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

This loop incorporates almost every idea in the chapter: a warmup-then-cosine learning-rate schedule, AdamW with parameter groups (no weight decay on biases or LayerNorm weights), bfloat16 mixed precision, gradient clipping, periodic evaluation with best-model checkpointing, a distributed sampler whose set_epoch call gives every epoch a fresh shuffle that is consistent across ranks, and a peak learning rate scaled with the global batch size via the linear scaling rule. With minor adjustments (an FSDP wrapper for ZeRO-3-style sharding, gradient accumulation to reach effective batch sizes larger than memory allows) it scales from a single-GPU pilot run up to multi-node frontier-scale training.

This site is currently in Beta. Contact: Chris Paton

Textbook of Usability · Textbook of Digital Health

Auckland Maths and Science Tutoring

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).