14.4 Variational autoencoders
A variational autoencoder pairs an encoder that maps an input to a distribution over a low-dimensional latent code with a decoder that maps a latent code back to the input space. The two networks are trained jointly to maximise a tractable lower bound on the data log-likelihood, the evidence lower bound, or ELBO. The reward for solving the optimisation is twofold. The decoder becomes a generative model: sampling a latent from the prior and pushing it through the decoder yields a fresh datum that resembles the training distribution. The encoder, meanwhile, learns to compress: each input is mapped to a tightly clustered Gaussian whose location encodes its content. Because the prior is smooth and the encoder is regularised to match it, the latent space is structured. Walks between two latent codes produce smooth interpolations in output space, the eyebrows of one face becoming the eyebrows of another, the eight gradually flattening into a one. This combination of explicit density modelling, smooth latent geometry, and end-to-end gradient training made the VAE the prototypical deep latent-variable model. Its variational machinery is also the conceptual foundation for the latent diffusion models (§14.12) that now power most production text-to-image systems.
We have just left §14.3's autoregressive image models, which factor $p(\mathbf{x})$ pixel-by-pixel and avoid latent variables altogether. This section introduces the alternative philosophy: posit unobserved causes $\mathbf{z}$ that explain the data, and learn both the generative process and an inference network for the inverse problem. The variational machinery developed here (encoder, prior, ELBO, reparameterisation, KL term) reappears almost line-for-line in §14.9 (denoising diffusion), where the latent is replaced by a sequence of progressively noised versions of the input, and the decoder by a learned denoiser.
The setup
Imagine the data as the visible end of a two-step generative story. First, a hidden variable $\mathbf{z}$, a small vector in $\mathbb{R}^d$, where $d$ might be 2, 32, or a few hundred, is drawn from a simple prior. By overwhelming convention $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$: a unit-variance, isotropic Gaussian centred at the origin. Second, the latent is fed through a stochastic decoder $p_\theta(\mathbf{x} \mid \mathbf{z})$, a neural network with parameters $\theta$, which produces the observation $\mathbf{x}$. For a binary image like MNIST the decoder typically outputs Bernoulli probabilities at each pixel; for a natural image, Gaussian means with fixed variance; for a discrete sequence, categorical logits.
Why is this setup attractive? Because the prior is simple and the decoder is flexible, the marginal $p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x} \mid \mathbf{z}) p(\mathbf{z}) \, d\mathbf{z}$ can in principle express any continuous distribution that the decoder is rich enough to represent. It also separates what a datum is (encoded compactly in $\mathbf{z}$) from how it appears in pixel space (decoded by $p_\theta$). If $\mathbf{z}$ is much lower-dimensional than $\mathbf{x}$, the decoder must learn the manifold on which natural data lie, throwing away nuisance variation and keeping the directions that matter.
The price is that the marginal is intractable. Maximum-likelihood training would require maximising $\sum_i \log \int p_\theta(\mathbf{x}_i \mid \mathbf{z}) p(\mathbf{z}) \, d\mathbf{z}$, and that integral has no closed form when $p_\theta$ is a neural network. Naive Monte Carlo with samples from the prior fails because $p_\theta(\mathbf{x} \mid \mathbf{z})$ is appreciable only on a small region of latent space: most $\mathbf{z}$ drawn from the prior decode to something wildly unlike any particular training datum, contribute negligibly to the integral, and leave the estimator with enormous variance. We need importance sampling, with a proposal distribution that places its weight where the integrand actually lives.
The VAE's solution is to introduce an inference network, the encoder $q_\phi(\mathbf{z} \mid \mathbf{x})$, that takes a datum and produces a tight Gaussian over the latent codes consistent with it. Rather than fit $q_\phi$ separately for each datum (as classical mean-field variational inference does), the encoder is a single neural network shared across the dataset: a feed-forward map $\mathbf{x} \mapsto (\boldsymbol{\mu}_\phi(\mathbf{x}), \boldsymbol{\sigma}_\phi(\mathbf{x}))$ that amortises inference. One forward pass, and you have a usable proposal. With the encoder in place, the integral can be replaced by an expectation under $q_\phi$, and the bound that results (the ELBO) is differentiable in both $\theta$ and $\phi$; this trains the generator and inference network end-to-end.
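For reference, the importance-sampling identity the encoder makes usable is
$$p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z}) \, d\mathbf{z} = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\!\left[\frac{p_\theta(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})}{q_\phi(\mathbf{z} \mid \mathbf{x})}\right],$$
valid for any $q_\phi$ that is positive wherever the integrand is. The ELBO derived next is exactly what results from moving the logarithm inside this expectation (Jensen's inequality); the gap it leaves is the KL term that appears below.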
Deriving the ELBO
Start from the marginal log-likelihood and write it as an expectation under any distribution $q_\phi(\mathbf{z} \mid \mathbf{x})$; since $\log p_\theta(\mathbf{x})$ does not depend on $\mathbf{z}$, it passes through the expectation unchanged:
$$\log p_\theta(\mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x})\right].$$
Use the identity $p_\theta(\mathbf{x}) = p_\theta(\mathbf{x}, \mathbf{z}) / p_\theta(\mathbf{z} \mid \mathbf{x})$, then multiply and divide inside the logarithm by $q_\phi(\mathbf{z} \mid \mathbf{x})$:
$$\log p_\theta(\mathbf{x}) = \mathbb{E}_{q_\phi}\!\left[\log \frac{p_\theta(\mathbf{x}, \mathbf{z})}{p_\theta(\mathbf{z} \mid \mathbf{x})}\right] = \mathbb{E}_{q_\phi}\!\left[\log \frac{p_\theta(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z} \mid \mathbf{x})}\right] + \mathbb{E}_{q_\phi}\!\left[\log \frac{q_\phi(\mathbf{z} \mid \mathbf{x})}{p_\theta(\mathbf{z} \mid \mathbf{x})}\right].$$
The second expectation is the KL divergence between the approximate posterior $q_\phi$ and the true (intractable) posterior $p_\theta(\mathbf{z} \mid \mathbf{x})$. KL divergence is non-negative, and zero if and only if the two distributions agree almost everywhere, so dropping it gives a lower bound:
$$\log p_\theta(\mathbf{x}) \geq \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\!\left[\log p_\theta(\mathbf{x}, \mathbf{z}) - \log q_\phi(\mathbf{z} \mid \mathbf{x})\right] \;\equiv\; \mathcal{L}(\theta, \phi; \mathbf{x}).$$
This is the evidence lower bound. Decompose the joint $\log p_\theta(\mathbf{x}, \mathbf{z}) = \log p_\theta(\mathbf{x} \mid \mathbf{z}) + \log p(\mathbf{z})$ and tidy:
$$\boxed{\;\mathcal{L}(\theta, \phi; \mathbf{x}) = \underbrace{\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\!\left[\log p_\theta(\mathbf{x} \mid \mathbf{z})\right]}_{\text{reconstruction}} \;-\; \underbrace{\mathrm{KL}\!\left(q_\phi(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\right)}_{\text{regularisation}}\;}$$
The two terms have transparent interpretations.
The reconstruction term is the expected log-likelihood of the data conditional on a latent drawn from the encoder. If the decoder outputs a Gaussian with fixed variance, the term reduces to a (negative) squared error; if Bernoulli, to a (negative) binary cross-entropy. Maximising this term encourages every $\mathbf{z}$ that the encoder might produce for a given $\mathbf{x}$ to decode back to something close to $\mathbf{x}$.
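Concretely, writing $\hat{\mathbf{x}}_\theta(\mathbf{z})$ for the decoder output, a Gaussian likelihood with fixed variance $\sigma_x^2$ gives
$$\log p_\theta(\mathbf{x} \mid \mathbf{z}) = -\frac{1}{2\sigma_x^2}\,\|\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z})\|^2 + \text{const},$$
while a Bernoulli likelihood gives $\log p_\theta(\mathbf{x} \mid \mathbf{z}) = \sum_j \big[x_j \log \hat{x}_j + (1 - x_j) \log(1 - \hat{x}_j)\big]$, the negative of the binary cross-entropy used in the code below.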
The regularisation term is the KL divergence from the encoder to the prior. It pulls each per-datum posterior towards the standard Gaussian. Without it, the encoder could place each training point in its own infinitesimal Gaussian, far from the others; the decoder would memorise the dataset; the marginal latent distribution would be a forest of spikes nothing like $p(\mathbf{z})$, and sampling from the prior would land in latent regions the decoder had never seen.
A second view: rewrite the ELBO as $\mathcal{L} = \log p_\theta(\mathbf{x}) - \mathrm{KL}(q_\phi \| p_\theta(\cdot \mid \mathbf{x}))$. Maximising over $\theta$ pushes $\log p_\theta(\mathbf{x})$ up (model fit). Maximising over $\phi$ minimises the gap between $q_\phi$ and the true posterior (inference quality). The two improvements proceed simultaneously through the same gradient steps. When the encoder is expressive enough to match the true posterior exactly, the bound is tight and the VAE recovers maximum likelihood.
A third view, due to Hoffman and Johnson (2016), aggregates the per-datum KL: $\mathbb{E}_{p_{\text{data}}}[\mathrm{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}) \| p(\mathbf{z}))] = \mathrm{KL}(q_\phi(\mathbf{z}) \| p(\mathbf{z})) + \mathbb{E}_{p_{\text{data}}}[\mathrm{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}) \| q_\phi(\mathbf{z}))]$, where $q_\phi(\mathbf{z}) = \mathbb{E}_{p_{\text{data}}}[q_\phi(\mathbf{z} \mid \mathbf{x})]$ is the aggregate posterior. The first term encourages the marginal in latent space to match the prior; the second is mutual information between $\mathbf{x}$ and $\mathbf{z}$. The vanilla ELBO penalises both, a fact $\beta$-VAE will exploit.
The reparameterisation trick
The reconstruction term is an expectation under $q_\phi$ that we cannot evaluate analytically because the decoder is a neural network. We need a Monte Carlo estimate, and we need its gradient with respect to $\phi$, but the distribution we are sampling from is itself a function of $\phi$. The naive score-function estimator $\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] = \mathbb{E}_{q_\phi}[f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})]$ is unbiased but has huge variance, especially when $f$ has a wide range, and ours, the per-datum reconstruction log-likelihood, certainly does.
Kingma and Welling's (2014) reparameterisation trick trades the dependent random variable for an independent one and a deterministic transformation. For a diagonal Gaussian variational family, write a sample as
$$\mathbf{z} = \boldsymbol{\mu}_\phi(\mathbf{x}) + \boldsymbol{\sigma}_\phi(\mathbf{x}) \odot \boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$$
where $\odot$ is element-wise multiplication. The randomness now lives in $\boldsymbol{\epsilon}$, which is independent of the parameters; the dependence on $\phi$ is through the smooth functions $\boldsymbol{\mu}_\phi$ and $\boldsymbol{\sigma}_\phi$. A single $\boldsymbol{\epsilon}$ produces one Monte Carlo sample of $\mathbf{z}$ and hence one sample of $\log p_\theta(\mathbf{x} \mid \mathbf{z})$, and the gradient $\nabla_\phi$ flows through $\mathbf{z}$ as it would through any other intermediate layer. Variance drops by orders of magnitude relative to the score-function estimator, and a single sample per data point per minibatch is enough in practice.
The trick generalises: any distribution that admits a location–scale or push-forward parameterisation works the same way. For a uniform on $[\boldsymbol{\mu} - \boldsymbol{\sigma}, \boldsymbol{\mu} + \boldsymbol{\sigma}]$, sample $\boldsymbol{\epsilon} \sim \mathrm{Uniform}(-1, 1)$ and set $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$. For an exponential with rate $\lambda$, sample $\boldsymbol{\epsilon} \sim \mathrm{Uniform}(0, 1)$ and set $\mathbf{z} = -\log(1-\boldsymbol{\epsilon})/\lambda$ (inverse-CDF sampling). Discrete distributions need either the Gumbel-softmax relaxation (continuous, biased) or the score-function estimator with control variates (discrete, unbiased, higher variance).
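In code the trick is a single line; `torch.distributions` exposes the same construction as `rsample`. A minimal sketch, with a toy objective standing in for the reconstruction term and the latent size chosen arbitrarily:

import torch

mu = torch.zeros(32, requires_grad=True)         # variational parameters
log_var = torch.zeros(32, requires_grad=True)

# Manual reparameterisation: the randomness comes from a parameter-free epsilon.
eps = torch.randn(32)
z = mu + (0.5 * log_var).exp() * eps             # gradients flow into mu and log_var

# Equivalent via torch.distributions: rsample() is the reparameterised sampler;
# sample() would detach the draw from the parameters.
q = torch.distributions.Normal(mu, (0.5 * log_var).exp())
z2 = q.rsample()

loss = (z ** 2).sum()                            # stand-in for -log p(x|z)
loss.backward()                                  # mu.grad and log_var.grad are populated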
A practical detail: networks output $\log \boldsymbol{\sigma}^2$ rather than $\boldsymbol{\sigma}$ or $\boldsymbol{\sigma}^2$. Exponentiating $\log \boldsymbol{\sigma}^2$ guarantees positivity without constraining the network output, keeps gradients well-scaled across many orders of magnitude, and slots cleanly into the closed-form KL below.
Without the reparameterisation trick, the VAE simply could not be trained by standard stochastic gradient descent. Almost every later latent-variable generative model (normalising flows, diffusion, VQ-VAE with its straight-through estimator) uses some form of the same idea: shift the randomness to a parameter-free source so that gradients can flow.
KL divergence in closed form
For the diagonal-Gaussian variational family $q_\phi(\mathbf{z} \mid \mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2))$ and the standard-normal prior $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$, the KL divergence has a closed form:
$$\mathrm{KL}\!\left(q_\phi(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\right) = \frac{1}{2}\sum_{i=1}^{d}\!\left(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\right).$$
Derivation in one line: write the two Gaussian densities, expand $\mathbb{E}_q[\log q - \log p]$, and use $\mathbb{E}_q[(z_i - \mu_i)^2] = \sigma_i^2$ and $\mathbb{E}_q[z_i^2] = \mu_i^2 + \sigma_i^2$; the $\log(2\pi)$ normalising constants cancel, and the formula above falls out dimension by dimension. The KL is non-negative, hits zero precisely when $\boldsymbol{\mu} = \mathbf{0}$ and $\boldsymbol{\sigma} = \mathbf{1}$, and has gradients that are trivial to compute.
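Spelled out for a single dimension (the full KL is the sum of these terms):
$$\mathbb{E}_q\!\left[\log q(z_i) - \log p(z_i)\right] = \underbrace{-\tfrac{1}{2}\log(2\pi\sigma_i^2) - \tfrac{1}{2}}_{\mathbb{E}_q[\log q]} \;+\; \underbrace{\tfrac{1}{2}\log(2\pi) + \tfrac{1}{2}\left(\mu_i^2 + \sigma_i^2\right)}_{\mathbb{E}_q[-\log p]} \;=\; \tfrac{1}{2}\!\left(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\right).$$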
A worked instance: take $d = 2$, $\boldsymbol{\mu} = (1.0,\ 0.0)$, $\boldsymbol{\sigma}^2 = (0.5,\ 1.5)$. Plug in dimension by dimension:
$$\frac{1}{2}\!\left(0.5 + 1.0 - 1 - \log 0.5\right) + \frac{1}{2}\!\left(1.5 + 0.0 - 1 - \log 1.5\right) = \frac{1}{2}(0.5 + 0.6931) + \frac{1}{2}(0.5 - 0.4055)$$
which evaluates to $0.597 + 0.047 \approx 0.644$ nats. The first dimension is doubly penalised: its mean is far from zero and its variance is shrunk below one. The second is hardly penalised at all (mean already zero, variance close to one). This is exactly the pressure the regularisation term applies to every minibatch.
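The closed form is easy to sanity-check against a Monte Carlo estimate; a minimal sketch (the sample count is an arbitrary choice):

import torch

mu = torch.tensor([1.0, 0.0])
var = torch.tensor([0.5, 1.5])

# Closed-form KL( N(mu, diag(var)) || N(0, I) ), summed over dimensions.
kl_closed = 0.5 * torch.sum(var + mu ** 2 - 1 - torch.log(var))

# Monte Carlo check: average log q(z) - log p(z) over samples from q.
q = torch.distributions.Normal(mu, var.sqrt())
p = torch.distributions.Normal(torch.zeros(2), torch.ones(2))
z = q.sample((100000,))
kl_mc = (q.log_prob(z) - p.log_prob(z)).sum(dim=-1).mean()

print(kl_closed.item(), kl_mc.item())   # both approximately 0.644 nats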
For a general Gaussian prior $\mathcal{N}(\mathbf{m}, \mathbf{S})$ and posterior $\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$ the KL is the familiar $\tfrac{1}{2}[\mathrm{tr}(\mathbf{S}^{-1}\boldsymbol{\Sigma}) + (\mathbf{m} - \boldsymbol{\mu})^\top \mathbf{S}^{-1} (\mathbf{m} - \boldsymbol{\mu}) - d + \log \det \mathbf{S} - \log \det \boldsymbol{\Sigma}]$. The simplicity of the standard-normal special case is part of what makes the diagonal Gaussian the default choice.
Training a VAE on MNIST
A canonical reference architecture: encoder $784 \to 256 \to 128$ with ReLU, branching at the last hidden layer into two parallel linear heads of size $32$ that output $\boldsymbol{\mu}$ and $\log \boldsymbol{\sigma}^2$ respectively. Decoder $32 \to 128 \to 256 \to 784$ with ReLU and a final sigmoid that interprets each output as a Bernoulli probability. Total parameter count: roughly half a million. Adam at learning rate $10^{-3}$, minibatches of 128, ten to twenty epochs on a single GPU.
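A minimal sketch of that architecture in PyTorch (the class layout and method names are illustrative choices, not a fixed API; the imports also cover the forward pass and loss shown next):

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, d=32):
        super().__init__()
        # Encoder trunk 784 -> 256 -> 128, then two parallel heads for mu and log sigma^2.
        self.enc = nn.Sequential(nn.Linear(784, 256), nn.ReLU(),
                                 nn.Linear(256, 128), nn.ReLU())
        self.mu_head = nn.Linear(128, d)
        self.logvar_head = nn.Linear(128, d)
        # Decoder 32 -> 128 -> 256 -> 784; the sigmoid is applied in forward().
        self.dec = nn.Sequential(nn.Linear(d, 128), nn.ReLU(),
                                 nn.Linear(128, 256), nn.ReLU(),
                                 nn.Linear(256, 784))

    def encode(self, x):
        h = self.enc(x)
        return self.mu_head(h), self.logvar_head(h)

    def decode(self, z):
        return self.dec(z)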
The forward pass and loss are short:
def forward(self, x):
    mu, logvar = self.encode(x.view(-1, 784))   # encoder heads: mean and log-variance
    std = (0.5 * logvar).exp()                  # sigma = exp(0.5 * log sigma^2)
    eps = torch.randn_like(std)                 # parameter-free noise
    z = mu + std * eps                          # reparameterised sample
    x_hat = torch.sigmoid(self.decode(z))       # per-pixel Bernoulli probabilities
    return x_hat, mu, logvar

def loss(x_hat, x, mu, logvar):
    # Negative ELBO for the batch: summed binary cross-entropy plus closed-form KL.
    recon = F.binary_cross_entropy(x_hat, x.view(-1, 784), reduction='sum')
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
The per-batch loss is the summed negative ELBO over pixels and over the minibatch. A common bug is averaging over pixels (which down-weights the reconstruction term by a factor of 784 and effectively makes you train a $\beta$-VAE with $\beta \approx 784$, collapsing the latent). Sum, then average over the batch.
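To make the bug concrete, a short sketch reusing the tensors from the loss above:

# Wrong: reduction='mean' averages over the 784 pixels as well as the batch,
# shrinking the reconstruction term roughly 784-fold relative to the KL.
recon_bad = F.binary_cross_entropy(x_hat, x.view(-1, 784), reduction='mean')

# Right: sum over pixels, then (if desired) divide by the batch size only.
recon = F.binary_cross_entropy(x_hat, x.view(-1, 784), reduction='sum') / x.shape[0]
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.shape[0]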
To sample from a trained VAE: draw $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, push through the decoder, and read off the mean image. To reconstruct: encode an input, take $\boldsymbol{\mu}$ as the latent (or sample, which is fine but slightly noisier), decode. To interpolate: encode two inputs to $\mathbf{z}_1$ and $\mathbf{z}_2$, decode along $\mathbf{z}(t) = (1-t)\mathbf{z}_1 + t \mathbf{z}_2$ for $t \in [0, 1]$ (or, better, slerp on the unit sphere).
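A sketch of all three operations, assuming a trained `model` exposing the `encode` and `decode` methods above (linear interpolation shown; slerp is the refinement mentioned):

import torch

def sample(model, n=16, d=32):
    z = torch.randn(n, d)                       # latents drawn from the prior
    return torch.sigmoid(model.decode(z))       # n x 784 Bernoulli means

def reconstruct(model, x):
    mu, _ = model.encode(x.view(-1, 784))       # posterior mean as the latent code
    return torch.sigmoid(model.decode(mu))

def interpolate(model, x1, x2, steps=8):
    z1, _ = model.encode(x1.view(-1, 784))
    z2, _ = model.encode(x2.view(-1, 784))
    t = torch.linspace(0, 1, steps).view(-1, 1)
    return torch.sigmoid(model.decode((1 - t) * z1 + t * z2))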
Diagnostic numbers on MNIST after a dozen epochs: total negative ELBO around $90$ nats per image, of which roughly $85$ nats is reconstruction and $5$ nats is KL. The split is informative. If the KL term falls towards zero, the encoder is collapsing to the prior and the latent is barely being used: posterior collapse, discussed below. If the KL stays large but reconstruction remains poor, the encoder is straying far from the prior without the decoder exploiting the information, and samples drawn from the prior will decode badly.
What the latent space looks like
Fit a VAE with a two-dimensional latent so the geometry can be drawn directly. Encode every test image and scatter the means $(\mu_1, \mu_2)$ coloured by digit class. The result is the diagnostic plot every VAE tutorial reproduces: ten roughly Gaussian blobs partially overlapping, arranged loosely around the origin, with smooth corridors between them rather than disconnected islands. Threes and eights cluster near each other; ones, with their narrow stroke, occupy a region away from the rest; sixes and zeros sit at opposite ends of a curl-direction axis.
Now sample on a grid: take $z_1, z_2 \in \{-3, -2, \dots, 2, 3\}$, decode each $\mathbf{z}$, and tile the 49 decodings into a $7 \times 7$ grid. The grid is the canonical "digit manifold" picture. Moving rightwards or upwards smoothly morphs one digit into another, a six rotating into a zero, an eight thinning into a one, without ever passing through nonsense in between. This continuity is the structural property the KL term bought us. Models trained without regularisation, or with the KL set to zero, produce the same reconstructions but a discontinuous latent: walking between two encoded points lands you in regions the decoder has never seen, and you get blurred chimaeras and ghost strokes.
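A sketch of how the grid is produced, assuming a trained two-latent `model` with a `decode` method as above:

import torch

ticks = torch.arange(-3.0, 3.5, 1.0)                      # {-3, -2, ..., 2, 3}
zs = torch.cartesian_prod(ticks, ticks)                   # 49 latent codes, shape (49, 2)
imgs = torch.sigmoid(model.decode(zs)).view(7, 7, 28, 28)
grid = imgs.permute(0, 2, 1, 3).reshape(7 * 28, 7 * 28)   # one 196 x 196 tiled image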
Higher-dimensional latents (32 is a common choice) cannot be visualised directly but can be probed by interpolation, by traversal of single coordinates with the others held fixed, or by aggregating posteriors and projecting with t-SNE or UMAP. The qualitative pattern persists: smooth, mostly-connected manifold; class structure emerges without supervision; no axis is guaranteed to align with a human-meaningful factor of variation, but in practice some often do.
Limitations
- Blurry samples. The Gaussian-output decoder, combined with the integration over the posterior, averages over many plausible reconstructions and produces soft images. Sharper alternatives exist (discretised mixtures of logistics, autoregressive PixelCNN decoders, vector-quantised codebooks), but vanilla VAEs on natural images look noticeably foggier than a GAN or a diffusion model trained on the same data.
- Posterior collapse. When the decoder is powerful enough to model the data on its own, the optimisation can settle into a degenerate solution where $q_\phi(\mathbf{z} \mid \mathbf{x}) \approx p(\mathbf{z})$ for every input, the KL term goes to zero, and the latent carries no information about $\mathbf{x}$. The decoder ignores $\mathbf{z}$ and behaves as an unconditional model. KL annealing (warming the KL weight from $0$ to $1$ over training; a minimal schedule is sketched after this list), free-bits, and skip connections all help, but the failure mode is real and stalks every powerful-decoder VAE.
- Reconstruction–KL trade-off. The two ELBO terms pull in opposite directions; the relative weight is set implicitly by the choice of likelihood variance and explicitly in $\beta$-VAE by the coefficient $\beta$. $\beta > 1$ buys disentangled, prior-aligned latents at the cost of reconstruction fidelity; $\beta < 1$ produces sharper reconstructions but a less structured latent. There is no principled choice of $\beta$ for a given dataset, only empirical tuning.
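KL annealing is the simplest of these mitigations to implement. A minimal sketch, where recon and kl are the summed terms from the loss above, global_step is a running count of optimiser updates, and the warm-up length is an arbitrary choice:

def kl_weight(step, warmup_steps=10000):
    # Linear warm-up: the KL weight ramps from 0 to 1 over the first warmup_steps updates.
    return min(1.0, step / warmup_steps)

# Inside the training loop:
objective = recon + kl_weight(global_step) * kl   # annealed negative ELBO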
VAEs are no longer state of the art for raw image generation; diffusion models (§14.9) produce sharper samples with comparable training cost. But VAEs remain ubiquitous as components. VQ-VAE (van den Oord, Vinyals, Kavukcuoglu, 2017) replaces the Gaussian latent with a discrete codebook, enabling the autoregressive priors that power DALL·E and Jukebox. Latent diffusion (Rombach et al., 2022), the architecture behind Stable Diffusion, trains a KL-regularised VAE to compress images to a small latent grid, then runs a diffusion model in that latent space, paying a fraction of the compute a pixel-space diffusion would cost. The variational machinery itself (encoder, prior, ELBO, reparameterisation) is now standard equipment for any latent-variable deep model that needs a tractable training objective.
What you should take away
- A VAE is a generative model defined by a prior $p(\mathbf{z})$ and a learned decoder $p_\theta(\mathbf{x} \mid \mathbf{z})$, paired with a learned encoder $q_\phi(\mathbf{z} \mid \mathbf{x})$ that amortises variational inference over the dataset.
- Training maximises the evidence lower bound, $\mathcal{L} = \mathbb{E}_{q_\phi}[\log p_\theta(\mathbf{x} \mid \mathbf{z})] - \mathrm{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}) \| p(\mathbf{z}))$, a reconstruction term minus a regularisation term, which is a lower bound on the marginal log-likelihood and tight when $q_\phi$ matches the true posterior.
- The reparameterisation trick rewrites a sample $\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$ as $\boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, so gradients flow through the encoder via standard backpropagation.
- The closed-form KL for a diagonal-Gaussian posterior against a standard-normal prior, $\tfrac{1}{2}\sum_i(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2)$, supplies the regularisation term analytically without any sampling.
- The trained latent space is smooth and structured: interpolations between codes decode to interpolations in input space. Vanilla VAEs produce blurry samples and can suffer posterior collapse, but the variational machinery, and especially KL-regularised latent compression, underlies VQ-VAE, latent diffusion, and most modern large-scale generative systems.