13.10 Implementing a Transformer from scratch

Now we build a complete decoder-only Transformer in PyTorch (the GPT-style architecture) from primitive operations. The implementation has no dependencies beyond torch. Everything below is real, working code: you can paste it into a file, run it, and watch a Transformer of roughly 0.8 million parameters learn to predict the next character of Shakespeare.

Imports and configuration

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class GPTConfig:
    vocab_size: int = 65          # tiny char-level vocabulary
    block_size: int = 128         # max context length
    n_layer: int = 4              # number of Transformer blocks
    n_head: int = 4               # number of attention heads
    n_embd: int = 128             # model dimension
    dropout: float = 0.0

Multi-head self-attention

We compute Q, K, V with one big linear layer, reshape into heads, do scaled dot-product attention with a causal mask, and project back.

class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.n_embd % cfg.n_head == 0
        self.n_head = cfg.n_head
        self.head_dim = cfg.n_embd // cfg.n_head
        self.qkv = nn.Linear(cfg.n_embd, 3 * cfg.n_embd, bias=False)
        self.proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.attn_drop = nn.Dropout(cfg.dropout)
        self.resid_drop = nn.Dropout(cfg.dropout)
        # causal mask: lower-triangular ones, registered as a buffer
        mask = torch.tril(torch.ones(cfg.block_size, cfg.block_size))
        self.register_buffer("mask", mask.view(1, 1, cfg.block_size, cfg.block_size))

    def forward(self, x):
        B, T, C = x.shape  # batch, time, channels
        # one matrix multiply for Q, K, V
        qkv = self.qkv(x)  # (B, T, 3C)
        q, k, v = qkv.split(C, dim=2)
        # reshape into heads: (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # apply causal mask
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, n_head, T, head_dim)
        # re-stitch heads
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))

A few things worth noticing. First, the mask is registered as a buffer, not a parameter: it has no gradients and is saved with the model, but the optimiser never updates it. Second, masked_fill is the standard way to write $-\infty$ into the positions where the mask is zero, so that the softmax gives them exactly zero weight. Third, the modern PyTorch alternative is F.scaled_dot_product_attention(q, k, v, is_causal=True), which can dispatch to a fused FlashAttention kernel when the hardware and inputs allow it; we wrote it the long way for clarity.
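A quick way to convince yourself that the long-hand version and the fused call compute the same thing is to run both on random tensors. The sketch below is a standalone check: the tensor sizes are arbitrary, and it assumes PyTorch 2.0 or later, where F.scaled_dot_product_attention first appeared. It also confirms that the causal mask leaves zero weight on every future position.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 8, 16   # batch, heads, time, head_dim (arbitrary test sizes)
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

def manual_causal_attention(q, k, v):
    """Masked softmax attention, written out as in CausalSelfAttention.forward."""
    T = q.size(-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    att = att.masked_fill(~mask, float('-inf'))   # hide future positions
    att = F.softmax(att, dim=-1)
    return att @ v, att

y_manual, weights = manual_causal_attention(q, k, v)
y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(y_manual, y_fused, atol=1e-5))        # the two paths agree
print(torch.triu(weights, diagonal=1).abs().max().item())  # weight on future tokens
```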

Feed-forward network

class MLP(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.fc = nn.Linear(cfg.n_embd, 4 * cfg.n_embd, bias=False)
        self.proj = nn.Linear(4 * cfg.n_embd, cfg.n_embd, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        x = self.fc(x)
        x = F.gelu(x)
        x = self.proj(x)
        return self.drop(x)

This is the classic FFN: expand to 4× the model dimension, apply GELU, and project back. For SwiGLU you would replace it with three matrices and the gated formula above; we keep it simple.
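The SwiGLU swap changes the parameter budget: three $d \times h$ matrices instead of two $d \times 4d$ ones. A common convention, used for example in the Llama models, is to shrink the hidden width to about $\tfrac{2}{3} \cdot 4d$ so that both variants cost roughly the same. A small sketch of the arithmetic (the helper names are illustrative, and production code usually also rounds $h$ up to a hardware-friendly multiple):

```python
def gelu_mlp_params(d: int) -> int:
    """Weights in the bias-free 4x GELU MLP: fc (d x 4d) plus proj (4d x d)."""
    return 2 * d * (4 * d)

def swiglu_params(d: int, hidden: int) -> int:
    """Weights in a bias-free SwiGLU FFN: w1, w2 (d x hidden) and w3 (hidden x d)."""
    return 3 * d * hidden

d = 128
hidden = int(2 * 4 * d / 3)   # the 2/3 * 4d convention, truncated to an integer
print(hidden)                 # 341
print(gelu_mlp_params(d), swiglu_params(d, hidden))   # 131072 130944
```

The two counts differ by about 0.1%, which is why the ⅔ factor is the usual starting point when comparing the two FFNs at equal size.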

The Transformer block (pre-norm)

class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.n_embd)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.n_embd)
        self.mlp = MLP(cfg)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

The pre-norm structure: LayerNorm → attention → add to residual; LayerNorm → FFN → add to residual.
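To see what the pre-norm ordering buys, compare it with the post-norm arrangement of the original 2017 Transformer, which normalises after the residual addition. The sketch below uses a hypothetical sublayer that outputs zeros, purely for illustration: in the pre-norm form the block then reduces exactly to the identity, so the residual stream passes through untouched, whereas post-norm re-normalises the stream at every block.

```python
import torch
import torch.nn as nn

def pre_norm(x, sublayer, norm):
    # x + f(LN(x)): the residual path itself is never normalised
    return x + sublayer(norm(x))

def post_norm(x, sublayer, norm):
    # LN(x + f(x)): the original 2017 ordering
    return norm(x + sublayer(x))

def zero_sublayer(h):
    # hypothetical do-nothing sublayer (stands in for attention or the FFN)
    return torch.zeros_like(h)

torch.manual_seed(0)
x = 3.0 * torch.randn(2, 5, 16) + 1.0   # residual stream with non-unit scale
norm = nn.LayerNorm(16)

print(torch.equal(pre_norm(x, zero_sublayer, norm), x))            # identity
print(torch.allclose(post_norm(x, zero_sublayer, norm), norm(x)))  # re-normalised
```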

Sinusoidal positional encoding

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, n_embd: int, block_size: int):
        super().__init__()
        pe = torch.zeros(block_size, n_embd)
        position = torch.arange(0, block_size).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float()
                             * -(math.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, block_size, n_embd)

    def forward(self, x):
        T = x.size(1)
        return x + self.pe[:, :T, :]
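One reason for pairing sines with cosines is that, for each frequency $\omega_i = 10000^{-2i/d}$, the encoding at position $p+k$ is a fixed rotation of the encoding at position $p$, by an angle $k\,\omega_i$ that depends only on the offset $k$. A standalone check (it rebuilds the same table rather than importing the class above; the sizes and positions are arbitrary):

```python
import math
import torch

def sinusoidal_table(block_size: int, n_embd: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Same construction as SinusoidalPositionalEncoding, without the batch dim.
    pe = torch.zeros(block_size, n_embd)
    position = torch.arange(0, block_size).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, n_embd, 2).float()
                         * -(math.log(10000.0) / n_embd))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe, div_term

pe, omega = sinusoidal_table(block_size=64, n_embd=16)
p, k = 10, 7
s, c = pe[p, 0::2], pe[p, 1::2]            # sin(p*w), cos(p*w) for each frequency
ck, sk = torch.cos(k * omega), torch.sin(k * omega)
# Rotate each (sin, cos) pair by the angle k*w, which depends on k but not on p.
s_shift = s * ck + c * sk
c_shift = c * ck - s * sk
print(torch.allclose(s_shift, pe[p + k, 0::2], atol=1e-5),
      torch.allclose(c_shift, pe[p + k, 1::2], atol=1e-5))
```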

The full model

class GPT(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_enc = SinusoidalPositionalEncoding(cfg.n_embd, cfg.block_size)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # weight tying
        self.head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.cfg.block_size
        x = self.tok_emb(idx)            # (B, T, C)
        x = self.pos_enc(x)
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)            # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int, temperature: float = 1.0,
                 top_k: int | None = None):
        for _ in range(max_new_tokens):
            # crop context if too long
            idx_cond = idx if idx.size(1) <= self.cfg.block_size \
                       else idx[:, -self.cfg.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
        return idx

That is a complete decoder-only Transformer in roughly 130 lines of code. The remaining ingredients are the data pipeline and the training loop.
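Before moving on, the sampling logic in generate can be exercised on its own. The sketch below repeats its temperature and top-k steps on toy logits (arbitrary values, one batch row, a vocabulary of 5) and checks that tokens outside the top $k$ receive exactly zero probability, while a lower temperature sharpens the distribution.

```python
import torch
import torch.nn.functional as F

def sample_probs(logits, temperature=1.0, top_k=None):
    # Same steps as GPT.generate, minus the multinomial draw.
    logits = logits / temperature            # new tensor, caller's logits untouched
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = -float('inf')
    return F.softmax(logits, dim=-1)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])
p_k2 = sample_probs(logits, top_k=2)
print(p_k2)                                      # only tokens 0 and 1 are non-zero
p_cold = sample_probs(logits, temperature=0.5)
p_hot = sample_probs(logits, temperature=2.0)
print(p_cold.max().item() > p_hot.max().item())  # lower temperature -> peakier
next_id = torch.multinomial(p_k2, num_samples=1) # shape (1, 1), value 0 or 1
```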

Training on tiny Shakespeare

# Load data: a single text file of ~1MB Shakespeare
with open('input.txt', 'r') as f:
    text = f.read()
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join(itos[i] for i in l)

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split, block_size, batch_size, device):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    return x.to(device), y.to(device)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
cfg = GPTConfig(vocab_size=vocab_size, block_size=128,
                n_layer=4, n_head=4, n_embd=128)
model = GPT(cfg).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(5000):
    model.train()
    xb, yb = get_batch('train', cfg.block_size, batch_size=64, device=device)
    _, loss = model(xb, yb)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 500 == 0:
        model.eval()  # disable dropout for evaluation (a no-op here, since dropout=0.0)
        with torch.no_grad():
            # a single validation batch: a cheap but noisy estimate
            xv, yv = get_batch('val', cfg.block_size, 64, device)
            _, val_loss = model(xv, yv)
        print(f'step {step}: train {loss.item():.3f}, val {val_loss.item():.3f}')

# Generate, starting from token 0
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=200)[0].tolist()))
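The key property of get_batch is that each target row is its input row shifted left by one position, so position $t$ is trained to predict character $t+1$. A self-contained check on a toy string (the string and sizes are arbitrary stand-ins for input.txt and the real config):

```python
import torch

text = "to be or not to be"
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

def get_batch(data, block_size, batch_size):
    # Same sampling as the training script, for a single split; also returns offsets.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    return x, y, ix

torch.manual_seed(0)
x, y, ix = get_batch(data, block_size=6, batch_size=3)
print(torch.equal(x[:, 1:], y[:, :-1]))   # targets are the inputs shifted by one
print(repr(''.join(itos[t] for t in x[0].tolist())),
      '->', repr(''.join(itos[t] for t in y[0].tolist())))
```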

On a single GPU, the training script above runs in a few minutes, and the loss drops from $\sim 4.2$ (a uniform guess over 65 characters gives $\log 65 \approx 4.17$) to about $1.5$. Generation produces faux-Shakespearean text with recognisable word boundaries, plausible character names, and the right rhythm of dialogue and stage directions. It is not literature, but the architecture is essentially that of GPT-3, scaled down by a factor of roughly 200,000 in parameter count.

That is the entire core mechanism of modern AI in under 200 lines.
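The $\log 65 \approx 4.17$ starting point quoted above is simply the cross-entropy of a uniform prediction, which is roughly what a freshly initialised model with small weights produces. A quick check, plus the perplexity implied by a final loss of about 1.5 (that 1.5 is the figure from the text, not a result measured here):

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65
logits = torch.zeros(8, vocab_size)           # identical logits -> uniform softmax
targets = torch.randint(vocab_size, (8,))
uniform_loss = F.cross_entropy(logits, targets).item()
print(round(uniform_loss, 3), round(math.log(vocab_size), 3))   # 4.174 4.174

final_loss = 1.5
print(round(math.exp(final_loss), 2))  # perplexity ~4.48: like choosing among ~4.5 chars
```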

What this tiny model tells us

It is worth pausing on what we have just built. A 4-layer, 4-head, 128-dimensional Transformer has roughly 800,000 parameters. GPT-3 has 175 billion. The architectural differences between our toy and GPT-3 are almost all quantitative: more layers (96 vs 4), a bigger model dimension (12288 vs 128), more heads (96 vs 4), a longer context (2048 vs 128 in the original GPT-3, far longer in successors), a bigger vocabulary (50257 vs 65), and BPE tokens instead of characters. The few qualitative differences are minor: GPT-3 uses learned position embeddings rather than fixed sinusoids, and it alternates dense attention layers with locally banded sparse ones. The structural code (the Block class, the CausalSelfAttention class, the residual additions, the LayerNorms, the loss) is essentially unchanged.
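The parameter count can be tallied directly from the config. Each block has about $12d^2$ weights ($3d^2$ for the fused QKV projection, $d^2$ for the attention output projection, $8d^2$ for the two FFN matrices) plus $4d$ for its two LayerNorms; the tied embedding adds $Vd$ and the final LayerNorm $2d$. The mask and positional table are buffers, not parameters. A small sketch of the tally (the function name is ours, not part of the model code):

```python
def gpt_param_count(n_layer: int, n_embd: int, vocab_size: int) -> int:
    d = n_embd
    per_block = (3 * d * d        # fused Q, K, V projection (no bias)
                 + d * d          # attention output projection
                 + 2 * 4 * d * d  # FFN: d -> 4d -> d
                 + 2 * 2 * d)     # two LayerNorms, weight + bias each
    embedding = vocab_size * d    # shared with the output head (weight tying)
    final_ln = 2 * d
    return n_layer * per_block + embedding + final_ln

toy = gpt_param_count(n_layer=4, n_embd=128, vocab_size=65)
print(toy)  # 797056

# Rough GPT-3 estimate from the 12 * L * d^2 rule (ignores embeddings and biases).
gpt3_approx = 12 * 96 * 12288 ** 2
print(f"{gpt3_approx:.3e}", round(175e9 / toy))
```

The toy figure should match sum(p.numel() for p in model.parameters()), since parameters() yields the tied embedding matrix only once.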

The same code that learns to babble Shakespeare at roughly $10^6$ parameters also learns to write essays, debug code, and pass professional examinations at $10^{11}$ parameters. The architecture scales. Scale itself is the differentiating factor between a curiosity and a production AI system. That is the practical content of "the scaling hypothesis": the architecture is not the bottleneck; the data, compute, and optimisation are.

Modern refinements: a production-style block

To bring our toy code closer to a 2026-vintage production block, the changes are local. Replace nn.LayerNorm with RMSNorm. Replace the FFN with SwiGLU. Add RoPE inside the attention. Use grouped-query attention. Drop the remaining biases (our linear layers already have none, but nn.LayerNorm still carries one). The first three components, as PyTorch code:

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain, no bias
    def forward(self, x):
        # divide by the root-mean-square over the last dim; no mean subtraction
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * x / rms

class SwiGLU(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # gate branch (passed through SiLU)
        self.w2 = nn.Linear(dim, hidden, bias=False)  # linear branch, multiplied by the gate
        self.w3 = nn.Linear(hidden, dim, bias=False)  # projection back to the model dim
    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

def apply_rope(x, freqs):
    # x: (B, n_heads, T, d_head); freqs: (T, d_head/2) of complex exponentials.
    # Pair adjacent feature dims (the interleaved layout of the original RoFormer
    # paper and Meta's Llama reference code; GPT-NeoX instead rotates the two
    # halves of the head dimension) and rotate each pair by its angle.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)  # freqs → (1, 1, T, d/2)
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)  # back to input dtype
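apply_rope expects a precomputed table of complex exponentials, which the code above leaves to the caller. The sketch below builds one (the helper name precompute_freqs and the base of 10000 follow common practice but are assumptions here) and checks RoPE's defining property: after rotation, the dot product between a query at position $m$ and a key at position $n$ depends only on the offset $m - n$.

```python
import torch

def precompute_freqs(T: int, d_head: int, base: float = 10000.0) -> torch.Tensor:
    # theta_i = base^(-2i/d_head); the angle at position t is t * theta_i.
    theta = base ** (-torch.arange(0, d_head, 2).float() / d_head)  # (d_head/2,)
    angles = torch.outer(torch.arange(T).float(), theta)            # (T, d_head/2)
    return torch.polar(torch.ones_like(angles), angles)            # e^{i*angle}

def apply_rope(x, freqs):
    # Adjacent-pair (interleaved) layout, as in the block above.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)

torch.manual_seed(0)
T, d_head = 16, 8
freqs = precompute_freqs(T, d_head)
q_vec, k_vec = torch.randn(d_head), torch.randn(d_head)

def rotated_score(m, n):
    q = torch.zeros(1, 1, T, d_head)
    k = torch.zeros(1, 1, T, d_head)
    q[0, 0, m] = q_vec          # the same query vector, placed at position m
    k[0, 0, n] = k_vec          # the same key vector, placed at position n
    q_rot = apply_rope(q, freqs)[0, 0, m]
    k_rot = apply_rope(k, freqs)[0, 0, n]
    return torch.dot(q_rot, k_rot).item()

print(rotated_score(5, 2), rotated_score(12, 9))   # same offset (3) -> same score
```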

RMSNorm, SwiGLU and RoPE take about two dozen lines; together with grouped-query attention, they capture most of the difference between a 2017-style Transformer and a 2026-style one. Conceptually, the architecture has not changed. Engineering refinements at this level give improvements on the order of 5–20% that compound over many generations of models, but the core idea (attention plus FFN plus residual plus normalisation, stacked deep) is what it has been since 2017.
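Grouped-query attention, the one change in the list that the code does not show, lets several query heads share each key/value head, which shrinks the KV cache at inference time. A minimal sketch, assuming the K/V heads are simply repeated to match the query heads before ordinary causal attention (the function name and head counts are illustrative):

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    # q: (B, n_head, T, d_head); k, v: (B, n_kv_head, T, d_head)
    n_head, n_kv_head = q.size(1), k.size(1)
    assert n_head % n_kv_head == 0
    group = n_head // n_kv_head
    # Each K/V head serves `group` consecutive query heads.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.manual_seed(0)
B, T, d_head = 2, 10, 16
q = torch.randn(B, 8, T, d_head)   # 8 query heads
k = torch.randn(B, 2, T, d_head)   # 2 shared key/value heads
v = torch.randn(B, 2, T, d_head)
y = grouped_query_attention(q, k, v)
print(y.shape)   # torch.Size([2, 8, 10, 16])
```

With 8 query heads and 2 K/V heads, the cached keys and values are a quarter the size of standard multi-head attention's, while the per-head computation is unchanged.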

This site is currently in Beta. Contact: Chris Paton


AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).