14.16 From scratch: a working DDPM
We now build a complete diffusion model, training on MNIST. The two components are a U-Net noise predictor and the training/sampling loops. This implementation is faithful to Ho et al. (2020) at the level of equations; we simplify the architecture for pedagogical clarity.
Sinusoidal time embeddings
The U-Net needs to know which timestep $t$ it is denoising. Following the Transformer (Vaswani et al., 2017), we use sinusoidal embeddings:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t is shape (batch,), output is (batch, dim)
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / (half - 1))
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
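To see what these embeddings buy us, here is a standalone functional restatement of the same arithmetic (the name time_embedding is ours, for inspection only; it is not part of the model). Nearby timesteps get nearly identical embeddings, while distant ones are easily distinguished — exactly the property the noise predictor needs:

```python
import math
import torch

def time_embedding(t, dim):
    # Functional restatement of SinusoidalTimeEmbedding, for inspection only.
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / (half - 1))
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = time_embedding(torch.arange(1000), 128)
print(emb.shape)  # torch.Size([1000, 128])

# Nearby timesteps are similar; distant ones are not.
sim = torch.nn.functional.cosine_similarity
print(sim(emb[10:11], emb[11:12]).item())    # close to 1
print(sim(emb[10:11], emb[500:501]).item())  # much smaller
```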
A small U-Net
class ResBlock(nn.Module):
    def __init__(self, ch_in, ch_out, t_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, 3, padding=1)
        self.conv2 = nn.Conv2d(ch_out, ch_out, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, ch_out)
        self.norm2 = nn.GroupNorm(8, ch_out)
        self.t_proj = nn.Linear(t_dim, ch_out)
        self.skip = nn.Conv2d(ch_in, ch_out, 1) if ch_in != ch_out else nn.Identity()

    def forward(self, x, t_emb):
        h = F.silu(self.norm1(self.conv1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = F.silu(self.norm2(self.conv2(h)))
        return h + self.skip(x)
class UNet(nn.Module):
    def __init__(self, ch=64, t_dim=128):
        super().__init__()
        self.t_emb = nn.Sequential(
            SinusoidalTimeEmbedding(t_dim),
            nn.Linear(t_dim, t_dim), nn.SiLU(),
            nn.Linear(t_dim, t_dim),
        )
        self.in_conv = nn.Conv2d(1, ch, 3, padding=1)
        # Encoder
        self.down1 = ResBlock(ch, ch, t_dim)
        self.down2 = ResBlock(ch, 2*ch, t_dim)
        self.down3 = ResBlock(2*ch, 4*ch, t_dim)
        # Bottleneck
        self.mid = ResBlock(4*ch, 4*ch, t_dim)
        # Decoder (with skip connections)
        self.up3 = ResBlock(8*ch, 2*ch, t_dim)
        self.up2 = ResBlock(4*ch, ch, t_dim)
        self.up1 = ResBlock(2*ch, ch, t_dim)
        self.out_conv = nn.Conv2d(ch, 1, 3, padding=1)
        self.pool = nn.AvgPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t):
        te = self.t_emb(t)
        h0 = self.in_conv(x)
        h1 = self.down1(h0, te)             # 28x28, ch
        h2 = self.down2(self.pool(h1), te)  # 14x14, 2ch
        h3 = self.down3(self.pool(h2), te)  # 7x7,   4ch
        m = self.mid(h3, te)
        u3 = self.up3(torch.cat([m, h3], dim=1), te)
        u2 = self.up2(torch.cat([self.up(u3), h2], dim=1), te)
        u1 = self.up1(torch.cat([self.up(u2), h1], dim=1), te)
        return self.out_conv(u1)
This is roughly the Ho et al. (2020) architecture with three resolution levels (28×28, 14×14, and 7×7 for MNIST); the channel width doubles at each downsampling step. The skip connections are essential: without them, the network would have to reconstruct high-frequency detail at each upsampling step from the bottleneck alone, which is hopeless.
Noise schedule and pre-computed quantities
def make_schedule(T=1000, beta_min=1e-4, beta_max=0.02, device='cuda'):
    betas = torch.linspace(beta_min, beta_max, T, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return {
        'betas': betas,
        'alphas': alphas,
        'alpha_bars': alpha_bars,
        'sqrt_ab': torch.sqrt(alpha_bars),
        'sqrt_one_minus_ab': torch.sqrt(1 - alpha_bars),
    }
A linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ over $T = 1000$ steps is the original DDPM choice. Cosine schedules (Nichol & Dhariwal, 2021) work better for higher resolutions.
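It is worth checking numerically that this schedule destroys essentially all signal by step $T$. The snippet below is standalone and simply duplicates make_schedule's arithmetic:

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
print(alpha_bars[0].item())   # ~0.9999: x_1 is almost exactly x_0
print(alpha_bars[-1].item())  # ~4e-5:  x_T retains almost no signal
```

Since $\bar\alpha_T \approx 4 \times 10^{-5}$, the terminal distribution of the forward process is very close to $\mathcal{N}(0, I)$, which is what justifies starting the sampler from pure Gaussian noise.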
Training step
def q_sample(x0, t, sched):
    eps = torch.randn_like(x0)
    sqrt_ab = sched['sqrt_ab'][t][:, None, None, None]
    sqrt_om = sched['sqrt_one_minus_ab'][t][:, None, None, None]
    return sqrt_ab * x0 + sqrt_om * eps, eps

def ddpm_loss(model, x0, sched, T):
    bsz = x0.size(0)
    t = torch.randint(0, T, (bsz,), device=x0.device)
    xt, eps = q_sample(x0, t, sched)
    eps_pred = model(xt, t)
    return F.mse_loss(eps_pred, eps)
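A useful sanity check on q_sample: because $(\sqrt{\bar\alpha_t})^2 + (\sqrt{1-\bar\alpha_t})^2 = 1$, the forward process is variance-preserving on unit-variance data. A standalone sketch with toy 1-D "data" (not MNIST) makes this visible:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(100000)  # toy unit-variance "data"
for t in [0, 500, 999]:
    eps = torch.randn_like(x0)
    xt = alpha_bars[t].sqrt() * x0 + (1.0 - alpha_bars[t]).sqrt() * eps
    print(t, round(xt.var().item(), 3))  # variance stays ~1 at every t
```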
The whole training loop:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def train_ddpm(epochs=20, batch_size=128, lr=2e-4, T=1000, device='cuda'):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # to [-1, 1]
    ])
    train_set = datasets.MNIST('.', train=True, download=True, transform=transform)
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    model = UNet().to(device)
    sched = make_schedule(T=T, device=device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for x, _ in loader:
            x = x.to(device)
            loss = ddpm_loss(model, x, sched, T)
            opt.zero_grad(); loss.backward(); opt.step()
        print(f"epoch {epoch}: loss {loss.item():.4f}")
    return model, sched
Sampling
@torch.no_grad()
def sample(model, sched, n=16, T=1000, device='cuda'):
    x = torch.randn(n, 1, 28, 28, device=device)
    for t in reversed(range(T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        beta = sched['betas'][t]
        ab = sched['alpha_bars'][t]
        a = sched['alphas'][t]
        coef = (1 - a) / torch.sqrt(1 - ab)
        mean = (x - coef * eps) / torch.sqrt(a)
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(beta) * noise
        else:
            x = mean
    return x.clamp(-1, 1)
After 20 epochs on MNIST you should see clearly recognisable digits. The 1000-step DDPM sampler is slow, about 20 seconds for a batch of 16 on a single GPU. Switching to DDIM with 50 steps reduces this to a second or two with negligible quality loss.
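The DDIM variant mentioned above can be sketched as follows. This is the deterministic ($\eta = 0$) case on a coarse grid of timesteps; the function name ddim_sample is ours, and it reuses the schedule dictionary from make_schedule. Instead of stepping through all $T$ timesteps, each update predicts $x_0$ from the current $\epsilon$ estimate and jumps directly to the previous grid point:

```python
import torch

@torch.no_grad()
def ddim_sample(model, sched, n=16, T=1000, steps=50, device='cuda'):
    # Deterministic DDIM (eta = 0): predict x0, then jump to the previous
    # timestep on a coarse grid instead of stepping through all T.
    ts = torch.linspace(T - 1, 0, steps, device=device).long()
    x = torch.randn(n, 1, 28, 28, device=device)
    for i, t in enumerate(ts):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        ab = sched['alpha_bars'][t]
        ab_prev = (sched['alpha_bars'][ts[i + 1]] if i + 1 < steps
                   else torch.tensor(1.0, device=device))
        x0_pred = (x - (1 - ab).sqrt() * eps) / ab.sqrt()
        x = ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps
    return x.clamp(-1, 1)
```

Crucially, this sampler uses the same trained noise predictor as the 1000-step loop; only the inference-time update rule changes.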
What you will see
Early in training the U-Net outputs near-zero noise predictions, so sampling produces near-zero outputs: uniform grey images. By epoch 5 you should see digit-like blobs with the correct statistics (white strokes on black). By epoch 20 the digits are well-formed, diversity is high (every digit class appears roughly equally often), and mode coverage is essentially complete.
This is a faithful, working DDPM in roughly 100 lines of code. Industrial-strength implementations differ in scale (more channels, more attention layers, EMA weights, mixed-precision training) but not in fundamental structure.
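Of those additions, EMA weights are the cheapest to retrofit: keep a deep copy of the model whose parameters track an exponential moving average of the trained weights, and sample from that copy. A minimal sketch (the function name ema_update is ours):

```python
import copy
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # Move each EMA parameter a fraction (1 - decay) toward the live weights.
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

# Usage: ema_model = copy.deepcopy(model); call ema_update(ema_model, model)
# after every optimizer step; sample from ema_model rather than model.
```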