Recipe to reproduce a small GPT-2 (124M params, 12 layers, 12 heads, $d_{\text{model}}=768$, context $T=1024$) from scratch. Targets the next-token prediction loss
$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i \mid x_{<i})$$
over a tokenised corpus of roughly 10B tokens (OpenWebText-scale). Reproduces
Karpathy's nanoGPT defaults.
Data pipeline.
- Download the corpus (OpenWebText, ~40GB raw text) and deduplicate by URL + MinHash (a sketch follows this list).
- Train a byte-level BPE tokeniser with a 50,257-token vocabulary (256 byte tokens + 50,000 merges + 1 end-of-text token, matching GPT-2).
- Tokenise the entire corpus into a single contiguous uint16 memmap. For 10B tokens this is ~20GB on disk; np.memmap then gives cheap random-offset reads without loading the file into RAM.
- Hold out a 0.5% validation slice (a sketch follows the tokenisation code below).
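The near-duplicate pass can be done with MinHash + LSH; a minimal sketch using the datasketch library (the character 5-gram shingling and 0.8 Jaccard threshold are illustrative choices, and docs is assumed to be an iterable of raw documents):

from datasketch import MinHash, MinHashLSH

def minhash(text, num_perm=128):
    m = MinHash(num_perm=num_perm)
    for shingle in {text[i:i+5] for i in range(len(text) - 4)}:  # character 5-grams
        m.update(shingle.encode("utf-8"))
    return m

lsh = MinHashLSH(threshold=0.8, num_perm=128)  # Jaccard similarity threshold
kept = []
for doc_id, doc in enumerate(docs):
    m = minhash(doc)
    if not lsh.query(m):            # no near-duplicate indexed yet
        lsh.insert(str(doc_id), m)
        kept.append(doc)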
# Tokenise once, store as a uint16 memmap (the GPT-2 vocab fits in 16 bits)
import numpy as np, tiktoken
tokenizer = tiktoken.get_encoding("gpt2")  # GPT-2's byte-level BPE
ids = tokenizer.encode_ordinary(open("corpus.txt").read())
arr = np.memmap("train.bin", dtype=np.uint16, mode="w+", shape=(len(ids),))
arr[:] = np.array(ids, dtype=np.uint16)
arr.flush()  # make sure everything hits disk
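And a matching sketch of the 0.5% validation holdout; in practice, split ids before the write above and emit train.bin and val.bin separately:

# Hold out the final 0.5% of tokens as validation
n = len(ids)
splits = {"train.bin": ids[:int(0.995 * n)], "val.bin": ids[int(0.995 * n):]}
for fname, split_ids in splits.items():
    mm = np.memmap(fname, dtype=np.uint16, mode="w+", shape=(len(split_ids),))
    mm[:] = np.array(split_ids, dtype=np.uint16)
    mm.flush()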
Model architecture. Decoder-only Transformer with causal self-attention, pre-LayerNorm, and learned positional embeddings.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, d, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = CausalSelfAttention(d, n_head)  # uses FlashAttention
        self.ln2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(
            nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d), nn.Dropout(0.0)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # pre-LN residual attention
        x = x + self.mlp(self.ln2(x))   # pre-LN residual MLP
        return x
class GPT(nn.Module):
    def __init__(self, vocab=50257, T=1024, d=768, n_layer=12, n_head=12):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(T, d)
        self.blocks = nn.ModuleList([Block(d, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying

    def forward(self, idx):
        B, T = idx.shape
        x = self.tok(idx) + self.pos(torch.arange(T, device=idx.device))
        for block in self.blocks:
            x = block(x)
        return self.head(self.ln_f(x))  # logits over the vocabulary
Use FlashAttention-2 for the attention kernel; it saves ~3× memory and runs ~1.5× faster than vanilla $\mathrm{softmax}(QK^\top/\sqrt{d})V$ at $T=1024$.
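CausalSelfAttention is referenced above but never defined; a minimal sketch using PyTorch's F.scaled_dot_product_attention, which dispatches to a FlashAttention kernel on supported GPUs:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d, n_head):
        super().__init__()
        assert d % n_head == 0
        self.n_head = n_head
        self.qkv = nn.Linear(d, 3 * d)  # fused Q, K, V projection
        self.proj = nn.Linear(d, d)     # output projection

    def forward(self, x):
        B, T, d = x.shape
        q, k, v = self.qkv(x).split(d, dim=2)
        # (B, T, d) -> (B, n_head, T, d // n_head)
        q, k, v = (t.view(B, T, self.n_head, d // self.n_head).transpose(1, 2)
                   for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # fused causal attention
        y = y.transpose(1, 2).contiguous().view(B, T, d)
        return self.proj(y)

With the attention module in place, a quick sanity check on the parameter count (nn.Module.parameters() deduplicates shared tensors, so the tied head adds nothing):

model = GPT()
print(sum(p.numel() for p in model.parameters()) / 1e6)  # ≈ 124M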
Optimiser and schedule.
- AdamW, $\beta_1=0.9$, $\beta_2=0.95$, $\epsilon=10^{-8}$, weight decay $0.1$, applied only to weights with dim ≥ 2 (matmuls and embeddings), not to biases or LayerNorm parameters; see the configure_optimizers sketch after this list.
- Peak LR $6\times 10^{-4}$, linear warmup for 2000 steps, then cosine decay to $6\times 10^{-5}$ over 600,000 steps (the cosine_lr helper below implements this).
- Gradient clipping at norm 1.0.
- Effective batch size 0.5M tokens (e.g. micro-batch 12 × seq 1024 × 40 grad-accum steps on a single A100, or scaled out across 8 GPUs).
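The training loop below calls model.configure_optimizers, which the model class above does not define; a minimal sketch of the grouping it implies, written as a method on GPT (decay only on tensors with dim ≥ 2):

    def configure_optimizers(self, weight_decay, lr, betas):
        decay = [p for p in self.parameters() if p.dim() >= 2]
        no_decay = [p for p in self.parameters() if p.dim() < 2]
        return torch.optim.AdamW(
            [{"params": decay, "weight_decay": weight_decay},
             {"params": no_decay, "weight_decay": 0.0}],
            lr=lr, betas=betas, eps=1e-8)

And a cosine_lr helper matching the numbers above (the min parameter name shadows the builtin, but matches the call in the loop below):

import math

def cosine_lr(step, warmup, max_steps, peak, min):
    if step < warmup:                              # linear warmup from 0 to peak
        return peak * (step + 1) / warmup
    frac = (step - warmup) / (max_steps - warmup)  # 0 -> 1 over the decay phase
    return min + 0.5 * (peak - min) * (1.0 + math.cos(math.pi * frac))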
Training loop.
import torch
import torch.nn.functional as F

B, T = 12, 1024  # micro-batch size, context length
grad_accum_steps, max_steps = 40, 600_000
train_mm = np.memmap("train.bin", dtype=np.uint16, mode="r")
val_mm = np.memmap("val.bin", dtype=np.uint16, mode="r")

def get_batch(split):
    data = train_mm if split == "train" else val_mm
    ix = torch.randint(len(data) - T - 1, (B,))  # random offsets into the token stream
    x = torch.stack([torch.from_numpy(data[i:i+T].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+T].astype(np.int64)) for i in ix])
    return x.to("cuda"), y.to("cuda")

model = GPT().to("cuda")
model = torch.compile(model)
opt = model.configure_optimizers(weight_decay=0.1, lr=6e-4, betas=(0.9, 0.95))
# No GradScaler needed: bf16 has the same exponent range as fp32, so
# gradients do not underflow the way they can in fp16.

for step in range(max_steps):
    lr = cosine_lr(step, warmup=2000, max_steps=600_000, peak=6e-4, min=6e-5)
    for g in opt.param_groups: g["lr"] = lr
    opt.zero_grad(set_to_none=True)
    for micro in range(grad_accum_steps):
        x, y = get_batch("train")
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        loss = loss / grad_accum_steps  # average over accumulation steps
        loss.backward()                 # gradients accumulate across micro-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    if step % 1000 == 0:
        val_loss = estimate_loss(model, "val")
        print(f"step {step} train {loss.item() * grad_accum_steps:.4f} val {val_loss:.4f}")
Compute estimate. Total FLOPs $\approx 6ND$, where $N$ is the parameter count and $D$ is the number of tokens processed, not the corpus size: 600,000 steps at 0.5M tokens/step gives $D \approx 3\times 10^{11}$ (the ~10B-token corpus is revisited roughly 30 times), so $6 \cdot 1.24\times 10^{8} \cdot 3\times 10^{11} \approx 2.2\times 10^{20}$ FLOPs. On a single A100 (~150 TFLOPs sustained bf16), that is ~17 GPU-days; on 8×A100 with proper data parallel, a bit over 2 wall-clock days at that throughput (Karpathy reports ~4 days on an 8×A100 40GB node at realistic MFU). Target validation loss ≈ 2.85 (perplexity ~17) on OpenWebText.
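As a quick check on the arithmetic:

N, tokens_per_step, steps = 124e6, 0.5e6, 600_000
flops = 6 * N * tokens_per_step * steps  # ≈ 2.2e20 FLOPs
days = flops / (150e12 * 86400)          # ≈ 17 A100-days at 150 TFLOPs sustained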
Pitfalls.
- Loss spikes early in training usually mean the LR is too high; drop the peak LR to $3\times 10^{-4}$.
- NaNs under bf16 usually trace back to an unstable LayerNorm; keep LayerNorm in fp32 (PyTorch's autocast already runs it in fp32).
- Forgetting weight tying costs ~1% perplexity.
- Not shuffling validation offsets gives a noisy validation curve.
Related terms: Transformer, Attention Mechanism, Adam, Mixed Precision Training, Gradient Descent
Discussed in:
- Chapter 10: Training & Optimisation, Training Optimisation
- Chapter 9: Neural Networks, Transformer Architecture