5.12 Empirical Bayes and Hierarchical Models
When you have several related estimation problems, treating them in isolation throws away information. Hierarchical models, and their frequentist twin, empirical Bayes, pool information across problems to give estimates that are jointly better than the unpooled MLEs.
Stein's paradox
Charles Stein's 1956 result is one of the most counter-intuitive in statistics. Suppose you observe $X_i \sim \mathcal{N}(\mu_i, 1)$ for $i = 1, \ldots, k$, and want to estimate the vector $(\mu_1, \ldots, \mu_k)$. The obvious estimator is $\hat\mu_i = X_i$. Stein showed that for $k \geq 3$ this estimator is inadmissible: it is uniformly dominated in total mean squared error by
$$\hat\mu_i^{\text{JS}} = \left(1 - \frac{k-2}{\sum_j X_j^2}\right) X_i,$$
the James–Stein estimator (the explicit form is due to James and Stein, 1961), which shrinks each component toward zero. The shrinkage helps even when the true means have nothing to do with each other: pooling totally unrelated estimation problems reduces the total error, although any individual component can get worse.
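A quick simulation makes the dominance visible. This is a minimal sketch assuming only numpy; the dimension $k$ and the true means are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n_reps = 10, 10_000

# True means chosen arbitrarily: the problems are "unrelated".
mu = rng.uniform(-2, 2, size=k)

mse_mle, mse_js = 0.0, 0.0
for _ in range(n_reps):
    x = mu + rng.standard_normal(k)           # X_i ~ N(mu_i, 1)
    shrink = 1 - (k - 2) / np.sum(x ** 2)     # James-Stein shrinkage factor
    mse_mle += np.sum((x - mu) ** 2)
    mse_js += np.sum((shrink * x - mu) ** 2)

print(f"MLE total MSE: {mse_mle / n_reps:.3f}")  # close to k = 10
print(f"JS  total MSE: {mse_js / n_reps:.3f}")   # reliably smaller
```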
The intuition: since $\mathbb{E}\|X\|^2 = \|\mu\|^2 + k$, the observed vector is on average longer than the true mean vector, so the raw estimates systematically overshoot. Shrinking corrects the overshoot. The estimator can also be read as the posterior mean under a hierarchical Gaussian prior whose variance is estimated from the data, which is the empirical-Bayes recipe: treat the hyperparameters of a hierarchical model as unknowns and learn them from the marginal likelihood.
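To spell out the empirical-Bayes reading, put the prior $\mu_i \sim \mathcal{N}(0, \tau^2)$ on the means. The posterior mean is
$$\mathbb{E}[\mu_i \mid X_i] = \frac{\tau^2}{1 + \tau^2}\, X_i = \left(1 - \frac{1}{1 + \tau^2}\right) X_i.$$
Marginally $X_i \sim \mathcal{N}(0, 1 + \tau^2)$, so $\sum_j X_j^2 \sim (1 + \tau^2)\,\chi^2_k$, and since $\mathbb{E}[1/\chi^2_k] = 1/(k-2)$,
$$\mathbb{E}\!\left[\frac{k-2}{\sum_j X_j^2}\right] = \frac{1}{1 + \tau^2}.$$
Plugging this unbiased estimate of the shrinkage weight into the posterior mean recovers the James–Stein estimator exactly.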
Hierarchical models
A hierarchical (multilevel) model typically has the form
$$y_{ij} \mid \theta_j \sim p(y \mid \theta_j),\qquad \theta_j \mid \phi \sim p(\theta \mid \phi),\qquad \phi \sim p(\phi).$$
The "school effects" textbook example from Gelman et al.'s Bayesian Data Analysis: eight schools' SAT-coaching effects are jointly estimated, with each school's effect drawn from a population distribution whose hyperparameters are inferred from the data. Schools with little data borrow strength from the population mean; schools with abundant data dominate their own estimate.
Partial pooling
Hierarchical estimation is often summarised as partial pooling:
- No pooling = fit each group separately. High variance for small groups.
- Complete pooling = ignore groups, fit one model. High bias.
- Partial pooling = hierarchical model. Trades a little bias for a large variance reduction, typically beating both extremes in per-group mean squared error.
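To see the trade-off numerically, here is a small simulation. It is a sketch: group sizes, variances and the number of groups are made-up illustrative values, and the hyperparameters are treated as known (in practice they would be estimated from the marginal likelihood, as above).

```python
import numpy as np

rng = np.random.default_rng(1)
n_groups, tau2, sigma2, n_reps = 50, 1.0, 4.0, 2_000
n_j = rng.integers(2, 30, size=n_groups)   # unequal group sizes

err = {"no pooling": 0.0, "complete": 0.0, "partial": 0.0}
for _ in range(n_reps):
    theta = rng.normal(0.0, np.sqrt(tau2), n_groups)        # true group effects
    ybar = theta + rng.normal(0.0, np.sqrt(sigma2 / n_j))   # group sample means
    grand = np.average(ybar, weights=n_j)                   # grand mean of all data
    w = tau2 / (tau2 + sigma2 / n_j)                        # shrinkage weights
    err["no pooling"] += np.mean((ybar - theta) ** 2)
    err["complete"]   += np.mean((grand - theta) ** 2)
    err["partial"]    += np.mean((w * ybar + (1 - w) * grand - theta) ** 2)

for name, total in err.items():
    print(f"{name:10s} per-group MSE: {total / n_reps:.3f}")
```

Small groups get shrinkage weights near zero and lean on the grand mean; large groups keep weights near one and stay close to their own data.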
Empirical Bayes in ML
The pattern recurs throughout machine learning under different names:
- Multi-task learning with a shared prior over task-specific heads.
- Meta-learning (MAML) where the initialisation is learned across tasks.
- Federated learning with hierarchical user-level priors.
- Recommender systems where user and item effects are partially pooled.
- Gaussian-process kernel learning via marginal-likelihood maximisation (Type II maximum likelihood); sketched below.
Whenever you have related problems and limited per-problem data, partial pooling is on the table.
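As a concrete instance of the last item, the sketch below fits an RBF kernel's lengthscale and the observation-noise level by maximising the Gaussian-process log marginal likelihood. The data, the unit signal variance and the initial values are assumptions made for illustration.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.optimize import minimize

rng = np.random.default_rng(2)
X = rng.uniform(0, 10, size=40)
y = np.sin(X) + 0.3 * rng.standard_normal(40)

def neg_log_marginal(log_params):
    """-log p(y) for a zero-mean GP with an RBF kernel
    (signal variance fixed at 1 for brevity) plus Gaussian noise."""
    ell, noise = np.exp(log_params)
    K = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2 / ell ** 2)
    K += (noise ** 2 + 1e-8) * np.eye(len(X))   # noise plus jitter on the diagonal
    L, lower = cho_factor(K, lower=True)
    alpha = cho_solve((L, lower), y)
    # log|K| = 2 * sum(log diag(L)), so half of it is the sum below
    return 0.5 * y @ alpha + np.sum(np.log(np.diag(L))) \
        + 0.5 * len(X) * np.log(2 * np.pi)

res = minimize(neg_log_marginal, x0=np.log([1.0, 0.5]))
ell_hat, noise_hat = np.exp(res.x)
print(f"lengthscale = {ell_hat:.2f}, noise sd = {noise_hat:.2f}")
```

Maximising the marginal likelihood here plays exactly the role of estimating $\phi$ above: the kernel hyperparameters are the population-level parameters, and the latent function values are integrated out analytically.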
What you should take away
- The James–Stein estimator dominates the obvious one. For three or more means, the maximum-likelihood estimate $\hat\mu_i = X_i$ is inadmissible: shrinkage toward zero (or any fixed point) reduces total mean squared error, even when the means are unrelated.
- Pooling exploits a structural property of high dimensions. Independent fluctuations on average overshoot the truth; shrinkage corrects the overshoot. The benefit is geometric, not Bayesian.
- Hierarchical models are the Bayesian counterpart. A prior on group-level parameters is itself drawn from a population distribution whose hyperparameters are inferred from the marginal likelihood. Empirical Bayes estimates those hyperparameters by maximisation; full Bayes integrates over them.
- Partial pooling sits between the extremes. No pooling fits each group separately (high variance); complete pooling ignores groups (high bias); partial pooling trades a little bias for much lower variance, with each group borrowing strength in proportion to how little data it has.
- The pattern recurs throughout ML. Multi-task learning, meta-learning, federated learning, recommender systems and Gaussian-process kernel learning all rely on partial pooling under different names. Whenever data is split across related problems with limited per-problem signal, hierarchy is the right starting point.