12.15 Building a character-level LSTM in PyTorch

We now build a small but functional character-level language model from scratch in PyTorch, in the spirit of Andrej Karpathy's 2015 min-char-rnn but with an LSTM in place of the vanilla RNN. The model reads a corpus character by character and learns to predict the next character.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- 1. Data ----
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))                   # vocabulary: every distinct character
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}  # character -> integer index
itos = {i: c for i, c in enumerate(chars)}  # integer index -> character
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

# ---- 2. Model ----
class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, n_layers=2, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # dropout acts between stacked LSTM layers, not after the last one
        self.lstm = nn.LSTM(embed_dim, hidden_dim, n_layers,
                            dropout=dropout, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # x: (B, T) integer indices; hidden=None means a zero initial (h, c)
        emb = self.embed(x)                      # (B, T, E)
        out, hidden = self.lstm(emb, hidden)     # (B, T, H)
        logits = self.head(out)                  # (B, T, V)
        return logits, hidden

# ---- 3. Training loop ----
def get_batch(data, batch_size, seq_len):
    n = data.size(0)
    # randint's upper bound is exclusive, so the largest start is n - seq_len - 1,
    # which keeps the last target index (s + seq_len) inside the corpus
    starts = torch.randint(0, n - seq_len, (batch_size,))
    x = torch.stack([data[s : s + seq_len] for s in starts])
    y = torch.stack([data[s + 1 : s + seq_len + 1] for s in starts])  # x shifted by one
    return x, y

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CharLSTM(vocab_size).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

batch_size, seq_len = 64, 128
for step in range(5000):
    x, y = get_batch(data, batch_size, seq_len)
    x, y = x.to(device), y.to(device)
    logits, _ = model(x)                     # fresh zero state for every batch
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
    opt.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    if step % 200 == 0:
        # loss on the current batch only: a noisy training estimate
        print(f'step {step}  loss {loss.item():.3f}  ppl {loss.exp().item():.1f}')

# ---- 4. Sampling ----
@torch.no_grad()
def sample(model, prefix, length=500, temperature=0.8, top_p=0.9):
    model.eval()
    idx = torch.tensor([[stoi[c] for c in prefix]], device=device)
    hidden = None
    out = list(prefix)
    for _ in range(length):
        # first pass reads the whole prefix; later passes feed one character
        # and reuse the carried hidden state
        logits, hidden = model(idx, hidden)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1).squeeze(0)   # (V,)
        # top-p truncation: smallest set of top characters with mass >= top_p
        sorted_p, sorted_i = probs.sort(descending=True)
        cum = sorted_p.cumsum(0)
        cutoff = int((cum < top_p).sum().item()) + 1
        kept_p = sorted_p[:cutoff]
        kept_i = sorted_i[:cutoff]
        kept_p = kept_p / kept_p.sum()                 # renormalise survivors
        next_id = kept_i[torch.multinomial(kept_p, 1)]
        out.append(itos[int(next_id)])
        idx = next_id.view(1, 1)
    model.train()
    return ''.join(out)

print(sample(model, 'ROMEO:', length=400))
```

Several pedagogical observations follow.

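Truncated BPTT in disguise. Each training step draws batch_size independent, randomly positioned chunks of seq_len characters, and the LSTM state starts from zero (hidden=None) for every chunk. Gradients therefore flow back at most seq_len steps, which is truncated BPTT. Just as importantly, the model never sees more than seq_len characters of context during training. A more thorough implementation reads consecutive chunks and carries the hidden state forward from one chunk to the next. It detaches the state from the computation graph at each boundary, so the values persist while the gradient path is cut.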
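
The stateful variant described above can be sketched as follows. This is a minimal sketch, not the chapter's code. stream_batches, TinyLSTM and train_one_epoch are hypothetical names. TinyLSTM is a small stand-in with the same forward(x, hidden) interface as CharLSTM. The corpus is random integers, so the block runs without shakespeare.txt. The corpus is cut into batch_size contiguous streams, so row b of each batch continues exactly where row b of the previous batch stopped. The (h, c) tuple is detached at every chunk boundary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def stream_batches(data, batch_size, seq_len):
    """Yield consecutive (x, y) chunks from batch_size contiguous streams.

    Row b of each chunk continues where row b of the previous chunk ended,
    so a hidden state carried across chunks sees unbroken text.
    """
    n = (data.size(0) - 1) // batch_size                        # characters per stream
    inputs = data[: n * batch_size].view(batch_size, n)
    targets = data[1 : n * batch_size + 1].view(batch_size, n)  # inputs shifted by one
    for t in range(0, n - seq_len + 1, seq_len):                # drop a ragged final chunk
        yield inputs[:, t : t + seq_len], targets[:, t : t + seq_len]


class TinyLSTM(nn.Module):
    """Hypothetical stand-in with CharLSTM's forward(x, hidden) interface."""

    def __init__(self, vocab_size, hidden_dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(self.embed(x), hidden)
        return self.head(out), hidden


def train_one_epoch(model, opt, data, batch_size, seq_len, vocab_size):
    hidden = None                                  # zero state once per epoch
    losses = []
    for x, y in stream_batches(data, batch_size, seq_len):
        if hidden is not None:
            # keep the state's values but cut the graph: backprop stops here
            hidden = tuple(h.detach() for h in hidden)
        logits, hidden = model(x, hidden)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
vocab_size = 10
data = torch.randint(0, vocab_size, (1001,))       # toy corpus, not Shakespeare
model = TinyLSTM(vocab_size)
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
losses = train_one_epoch(model, opt, data, batch_size=4, seq_len=25, vocab_size=vocab_size)
print(len(losses))  # 10 chunks: 250 characters per stream / 25
```

The detach is what makes this work. Without it, the second loss.backward() would try to backpropagate into the previous chunk's graph, which has already been freed, and PyTorch raises a RuntimeError.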

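Cross-entropy and perplexity. F.cross_entropy returns the mean negative log-likelihood of the correct next character, in nats, and each printed perplexity is its exponential. Perplexity can be read as the effective number of equally likely characters the model is choosing between at each step. The printed values come from the current training batch, so they are noisy training estimates, not held-out scores. A well-trained character-level LSTM on Shakespeare reaches perplexity in the 3 to 5 range; a vanilla RNN with the same hyperparameters tends to plateau around 5 to 8.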
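
The relation between the printed loss, perplexity and bits per character can be checked numerically. This is a minimal sketch that assumes a vocabulary of 65 characters, a hypothetical figure; substitute the vocab_size of your own corpus. A model that predicts uniformly has loss ln V and perplexity exactly V, which is roughly where the first printed line should start. The 3 to 5 perplexity range quoted above corresponds to a loss of about 1.10 to 1.61 nats, or about 1.58 to 2.32 bits per character.

```python
import math


def perplexity(loss_nats):
    """Perplexity from a mean per-character cross-entropy in nats."""
    return math.exp(loss_nats)


def bits_per_char(loss_nats):
    """The same loss measured in bits: divide by ln 2."""
    return loss_nats / math.log(2)


vocab_size = 65                      # hypothetical vocabulary size
uniform_loss = math.log(vocab_size)  # cross-entropy of a uniform predictor
print(f"uniform: loss {uniform_loss:.3f} nats, ppl {perplexity(uniform_loss):.1f}")

for target_ppl in (3, 5):            # the range quoted in the text
    loss = math.log(target_ppl)
    print(f"ppl {target_ppl}: loss {loss:.2f} nats, {bits_per_char(loss):.2f} bits/char")
```

Perplexity equals 2 raised to the bits-per-character figure, so the two scales carry the same information.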

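Embedding the inputs. We embed integer character indices through nn.Embedding, the lookup table from §12.4: row i of its weight matrix is the learned vector for character i. With batch_first=True, PyTorch's LSTM layer takes input of shape (batch, time, feature) and returns output in the same layout. The hidden and cell states, however, keep the shape (num_layers, batch, hidden_dim) regardless of batch_first.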
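
Both points can be checked on toy tensors. The sketch below uses arbitrarily chosen toy sizes. It shows that an nn.Embedding lookup equals multiplying one-hot vectors by the embedding weight matrix. It also shows that batch_first=True changes the layout of the LSTM's input and output but not of its (h, c) states.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, E, H, L = 10, 8, 16, 2     # toy sizes: vocab, embedding, hidden, layers
B, T = 3, 5                   # batch, time

embed = nn.Embedding(V, E)
x = torch.randint(0, V, (B, T))

# A lookup is the same as selecting rows with one-hot vectors.
lookup = embed(x)                                # (B, T, E)
one_hot = F.one_hot(x, num_classes=V).float()    # (B, T, V)
matmul = one_hot @ embed.weight                  # (B, T, E)
print(torch.allclose(lookup, matmul))            # True

lstm = nn.LSTM(E, H, L, batch_first=True)
out, (h, c) = lstm(lookup)
print(out.shape)              # torch.Size([3, 5, 16]): (B, T, H)
print(h.shape, c.shape)       # torch.Size([2, 3, 16]) each: (L, B, H)
```

The one-hot product is how the lookup is usually explained, but nn.Embedding indexes rows directly, which avoids materialising a (B, T, V) tensor.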

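Sampling temperature and top-$p$. The sampler implements both knobs. Dividing the logits by a temperature below 1 sharpens the distribution, so lowering it towards 0 produces more confident, repetitive samples; greedy argmax decoding is the limit. Raising it above 1 flattens the distribution, injecting more diversity and eventually noise. Top-$p$ (nucleus) truncation then keeps only the smallest set of most probable characters whose cumulative probability reaches $p$ and renormalises over them. A $p$ of around 0.9 is a sensible default.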
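
The effect of the two knobs can be seen on a single toy distribution. This sketch reuses the sampler's truncation rule, but for readability it returns a full-length distribution with zeros for the discarded characters. apply_temperature and top_p_filter are hypothetical helper names, and the four logits are made up.

```python
import torch
import torch.nn.functional as F


def apply_temperature(logits, temperature):
    """Softmax of logits / temperature: T < 1 sharpens, T > 1 flattens."""
    return F.softmax(logits / temperature, dim=-1)


def top_p_filter(probs, top_p):
    """Keep the smallest set of most-probable entries whose mass reaches top_p."""
    sorted_p, sorted_i = probs.sort(descending=True)
    cum = sorted_p.cumsum(0)
    cutoff = int((cum < top_p).sum().item()) + 1   # same rule as sample()
    kept = torch.zeros_like(probs)
    kept[sorted_i[:cutoff]] = sorted_p[:cutoff]
    return kept / kept.sum()                        # renormalise the survivors


logits = torch.tensor([2.0, 1.0, 0.5, -1.0])       # made-up scores for 4 characters
for t in (0.5, 1.0, 2.0):
    p = apply_temperature(logits, t)
    print(f"T={t}: {[round(v, 3) for v in p.tolist()]}")

filtered = top_p_filter(apply_temperature(logits, 1.0), top_p=0.9)
print([round(v, 3) for v in filtered.tolist()])
```

At T = 1 the two most likely characters hold about 0.83 of the mass and the top three about 0.97. The cumulative mass therefore first reaches 0.9 at the third character, and only the least likely one is discarded.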

Samples from a 2-layer, 256-unit LSTM trained for a few thousand steps on Shakespeare look strikingly like Shakespeare at the surface level (correctly capitalised speaker labels, plausible diction, balanced quotes) but fall apart on close reading. This is the central pedagogical point of Karpathy's 2015 blog post. It is also why recurrent character models became a viral demonstration of deep learning's power: they imitate local statistics very well. Deeper coherence (themes, consistent characters, plot) is beyond their reach. Modern Transformer language models go further, and the next chapter explains how.

This site is currently in Beta. Contact: Chris Paton

Textbook of Usability · Textbook of Digital Health

Auckland Maths and Science Tutoring

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).