6.3 The supervised learning setup
Supervised learning has a single recipe that hardly changes from one task to the next. You collect examples, you decide what counts as an input and what counts as a label, you choose the family of functions you will draw from, you pick a way of scoring mistakes, and then you tune the function to make as few mistakes as possible. Whether you are predicting house prices, flagging spam, recognising handwritten digits or translating English into French, the steps are the same. What changes is the type of label, the shape of the function and the scoring rule.
This section formalises that template. We will write it down once, carefully, and then walk through four concrete tasks: regression, binary classification, multi-class classification and structured prediction. By the end you should be able to look at a new prediction problem and slot it into the template within a few minutes.
§6.4 then asks the harder question of when doing well on training data implies doing well on unseen data.
The supervised template
Every supervised learning problem can be written down with the same three ingredients, then attacked with the same single principle.
Ingredient one: a dataset. We assume we have $n$ labelled examples,
$$ \mathcal{D} = \{(\mathbf{x}_1, y_1), (\mathbf{x}_2, y_2), \ldots, (\mathbf{x}_n, y_n)\}, $$
drawn independently from some unknown joint distribution $p(\mathbf{x}, y)$. Each $\mathbf{x}_i$ is a vector of features describing one example. Each $y_i$ is the label we want to predict. The phrase i.i.d., independent and identically distributed, is shorthand for the assumption that the examples were drawn separately from the same underlying process. Real data often violates this. Patients in 2030 are not drawn from the same distribution as patients in 2014. Spam in English is not drawn from the same distribution as spam in Mandarin. We will return to this carefully in later sections; for now, work in the i.i.d. setting because it is where the theory has bite.
Ingredient two: a hypothesis class. A hypothesis class $\mathcal{H}$ is a family of candidate functions $h: \mathcal{X} \to \mathcal{Y}$. Choosing $\mathcal{H}$ is among the most consequential decisions in a machine learning project. If $\mathcal{H}$ is too small it cannot represent the truth; if it is too large it can memorise noise. Common choices are linear models, decision trees of bounded depth, neural networks of a fixed architecture and nearest-neighbour rules.
Ingredient three: a loss function. A loss $\mathcal{L}(h(\mathbf{x}), y)$ is a real number that says how bad it is to predict $h(\mathbf{x})$ when the truth is $y$. Zero means perfect; larger means worse.
The principle: empirical risk minimisation. The thing we ultimately care about is the true risk, the expected loss on a fresh draw from the population:
$$ R(h) = \mathbb{E}_{(\mathbf{x}, y) \sim p}[\mathcal{L}(h(\mathbf{x}), y)]. $$
We cannot compute $R(h)$ because we do not know $p$. The trick we play instead is to compute the empirical risk, the average loss over the training set we actually have:
$$ \hat R(h) = \frac{1}{n} \sum_{i=1}^{n} \mathcal{L}(h(\mathbf{x}_i), y_i), $$
and pick the hypothesis $\hat h \in \mathcal{H}$ that minimises it. This is empirical risk minimisation (ERM). Almost every algorithm in this book is a variation on it: sometimes with a regularisation term added, sometimes with a clever optimiser, sometimes with stochastic mini-batches, but the skeleton is unchanged.
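To make the principle concrete, here is a minimal sketch of ERM over a deliberately tiny hypothesis class. The data-generating rule ($y = 2x$ plus noise), the grid of slopes and the function names are all assumptions chosen for illustration; real learners search far larger classes with an optimiser rather than by enumeration.

```python
import random


def squared_loss(prediction, label):
    """Loss for a single example."""
    return (prediction - label) ** 2


def empirical_risk(h, data, loss):
    """Average loss of hypothesis h over the dataset (the quantity R-hat)."""
    return sum(loss(h(x), y) for x, y in data) / len(data)


# Hypothetical dataset: y = 2x plus a little Gaussian noise.
rng = random.Random(0)
xs = [rng.uniform(0.0, 1.0) for _ in range(200)]
data = [(x, 2.0 * x + rng.gauss(0.0, 0.1)) for x in xs]

# A tiny hypothesis class: h_w(x) = w * x for slopes w = 0.0, 0.1, ..., 4.0.
slopes = [i / 10 for i in range(41)]
hypotheses = {w: (lambda x, w=w: w * x) for w in slopes}

# ERM: score every hypothesis on the training set and keep the best one.
risks = {w: empirical_risk(h, data, squared_loss) for w, h in hypotheses.items()}
w_hat = min(risks, key=risks.get)
print("selected slope:", w_hat, "empirical risk:", risks[w_hat])
```

Enumerating the class only works because it has 41 members; the same "score, then pick the minimiser" logic is what gradient-based training approximates for continuous parameters.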
Regression
In regression, the label is a real number. House prices, blood pressure, tomorrow's temperature and the time a patient will spend in hospital are all regression targets because the answer lives somewhere on the real line, $y \in \mathbb{R}$.
The default loss is mean squared error (MSE):
$$ \mathcal{L}(\hat y, y) = (\hat y - y)^2. $$
MSE penalises large mistakes much more than small ones, because the error is squared. That is sometimes what you want: being out by a million dollars on a house valuation is much worse than being out by a thousand. Sometimes it is the wrong thing, because a single huge outlier (a mansion in a sea of bungalows) will dominate the objective and pull the fitted line towards it. Two more robust alternatives:
- Mean absolute error (MAE), $\mathcal{L}(\hat y, y) = |\hat y - y|$. Treats every dollar of error the same. Fits the median of $y$ given $\mathbf{x}$ rather than the mean.
- Huber loss, quadratic for small residuals and linear for large ones. A pragmatic compromise between MSE and MAE.
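The difference between these losses is easiest to see on a constant predictor. The sketch below uses made-up prices (in thousands) with one mansion-like outlier, and searches a grid for the single number that minimises each average loss. The Huber form used here, $\tfrac12 r^2$ for $|r| \le \delta$ and $\delta(|r| - \tfrac12\delta)$ otherwise, is the conventional one; the threshold `delta` is an assumed parameter, not something fixed by the text above.

```python
import statistics


def mse(pred, y):
    return (pred - y) ** 2


def mae(pred, y):
    return abs(pred - y)


def huber(pred, y, delta=1.0):
    """Quadratic for residuals up to delta, linear beyond it."""
    r = abs(pred - y)
    if r <= delta:
        return 0.5 * r * r
    return delta * (r - 0.5 * delta)


def average_loss(loss, c, ys):
    """Empirical risk of the constant prediction c."""
    return sum(loss(c, y) for y in ys) / len(ys)


# Hypothetical sale prices in thousands; the last one is the outlier.
prices = [200.0, 210.0, 220.0, 230.0, 2000.0]

# Best constant prediction under each loss, found by grid search.
candidates = [float(c) for c in range(0, 2001)]
best_mse = min(candidates, key=lambda c: average_loss(mse, c, prices))
best_mae = min(candidates, key=lambda c: average_loss(mae, c, prices))

print("MSE picks", best_mse, "(mean is", statistics.mean(prices), ")")
print("MAE picks", best_mae, "(median is", statistics.median(prices), ")")
```

The MSE minimiser is dragged far above the four ordinary prices by the single outlier, while the MAE minimiser stays at the median.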
Worked example: predicting house prices. Suppose you want to predict the sale price of a house from a handful of features: square footage, number of bedrooms, distance to the nearest train station and suburb. Your training data is thousands of past sales. The simplest hypothesis class is linear regression:
$$ h(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b, $$
where $\mathbf{w}$ is a weight vector with one entry per feature and $b$ is an intercept. The empirical risk for MSE is
$$ \hat R(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^{n} \big( y_i - (\mathbf{w}^\top \mathbf{x}_i + b) \big)^2. $$
This particular objective has a closed-form minimiser, given by the normal equations, which is one reason linear regression is taught first. For most other models we will use gradient descent.
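As an illustration of the closed form, the sketch below solves the normal equations for the special case of a single feature plus an intercept, where they reduce to a 2-by-2 linear system. The function name and the synthetic data (prices lying exactly on a line) are assumptions for the example; with many features you would solve the same system with a linear algebra library.

```python
def fit_simple_linear_regression(xs, ys):
    """Minimise mean squared error for h(x) = w * x + b with one feature.

    Setting the gradient of the empirical risk to zero gives
        w * sum(x^2) + b * sum(x) = sum(x * y)
        w * sum(x)   + b * n      = sum(y)
    which is solved here directly.
    """
    n = len(xs)
    sx, sy = sum(xs), sum(ys)
    sxx = sum(x * x for x in xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    det = n * sxx - sx * sx
    if det == 0:
        raise ValueError("all x values are identical; the slope is not determined")
    w = (n * sxy - sx * sy) / det
    b = (sy - w * sx) / n
    return w, b


# Hypothetical data lying exactly on price = 3 * area + 50.
areas = [50.0, 80.0, 100.0, 120.0, 150.0]
prices = [3.0 * a + 50.0 for a in areas]

w, b = fit_simple_linear_regression(areas, prices)
print("weight:", w, "intercept:", b)
```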
The fitted model gives you both a prediction (multiply the features by the weights, add the bias) and a kind of explanation (a positive weight on square footage means more space sells for more, a negative weight on distance to the train station means farther from the train sells for less). That is part of why linear regression remains the workhorse of applied statistics: each weight is directly interpretable.
Binary classification
In binary classification, the label is one of two categories: spam or not spam, fraudulent transaction or genuine, malignant tumour or benign. We code these as $y \in \{0, 1\}$.
The output of the model is usually not the class itself but the probability that the class is 1. We squash a real-valued score $z = \mathbf{w}^\top \mathbf{x} + b$ through the sigmoid $\sigma(z) = 1/(1 + e^{-z})$ so that the output lies in $(0, 1)$. To turn a probability into a decision we threshold it, usually at $0.5$.
The standard loss is binary cross-entropy (also called log-loss):
$$ \mathcal{L}(\hat y, y) = -\big[\, y \log \hat y + (1 - y) \log(1 - \hat y) \,\big]. $$
It rewards the model for being confident and right, and punishes it harshly for being confident and wrong. It is the negative log-likelihood of a Bernoulli model, which is the principled probabilistic justification.
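A few numbers make the asymmetry visible. The sketch below implements the sigmoid and the binary cross-entropy from the formulas above; the clipping constant `eps` is an assumption added only to keep `log` away from zero.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_cross_entropy(p, y, eps=1e-12):
    """Loss for predicted probability p of class 1 when the true label is y."""
    p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


# True label is 1 in all three cases.
confident_right = binary_cross_entropy(0.99, 1)
hesitant = binary_cross_entropy(0.60, 1)
confident_wrong = binary_cross_entropy(0.01, 1)
print(confident_right, hesitant, confident_wrong)
```

The loss for the confident wrong prediction is several hundred times the loss for the confident right one, which is the "punishes harshly" behaviour described above.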
Worked example: spam detection. Take the body of an email, count how often each word from a fixed vocabulary appears, and stack those counts into a feature vector $\mathbf{x}$. (This is the bag-of-words representation. It throws away word order, which sounds drastic but works surprisingly well for spam.) The label $y$ is 1 for spam and 0 for ham. The hypothesis class is logistic regression:
$$ h(\mathbf{x}) = \sigma(\mathbf{w}^\top \mathbf{x} + b), $$
interpreted as the predicted probability that the email is spam. Words like "viagra" and "lottery" earn large positive weights; words like "meeting" and "thanks" earn negative ones. The training procedure is gradient descent on the average binary cross-entropy. At test time you classify an email as spam if $h(\mathbf{x}) > 0.5$, or you tune the threshold to trade off false positives against false negatives, a topic we will return to in §6.14.
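The whole pipeline fits in a short sketch. Everything concrete here is invented for illustration: the four-word vocabulary, the six toy emails, the learning rate and the number of steps. It uses plain batch gradient descent, relying on the fact that the gradient of binary cross-entropy with respect to the score $z$ is simply $h(\mathbf{x}) - y$.

```python
import math

VOCAB = ["lottery", "winner", "meeting", "thanks"]  # hypothetical vocabulary


def bag_of_words(text):
    """Count how often each vocabulary word appears; word order is discarded."""
    words = text.lower().split()
    return [float(words.count(v)) for v in VOCAB]


# Toy labelled emails: 1 = spam, 0 = ham.
emails = [
    ("lottery winner claim your lottery prize", 1),
    ("you are a lottery winner", 1),
    ("winner winner claim now", 1),
    ("meeting moved to friday thanks", 0),
    ("thanks for the meeting notes", 0),
    ("agenda for the next meeting", 0),
]
X = [bag_of_words(text) for text, _ in emails]
y = [label for _, label in emails]


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_proba(w, b, x):
    """h(x) = sigmoid(w . x + b), the predicted probability of spam."""
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)


w = [0.0] * len(VOCAB)
b = 0.0
learning_rate = 0.5  # assumed value, not tuned

for _ in range(500):
    grad_w = [0.0] * len(VOCAB)
    grad_b = 0.0
    for x, label in zip(X, y):
        error = predict_proba(w, b, x) - label  # d(loss)/d(score)
        for j, xj in enumerate(x):
            grad_w[j] += error * xj
        grad_b += error
    n = len(X)
    w = [wj - learning_rate * g / n for wj, g in zip(w, grad_w)]
    b -= learning_rate * grad_b / n

for word, weight in zip(VOCAB, w):
    print(f"{word:>8}: {weight:+.2f}")
```

On this toy set the spam-only words end up with positive weights and the ham-only words with negative ones, mirroring the pattern described in the text.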
Multi-class classification
Now the label is one of $K \ge 2$ categories: digit zero through nine, dog/cat/bird/fish, ICD-10 diagnosis code. We write $y \in \{1, 2, \ldots, K\}$.
The model outputs a probability distribution over the $K$ classes, computed by the softmax function applied to a vector of $K$ scores:
$$ \mathrm{softmax}(\mathbf{z})_k = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}. $$
The softmax outputs are non-negative and sum to one, so they form a valid probability distribution. The loss is categorical cross-entropy, the natural multi-class generalisation of binary cross-entropy:
$$ \mathcal{L}(\hat{\mathbf{y}}, y) = -\log \hat y_y, $$
i.e. the negative log of the probability the model assigned to the correct class.
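A minimal sketch of both formulas follows. Subtracting the largest score before exponentiating is a standard stability trick that leaves the result unchanged; the label convention ($y$ from 1 to $K$, as in the text) is mapped onto Python's zero-based lists inside the loss.

```python
import math


def softmax(scores):
    """Turn K real-valued scores into K probabilities that sum to one."""
    m = max(scores)  # shift for numerical stability; cancels in the ratio
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(probs, label):
    """Negative log-probability of the correct class; label runs from 1 to K."""
    return -math.log(probs[label - 1])


scores = [2.0, 1.0, 0.1]
probs = softmax(scores)
print("probabilities:", probs)
print("loss if the true class is 1:", categorical_cross_entropy(probs, 1))
print("loss if the true class is 3:", categorical_cross_entropy(probs, 3))
```

With $K = 2$ and scores $(z, 0)$, the first softmax output equals $\sigma(z)$, which is the sense in which this loss generalises binary cross-entropy.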
Worked example: handwritten digit recognition. Take a 28-by-28 grayscale image of a handwritten digit. Flatten it into a 784-dimensional vector, with each entry the brightness of one pixel. Stack these into a feature matrix. The label is the digit, an integer from 0 to 9. A simple hypothesis class is linear softmax classification:
$$ h(\mathbf{x})_k = \mathrm{softmax}(\mathbf{W}\mathbf{x} + \mathbf{b})_k, \quad k = 1, \ldots, 10. $$
Here $\mathbf{W}$ is a $10 \times 784$ weight matrix and $\mathbf{b}$ is a 10-vector of biases, one row of weights and one bias per digit. Because the digits run from 0 to 9 while $k$ runs from 1 to 10, output $k$ corresponds to digit $k - 1$. Training minimises the average cross-entropy over the training set by gradient descent. At test time you predict the digit with the largest output probability.
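The shapes are easier to keep straight in code. This sketch shows only the prediction step, with randomly initialised (untrained) weights and a random vector standing in for a flattened image, so the predicted digit is meaningless; the point is the flow from 784 pixels to 10 scores to 10 probabilities to one decision.

```python
import math
import random

NUM_CLASSES, NUM_PIXELS = 10, 784


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(W, b, x):
    """Compute softmax(W x + b) and return (most probable digit, probabilities)."""
    scores = [
        sum(w_kj * x_j for w_kj, x_j in zip(row, x)) + b_k
        for row, b_k in zip(W, b)
    ]
    probs = softmax(scores)
    digit = max(range(NUM_CLASSES), key=lambda k: probs[k])  # list index = digit
    return digit, probs


rng = random.Random(0)
# Untrained parameters: one row of 784 weights and one bias per digit.
W = [[rng.gauss(0.0, 0.01) for _ in range(NUM_PIXELS)] for _ in range(NUM_CLASSES)]
b = [0.0] * NUM_CLASSES
# Stand-in for a flattened 28-by-28 image with brightness values in [0, 1).
x = [rng.random() for _ in range(NUM_PIXELS)]

digit, probs = predict_digit(W, b, x)
num_parameters = NUM_CLASSES * NUM_PIXELS + NUM_CLASSES
print("predicted digit:", digit, "| parameters:", num_parameters)
```

The model has $10 \times 784 + 10 = 7850$ parameters in total.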
This very plain model is enough to reach about 92% accuracy on MNIST, which is impressive for something with no convolutions and no hidden layers. Pushing accuracy much higher calls for more expressive models, chief among them the deep learning machinery of Chapter 9.
Structured prediction
Sometimes the label is not a number or a class but a structured object: a sentence, a parse tree, a segmentation map or a graph. Structured prediction is the umbrella term for these problems. The supervised template still fits, but the loss is more elaborate because comparing two sentences or two segmentations is more elaborate than comparing two numbers.
Three concrete examples:
- Machine translation. Input: a sentence in English. Output: a sentence in French. The space of possible outputs is unbounded. Quality is scored against one or more human reference translations using metrics like BLEU, which counts overlapping $n$-grams.
- Image segmentation. Input: an image. Output: a label per pixel. The loss is per-pixel cross-entropy during training and intersection over union (IoU) at evaluation, which measures the overlap between predicted and true masks.
- Code generation. Input: a docstring or test cases. Output: source code. Evaluation typically uses functional correctness (does the code pass the tests?) rather than string similarity.
Structured prediction is harder than ordinary classification for two reasons. First, the output space is huge or infinite, so you cannot enumerate it; the model has to construct the answer piece by piece. Second, evaluation is messy because there can be many correct answers, and good metrics require care. Most modern structured predictors are sequence models that emit one token at a time and are scored against reference outputs, but the underlying recipe (dataset, hypothesis class, loss, fit, evaluate) is unchanged.
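Two of the metrics named above are simple enough to sketch. The first function computes IoU for binary masks stored as flat lists. The second computes clipped $n$-gram precision, which is only one ingredient of BLEU: the full metric also combines several $n$-gram orders and applies a brevity penalty, both omitted here. Whitespace tokenisation and the example sentences are assumptions for illustration.

```python
from collections import Counter


def iou(pred_mask, true_mask):
    """Intersection over union of two binary masks given as flat 0/1 lists."""
    intersection = sum(1 for p, t in zip(pred_mask, true_mask) if p and t)
    union = sum(1 for p, t in zip(pred_mask, true_mask) if p or t)
    # Convention chosen here: two empty masks count as a perfect match.
    return intersection / union if union else 1.0


def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def clipped_ngram_precision(candidate, reference, n):
    """Fraction of candidate n-grams that also occur in the reference.

    Each n-gram is credited at most as many times as it appears in the
    reference, so repeating a correct word cannot inflate the score.
    """
    cand = Counter(ngrams(candidate.split(), n))
    ref = Counter(ngrams(reference.split(), n))
    total = sum(cand.values())
    if total == 0:
        return 0.0
    overlap = sum(min(count, ref[gram]) for gram, count in cand.items())
    return overlap / total


print(iou([1, 1, 0, 0], [0, 1, 1, 0]))

candidate = "le chat est sur le tapis"
reference = "le chat est assis sur le tapis"
print(clipped_ngram_precision(candidate, reference, 1))
print(clipped_ngram_precision(candidate, reference, 2))
```

In the translation example every candidate word appears in the reference, yet one of the five candidate bigrams does not, which is why higher-order $n$-grams are needed to detect the missing word.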
Evaluation
A model that is good on its training data is not necessarily good on new data. To learn how good a model really is, evaluate it on data it has not seen during fitting.
- Train/validation/test split. Carve the data into three disjoint pieces. Use the training set to fit the model. Use the validation set (sometimes called the dev set) to choose hyperparameters, decide when to stop training and compare candidate models. Use the test set exactly once, at the end, to report a final number. A typical split is 80/10/10 or 70/15/15. For very large datasets, even 98/1/1 leaves a substantial test set.
- Cross-validation. When data is scarce, a single split leaves too few examples for a reliable validation estimate. Instead, divide the data into $k$ folds, train $k$ models, each holding out a different fold for validation, and average the resulting metric. Five-fold cross-validation is a common default.
- Bootstrap confidence intervals. A single test-set accuracy is a noisy estimate. Resample the test set with replacement many times, recompute the metric on each resample and report the 2.5th and 97.5th percentiles as a 95% confidence interval. This makes it harder to fool yourself into thinking 86.2% beats 85.9%.
- Calibrated reporting. Always report a metric with an interval. A point estimate without uncertainty is half a number.
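The splitting schemes in the list above amount to careful index bookkeeping. The sketch below works on example indices rather than data, so it applies to any dataset; the function names, default fractions and seed are assumptions. It shuffles before splitting, which is appropriate under the i.i.d. assumption but not for time-ordered data.

```python
import random


def train_val_test_split(n, val_frac=0.1, test_frac=0.1, seed=0):
    """Return three disjoint lists of indices covering range(n)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]
    return train, val, test


def k_fold_indices(n, k=5, seed=0):
    """Yield (train, validation) index lists, one pair per fold."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]
    for i in range(k):
        val = folds[i]
        train = [j for m, fold in enumerate(folds) if m != i for j in fold]
        yield train, val


train, val, test = train_val_test_split(1000)
print(len(train), len(val), len(test))

for fold_number, (fold_train, fold_val) in enumerate(k_fold_indices(100, k=5)):
    print("fold", fold_number, "train:", len(fold_train), "val:", len(fold_val))
```

When cross-validation is combined with a final test set, the folds are drawn from the non-test portion only, so the test indices stay untouched until the end.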
The single most important rule is that the test set is sacred. Every time you peek at it during development you burn a little of its statistical purity, because you implicitly select against models that look bad on it. Andrew Ng has a useful slogan: "the test set is fired exactly once." Keep it locked away.
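Once the test set has finally been used, the bootstrap interval described in the list above can be computed from nothing more than a per-example record of which predictions were correct. The figures here (1000 test examples, 862 correct) are hypothetical, as are the function name, the number of resamples and the simple index-based percentile rule.

```python
import random


def bootstrap_accuracy_interval(correct, n_resamples=2000, seed=0):
    """Percentile bootstrap 95% interval for accuracy.

    `correct` holds one 0/1 entry per test example.
    """
    rng = random.Random(seed)
    n = len(correct)
    accuracies = []
    for _ in range(n_resamples):
        resample = rng.choices(correct, k=n)  # sample with replacement
        accuracies.append(sum(resample) / n)
    accuracies.sort()
    lower = accuracies[int(0.025 * n_resamples)]
    upper = accuracies[int(0.975 * n_resamples) - 1]
    return lower, upper


# Hypothetical test results: 862 of 1000 predictions correct.
correct = [1] * 862 + [0] * 138
accuracy = sum(correct) / len(correct)
lower, upper = bootstrap_accuracy_interval(correct)
print(f"accuracy {accuracy:.3f}, 95% interval [{lower:.3f}, {upper:.3f}]")
```

With a test set of this size the interval is a few percentage points wide, so a rival model scoring 85.9% falls well inside it and cannot be declared worse on this evidence alone.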
What you should take away
- Supervised learning has one recipe (dataset, features, labels, hypothesis class, loss, fit, evaluate) that is reused across every task type.
- Empirical risk minimisation is the trick that lets us optimise something we can compute (training loss) as a stand-in for what we actually care about (true risk).
- Regression uses real-valued labels and squared (or absolute, or Huber) error; binary classification uses sigmoid outputs and binary cross-entropy; multi-class classification uses softmax outputs and categorical cross-entropy.
- Structured prediction extends the same template to outputs that are sentences, masks or trees, with task-specific losses and evaluation metrics such as BLEU, IoU and edit distance.
- Always evaluate on held-out data, prefer interval estimates over point estimates and treat the test set as sacred: use it once, for the final report.