9.17 The same MLP in PyTorch
In Section 9.16 we wrote everything by hand. We initialised the weights ourselves, we wrote the forward pass as a chain of matrix multiplications, we wrote the backward pass by applying the chain rule layer by layer, we wrote the cross-entropy loss, we wrote the SGD update, and we wrote the minibatch loop. The whole program was several hundred lines of NumPy and ran on the CPU. It worked, and it taught us what every part does. But almost nobody in industry or research writes neural networks that way. They use a deep learning framework, and the framework that has won over the last decade is PyTorch.
PyTorch automates almost everything we did by hand. It builds a computational graph as you compute the forward pass, and then differentiates that graph automatically when you call .backward(). It ships dozens of pre-built layer types, optimisers, loss functions, learning-rate schedules, and data utilities. It moves tensors and parameters to the GPU with a single .to(device) call. It provides a DataLoader that handles shuffling, batching, and parallel data loading. The same network we built in Section 9.16 is roughly thirty lines of PyTorch and trains ten times faster on a modest GPU.
The point of this section is not to teach you a new model. It is the same MLP. The point is to show, line by line, what PyTorch is doing for you, so that when you open any modern deep learning codebase you can read it. Every research paper from 2017 onwards releases its code in PyTorch (or, less commonly, JAX). The vocabulary you learn here transfers directly to GPT, BERT, ResNet, diffusion models, and everything else in the rest of the book.
Setting up
Before we can build a network we need to import PyTorch, choose a device (CPU or GPU), load the MNIST data, and wrap it in a DataLoader. Here is the working code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import fetch_openml
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {device}')
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
X = torch.tensor(X.astype('float32') / 255.0)
y = torch.tensor(y.astype('int64'))
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=512)
The first three imports give us the three things we need most: the core torch library (tensors and autograd), the nn module (pre-built layers), and the functional module F (stateless functions like ReLU and softmax). The fourth import is the DataLoader, which is the tool we use for minibatching, shuffling, and (optionally) parallel loading. We also import fetch_openml to grab MNIST in one line; this is a small convenience over the manual download in Section 9.16.
The line device = torch.device(...) is one of the most useful idioms in PyTorch. It checks whether a CUDA-capable GPU is visible, and if so picks 'cuda'; otherwise it falls back to 'cpu'. We will pass this device object to .to(device) calls later. The same script then runs on a laptop without modification or on a workstation with a GPU and gets a ten-fold speedup automatically. On Apple Silicon the equivalent string is 'mps'; many modern scripts probe for that as well.
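The probing logic described above can be sketched as a small helper. The function name pick_device is our own for illustration; the guard around torch.backends.mps assumes a reasonably recent PyTorch (the MPS backend appeared in version 1.12):

```python
import torch

def pick_device() -> torch.device:
    # Prefer an NVIDIA GPU, then Apple Silicon's Metal backend, then the CPU.
    if torch.cuda.is_available():
        return torch.device('cuda')
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

device = pick_device()
```

The rest of the script is unchanged whichever branch fires, which is the whole point of the idiom.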
We then load the data. MNIST is 70,000 28-by-28 grey-scale digit images flattened to vectors of length 784. We normalise pixel values from $[0, 255]$ to $[0, 1]$ by dividing by 255. We convert to PyTorch tensors with torch.tensor. The training labels must be int64 (PyTorch's CrossEntropyLoss requires this dtype). We then split into the first 60,000 training examples and the last 10,000 test examples, exactly as in Section 9.16.
TensorDataset is a thin wrapper that pairs an input tensor with a label tensor and exposes them as a dataset that supports indexing. DataLoader then takes a dataset and gives us an iterable that yields minibatches. With batch_size=64 and shuffle=True, every epoch the loader walks through the data in a fresh random order, packing 64 examples per batch. For evaluation we use a larger batch size of 512 and no shuffling: bigger batches are faster, and shuffling does not affect accuracy. The hand-rolled minibatch loop from Section 9.16 disappears.
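The batching behaviour is easy to check on a tiny synthetic dataset (random tensors standing in for MNIST): with 100 examples and batch_size=64, the loader yields one full batch of 64 and one partial batch of 36.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Tiny synthetic stand-in for MNIST: 100 examples of dimension 784, labels 0-9.
X = torch.randn(100, 784)
y = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

batches = list(loader)
xb, yb = batches[0]   # each batch is a (inputs, labels) pair of tensors
```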
Defining the network
In NumPy we represented the network as a list of weight matrices and bias vectors. In PyTorch we describe it as a Python class that inherits from nn.Module. Here is the same architecture, three fully connected layers with ReLU activations:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # logits, softmax inside the loss
        return x

model = MLP().to(device)
nn.Module is the base class for every neural network component in PyTorch, from a single layer up to a 70-billion-parameter language model. By inheriting from it we get three benefits for free. First, every nn.Linear (or other layer) we assign as an attribute is automatically registered as a submodule and its weights and biases become discoverable through model.parameters(). Second, calling the model as model(x) automatically calls model.forward(x) while also running PyTorch's hook system. Third, model.to(device), model.train(), model.eval(), and model.state_dict() all walk the module tree recursively, so we never have to track parameters by hand.
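The automatic parameter registration can be made concrete. Repeating the class in self-contained form, model.parameters() yields six tensors (a weight and a bias for each of the three layers), and summing their element counts gives the total parameter count:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = MLP()
# 784*128 + 128 + 128*64 + 64 + 64*10 + 10 = 109,386 learnable scalars
n_params = sum(p.numel() for p in model.parameters())
```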
nn.Linear(784, 128) creates a layer with a weight matrix of shape $128 \times 784$ (PyTorch stores weights as output features by input features) and a 128-dimensional bias vector. PyTorch initialises the weights using Kaiming uniform initialisation by default (a variant of He initialisation, see Section 9.10), which is appropriate for ReLU networks. In Section 9.16 we wrote that initialisation by hand; here it is the silent default.
The forward method describes how data flows through the network. Inside forward we use the functional form F.relu because ReLU has no learnable parameters; for layers that do have parameters (like nn.Linear) we use the module form so PyTorch can track them. The activation pattern is exactly the chain we wrote out in Section 9.6: linear, ReLU, linear, ReLU, linear. The third layer outputs logits, ten unnormalised real numbers per example. We deliberately do not apply softmax here. The cross-entropy loss in PyTorch combines softmax and negative log-likelihood in a single numerically stable operation, so the convention is to feed it raw logits.
A subtle but important point: PyTorch builds the computational graph dynamically, on the fly, every time forward runs. Each tensor produced by self.fc1(x), F.relu(...), and so on remembers the operation that created it and the input tensors. This dynamic graph is what makes loss.backward() possible later, and it is the difference between PyTorch and the older static-graph frameworks (TensorFlow 1.x, Theano). You can put if statements, loops, and Python control flow inside forward and PyTorch will correctly differentiate through them.
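A minimal demonstration of differentiating through Python control flow, using a toy function (piecewise is our own name, not a PyTorch API): whichever branch actually executes is the branch that gets differentiated.

```python
import torch

def piecewise(x):
    # Ordinary Python control flow inside the computation.
    if x.sum() > 0:
        return (x * 2).sum()
    return (x * 3).sum()

x = torch.ones(3, requires_grad=True)   # sum() > 0, so the *2 branch runs
piecewise(x).backward()
# x.grad is now 2.0 everywhere: the gradient of the branch that executed
```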
Finally, model = MLP().to(device) constructs the model on the CPU and then moves all of its parameters and buffers to the chosen device. From now on every input tensor passed to the model must live on the same device, otherwise PyTorch raises an error.
Optimiser and loss
We need two more ingredients before we can train: a loss function and an optimiser. PyTorch ships both in one-liners.
optimiser = torch.optim.SGD(model.parameters(), lr=0.5)
loss_fn = nn.CrossEntropyLoss()
torch.optim.SGD implements plain stochastic gradient descent with the update $\theta \leftarrow \theta - \eta \nabla_\theta \mathcal{L}$, exactly the rule we coded by hand in Section 9.16. It accepts an iterable of parameters; model.parameters() returns a generator that walks the entire module tree and yields every learnable tensor. Optionally we can pass momentum=0.9 to reproduce the momentum variant from the end of Section 9.16, or weight_decay=1e-4 to add L2 regularisation (Section 9.12). To switch to Adam we change one line: optimiser = torch.optim.Adam(model.parameters(), lr=1e-3). The rest of the training code does not change. This drop-in interchangeability is one of the strongest reasons to use a framework: comparing optimisers is trivial.
nn.CrossEntropyLoss() is the multi-class classification loss we derived in Section 9.9. PyTorch's implementation does three things at once. First, it applies the log-softmax to the logits in a numerically stable way (by subtracting the maximum logit before exponentiating, which prevents overflow). Second, it indexes into the log-probabilities with the integer target labels. Third, it negates and averages over the batch. Combining these steps is how PyTorch avoids the log(0) problems that a naive softmax-then-log would suffer. The function expects logits of shape (B, C) and integer targets of shape (B,); passing probabilities or one-hot encodings is a common mistake.
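The three steps can be verified numerically: the built-in loss equals a stable log-softmax, indexed by the integer targets, negated and averaged over the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # batch of 4 examples, 10 classes
targets = torch.tensor([3, 1, 0, 7])  # integer labels, shape (B,)

# The built-in loss...
loss = nn.CrossEntropyLoss()(logits, targets)

# ...equals log-softmax, index by target, negate, average over the batch.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), targets].mean()
```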
Training loop
Now we put it together. We separate training and evaluation into two functions, then run for twenty epochs.
def train_epoch(model, loader, optimiser, loss_fn):
    model.train()
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimiser.zero_grad()
        logits = model(X_batch)
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimiser.step()

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    correct = total = 0
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        logits = model(X_batch)
        pred = logits.argmax(dim=1)
        correct += (pred == y_batch).sum().item()
        total += y_batch.size(0)
    return correct / total

for epoch in range(20):
    train_epoch(model, train_loader, optimiser, loss_fn)
    train_acc = evaluate(model, train_loader)
    test_acc = evaluate(model, test_loader)
    print(f'Epoch {epoch+1}: train acc {train_acc:.4f}, test acc {test_acc:.4f}')
The training function is six lines of real work. We start with model.train(), which puts the model into training mode. For an MLP without dropout or batch normalisation this is a no-op, but we always include the call because forgetting it once you add dropout is a classic PyTorch bug.
Inside the loop, the first thing we do is move the batch to the device. The data lives on the CPU because DataLoader produces CPU tensors; the model lives on the GPU because we called .to(device) earlier. Without X_batch.to(device) PyTorch would complain about a device mismatch.
optimiser.zero_grad() resets every parameter's .grad attribute to zero. PyTorch accumulates gradients into .grad by default. This sounds like a bug but it is actually a feature: it lets you split a large effective batch across multiple smaller forward passes (gradient accumulation), and it lets multi-task setups sum gradients from different losses. The cost of the convenience is that you must remember to clear gradients at the start of every step; forgetting is one of the most common PyTorch errors.
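A sketch of gradient accumulation on a toy linear model: because backward() adds into .grad, four micro-batches of 16 reproduce the gradient of one batch of 64, provided each micro-loss (which averages over its own micro-batch) is divided by the number of accumulation steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 3)
loss_fn = nn.CrossEntropyLoss()
X = torch.randn(64, 8)
y = torch.randint(0, 3, (64,))

# Full-batch gradient, for reference.
model.zero_grad()
loss_fn(model(X), y).backward()
full_grad = model.weight.grad.clone()

# Same gradient from four micro-batches: backward() accumulates into .grad,
# and dividing each micro-loss by the step count recovers the full-batch mean.
accum_steps = 4
model.zero_grad()
for Xm, ym in zip(X.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(Xm), ym) / accum_steps).backward()
```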
logits = model(X_batch) runs the forward pass. As it runs, PyTorch builds the computational graph. loss = loss_fn(logits, y_batch) adds the loss node on top.
loss.backward() is the line that earns PyTorch its keep. It walks the computational graph backwards from the loss tensor, applying the chain rule at every node, and accumulates the gradient with respect to each leaf tensor (every parameter) into that tensor's .grad. The thirty lines of by-hand backpropagation we wrote in Section 9.16 collapse to this one line, and PyTorch never makes a chain-rule mistake.
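The claim is easy to spot-check on a function whose derivative we know by hand: for $y = x^3 + 2x$ the derivative is $3x^2 + 2$, which equals 14 at $x = 2$, and autograd agrees.

```python
import torch

# d/dx of x^3 + 2x is 3x^2 + 2; compare autograd with the hand derivative.
x = torch.tensor(2.0, requires_grad=True)
y = x**3 + 2 * x
y.backward()                        # populates x.grad

analytic = 3 * x.detach()**2 + 2    # = 14 at x = 2
```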
optimiser.step() applies the update rule. For SGD that is $\theta \leftarrow \theta - \eta \, \theta.\mathrm{grad}$ for every parameter the optimiser holds. For Adam it would maintain running first and second moments. The optimiser code is generic and tested.
The evaluation function is wrapped in @torch.no_grad(). This decorator disables graph construction inside the function, which roughly halves memory use and modestly speeds things up; we do not need gradients during evaluation. We also call model.eval(), which switches dropout to identity and batch normalisation to its running statistics. For our plain MLP both calls are no-ops, but again we include them as a habit.
We compute predictions with logits.argmax(dim=1) (which returns the index of the largest logit in each row), compare to the labels, and accumulate correct counts. The .item() call converts a zero-dimensional tensor to a plain Python number; this is the safe way to extract scalar values from PyTorch, since it also detaches the value from the graph.
A typical run prints something like Epoch 20: train acc 0.9988, test acc 0.9785, the same ballpark as the NumPy version.
What changed from the NumPy version
It is worth tabulating the mapping concept by concept. Each row is a piece of the NumPy implementation from Section 9.16 and the PyTorch idiom that replaces it.
| NumPy version | PyTorch idiom |
|---|---|
| init_params([784, 128, 64, 10]) | nn.Linear(784, 128), etc. |
| Hand-rolled forward() | model(x) (calls forward automatically) |
| Hand-rolled backward() | loss.backward() |
| params[i] = (W - lr*dW, b - lr*db) | optimiser.step() |
| softmax(z) then cross_entropy(Y_hat, Y) | nn.CrossEntropyLoss()(logits, y) |
| np.maximum(0, z) | F.relu(x) |
| Loop over manual minibatches | DataLoader(...) |
| CPU-only NumPy | model.to(device) enables GPU |
PyTorch automates initialisation (Kaiming by default for nn.Linear), the entire backward pass via reverse-mode autograd, the optimiser book-keeping (gradient state for Adam, momentum buffers for SGD with momentum), GPU placement (one method call moves the whole model), and minibatching with shuffling and parallel loading (DataLoader with num_workers > 0 spawns worker processes that prefetch batches in the background).
What you still write by hand is the architecture of the network (which layers, what sizes, what activations), the choice of loss and optimiser, the high-level training loop, the evaluation logic, and any logging or checkpointing you want. These are the interesting design choices. Everything mechanical is gone. This is why a 70-billion-parameter language model can be expressed in two or three hundred lines of PyTorch: the framework eliminates the boilerplate so the only code you read is the code that encodes the modelling decisions.
A second observation is that PyTorch's correctness is extremely well tested. The autograd engine has been used to differentiate billions of model evaluations by millions of users; the chance that your hand-rolled backward pass is correct on the first try is much lower. This means that bugs in PyTorch programs are almost always in your forward pass, your data pipeline, or your training-loop logic, not in the gradients. That focuses debugging effort.
Performance comparison
On a typical modern GPU the PyTorch version trains the same network in roughly 5 to 10 seconds per epoch. The NumPy version from Section 9.16 takes 30 to 60 seconds per epoch on the same machine's CPU. On CPU-only PyTorch the speedup over NumPy is smaller, perhaps 2-3x, because PyTorch then uses MKL or OpenBLAS under the hood, which NumPy already calls in the background.
Where do PyTorch's wins come from? Several places. First, NVIDIA's cuBLAS and cuDNN libraries provide hand-optimised matrix multiplication and convolution kernels that exploit tensor cores; PyTorch dispatches to these automatically for any tensor that lives on the GPU. Second, all reductions, elementwise operations, softmax, and cross-entropy run as CUDA kernels with coalesced memory access. Third, PyTorch fuses certain common compositions into single kernels, most importantly softmax-plus-negative-log-likelihood, which reads the logits once instead of three times. Fourth, the data layout is contiguous and aligned for vectorised access. Fifth, with DataLoader(num_workers=4) data loading happens on CPU worker processes in parallel with GPU compute, so the GPU never stalls waiting for data.
For larger networks the gap widens. A modest CNN that takes 30 minutes per epoch in NumPy takes one minute in PyTorch on a single GPU. A Transformer that would take days in NumPy is the routine workload of every research lab. The performance gap is not a constant factor; it grows with model size because the GPU's relative advantage over the CPU grows with the size of the matrix multiplications. This is why deep learning effectively did not exist as a practical discipline before GPU frameworks.
A more idiomatic version
The version above is verbose so that we can see the parallels with Section 9.16. A working researcher would write the same network more compactly using nn.Sequential, would use a more capable optimiser, and would add a learning-rate schedule. Here is a polished version:
model = nn.Sequential(
    nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
).to(device)

optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=20)

for epoch in range(20):
    train_epoch(model, train_loader, optimiser, loss_fn)
    scheduler.step()
    print(f'Epoch {epoch+1}: train {evaluate(model, train_loader):.4f}, '
          f'test {evaluate(model, test_loader):.4f}, lr {scheduler.get_last_lr()[0]:.6f}')
nn.Sequential is a container that stacks layers in order and runs them as a chain. We no longer need a custom subclass of nn.Module for a simple feed-forward network; we list the layers, and Sequential writes the forward method for us. Subclasses are still the right choice when you have branching, skip connections, or any non-linear control flow.
torch.optim.AdamW is the optimiser used in nearly every modern deep learning paper. It is Adam with decoupled weight decay (Loshchilov and Hutter, 2017), which corrects a subtle bug in the original Adam-plus-L2 combination. For most modern problems AdamW with lr=1e-3 and weight_decay=1e-4 is a sensible default that needs almost no tuning to get within a percentage point of optimal.
CosineAnnealingLR is a learning-rate schedule that decays the learning rate following a cosine curve from its initial value down to zero over T_max epochs. Cosine schedules tend to outperform step decay or constant schedules empirically and have become the default in vision and language. We call scheduler.step() once per epoch (not per minibatch), and after the epoch's optimiser steps, never before. This single change typically buys an extra half percentage point of accuracy on MNIST and several points on harder tasks.
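The shape of the schedule can be inspected without training anything, by driving the scheduler with a throwaway parameter: the learning rate starts at its initial value and decays monotonically to (essentially) zero after T_max steps.

```python
import torch

# A throwaway parameter and optimiser exist only to drive the scheduler.
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)

lrs = [opt.param_groups[0]['lr']]
for _ in range(20):
    opt.step()     # step the optimiser before the scheduler, as in training
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])
# lrs traces a cosine from 1e-3 down to zero over the 20 epochs
```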
Adding dropout (Section 9.12) for regularisation is one extra line: insert nn.Dropout(0.2) after each ReLU. Adding batch normalisation (Section 9.13) is similarly one line: insert nn.BatchNorm1d(128) between nn.Linear and nn.ReLU. The training loop does not change. This is the lever that PyTorch hands you: architectural experimentation costs you only the lines that describe the architecture.
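A sketch of the stack with both additions in place (dropout probability 0.2 and the batch-norm placement are the choices suggested above, not the only reasonable ones). In eval mode dropout is the identity and batch norm uses running statistics, so the forward pass is deterministic.

```python
import torch
import torch.nn as nn

# Same MLP with BatchNorm1d after each Linear and Dropout after each ReLU.
model = nn.Sequential(
    nn.Linear(784, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(64, 10),
)

x = torch.randn(32, 784)
model.eval()          # eval mode: dropout off, batch norm uses running stats
out = model(x)
```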
What you should take away
- PyTorch automates initialisation, the entire backward pass, optimiser book-keeping, GPU placement, and minibatching. The lines you keep writing are the modelling decisions: architecture, loss, optimiser, training-loop structure, evaluation.
- Every neural network is a subclass of nn.Module (or an nn.Sequential); training is the four-line dance of zero_grad, forward, loss.backward(), optimiser.step().
- Always set model.train() before training and model.eval() before evaluation; wrap evaluation in @torch.no_grad(). This is harmless on a plain MLP and essential the moment you add dropout or batch normalisation.
- The same network runs ten times faster on a GPU because of cuBLAS/cuDNN kernels, fused operations, and parallel data loading. The performance gap grows with model size.
- The PyTorch idioms in this section are the same idioms used by every modern model: GPT, ResNet, diffusion, multimodal, reinforcement learning. Once you read this loop fluently you can read the rest of the book's code.