Train GPT-2 Recipe

Recipe to reproduce a small GPT-2 (125M params, 12 layers, 12 heads, $d_{\text{model}}=768$, context $T=1024$) from scratch. Targets the next-token prediction loss

$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i \mid x_{\lt i})$$

over a tokenised corpus of roughly 10B tokens (OpenWebText-scale). The hyperparameters follow the defaults of Karpathy's nanoGPT.
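A quick numerical check on the loss: if the model spread its probability uniformly over the 50,257-token vocabulary, every term would be $\log(1/50257)$, so $\mathcal{L} = \ln 50257 \approx 10.82$ nats. That is roughly where training should start from small random weights. Perplexity is $e^{\mathcal{L}}$. The plain-Python sketch below computes both from the probabilities assigned to the true tokens; the function names are illustrative, not taken from nanoGPT.

```python
import math

def nll_loss(target_probs):
    """Mean negative log-likelihood: L = -(1/N) * sum_i log p(x_i | x_<i).

    target_probs[i] is the probability the model gave to the true token x_i.
    """
    return -sum(math.log(p) for p in target_probs) / len(target_probs)

def perplexity(loss):
    """Perplexity is exp of the mean per-token NLL (natural log)."""
    return math.exp(loss)

V = 50_257  # GPT-2 vocabulary size
uniform_loss = nll_loss([1.0 / V] * 1024)  # uniform guess at every position
print(f"uniform loss {uniform_loss:.2f}, perplexity {perplexity(uniform_loss):.0f}")
print(f"target loss 2.85 -> perplexity {perplexity(2.85):.1f}")
```

The first line should show a loss of about 10.82 (perplexity 50257). The second confirms that a loss of 2.85 corresponds to a perplexity of about 17.3.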

Data pipeline.

  1. Download corpus (OpenWebText, ~40GB raw text) and deduplicate by URL + MinHash.
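  2. Tokenise with GPT-2's byte-level BPE: 256 byte tokens + 50,000 merges + <|endoftext|> = 50,257 tokens. nanoGPT reuses the pretrained GPT-2 tokeniser (via tiktoken) rather than training one. If you train your own, use 50,000 merges to match the vocab size.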
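  3. Tokenise the entire corpus into a single contiguous uint16 memmap (every token id is below 2^16 = 65,536). At 2 bytes per token, 10B tokens take ~20GB on disk. Drawing a training sample is then just a slice at a random offset, with no decoding or per-example files.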
  4. Hold out a 0.5% validation slice.
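
import numpy as np
import tiktoken
enc = tiktoken.get_encoding("gpt2")  # GPT-2 byte-level BPE, 50,257-token vocab
with open("corpus.txt", encoding="utf-8") as f:  # tokenise once (whole file in RAM)
    ids = enc.encode_ordinary(f.read())
arr = np.memmap("train.bin", dtype=np.uint16, mode="w+", shape=(len(ids),))
arr[:] = np.array(ids, dtype=np.uint16)
arr.flush()  # persist the uint16 memmap to disk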
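Step 1's near-duplicate filter can be sketched with MinHash. Each document becomes a set of word 5-grams (shingles), and each of many seeded hash functions keeps its minimum value over that set. The fraction of positions where two signatures agree then estimates the Jaccard similarity of the two documents. This is a minimal pure-Python illustration under assumed settings: 5-word shingles, 128 BLAKE2b seeds passed through the salt parameter, a 0.8 threshold, and hypothetical helper names. It compares every pair of documents; a pass over ~40GB would instead bucket signatures with locality-sensitive hashing (banding).

```python
import hashlib

def shingles(text, k=5):
    """Set of k-word shingles (a single, possibly short, shingle for tiny texts)."""
    words = text.lower().split()
    return {" ".join(words[i:i + k]) for i in range(max(1, len(words) - k + 1))}

def minhash_signature(shingle_set, num_hashes=128):
    """For each seeded hash function, keep the minimum hash over all shingles."""
    sig = []
    for seed in range(num_hashes):
        salt = seed.to_bytes(8, "little")  # the seed selects the hash function
        sig.append(min(
            int.from_bytes(hashlib.blake2b(s.encode(), digest_size=8, salt=salt).digest(), "little")
            for s in shingle_set
        ))
    return sig

def estimated_jaccard(sig_a, sig_b):
    """Fraction of positions where the signatures agree, an estimate of Jaccard(A, B)."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)

def dedup(docs, threshold=0.8):
    """Drop exact-URL repeats, then near-duplicates of any already-kept document.

    docs is a list of (url, text) pairs. All-pairs comparison: demo scale only.
    """
    seen_urls, kept, kept_sigs = set(), [], []
    for url, text in docs:
        if url in seen_urls:
            continue
        seen_urls.add(url)
        sig = minhash_signature(shingles(text))
        if any(estimated_jaccard(sig, s) >= threshold for s in kept_sigs):
            continue
        kept.append((url, text))
        kept_sigs.append(sig)
    return kept

base = " ".join(f"token{i}" for i in range(200))  # synthetic 200-word article
docs = [
    ("https://a.example/1", base),
    ("https://a.example/1", "same URL, different text"),       # URL repeat
    ("https://b.example/2", base + " plus one extra sentence"),  # near-duplicate
    ("https://c.example/3", "an unrelated article about tokenisers and memmaps"),
]
print([url for url, _ in dedup(docs)])
```

Here the repeated URL is dropped before any hashing. The copy with one extra sentence (true Jaccard 196/200 = 0.98) is caught as a near-duplicate, and the unrelated article is kept.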

Model architecture. Decoder-only Transformer with causal self-attention, pre-LayerNorm, and learned positional embeddings.

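import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d, n_head):
        super().__init__()
        self.n_head = n_head
        self.qkv = nn.Linear(d, 3*d)  # fused Q, K, V projections
        self.proj = nn.Linear(d, d)
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = (t.view(B, T, self.n_head, -1).transpose(1, 2)  # (B, heads, T, head_dim)
                   for t in self.qkv(x).split(C, dim=2))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # FlashAttention when available
        return self.proj(y.transpose(1, 2).reshape(B, T, C))

class Block(nn.Module):
    def __init__(self, d, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = CausalSelfAttention(d, n_head)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d), nn.Dropout(0.0))
    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # pre-LN: normalise the input of each residual branch
        return x + self.mlp(self.ln2(x))

class GPT(nn.Module):
    def __init__(self, vocab=50257, T=1024, d=768, n_layer=12, n_head=12):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(T, d)
        self.blocks = nn.ModuleList([Block(d, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab, bias=False)
        self.head.weight = self.tok.weight    # weight tying
    def forward(self, idx):
        x = self.tok(idx) + self.pos(torch.arange(idx.size(1), device=idx.device))
        for block in self.blocks:
            x = block(x)
        return self.head(self.ln_f(x))  # logits, shape (B, T, vocab)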
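As a check on the "125M" figure, the parameters can be counted from the layer shapes. The sketch assumes GPT-2's layout. The attention uses a biased $d \to 3d$ fused QKV projection plus a biased $d \to d$ output projection, every other nn.Linear except the tied head has a bias, and each LayerNorm has a weight and a bias. The function name is illustrative.

```python
def gpt2_param_count(vocab=50257, T=1024, d=768, n_layer=12):
    """Count parameters of the pre-LN GPT above (biases on, output head tied)."""
    embeddings = vocab * d + T * d               # token + learned position tables
    attn = (d * 3 * d + 3 * d) + (d * d + d)     # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)  # d -> 4d -> d
    layer_norms = 2 * (2 * d)                    # ln1 and ln2: weight + bias each
    per_block = attn + mlp + layer_norms         # = 12*d^2 + 13*d
    final_ln = 2 * d
    # head.weight is tied to tok.weight, so it adds no new parameters
    return embeddings + n_layer * per_block + final_ln

total = gpt2_param_count()
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

This gives 124,439,808, usually quoted as 124M. Once the model is built, sum(p.numel() for p in model.parameters()) should agree, since parameters() yields the tied weight only once. Without tying, the head would add another $50257 \times 768 \approx 38.6$M parameters.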

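Use FlashAttention-2 for the attention kernel. It computes $\mathrm{softmax}(QK^\top/\sqrt{d_k})V$ in tiles without materialising the $T\times T$ score matrix, so the extra attention memory grows linearly rather than quadratically in $T$. At $T=1024$ that cuts attention memory roughly 3x and runs roughly 1.5x faster than the naive implementation. The exact gains depend on the GPU, the head size ($d_k = d_{\text{model}}/n_{\text{head}} = 64$ here) and the batch shape. In PyTorch 2.x, F.scaled_dot_product_attention can dispatch to a FlashAttention kernel.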
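FlashAttention's saving comes from never storing the full row of scores. It streams over blocks of keys and keeps, per query, a running maximum $m$, a running softmax denominator $l$ and a running weighted sum of values, rescaling them by $e^{m_{\text{old}} - m_{\text{new}}}$ whenever the maximum changes. The pure-Python sketch below applies this online-softmax recurrence to one query row (e.g. the last position, where the causal mask leaves every key visible) and checks it against the naive computation. It illustrates the arithmetic only, not the tiled GPU kernel; the block size and names are illustrative.

```python
import math
import random

def naive_attention_row(q, keys, values):
    """softmax(q . k_j / sqrt(d_k))-weighted sum of values, materialising all scores."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / z for c in range(len(values[0]))]

def online_attention_row(q, keys, values, block=4):
    """Same result, visiting keys block by block with running max m, denominator l, accumulator acc."""
    scale = 1.0 / math.sqrt(len(q))
    m, l = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in ks]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescales old statistics; exp(-inf) = 0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pj * v[c] for pj, v in zip(p, vs)) for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

random.seed(0)
d_k, T = 8, 10
q = [random.gauss(0, 1) for _ in range(d_k)]
keys = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(T)]
values = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(T)]
a, b = naive_attention_row(q, keys, values), online_attention_row(q, keys, values)
print(max(abs(x - y) for x, y in zip(a, b)))  # differences are floating-point round-off only
```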

Optimiser and schedule.

  • AdamW, $\beta_1=0.9$, $\beta_2=0.95$, $\epsilon=10^{-8}$, weight decay $0.1$ (applied only to 2D weight tensors, i.e. matmul and embedding weights, not to biases or LayerNorm parameters).
  • Peak LR $6\times 10^{-4}$, linear warmup for 2000 steps, then cosine decay to $6\times 10^{-5}$ over 600,000 steps.
  • Gradient clipping at norm 1.0.
  • Effective batch size ≈0.5M tokens per optimiser step: micro-batch 12 × sequence length 1024 × 40 gradient-accumulation steps = 491,520 tokens on a single A100, or the same total with 5 accumulation steps per GPU across 8 GPUs.
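The batch arithmetic is easy to get wrong when moving between one GPU and several, so it helps to compute tokens per optimiser step explicitly. Combined with the 600,000-step schedule, the same arithmetic shows how many tokens the full run processes relative to the ~10B-token corpus. The helper name below is illustrative.

```python
def tokens_per_step(micro_batch, seq_len, grad_accum, n_gpus=1):
    """Tokens contributing to one optimiser step (global batch size in tokens)."""
    return micro_batch * seq_len * grad_accum * n_gpus

single = tokens_per_step(12, 1024, 40)          # one A100, 40 accumulation steps
multi = tokens_per_step(12, 1024, 5, n_gpus=8)  # 8 GPUs, 5 accumulation steps each
total_tokens = 600_000 * single                 # full cosine schedule
print(single, multi)
print(f"{total_tokens:.3g} tokens, ~{total_tokens / 10e9:.0f} passes over a 10B-token corpus")
```

Both layouts give 491,520 tokens per step. The full schedule touches about 2.95e11 tokens, i.e. roughly 29 passes over a 10B-token corpus.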

Training loop.

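import math
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical settings matching the recipe; val.bin is an assumed file holding step 4's held-out slice
B, T, grad_accum_steps, max_steps = 12, 1024, 40, 600_000
train_mm = np.memmap("train.bin", dtype=np.uint16, mode="r")
val_mm = np.memmap("val.bin", dtype=np.uint16, mode="r")

def get_batch(split):
    data = train_mm if split == "train" else val_mm
    ix = torch.randint(len(data) - T - 1, (B,))  # random offsets into the memmap
    x = torch.stack([torch.from_numpy(data[i:i+T].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+T].astype(np.int64)) for i in ix])
    return x.to("cuda"), y.to("cuda")

def cosine_lr(step, warmup=2000, max_steps=600_000, peak=6e-4, min_lr=6e-5):
    if step < warmup:  # linear warmup
        return peak * (step + 1) / warmup
    t = min(1.0, (step - warmup) / (max_steps - warmup))  # cosine decay to min_lr
    return min_lr + 0.5 * (1 + math.cos(math.pi * t)) * (peak - min_lr)

@torch.no_grad()
def estimate_loss(model, split, iters=200):
    model.eval()
    losses = []
    for _ in range(iters):
        x, y = get_batch(split)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            losses.append(F.cross_entropy(model(x).flatten(0, 1), y.flatten()).item())
    model.train()
    return sum(losses) / len(losses)

model = GPT().to("cuda")
# Weight decay only on 2D tensors (matmul and embedding weights), not biases or LayerNorm
decay = [p for p in model.parameters() if p.dim() >= 2]
no_decay = [p for p in model.parameters() if p.dim() < 2]
opt = torch.optim.AdamW([{"params": decay, "weight_decay": 0.1},
                         {"params": no_decay, "weight_decay": 0.0}],
                        lr=6e-4, betas=(0.9, 0.95), eps=1e-8)
model = torch.compile(model)
# No GradScaler: bf16 has fp32's exponent range, so loss scaling is unnecessary

for step in range(max_steps):
    lr = cosine_lr(step)
    for g in opt.param_groups: g["lr"] = lr

    opt.zero_grad(set_to_none=True)
    for _ in range(grad_accum_steps):
        x, y = get_batch("train")
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        (loss / grad_accum_steps).backward()  # accumulated grads = mean over micro-batches

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

    if step % 1000 == 0:
        val_loss = estimate_loss(model, "val")
        print(f"step {step}  train {loss.item():.4f}  val {val_loss:.4f}")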
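The division by grad_accum_steps is what makes gradient accumulation equivalent to one large batch. With $k$ equal-sized micro-batches, summing (gradient of each micro-batch mean loss) / $k$ gives exactly the gradient of the mean loss over the whole effective batch. The check below uses a one-parameter least-squares model on made-up data as a stand-in for the network; the names are illustrative.

```python
def grad_mse(w, batch):
    """d/dw of mean((w*x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

data = [(float(i), 2.0 * i + 1.0) for i in range(12)]       # hypothetical toy data
w, accum_steps = 0.5, 4
micro = [data[i::accum_steps] for i in range(accum_steps)]  # 4 equal micro-batches of 3

full = grad_mse(w, data)                                            # one big batch
accumulated = sum(grad_mse(w, mb) / accum_steps for mb in micro)    # what backward() sums up
print(full, accumulated)
```

If the micro-batches had unequal sizes, the two numbers would differ, because each micro-batch mean would get the same weight regardless of how many examples it contained.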

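Compute estimate. Training costs ≈ $6 \cdot N \cdot D$ FLOPs, where $N$ is the parameter count and $D$ the number of tokens processed. A single pass over the 10B-token corpus costs $6 \cdot 1.25\times 10^{8} \cdot 10^{10} \approx 7.5\times 10^{18}$ FLOPs. On a single A100 sustaining ~150 TFLOP/s in bf16, that is $5\times 10^{4}$ s, about 14 GPU-hours, or ~1.7 wall-clock hours on 8×A100 with data parallelism. The full schedule above processes $600{,}000 \times 491{,}520 \approx 2.9\times 10^{11}$ tokens (about 29 passes over the corpus), i.e. ≈ $2.2\times 10^{20}$ FLOPs. That is ~17 GPU-days, or ~2.1 wall-clock days on 8×A100 at the same throughput. The target validation loss of ≈ 2.85 (perplexity ≈ 17) on OpenWebText assumes the full schedule, not a single pass.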
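The compute figures can be reproduced with a few lines of arithmetic. The 150 TFLOP/s sustained per-GPU throughput is the recipe's assumption, and the sketch also assumes perfect data-parallel scaling, so treat the outputs as order-of-magnitude estimates. The helper names are illustrative.

```python
def training_flops(n_params, n_tokens):
    """Standard approximation: ~6 FLOPs per parameter per training token (forward + backward)."""
    return 6 * n_params * n_tokens

def gpu_time_hours(flops, flops_per_sec=150e12, n_gpus=1):
    """Wall-clock hours at a sustained per-GPU throughput, assuming perfect scaling."""
    return flops / (flops_per_sec * n_gpus) / 3600

N = 1.25e8
one_pass = training_flops(N, 1e10)               # single pass over 10B tokens
full_run = training_flops(N, 600_000 * 491_520)  # full 600k-step schedule
for name, f in [("one pass", one_pass), ("full run", full_run)]:
    print(f"{name}: {f:.2e} FLOPs, {gpu_time_hours(f):.1f} GPU-hours, "
          f"{gpu_time_hours(f, n_gpus=8):.1f} h on 8 GPUs")
```

This prints about 13.9 GPU-hours (1.7 h on 8 GPUs) for one pass and 409.6 GPU-hours (51.2 h, roughly 2.1 days, on 8 GPUs) for the full schedule.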

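Pitfalls. Loss spikes early in training usually mean the LR is too high; drop the peak LR to $3\times 10^{-4}$. NaNs under bf16 often point to an unstable LayerNorm, so keep LayerNorm in fp32. torch.autocast already runs layer_norm in fp32, so this mainly matters if the whole model is cast to bf16. Forgetting weight tying adds ~38.6M parameters and costs ~1% perplexity. Validation loss estimated from only a few random batches gives a noisy curve; average over many random offsets per evaluation (nanoGPT uses 200 batches).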

Related terms: Transformer, Attention Mechanism, Adam, Mixed Precision Training, Gradient Descent

This site is currently in Beta. Contact: Chris Paton

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).