14.16 From scratch: a working DDPM

We now build a complete diffusion model and train it on MNIST. There are two components: a U-Net noise predictor, and the training and sampling loops. The implementation is faithful to Ho et al. (2020) at the level of the equations; we simplify the architecture for pedagogical clarity.

Sinusoidal time embeddings

The U-Net needs to know which timestep $t$ it is denoising. Following the Transformer (Vaswani et al., 2017) we use sinusoidal embeddings:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t is shape (batch,), output is (batch, dim)
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / (half - 1))
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
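A quick way to check the embedding is to evaluate the same formula on a few timesteps. The sketch below copies the body of `forward` into a standalone function, `sinusoidal_embedding` (a name of our own, not from the text), and confirms three properties. First, the output has shape `(batch, dim)`. Second, $t = 0$ maps to all zeros in the sine half and all ones in the cosine half. Third, nearby timesteps receive more similar vectors than distant ones.

```python
import math
import torch
import torch.nn.functional as F

def sinusoidal_embedding(t, dim):
    """Same formula as SinusoidalTimeEmbedding.forward (hypothetical helper name)."""
    half = dim // 2
    # Geometric sequence of frequencies from 1 down to 1/10000
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / (half - 1))
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t = torch.tensor([0, 1, 500, 999])
emb = sinusoidal_embedding(t, 128)
print(emb.shape)                       # torch.Size([4, 128])
print(emb[0, :64].abs().max().item())  # 0.0, since sin(0) = 0
print(emb[0, 64:].min().item())        # 1.0, since cos(0) = 1

# t=1 is much closer (in cosine similarity) to t=0 than t=999 is
cos_sim = F.cosine_similarity
print(cos_sim(emb[1:2], emb[0:1]).item() > cos_sim(emb[3:4], emb[0:1]).item())  # True
```

The smooth dependence on $t$ is what lets the MLP in `UNet.t_emb` interpolate between timesteps it has seen during training.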

A small U-Net

class ResBlock(nn.Module):
    def __init__(self, ch_in, ch_out, t_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, 3, padding=1)
        self.conv2 = nn.Conv2d(ch_out, ch_out, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, ch_out)
        self.norm2 = nn.GroupNorm(8, ch_out)
        self.t_proj = nn.Linear(t_dim, ch_out)
        self.skip = nn.Conv2d(ch_in, ch_out, 1) if ch_in != ch_out else nn.Identity()

    def forward(self, x, t_emb):
        h = F.silu(self.norm1(self.conv1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = F.silu(self.norm2(self.conv2(h)))
        return h + self.skip(x)

class UNet(nn.Module):
    def __init__(self, ch=64, t_dim=128):
        super().__init__()
        self.t_emb = nn.Sequential(
            SinusoidalTimeEmbedding(t_dim),
            nn.Linear(t_dim, t_dim), nn.SiLU(),
            nn.Linear(t_dim, t_dim),
        )
        self.in_conv = nn.Conv2d(1, ch, 3, padding=1)
        # Encoder
        self.down1 = ResBlock(ch,   ch,   t_dim)
        self.down2 = ResBlock(ch,   2*ch, t_dim)
        self.down3 = ResBlock(2*ch, 4*ch, t_dim)
        # Bottleneck
        self.mid   = ResBlock(4*ch, 4*ch, t_dim)
        # Decoder (with skip connections)
        self.up3   = ResBlock(8*ch, 2*ch, t_dim)
        self.up2   = ResBlock(4*ch, ch,   t_dim)
        self.up1   = ResBlock(2*ch, ch,   t_dim)
        self.out_conv = nn.Conv2d(ch, 1, 3, padding=1)
        self.pool = nn.AvgPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t):
        te = self.t_emb(t)
        h0 = self.in_conv(x)
        h1 = self.down1(h0, te)
        h2 = self.down2(self.pool(h1), te)
        h3 = self.down3(self.pool(h2), te)
        m  = self.mid(h3, te)
        u3 = self.up3(torch.cat([m, h3], dim=1), te)
        u2 = self.up2(torch.cat([self.up(u3), h2], dim=1), te)
        u1 = self.up1(torch.cat([self.up(u2), h1], dim=1), te)
        return self.out_conv(u1)

This is a simplified version of the Ho et al. architecture: three resolution levels (28×28, 14×14 and 7×7) and no self-attention layers. Channel widths double at each downsampling (64, 128, 256). The skip connections are essential: without them the decoder would have to reconstruct high-frequency detail at every upsampling step from the coarse 7×7 bottleneck alone, which is very difficult.
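The number of levels follows from the input size. `AvgPool2d(2)` floors odd sizes, so a fourth level would pool the 7×7 map down to 3×3. `Upsample(scale_factor=2)` would then bring it back only to 6×6, which cannot be concatenated with a 7×7 skip tensor. The sketch below traces the spatial size through the down and up paths; `level_sizes` is a hypothetical helper, not part of the model.

```python
import torch
import torch.nn as nn

pool = nn.AvgPool2d(2)
up = nn.Upsample(scale_factor=2, mode='nearest')

def level_sizes(size, levels):
    """Spatial sizes on the way down, then after each upsampling on the way back."""
    x = torch.zeros(1, 1, size, size)
    down = [x.shape[-1]]
    for _ in range(levels - 1):
        x = pool(x)
        down.append(x.shape[-1])
    ups = []
    for _ in range(levels - 1):
        x = up(x)
        ups.append(x.shape[-1])
    return down, ups

print(level_sizes(28, 3))  # ([28, 14, 7], [14, 28]): matches skips h2 and h1
print(level_sizes(28, 4))  # ([28, 14, 7, 3], [6, 12, 24]): 6 != 7, torch.cat would fail
```

Supporting a fourth level would need padding the input (for example to 32×32) so that every pooled size stays even.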

Noise schedule and pre-computed quantities

def make_schedule(T=1000, beta_min=1e-4, beta_max=0.02, device='cuda'):
    betas = torch.linspace(beta_min, beta_max, T, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return {
        'betas': betas,
        'alphas': alphas,
        'alpha_bars': alpha_bars,
        'sqrt_ab': torch.sqrt(alpha_bars),
        'sqrt_one_minus_ab': torch.sqrt(1 - alpha_bars),
    }

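A linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ over $T = 1000$ steps is the original DDPM choice. Nichol & Dhariwal (2021) observed that this schedule is sub-optimal for low-resolution images such as 32×32 and 64×64: the end of the forward process is almost pure noise and contributes little to sample quality. Their cosine schedule destroys information more gradually and works better at these resolutions.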
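For comparison, here is a sketch of the cosine schedule, assuming the definition from Nichol & Dhariwal. It sets $\bar\alpha_t = f(t)/f(0)$ with $f(t) = \cos^2\!\left(\frac{t/T + s}{1 + s}\cdot\frac{\pi}{2}\right)$ and offset $s = 0.008$, then derives $\beta_t = 1 - \bar\alpha_t/\bar\alpha_{t-1}$, clipped at 0.999 to avoid the singularity near $t = T$. The helper names `make_cosine_schedule` and `make_linear_alpha_bars` are ours. The former returns the same dictionary keys as `make_schedule`, so it could be dropped in as a replacement. Printing $\bar\alpha_t$ at a few timesteps shows how much sooner the linear schedule approaches pure noise.

```python
import math
import torch

def make_linear_alpha_bars(T=1000, beta_min=1e-4, beta_max=0.02):
    """alpha_bar_t for the linear DDPM schedule (as in make_schedule, in float64)."""
    betas = torch.linspace(beta_min, beta_max, T, dtype=torch.float64)
    return torch.cumprod(1.0 - betas, dim=0)

def make_cosine_schedule(T=1000, s=0.008, device='cpu'):
    """Cosine schedule (Nichol & Dhariwal, 2021), returning make_schedule's keys."""
    steps = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos((steps / T + s) / (1 + s) * math.pi / 2) ** 2
    ab_full = f / f[0]                         # alpha_bar for t = 0..T, alpha_bar_0 = 1
    betas = (1 - ab_full[1:] / ab_full[:-1]).clamp(max=0.999)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)  # recomputed so it matches the clipped betas
    sched = {
        'betas': betas,
        'alphas': alphas,
        'alpha_bars': alpha_bars,
        'sqrt_ab': torch.sqrt(alpha_bars),
        'sqrt_one_minus_ab': torch.sqrt(1 - alpha_bars),
    }
    # Computed in float64 for accuracy, returned as float32 like make_schedule
    return {k: v.float().to(device) for k, v in sched.items()}

lin = make_linear_alpha_bars()
cos_ab = make_cosine_schedule()['alpha_bars']
for t in [0, 250, 500, 750, 999]:
    print(f"t={t:4d}  linear {lin[t].item():.5f}  cosine {cos_ab[t].item():.5f}")
```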

Training step

def q_sample(x0, t, sched):
    eps = torch.randn_like(x0)
    sqrt_ab = sched['sqrt_ab'][t][:, None, None, None]
    sqrt_om = sched['sqrt_one_minus_ab'][t][:, None, None, None]
    return sqrt_ab * x0 + sqrt_om * eps, eps

def ddpm_loss(model, x0, sched, T):
    bsz = x0.size(0)
    t = torch.randint(0, T, (bsz,), device=x0.device)
    xt, eps = q_sample(x0, t, sched)
    eps_pred = model(xt, t)
    return F.mse_loss(eps_pred, eps)
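Two quick checks help when reading the printed losses. First, since $\epsilon \sim \mathcal{N}(0, I)$, a network that always predicts zero has expected loss $\mathbb{E}[\epsilon^2] = 1$ at every $t$, so a training loss well below 1 means the network has learned something. Second, because $\bar\alpha_t + (1 - \bar\alpha_t) = 1$, `q_sample` preserves variance: unit-variance inputs give unit-variance $x_t$. The sketch below verifies both on CPU. It uses trimmed copies of the functions above, a hypothetical `ZeroModel`, and random tensors standing in for MNIST batches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Trimmed CPU copies of make_schedule / q_sample / ddpm_loss from above
def make_schedule(T=1000, beta_min=1e-4, beta_max=0.02, device='cpu'):
    betas = torch.linspace(beta_min, beta_max, T, device=device)
    alpha_bars = torch.cumprod(1.0 - betas, dim=0)
    return {'sqrt_ab': torch.sqrt(alpha_bars), 'sqrt_one_minus_ab': torch.sqrt(1 - alpha_bars)}

def q_sample(x0, t, sched):
    eps = torch.randn_like(x0)
    sqrt_ab = sched['sqrt_ab'][t][:, None, None, None]
    sqrt_om = sched['sqrt_one_minus_ab'][t][:, None, None, None]
    return sqrt_ab * x0 + sqrt_om * eps, eps

def ddpm_loss(model, x0, sched, T):
    t = torch.randint(0, T, (x0.size(0),), device=x0.device)
    xt, eps = q_sample(x0, t, sched)
    return F.mse_loss(model(xt, t), eps)

class ZeroModel(nn.Module):
    """Hypothetical baseline that always predicts zero noise."""
    def forward(self, x, t):
        return torch.zeros_like(x)

torch.manual_seed(0)
T = 1000
sched = make_schedule(T=T)

# Baseline loss: close to 1.0 for any data
x0 = torch.rand(256, 1, 28, 28) * 2 - 1  # stand-in for a batch scaled to [-1, 1]
loss = ddpm_loss(ZeroModel(), x0, sched, T)
print(f"zero-prediction loss: {loss.item():.3f}")

# Variance preservation at an intermediate timestep: close to 1.0
t = torch.full((256,), 500, dtype=torch.long)
xt, _ = q_sample(torch.randn(256, 1, 28, 28), t, sched)
print(f"var(x_t) at t=500: {xt.var().item():.3f}")
```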

The whole training loop:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def train_ddpm(epochs=20, batch_size=128, lr=2e-4, T=1000, device='cuda'):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # to [-1, 1]
    ])
    train_set = datasets.MNIST('.', train=True, download=True, transform=transform)
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    model = UNet().to(device)
    sched = make_schedule(T=T, device=device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

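    for epoch in range(epochs):
        running, n_batches = 0.0, 0
        for x, _ in loader:  # labels are unused: the model is unconditional
            x = x.to(device)
            loss = ddpm_loss(model, x, sched, T)
            opt.zero_grad(); loss.backward(); opt.step()
            running += loss.item(); n_batches += 1
        print(f"epoch {epoch}: mean loss {running / n_batches:.4f}")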
    return model, sched

Sampling

@torch.no_grad()
def sample(model, sched, n=16, T=1000, device='cuda'):
    x = torch.randn(n, 1, 28, 28, device=device)
    for t in reversed(range(T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        beta = sched['betas'][t]
        ab   = sched['alpha_bars'][t]
        a    = sched['alphas'][t]
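        # coef = beta_t / sqrt(1 - alpha_bar_t), using beta_t = 1 - alpha_t
        coef = (1 - a) / torch.sqrt(1 - ab)
        mean = (x - coef * eps) / torch.sqrt(a)  # reverse-process mean mu_theta(x_t, t)
        if t > 0:
            # Variance sigma_t^2 = beta_t; no noise is added on the final step
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(beta) * noise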
        else:
            x = mean
    return x.clamp(-1, 1)
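For reference, the loop body implements the DDPM reverse step with the fixed variance $\sigma_t^2 = \beta_t$, one of the two variance choices discussed by Ho et al.:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z,
\qquad z \sim \mathcal{N}(0, I), \quad \sigma_t^2 = \beta_t,
$$

with $z = 0$ on the final step.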

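After 20 epochs on MNIST you should see clearly recognisable digits. The 1000-step DDPM sampler is slow because every sample requires $T = 1000$ sequential U-Net evaluations. Expect about 20 seconds for a batch of 16 on a single GPU, though the exact figure depends on the hardware. Switching to DDIM (Song et al., 2021) with 50 steps cuts the number of network evaluations by a factor of 20, bringing sampling down to a second or two with only a small loss in quality.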
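A minimal sketch of a deterministic DDIM sampler ($\eta = 0$) is below. It assumes the same `model(x, t)` interface and schedule dictionary as above; the name `ddim_sample` and the evenly spaced timestep subsequence are our choices. Each step forms an estimate of the clean image, $\hat x_0 = (x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta)/\sqrt{\bar\alpha_t}$. It then moves that estimate to the next timestep $t'$ in the subsequence: $x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat x_0 + \sqrt{1-\bar\alpha_{t'}}\,\epsilon_\theta$. The self-check uses a hypothetical "oracle" model that returns the exact noise for a known image, in which case the sampler must land on that image.

```python
import torch

@torch.no_grad()
def ddim_sample(model, sched, n=16, steps=50, T=1000, device='cpu', x=None):
    """Deterministic DDIM (eta = 0) over `steps` evenly spaced timesteps."""
    if x is None:
        x = torch.randn(n, 1, 28, 28, device=device)
    # Decreasing timestep subsequence, e.g. 999, 979, ..., 0 for steps=50
    ts = torch.linspace(T - 1, 0, steps, device=device).round().long()
    for i, t in enumerate(ts):
        t_batch = torch.full((x.size(0),), int(t), device=device, dtype=torch.long)
        eps = model(x, t_batch)
        ab = sched['alpha_bars'][t]
        # alpha_bar at the next (smaller) timestep; 1.0 after the last step (clean data)
        if i + 1 < steps:
            ab_prev = sched['alpha_bars'][ts[i + 1]]
        else:
            ab_prev = torch.tensor(1.0, device=device)
        # Predict the clean image, then re-noise it to the previous noise level
        x0_pred = (x - torch.sqrt(1 - ab) * eps) / torch.sqrt(ab)
        x = torch.sqrt(ab_prev) * x0_pred + torch.sqrt(1 - ab_prev) * eps
    return x.clamp(-1, 1)

# Self-check: an oracle returning the exact noise for a known x0_true
T = 1000
sched = {'alpha_bars': torch.cumprod(1 - torch.linspace(1e-4, 0.02, T), dim=0)}
torch.manual_seed(0)
x0_true = torch.rand(4, 1, 28, 28) * 2 - 1
calls = []

def oracle(x, t):
    calls.append(int(t[0]))
    ab = sched['alpha_bars'][t][:, None, None, None]
    return (x - ab.sqrt() * x0_true) / (1 - ab).sqrt()

out = ddim_sample(oracle, sched, n=4, steps=50, T=T)
print(len(calls), (out - x0_true).abs().max().item() < 1e-3)  # 50 True
```

With the trained model, `ddim_sample(model, sched, device='cuda')` would stand in for `sample`. Because $\eta = 0$ makes every step deterministic, a fixed starting noise tensor `x` always maps to the same image, which is convenient for comparing checkpoints.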

What you will see

Early in training the noise predictions are inaccurate, and samples look like unstructured noise rather than digits. By epoch 5 you should see digit-like blobs with the right overall statistics (white strokes on a black background). By epoch 20 the digits are well formed, and diversity is typically high: all ten digit classes appear in roughly equal proportions, and mode coverage is essentially complete.

This is a faithful, working DDPM in roughly 100 lines of code. Industrial-strength implementations differ in scale and engineering (wider networks, self-attention layers, EMA weights, mixed-precision training) but not in fundamental structure.
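Of the differences listed above, EMA weights are the simplest to add to the training loop here. The idea is to keep a frozen copy of the network and, after every optimiser step, move each of its parameters a fraction $1 - \text{decay}$ of the way towards the live weights. Samples are then drawn from the averaged copy, whose weights fluctuate less than the live ones. The `EMA` class and the decay values below are illustrative assumptions, not taken from a particular implementation.

```python
import copy
import torch
import torch.nn as nn

class EMA:
    """Hypothetical helper: exponential moving average of a model's parameters."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # ema <- ema + (1 - decay) * (live - ema) = decay * ema + (1 - decay) * live
        for p_ema, p in zip(self.ema_model.parameters(), model.parameters()):
            p_ema.lerp_(p, 1.0 - self.decay)
        # Buffers (none in this U-Net, but e.g. BatchNorm statistics) are copied as-is
        for b_ema, b in zip(self.ema_model.buffers(), model.buffers()):
            b_ema.copy_(b)

# In train_ddpm one would create ema = EMA(model) before the loop, call
# ema.update(model) after opt.step(), and pass ema.ema_model to sample().

torch.manual_seed(0)
model = nn.Linear(4, 4)
w0 = model.weight.detach().clone()
ema = EMA(model, decay=0.9)
with torch.no_grad():
    model.weight.add_(1.0)  # pretend an optimiser step moved every weight by +1
ema.update(model)
print(torch.allclose(ema.ema_model.weight, w0 + 0.1))  # True: moved 10% of the way
```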

This site is currently in Beta. Contact: Chris Paton

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).