11.8 A from-scratch CIFAR-10 ResNet in PyTorch

Up to this point the chapter has worked at the level of equations and architectural diagrams. None of that is real until you can run it. This final practical section walks through a complete, self-contained PyTorch script that trains a small ResNet on CIFAR-10 from scratch in roughly two to three hours on a single consumer GPU and reaches between 91% and 93% test accuracy. The code blocks below are meant to be copied in order into a single script; together they cover the data pipeline, the model and the training loop.

The previous section on transfer learning showed the most pragmatic path to a working CNN: take a model pre-trained on ImageNet and fine-tune the head. That route is right almost every time you face a small real-world dataset. But you should also train a network from random weights at least once in your career, because that exercise builds intuition for everything that goes wrong: data loaders that bottleneck the GPU, learning rates that diverge after twenty epochs, batch-norm statistics that drift when batches are too small, validation curves that plateau because augmentation is too aggressive. CIFAR-10 is the right size for this rite of passage: small enough to fit in memory and iterate on quickly, yet large enough that a real residual network is the appropriate tool.

CIFAR-10 dataset

CIFAR-10 is the standard small-image benchmark for academic computer-vision research. It contains 60,000 colour photographs at 32×32 pixel resolution, evenly split across ten mutually exclusive classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck. The official split assigns 50,000 images to the training set and 10,000 to a held-out test set, with exactly 1,000 test images per class so that per-class accuracy is meaningful. The images were curated by Alex Krizhevsky and colleagues at the University of Toronto in 2009 as a labelled subset of the much larger 80 Million Tiny Images collection, and have anchored the small-image vision literature ever since.

At 32×32 the images are deliberately coarse. A horse is a smudge of brown over green; an automobile is a few angular pixels above a dark stripe of road. Informal human studies put accuracy on the test set at around 94% when people look at the images carefully, so the headline numbers you will see in this section, 91% to 93%, are within striking distance of human performance. The task is non-trivial because intra-class variation is high (cats range from black silhouettes to ginger tabbies on white backgrounds) and inter-class confusion is real (the automobile/truck and cat/dog pairs account for a disproportionate share of errors).

Loading CIFAR-10 in PyTorch is a one-liner. The torchvision.datasets.CIFAR10 class downloads the official archive on first use (when download=True), caches it under the root directory you pass (./data/ here), and yields (image, label) pairs. We wrap it in a DataLoader with shuffling, a batch size of 128 and four worker processes for parallel decoding and augmentation. Channel-wise mean and standard deviation are pre-computed across the training set, on the [0, 1] scale produced by ToTensor, and applied as a normalisation step so that each colour channel enters the network with approximately zero mean and unit variance. This small but crucial detail lets us use higher learning rates without divergence.
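As a sketch of how those normalisation constants can be obtained, the helper below computes per-channel mean and standard deviation from a uint8 array shaped (N, H, W, 3), which is the layout of the .data attribute of torchvision's CIFAR10 dataset. The name channel_stats is our own, not a torchvision function. Rather than converting all 50,000 images to floating point, it histograms the 256 possible pixel values of each channel, which keeps memory use small.

```python
import numpy as np

def channel_stats(images):
    """Per-channel mean and std of a uint8 array shaped (N, H, W, C),
    on the [0, 1] scale that torchvision's T.ToTensor() produces."""
    levels = np.arange(256) / 255.0          # intensity of each uint8 level
    means, stds = [], []
    for c in range(images.shape[-1]):
        # histogram of this channel's pixel values over every image
        counts = np.bincount(images[..., c].ravel(), minlength=256)
        n = counts.sum()
        mu = (counts * levels).sum() / n
        var = (counts * (levels - mu) ** 2).sum() / n   # population variance
        means.append(float(mu))
        stds.append(float(np.sqrt(var)))
    return tuple(means), tuple(stds)
```

Applied to CIFAR10('./data', train=True, download=True).data, it should give values close to the mean and std constants used in the training script below. The standard deviation is taken over all pixels of a channel at once, which is what T.Normalize expects.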

The ResNet block

The fundamental unit of a residual network is the residual block. It takes a feature map, applies two 3×3 convolutions interleaved with batch normalisation and ReLU, and adds the result back to the input via a shortcut connection. The point, recall from §11.3, is that the network learns a residual function $F(x)$ rather than a full mapping, and identity is the default behaviour when $F(x) \to 0$. This is what makes networks of fifty or a hundred layers trainable.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Basic residual block: two 3x3 convs plus a shortcut connection."""
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # 3x3 conv, padding 1: keeps H and W at stride 1, halves them at stride 2
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        if stride == 1 and in_ch == out_ch:
            self.shortcut = nn.Identity()          # shapes already match
        else:
            self.shortcut = nn.Sequential(         # 1x1 projection to match shape
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))      # conv -> BN -> ReLU
        out = self.bn2(self.conv2(out))            # conv -> BN, no activation yet
        return F.relu(out + self.shortcut(x))      # add shortcut, then final ReLU

A handful of details are worth dwelling on. First, every convolution sets bias=False. This is not laziness: the batch-norm layer that immediately follows subtracts the per-channel mean and learns its own affine offset, so any bias term in the convolution would be cancelled out and contribute nothing but redundant parameters. Second, the kernel size, stride and padding values (3, stride, 1) keep the spatial dimensions matched: with stride 1 and padding 1, a 3×3 kernel preserves height and width; with stride 2 and padding 1, it halves them. Third, the shortcut branch is the identity when input and output shapes agree, but when we change the channel count or downsample we need a 1×1 projection convolution with the same stride (followed by its own batch-norm) so that the shapes line up and the addition is well-defined. He et al. (2015) call this the projection shortcut (their option B). On ImageNet they found it slightly better than parameter-free zero-padding shortcuts (option A), arguing that the zero-padded channels receive no residual learning; their own CIFAR experiments, however, used option A throughout, so the model here has a few thousand more parameters than the original ResNet-20.

Fourth, the order of operations in forward follows the original ResNet recipe: conv → BN → ReLU on the first sub-block, conv → BN with no activation on the second, then add the shortcut, then a final ReLU. There is a "pre-activation" variant that places BN and ReLU before the convolutions and tends to train slightly more stably for very deep networks, but for a 20-layer network the original ordering works perfectly well and is what we use here. Finally, note that nothing in the block has a hard-coded spatial size: the same module works for 32×32, 16×16 or 8×8 feature maps. That is one of the quiet virtues of convolutional architectures.
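The spatial bookkeeping in the second and third points follows the standard convolution output-size formula, $\lfloor (H + 2p - k)/s \rfloor + 1$ for input size $H$, kernel size $k$, padding $p$ and stride $s$. The short check below compares that formula with the shapes nn.Conv2d actually produces for the main path and the 1×1 projection shortcut of a downsampling block; the conv_out helper is our own, not part of PyTorch.

```python
import torch
import torch.nn as nn

def conv_out(size, k, stride, pad):
    # floor((size + 2*pad - k) / stride) + 1
    return (size + 2 * pad - k) // stride + 1

x = torch.randn(1, 16, 32, 32)                  # (N, C, H, W)
same = nn.Conv2d(16, 16, 3, 1, 1, bias=False)   # 3x3, stride 1, padding 1
main = nn.Conv2d(16, 32, 3, 2, 1, bias=False)   # 3x3, stride 2, padding 1
proj = nn.Conv2d(16, 32, 1, 2, bias=False)      # 1x1 projection, stride 2, no padding

print(conv_out(32, 3, 1, 1), tuple(same(x).shape))  # height and width preserved
print(conv_out(32, 3, 2, 1), tuple(main(x).shape))  # height and width halved
print(conv_out(32, 1, 2, 0), tuple(proj(x).shape))  # shortcut matches the main path
```

Because the 3×3 stride-2 path and the 1×1 stride-2 projection land on the same output shape, the element-wise addition in forward is well-defined.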

The full ResNet

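ResNet-20, the standard CIFAR variant, stacks these blocks in three stages of three blocks each, with channel widths of 16, 32 and 64; counting the stem convolution, the eighteen convolutions inside the blocks and the final linear layer gives the 20 weighted layers of He et al.'s $6n + 2$ formula with $n = 3$. The first block of each stage after the first downsamples spatially with stride 2 and doubles the channel count, so the feature map shrinks from 32×32×16 to 16×16×32 to 8×8×64 as it climbs the network. This pattern of "halve the resolution, double the channels" keeps the computational cost per convolutional layer roughly constant: halving height and width cuts the number of spatial positions by a factor of four, while doubling both input and output channels multiplies the work per position by four. The number of activations per layer, by contrast, halves at each stage. Most modern CNNs follow the same pattern.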
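A quick calculation makes the trade-off concrete. The sketch below counts the output activations and multiply-accumulate operations (MACs) of one stride-1, c→c 3×3 convolution in each stage, the kind that makes up most of every stage. It ignores batch-norm, the shortcuts and the first (downsampling) convolution of each stage; conv3x3_cost is an illustrative helper name.

```python
# Cost of one 3x3 convolution (c -> c channels, padding 1, stride 1) in each
# ResNet-20 stage for a 32x32 CIFAR input. BN and bias costs are ignored.
def conv3x3_cost(hw, c):
    activations = hw * hw * c            # output values produced by the layer
    macs = hw * hw * c * (3 * 3 * c)     # each output value needs 3*3*c multiply-adds
    return activations, macs

stages = [(32, 16), (16, 32), (8, 64)]   # (spatial size, channels) per stage
costs = [conv3x3_cost(hw, c) for hw, c in stages]
for (hw, c), (acts, macs) in zip(stages, costs):
    print(f"{hw}x{hw}x{c}: activations={acts:,}  MACs={macs:,}")
```

Every stage comes out at 2,359,296 MACs per layer, while the activation count falls from 16,384 to 8,192 to 4,096.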

class ResNet20(nn.Module):
    """ResNet-20 for 32x32 CIFAR images: stem, 3 stages of 3 blocks, linear head."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(                           # (3, 32, 32) -> (16, 32, 32)
            nn.Conv2d(3, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.stage1 = self._make_stage(16, 16, 3, stride=1)  # -> (16, 32, 32)
        self.stage2 = self._make_stage(16, 32, 3, stride=2)  # -> (32, 16, 16)
        self.stage3 = self._make_stage(32, 64, 3, stride=2)  # -> (64, 8, 8)
        self.pool = nn.AdaptiveAvgPool2d(1)                  # -> (64, 1, 1)
        self.fc   = nn.Linear(64, num_classes)

    def _make_stage(self, in_ch, out_ch, n_blocks, stride):
        # only the first block may downsample or change the channel count
        layers = [ResBlock(in_ch, out_ch, stride)]
        for _ in range(n_blocks - 1):
            layers.append(ResBlock(out_ch, out_ch, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.pool(x).flatten(1)   # (N, 64)
        return self.fc(x)             # (N, num_classes) logits

The stem is a single 3×3 convolution rather than the 7×7 stride-2 convolution (followed by a stride-2 max-pool) used by ImageNet-scale ResNets: at 32×32 we cannot afford to throw away resolution that aggressively. The body is three calls to a small helper, _make_stage, which builds a sequence of n_blocks residual blocks in which only the first may downsample. After the third stage the feature map is 8×8×64, and an adaptive average pool collapses the spatial dimensions to a 64-dimensional vector. A single linear layer produces the ten class logits.

The whole network has roughly 270,000 trainable parameters, about a hundred times fewer than a ResNet-50 (roughly 25 million), which is why it trains so quickly on a single GPU. Counting parameters by hand is a useful sanity check: stage 1 has three basic blocks of two $3\times 3$ conv layers each at $16\to 16$ channels, so $3 \cdot 2 \cdot (3 \cdot 3 \cdot 16 \cdot 16) \approx 14\text{k}$ weights; stage 3 contributes most of the body, since a single $64\to 64$ $3\times 3$ convolution alone has $3 \cdot 3 \cdot 64 \cdot 64$ = 36,864 weights; and the final linear layer is $64 \cdot 10 = 640$ weights plus 10 biases. If your printout disagrees by an order of magnitude, you have wired the channel widths wrongly. A print(sum(p.numel() for p in model.parameters())) line in your script is a cheap insurance policy.
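The same arithmetic can be carried through the whole network. The sketch below mirrors the structure of the ResBlock and ResNet20 classes above without building them (the helper names conv, bn, block and stage are our own). It counts learnable parameters only: bias-free convolutions, two parameters per batch-norm channel (the running statistics are buffers, not parameters) and a linear layer with bias.

```python
# Analytic parameter count for the ResNet20 class above (projection shortcuts).
def conv(c_in, c_out, k):
    return c_in * c_out * k * k          # bias=False, so weights only

def bn(c):
    return 2 * c                         # learnable gamma and beta per channel

def block(c_in, c_out, stride):
    p = conv(c_in, c_out, 3) + bn(c_out) + conv(c_out, c_out, 3) + bn(c_out)
    if stride != 1 or c_in != c_out:     # 1x1 projection shortcut plus its BN
        p += conv(c_in, c_out, 1) + bn(c_out)
    return p

def stage(c_in, c_out, n_blocks, stride):
    return block(c_in, c_out, stride) + (n_blocks - 1) * block(c_out, c_out, 1)

stem = conv(3, 16, 3) + bn(16)
stages = [stage(16, 16, 3, 1), stage(16, 32, 3, 2), stage(32, 64, 3, 2)]
fc = 64 * 10 + 10                        # weights plus biases
total = stem + sum(stages) + fc
print(stages, total)
```

It reports 14,016 parameters for stage 1, 51,648 for stage 2 and 205,696 for stage 3, for a total of 272,474, which is what sum(p.numel() for p in model.parameters()) should print for the ResNet20 class above. Dropping the two projection shortcuts and their batch-norms leaves 269,722, consistent with the 0.27M that He et al. report for their zero-padding-shortcut ResNet-20.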

Training loop

The training loop is the part students most often get wrong. The skeleton below uses SGD with Nesterov momentum (0.9), a weight decay of $5\times 10^{-4}$, an initial learning rate of 0.1, and a cosine schedule that decays the rate smoothly to zero over 200 epochs. SGD is preferred over Adam for CIFAR-style image classification because, in published comparisons, it typically yields a percentage point or so of extra test accuracy. Adam shines on transformers and in settings with sparse gradients; on convolutional image classifiers it tends to generalise slightly worse.

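import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

torch.manual_seed(0)  # makes re-runs (mostly) reproducible

# Per-channel training-set mean and std, on the [0, 1] scale produced by ToTensor()
mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),    # zero-pad 4 px on every side, crop back to 32x32
    T.RandomHorizontalFlip(),       # mirror with probability 0.5
    T.ColorJitter(0.2, 0.2, 0.2),   # mild brightness/contrast/saturation jitter
    T.ToTensor(),
    T.Normalize(mean, std),
])
test_tf = T.Compose([T.ToTensor(), T.Normalize(mean, std)])  # no augmentation at test time

train_ds = CIFAR10('./data', train=True,  download=True, transform=train_tf)
test_ds  = CIFAR10('./data', train=False, download=True, transform=test_tf)
# On Windows and macOS, DataLoader workers are spawned as new processes, so wrap the
# top-level script in an `if __name__ == '__main__':` block when num_workers > 0.
train_dl = DataLoader(train_ds, 128, shuffle=True,  num_workers=4, pin_memory=True)
test_dl  = DataLoader(test_ds,  256, shuffle=False, num_workers=4, pin_memory=True)

device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'      # mixed precision only on CUDA GPUs
model   = ResNet20().to(device)
opt     = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                          weight_decay=5e-4, nesterov=True)
sched   = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=200)  # stepped once per epoch
# torch.amp.GradScaler needs PyTorch >= 2.3; older releases use torch.cuda.amp.GradScaler
scaler  = torch.amp.GradScaler('cuda', enabled=use_amp)

for epoch in range(200):
    model.train()
    for x, y in train_dl:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):  # fp16 on CUDA
            logits = model(x)
            loss = F.cross_entropy(logits, y)
        scaler.scale(loss).backward()   # backprop a scaled loss so fp16 grads don't underflow
        scaler.step(opt)                # unscale grads; skip the step if they contain inf/NaN
        scaler.update()                 # adapt the scale factor for the next iteration
    sched.step()                        # cosine decay, one step per epoch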
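With eta_min left at its default of zero, the learning rate CosineAnnealingLR assigns to epoch $t$ of $T$ is $\eta_t = \tfrac{1}{2}\eta_0\left(1 + \cos(\pi t / T)\right)$. The small check below records the rate in force during each epoch, stepping the scheduler once per epoch exactly as in the loop above, and compares it with that closed form. The cosine_lr helper and the dummy parameter are illustrative only.

```python
import math
import torch

def cosine_lr(epoch, base_lr=0.1, total=200):
    # closed form of cosine annealing with eta_min = 0
    return 0.5 * base_lr * (1 + math.cos(math.pi * epoch / total))

param = torch.nn.Parameter(torch.zeros(1))    # dummy parameter, never gets a gradient
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=200)

lrs = []
for epoch in range(200):
    lrs.append(opt.param_groups[0]['lr'])     # rate used during this epoch
    opt.step()                                # no-op here: param.grad is None
    sched.step()

for e in (0, 50, 100, 150, 199):
    print(e, round(lrs[e], 6), round(cosine_lr(e), 6))
```

A common bug is calling sched.step() inside the batch loop: with T_max=200 the rate would then fall to zero after only 200 iterations, about half of the first epoch's 391 batches, and then start rising again.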

Several augmentation choices matter. Random cropping with four pixels of zero padding is the canonical CIFAR augmentation and typically adds around two percentage points of test accuracy. Horizontal flipping helps because a mirrored airplane, cat or truck is still a valid example of its class. Mild colour jitter reduces overfitting without hurting convergence. We deliberately avoid heavier augmentation such as Cutout, AutoAugment or RandAugment in this baseline because it pushes training time up, and its marginal accuracy gain is best understood after the simple version works.
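The script relies on torchvision's T.RandomCrop and T.RandomHorizontalFlip. The function below is only an illustrative re-implementation (pad_crop_flip is a hypothetical name) showing what those two transforms do to a single (C, H, W) tensor: zero-pad by four pixels on every side, cut out a random window of the original size, then mirror it left-right with probability 0.5.

```python
import torch
import torch.nn.functional as F

def pad_crop_flip(img, pad=4, generator=None):
    """Random crop with zero padding plus a random horizontal flip.

    img: tensor of shape (C, H, W); returns a tensor of the same shape.
    """
    _, h, w = img.shape
    # F.pad pads the last dimension first: (left, right, top, bottom)
    padded = F.pad(img, (pad, pad, pad, pad))
    # top-left corner of the crop, uniform over the 2*pad + 1 possible offsets
    top = int(torch.randint(0, 2 * pad + 1, (1,), generator=generator))
    left = int(torch.randint(0, 2 * pad + 1, (1,), generator=generator))
    out = padded[:, top:top + h, left:left + w]
    if torch.rand(1, generator=generator).item() < 0.5:
        out = out.flip(-1)  # mirror along the width axis
    return out

g = torch.Generator().manual_seed(0)
aug = pad_crop_flip(torch.ones(3, 32, 32), generator=g)
print(aug.shape, int((aug == 0).sum()))  # shape is unchanged; zeros come from the padding
```

Because the offset is at most four pixels in each direction, every augmented image keeps at least a 28×28 window of the original content.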

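Mixed-precision training via PyTorch's automatic mixed precision (torch.amp) runs most of the forward and backward pass in fp16, which can roughly halve activation memory and substantially raise throughput on GPUs with tensor cores. The GradScaler multiplies the loss by a scale factor before backpropagation so that small fp16 gradients do not underflow to zero, then unscales the gradients (skipping the step if any are inf or NaN) before the optimiser update. The set_to_none=True flag on zero_grad is a small speedup and has been the default since PyTorch 2.0. pin_memory=True and non_blocking=True together let host-to-GPU copies overlap with computation. None of these tricks should change the final accuracy meaningfully; they change how long you wait for it.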
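A two-line experiment shows the underflow problem that loss scaling solves. A gradient of $10^{-8}$ is below the smallest positive fp16 value (about $6\times 10^{-8}$) and rounds to zero, but after multiplying by $2^{16}$, which is GradScaler's default initial scale, it becomes representable, and dividing by the same factor in fp32 recovers it. In real training the scale is applied to the loss, and the chain rule carries it through to every gradient.

```python
import torch

g = 1e-8                                             # a small gradient value
print(torch.tensor(g, dtype=torch.float16))          # underflows to 0 in fp16

scale = 2.0 ** 16                                    # GradScaler's default init_scale
scaled = torch.tensor(g * scale, dtype=torch.float16)  # about 6.6e-4, fine in fp16
recovered = scaled.float() / scale                   # unscale in fp32
print(scaled, recovered)
```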

A separate evaluation loop, run once per epoch with model.eval() and torch.no_grad(), computes test accuracy. Switching to eval mode is not optional: it tells the batch-norm layers to use the running statistics accumulated during training rather than the statistics of the current batch. Without it, each test prediction would depend on the other images in its batch, the running statistics would be updated with test data, and the reported numbers would be noisier. Over the first ten epochs you should expect test accuracy to climb from chance (10%) to roughly 70%; by epoch 50 it should be around 87%, and by epoch 200 the curve flattens out near 92%.
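A minimal version of that evaluation pass might look like the function below (evaluate is our own helper, not a PyTorch API). It switches the model to eval mode, disables gradient tracking with the torch.no_grad() decorator, and counts correct top-1 predictions over the whole loader.

```python
import torch

@torch.no_grad()                      # no autograd graph during evaluation
def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over every batch in `loader`."""
    model.eval()                      # batch-norm uses its running statistics
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        preds = model(x).argmax(dim=1)            # index of the largest logit
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / total
```

In the training loop it would be called once per epoch after sched.step(), for example acc = evaluate(model, test_dl, device); the model.train() call at the top of the next epoch switches batch-norm back to per-batch statistics.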

A few practical sanity checks save hours of debugging. Print the loss after the very first batch: it should be close to $\ln(10) \approx 2.30$, the cross-entropy of a uniform distribution over ten classes. If it is much larger, a likely cause is that you have forgotten to normalise the inputs. Plot training and test accuracy against epoch on the same axes; the gap between them tells you whether you are over- or under-regularised. If training accuracy is climbing past 99% while test accuracy stalls below 90%, increase weight decay or add augmentation; if both are stuck at 80%, your model is under-fitting and you need a larger network or a longer schedule. Finally, set the random seed with torch.manual_seed(0) so that re-runs are at least mostly reproducible. Determinism is not perfect on CUDA (some cuDNN convolution kernels are non-deterministic, and with torch.backends.cudnn.benchmark = True the autotuner may pick different algorithms from run to run), but seeding is enough to tell whether a change you made helped or hurt.
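The $\ln(10)$ figure is easy to confirm: logits that are all equal correspond to the uniform distribution, and their cross-entropy is exactly $\ln(10)$ whatever the labels. A freshly initialised ResNet20 does not output exactly equal logits, so its first-batch loss will only be close to this value.

```python
import math
import torch
import torch.nn.functional as F

logits = torch.zeros(128, 10)              # identical logits = uniform prediction
labels = torch.randint(0, 10, (128,))      # any labels give the same loss
loss = F.cross_entropy(logits, labels)
print(f"{loss.item():.4f} vs ln(10) = {math.log(10):.4f}")
```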

Expected results

A correctly implemented ResNet-20 trained for 200 epochs with this recipe reaches 91% to 93% test accuracy. The original He et al. paper reports 91.25% for ResNet-20 on CIFAR-10 and 92.49% for ResNet-32; modern PyTorch reproductions that include cosine annealing and Nesterov momentum typically land between 92% and 93%. If your run finishes below 90%, the most likely culprits are missing augmentation, missing weight decay, or a learning-rate warmup that bites into the early epochs.

The state of the art on CIFAR-10 sits around 99.5% top-1 accuracy and is held by very large vision transformers and ConvNeXt-style networks pre-trained on JFT-300M or ImageNet-21k and then fine-tuned with heavy AutoAugment, Mixup, CutMix and stochastic depth. Closing the roughly seven-and-a-half-percentage-point gap from your 92% baseline to the leaderboard requires orders of magnitude more compute, most of it spent on large-scale pre-training, and a dozen orthogonal regularisation tricks. The lesson is that 92% is most of the way there: a small, clean ResNet captures the easy bulk of the achievable accuracy, and the long tail belongs to industrial training pipelines.

Wall-clock training time depends on the GPU. On an NVIDIA RTX 3060 (12 GB) you should expect about 35 seconds per epoch, so 200 epochs take around two hours. On an A100 it is closer to ten seconds an epoch and the whole run finishes in well under an hour. On an Apple Silicon GPU via the MPS backend (which requires changing the device line to select torch.device('mps')), it is roughly ninety seconds per epoch, and on CPU alone a full run is impractical, taking well over a day. If you cannot access a CUDA GPU, run for fifty epochs instead of two hundred, setting T_max=50 so that the cosine schedule still decays to zero, and accept landing at around 88% rather than 92%; the qualitative shape of the curve is the same.
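One way to encode that advice is to choose the device and the epoch budget together and derive T_max from the budget, so the schedule always reaches zero at the end of the run. The helpers below are a sketch (pick_device and make_schedule are hypothetical names); torch.backends.mps.is_available() requires PyTorch 1.12 or later.

```python
import torch

def pick_device():
    """Prefer CUDA, then Apple's MPS backend, then the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

def make_schedule(params, epochs):
    """SGD plus a cosine schedule whose period matches the epoch budget."""
    opt = torch.optim.SGD(params, lr=0.1, momentum=0.9,
                          weight_decay=5e-4, nesterov=True)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    return opt, sched

device = pick_device()
epochs = 200 if device.type == 'cuda' else 50   # shorter run without a CUDA GPU
print(device, epochs)
```

In the training script, the loop would then run for range(epochs) with opt, sched = make_schedule(model.parameters(), epochs), so the learning rate reaches zero exactly when training stops.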

What you should take away

  1. A complete CNN training pipeline fits in around eighty lines of PyTorch: a dataset, an augmentation pipeline, a residual block, a stack of three stages, an SGD-plus-cosine training loop and a periodic evaluation pass. Nothing more is needed for 92% on CIFAR-10.
  2. Residual blocks are short, identical, composable units; building a ResNet is mostly a question of how many to stack and where to halve the resolution. The "halve resolution, double channels" pattern is common to most modern CNNs.
  3. SGD with Nesterov momentum, weight decay $5\times 10^{-4}$, learning rate 0.1 with cosine annealing, and the standard CIFAR augmentations (random crop with padding, horizontal flip, mild colour jitter) form a recipe, a close descendant of the original 2015 ResNet training setup, that remains a strong baseline.
  4. Mixed precision and pinned memory can substantially increase throughput with no meaningful cost in accuracy. They are the first optimisations to reach for once your training script is correct.
  5. The headline 92% is within two percentage points of human performance on this benchmark and about seven and a half points below the published state of the art. The remaining gap is paid for in compute and engineering, not in fundamentally different ideas, which is exactly why training a small ResNet from scratch is the right way to consolidate your understanding of convolutional networks.

This site is currently in Beta. Contact: Chris Paton

AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).