
Train a Diffusion Model

DDPM (Ho et al. 2020) defines a forward process that gradually adds Gaussian noise to an image $x_0 \sim q(x_0)$ over $T$ steps:

$$q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_t I),$$

with closed form $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, where $\bar\alpha_t = \prod_{s=1}^t (1-\beta_s)$ and $\epsilon \sim \mathcal{N}(0, I)$. The reverse process is parametrised by a UNet $\epsilon_\theta(x_t, t)$ that predicts the noise, trained with the simplified loss

$$\mathcal{L} = \mathbb{E}_{x_0,\,\epsilon,\,t}\big[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\big].$$
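The closed form and the loss can be checked numerically on scalars. The sketch below is plain Python (no PyTorch) and assumes the linear schedule described next. It first iterates the one-step process to confirm that the mean and variance of $x_t$ match $\sqrt{\bar\alpha_t}\,x_0$ and $1-\bar\alpha_t$. It then Monte-Carlo-estimates $\mathcal{L}$ for two hypothetical predictors on toy data where every $x_0$ equals a constant `C`. Because $x_0$ is then known, `oracle_eps` recovers $\epsilon$ exactly and drives the loss to zero, while `zero_eps` scores $\mathbb{E}[\epsilon^2] = 1$.

```python
import math
import random

T = 1000
# Linear schedule from the next paragraph; list index i holds step t = i + 1.
BETAS = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
ALPHA_BAR = []
_prod = 1.0
for _b in BETAS:
    _prod *= 1.0 - _b
    ALPHA_BAR.append(_prod)


def forward_moments(x0, i):
    """Mean and variance of x_t given x0 (t = i + 1), obtained by iterating
    x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise one step at a time."""
    mean, var = x0, 0.0
    for b in BETAS[: i + 1]:
        mean = math.sqrt(1.0 - b) * mean
        var = (1.0 - b) * var + b
    return mean, var


def simple_loss(predict_eps, x0_batch, rng):
    """Monte Carlo estimate of E[(eps - eps_theta(x_t, t))^2], t uniform, eps ~ N(0, 1)."""
    total = 0.0
    for x0 in x0_batch:
        i = rng.randrange(T)
        eps = rng.gauss(0.0, 1.0)
        xt = math.sqrt(ALPHA_BAR[i]) * x0 + math.sqrt(1.0 - ALPHA_BAR[i]) * eps
        total += (eps - predict_eps(xt, i)) ** 2
    return total / len(x0_batch)


C = 0.7  # hypothetical toy data: every x0 equals C


def oracle_eps(xt, i):
    # x0 is known (= C), so the noise can be recovered exactly from x_t.
    return (xt - math.sqrt(ALPHA_BAR[i]) * C) / math.sqrt(1.0 - ALPHA_BAR[i])


def zero_eps(xt, i):
    return 0.0  # ignores its input entirely


mean, var = forward_moments(C, 499)
print(mean, math.sqrt(ALPHA_BAR[499]) * C)        # iterated vs closed-form mean
print(var, 1.0 - ALPHA_BAR[499])                  # iterated vs closed-form variance
rng = random.Random(0)
print(simple_loss(oracle_eps, [C] * 2_000, rng))  # close to 0
print(simple_loss(zero_eps, [C] * 20_000, rng))   # close to E[eps^2] = 1
```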

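Noise schedule. A linear $\beta_t$ from $10^{-4}$ to $0.02$ over $T=1000$ steps (the DDPM setting) works for 32×32 images such as CIFAR-10. For 64×64 and above, the cosine schedule of Nichol & Dhariwal (2021) usually gives better samples: $\bar\alpha_t = f(t)/f(0)$ with $f(t) = \cos^2\!\big(\tfrac{t/T + s}{1+s}\cdot\tfrac{\pi}{2}\big)$ and $s = 0.008$, from which $\beta_t = 1 - \bar\alpha_t/\bar\alpha_{t-1}$, clipped to at most $0.999$. The linear schedule drives $\bar\alpha_t$ close to zero well before $t = T$ (about 0.08 already at $t = 500$, against about 0.49 for the cosine schedule), so many of the final steps see almost pure noise.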
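A plain-Python sketch comparing $\bar\alpha_t$ under the two schedules, assuming $T = 1000$. The $0.999$ cap on $\beta_t$ follows Nichol & Dhariwal; the function names are illustrative. Index `i` of each returned list holds step $t = i + 1$.

```python
import math


def linear_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar for the linear beta schedule; list index i holds step t = i + 1."""
    out, prod = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def cosine_alpha_bar(T=1000, s=0.008, max_beta=0.999):
    """alpha_bar for the cosine schedule, built from betas 1 - f(t)/f(t-1) clipped at max_beta."""
    f = lambda t: math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    out, prod = [], 1.0
    for t in range(1, T + 1):
        beta = min(1.0 - f(t) / f(t - 1), max_beta)
        prod *= 1.0 - beta
        out.append(prod)
    return out


lin, cosine = linear_alpha_bar(), cosine_alpha_bar()
for t in (250, 500, 750, 1000):
    print(f"t={t:4d}  linear={lin[t - 1]:.4g}  cosine={cosine[t - 1]:.4g}")
```

In the printed table the linear column is already near zero by $t = 750$, while the cosine column decays steadily until the last few steps.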

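UNet architecture. An encoder–decoder with skip connections, conditioned on the timestep $t$: a sinusoidal embedding of $t$ goes through an MLP and is injected into every residual block, here as an added per-channel bias (the `t_proj` term below); a FiLM-style scale-and-shift is a common variant. For class-conditional training, as in the loop below, a learned class embedding is typically added to the time embedding. For 64×64 RGB: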

```
in:   3×64×64
down: 64 → 128 → 256 → 512   (feature maps 64×64 → 32×32 → 16×16 → 8×8;
                              each stage: 2 ResBlocks, plus self-attention
                              at the 16×16 and 8×8 stages)
mid:  512 with self-attention
up:   512 → 256 → 128 → 64   (skip connections from the down path)
out:  3×64×64                (predicted ε)
```
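```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Residual block with an additive timestep embedding (GroupNorm(32, ·) needs channels divisible by 32)."""
    def __init__(self, c_in, c_out, t_dim):
        super().__init__()  # must run before any submodule is assigned
        self.norm1 = nn.GroupNorm(32, c_in)
        self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, c_out)  # one bias per output channel from t_emb
        self.norm2 = nn.GroupNorm(32, c_out)
        self.conv2 = nn.Conv2d(c_out, c_out, 3, padding=1)
        # 1×1 conv matches channel counts on the residual path when they differ
        self.skip = nn.Conv2d(c_in, c_out, 1) if c_in != c_out else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]  # broadcast over H, W
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)
```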
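The sinusoidal timestep embedding that feeds `t_proj` (after the MLP) can be sketched in plain Python. This assumes Transformer-style geometric frequencies with `max_period = 10000` and a sin-then-cos layout. Implementations differ in ordering and scaling, so treat it as one common convention rather than the exact embedding behind this recipe; the name `timestep_embedding` is illustrative.

```python
import math


def timestep_embedding(t, dim, max_period=10_000):
    """Sinusoidal embedding of an integer timestep t into a list of length dim (dim even).

    Frequencies fall geometrically from 1 towards 1/max_period, as in the Transformer
    positional encoding; the first half holds sines, the second half cosines.
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


emb = timestep_embedding(500, 128)
print(len(emb), emb[0], emb[64])  # length, first sine and first cosine component
```

Nearby timesteps get similar vectors in the low-frequency components and distinct ones in the high-frequency components, which gives the MLP a smooth but informative encoding of $t$.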

Training loop.

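```python
import copy
import math
import torch
import torch.nn.functional as F

# Assumed to be defined elsewhere: UNet(), whose forward is model(x_t, t, y) -> predicted ε;
# `loader`, yielding (x0 in [-1, 1], class labels y); and NULL_CLASS, an extra label
# index reserved for "no class".
T = 1000
device = "cuda"

def cosine_schedule(T, s=0.008, max_beta=0.999):
    # Nichol & Dhariwal: β_t = 1 - f(t)/f(t-1), f(t) = cos²(((t/T + s)/(1 + s)) · π/2), capped at 0.999
    f = lambda t: math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    return torch.tensor([min(1 - f(t + 1) / f(t), max_beta) for t in range(T)])

def forever(loader):
    while True:
        yield from loader                        # a fresh (reshuffled) pass each epoch

betas = cosine_schedule(T).to(device)            # same device as t, so indexing works
alphas = 1 - betas
alpha_bar = torch.cumprod(alphas, dim=0)         # ᾱ_t
sqrt_ab = torch.sqrt(alpha_bar)
sqrt_1mab = torch.sqrt(1 - alpha_bar)

model = UNet().to(device)
ema = copy.deepcopy(model); ema.requires_grad_(False)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.0)
batches = forever(loader)                        # one persistent iterator, not iter() per step

for step in range(800_000):
    x0, y = next(batches)                        # x0 ∈ [-1, 1]
    x0, y = x0.to(device), y.to(device)
    t = torch.randint(0, T, (x0.size(0),), device=device)
    eps = torch.randn_like(x0)
    xt = sqrt_ab[t][:, None, None, None] * x0 + sqrt_1mab[t][:, None, None, None] * eps

    # Classifier-free guidance: drop labels 10% of the time
    mask = torch.rand(y.size(0), device=device) < 0.1
    y_in = torch.where(mask, torch.full_like(y, NULL_CLASS), y)

    eps_pred = model(xt, t, y_in)
    loss = F.mse_loss(eps_pred, eps)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

    # EMA update: ema ← 0.9999 · ema + 0.0001 · model
    with torch.no_grad():
        for p, q in zip(model.parameters(), ema.parameters()):
            q.mul_(0.9999).add_(p, alpha=0.0001)
```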

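Hyperparameters. AdamW with lr $2\times 10^{-4}$ held constant (Ho et al. also used a 5,000-step linear warmup, which the loop above omits), batch 128, 800k steps, gradient clipping at norm 1.0, EMA decay 0.9999. Train in fp32 or bf16; fp16 is prone to numerical instability, for example in the UNet's GroupNorm layers.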
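Why a decay of 0.9999: an EMA with decay $d$ averages over roughly the last $1/(1-d)$ updates. The plain-Python sketch below applies the same update rule as the training loop to a scalar; the helper names are hypothetical. In the real loop the EMA starts as a copy of the initial weights, so a long run is what lets it forget that starting point.

```python
def ema_track(values, decay, init=0.0):
    """Scalar version of the loop's update: ema <- decay * ema + (1 - decay) * v."""
    ema = init
    for v in values:
        ema = decay * ema + (1 - decay) * v
    return ema


def weight_on_recent(decay, n):
    """Share of the EMA's total weight carried by the most recent n updates."""
    return 1 - decay ** n


decay = 0.9999
print(1 / (1 - decay))                    # characteristic window, about 10,000 steps
print(weight_on_recent(decay, 10_000))    # about 1 - 1/e of the weight is in the last 10k steps
print(ema_track([1.0] * 1_000, decay))    # early on, still mostly its start value
print(ema_track([1.0] * 800_000, decay))  # after 800k steps it has fully caught up
```

After 1,000 steps the scalar EMA has moved less than 10% of the way from its start value to the signal, one reason early EMA checkpoints can look worse than the raw weights; after 800k steps the starting point no longer matters.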

Sampling (DDPM, $T=1000$ steps).

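```python
@torch.no_grad()
def sample(ema, y, w=3.0):
    # y: LongTensor of class labels; generates one 3×64×64 image per label on y's device
    ema.eval()
    B = y.size(0)
    null = torch.full_like(y, NULL_CLASS)                        # "no class" labels
    x = torch.randn(B, 3, 64, 64, device=y.device)               # x_T ~ N(0, I)
    for t in reversed(range(T)):
        tt = torch.full((B,), t, device=y.device, dtype=torch.long)
        eps_c = ema(x, tt, y)                                    # conditional ε
        eps_u = ema(x, tt, null)                                 # unconditional ε
        eps = (1 + w) * eps_c - w * eps_u                        # CFG
        # Posterior mean of x_{t-1}: (x_t - β_t / √(1-ᾱ_t) · ε) / √α_t
        mu = (x - betas[t] / sqrt_1mab[t] * eps) / torch.sqrt(alphas[t])
        if t > 0:
            x = mu + torch.sqrt(betas[t]) * torch.randn_like(x)  # σ_t² = β_t
        else:
            x = mu
    return x.clamp(-1, 1)                                        # back to the data range
```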
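A trained UNet is hard to check by eye, so one way to exercise the ancestral update and the CFG line is a 1-D toy where the optimal ε-predictors are known in closed form, using $\mathbb{E}[\epsilon \mid x_t] = (x_t - \sqrt{\bar\alpha_t}\,\mathbb{E}[x_0 \mid x_t])/\sqrt{1-\bar\alpha_t}$. In this plain-Python sketch the data are hypothetical: the conditional "class" is the single point +1, and the unconditional data are ±1 with equal probability. `eps_cond` and `eps_uncond` stand in for `ema(x, tt, y)` and the null-class call, and the linear schedule is assumed for brevity. With exact predictors the final step returns the model's $x_0$ estimate, so both the unguided ($w = 0$) and guided ($w = 3$) runs should land on +1.

```python
import math
import random

T = 1000
BETAS = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]  # linear schedule
ALPHAS = [1.0 - b for b in BETAS]
ALPHA_BAR = []
_prod = 1.0
for _a in ALPHAS:
    _prod *= _a
    ALPHA_BAR.append(_prod)


def eps_from_x0_mean(x, t, x0_mean):
    # Optimal eps-prediction: (x_t - sqrt(ab) * E[x0 | x_t]) / sqrt(1 - ab)
    ab = ALPHA_BAR[t]
    return (x - math.sqrt(ab) * x0_mean) / math.sqrt(1.0 - ab)


def eps_cond(x, t):
    # Hypothetical class y: every sample is +1, so E[x0 | x_t] = 1.
    return eps_from_x0_mean(x, t, 1.0)


def eps_uncond(x, t):
    # Hypothetical unconditional data: +1 or -1 with equal probability,
    # for which E[x0 | x_t] = tanh(sqrt(ab) * x_t / (1 - ab)).
    ab = ALPHA_BAR[t]
    return eps_from_x0_mean(x, t, math.tanh(math.sqrt(ab) * x / (1.0 - ab)))


def sample(w, rng):
    """DDPM ancestral sampling with classifier-free guidance, on scalars."""
    x = rng.gauss(0.0, 1.0)  # x_T ~ N(0, 1)
    for t in reversed(range(T)):
        eps = (1 + w) * eps_cond(x, t) - w * eps_uncond(x, t)  # CFG
        mu = (x - BETAS[t] / math.sqrt(1.0 - ALPHA_BAR[t]) * eps) / math.sqrt(ALPHAS[t])
        x = mu + math.sqrt(BETAS[t]) * rng.gauss(0.0, 1.0) if t > 0 else mu
    return x


rng = random.Random(0)
print([round(sample(0.0, rng), 6) for _ in range(3)])  # w = 0: conditional model alone
print([round(sample(3.0, rng), 6) for _ in range(3)])  # w = 3: guided
```

Because the guided ε is an affine combination with weights $(1+w)$ and $-w$ that sum to 1, the guided posterior mean is the same combination of the conditional and unconditional means, which is why guidance never leaves the sampler's update rule.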

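Classifier-free guidance (Ho & Salimans 2022) extrapolates from the unconditional prediction towards the conditional one, $\tilde\epsilon = (1+w)\,\epsilon_\theta(x_t, t, y) - w\,\epsilon_\theta(x_t, t, \varnothing)$, which is the `eps` line in the sampler. $w = 0$ is the plain conditional model; larger $w$ trades diversity for sample quality, and $w \in [1, 5]$ is a typical range. Faster samplers (DDIM, DPM-Solver) reduce the 1000-step loop to tens of steps, typically 10 to 50.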
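The faster samplers reuse the same ε-network. As a hedged sketch of deterministic DDIM ($\eta = 0$) in plain Python: each step forms the implied clean sample $\hat x_0 = (x_t - \sqrt{1-\bar\alpha_t}\,\hat\epsilon)/\sqrt{\bar\alpha_t}$ and re-noises it to the next, earlier timestep of a 50-step stride, adding no fresh noise. To keep it checkable, the network is replaced by the exact ε-predictor for hypothetical toy data that are ±1 with equal probability, and the linear schedule stands in for whichever schedule the model was trained with.

```python
import math

T = 1000
BETAS = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]  # linear schedule
ALPHA_BAR = []
_prod = 1.0
for _b in BETAS:
    _prod *= 1.0 - _b
    ALPHA_BAR.append(_prod)


def eps_mixture(x, t):
    # Exact eps-predictor for hypothetical toy data: x0 = +1 or -1 with equal probability.
    ab = ALPHA_BAR[t]
    x0_mean = math.tanh(math.sqrt(ab) * x / (1.0 - ab))
    return (x - math.sqrt(ab) * x0_mean) / math.sqrt(1.0 - ab)


def ddim_sample(x, eps_model, num_steps=50):
    """Deterministic DDIM (eta = 0) over an evenly strided subset of the T timesteps."""
    steps = list(range(0, T, T // num_steps))[::-1]  # 980, 960, ..., 0 for the defaults
    for k, t in enumerate(steps):
        ab = ALPHA_BAR[t]
        ab_prev = ALPHA_BAR[steps[k + 1]] if k + 1 < len(steps) else 1.0  # 1.0 = clean data
        eps = eps_model(x, t)
        x0_hat = (x - math.sqrt(1.0 - ab) * eps) / math.sqrt(ab)  # implied clean sample
        x = math.sqrt(ab_prev) * x0_hat + math.sqrt(1.0 - ab_prev) * eps  # no fresh noise
    return x


for z in (-1.5, -0.2, 0.4, 2.0):
    print(z, ddim_sample(z, eps_mixture))
```

Each starting value lands on +1 or -1 with the same sign: the update is a non-negative combination of $x_t$ and $\hat x_0$, which share a sign here, so the deterministic map cannot cross zero.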

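Compute. 800k steps × batch 128 × ~5 GFLOPs per forward pass at 64×64 ≈ $5\times10^{17}$ FLOPs, or about $1.5\times10^{18}$ once the backward pass (roughly 2× the forward) is included. Wall-clock time depends mainly on model size and hardware utilisation: Ho et al. report about 10.6 hours for their 800k-step CIFAR-10 model on a TPU v3-8, while larger ImageNet-64 UNets take on the order of 3-5 days on 8×A100. Latent diffusion (Stable Diffusion) runs the diffusion in a 4×64×64 VAE latent for 512×512 images, 48× fewer values than the 3×512×512 pixel tensor, which cuts compute by roughly an order of magnitude.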
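The arithmetic can be reproduced in a few lines. This sketch assumes the ~5 GFLOPs figure is a forward-pass cost and uses the common rule of thumb that a training step costs about three forward passes; both are estimates, not measurements.

```python
steps, batch = 800_000, 128
fwd_flops_per_sample = 5e9            # the ~5 GFLOPs estimate, assumed forward-only
forward_total = steps * batch * fwd_flops_per_sample
train_total = 3 * forward_total       # rule of thumb: backward pass ≈ 2× forward
print(f"forward ≈ {forward_total:.2e} FLOPs, training ≈ {train_total:.2e} FLOPs")

# Latent diffusion at 512×512: RGB pixels vs. a 4×64×64 VAE latent
pixels = 3 * 512 * 512
latent = 4 * 64 * 64
print(f"{pixels // latent}× fewer values to denoise")
```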

Pitfalls. Sampling from the raw weights instead of the EMA copy costs roughly 3 FID points; always sample from the EMA. Ho et al. found that predicting $x_0$ directly, instead of $\epsilon$, gave worse sample quality. At 64×64 and above the linear schedule destroys the signal too early; switch to cosine. Sampling in fp16 can introduce visible banding.

Related terms: Diffusion Model, Convolutional Neural Network, Attention Mechanism, Adam


This site is currently in Beta. Contact: Chris Paton


AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).