12.14 Beam search and other decoding strategies
A trained sequence model defines a distribution $P_\theta(y \mid x)$ over output sequences. At test time, we want to produce a single sequence (or a small set of candidates) from this distribution. Decoding is a deeper subject than it first appears: the choice among greedy decoding, beam search, top-$k$, top-$p$, and temperature sampling materially affects output quality, diversity, and behaviour.
12.14.1 Greedy decoding
At each step, output the most probable next token:
$$y_t = \arg\max_{w \in \mathcal{V}} P_\theta(w \mid y_{\lt t}, x).$$
def greedy_decode(model, x, max_len, eos_id):
    """Pick the argmax token at each step until EOS or the length limit."""
    y = [BOS_ID]                          # BOS_ID: beginning-of-sequence token id
    state = model.encode(x)
    for _ in range(max_len):
        logits, state = model.decode_step(y[-1], state)
        next_token = int(logits.argmax(-1))
        y.append(next_token)
        if next_token == eos_id:
            break
    return y[1:]                          # strip the BOS token
Greedy is fast and deterministic but myopic: a low-probability token at step $t$ may be the gateway to a much-higher-probability completion. Greedy decoding routinely produces dull, repetitive output and misses the global optimum.
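The myopia is easy to see on a toy two-step distribution (the probabilities below are made up for illustration):

```python
import math

# Hypothetical two-step model: P(first token) and P(second | first).
p_first = {"the": 0.5, "a": 0.4}
p_second = {"the": {"cat": 0.3}, "a": {"dog": 0.9}}

# Greedy commits to "the" (0.5) and then its best continuation "cat" (0.3).
greedy_logp = math.log(p_first["the"] * p_second["the"]["cat"])   # log 0.15
# The lower-probability first token "a" opens a much better completion.
alt_logp = math.log(p_first["a"] * p_second["a"]["dog"])          # log 0.36

assert alt_logp > greedy_logp   # greedy missed the higher-probability sequence
```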
12.14.2 Beam search
Beam search maintains a set of $B$ partial sequences (the beam) ranked by accumulated log-probability. At each step, expand every partial sequence by every possible next token, score the resulting $|B| \cdot V$ candidates, and keep the top $B$. Continue until all sequences in the beam have terminated, or a length limit is reached.
import torch.nn.functional as F

def beam_search(model, x, beam_size, max_len, eos_id):
    state = model.encode(x)
    # Each beam entry: (cumulative log-prob, tokens, decoder state, finished?)
    beams = [(0.0, [BOS_ID], state, False)]
    for _ in range(max_len):
        candidates = []
        for score, toks, st, finished in beams:
            if finished:                       # carry finished beams forward unchanged
                candidates.append((score, toks, st, True))
                continue
            logits, new_st = model.decode_step(toks[-1], st)
            log_probs = F.log_softmax(logits, dim=-1)
            top = log_probs.topk(beam_size)
            for lp, idx in zip(top.values.tolist(), top.indices.tolist()):
                fin = (idx == eos_id)
                candidates.append((score + lp, toks + [idx], new_st, fin))
        # Length-normalised score: divide by the number of generated tokens
        # (len(toks) - 1, excluding the BOS token).
        beams = sorted(candidates, key=lambda b: b[0] / (len(b[1]) - 1),
                       reverse=True)[:beam_size]
        if all(b[3] for b in beams):
            break
    best = max(beams, key=lambda b: b[0] / (len(b[1]) - 1))
    return best[1][1:]                         # strip the BOS token
Two practical points:
- Length normalisation. Without normalisation, beam search systematically prefers shorter sequences, since every additional token adds a negative log-probability. Standard practice is to divide the score by $|y|^\alpha$ with $\alpha \in [0.6, 1.0]$, or to use the length penalty of Wu et al. (2016), $\bigl((5 + |y|)/6\bigr)^\alpha$, who also add a coverage penalty for translation.
- Beam width. Wider beams find higher-probability outputs, but at the cost of more compute and, perversely, often worse output quality. This "beam search curse" (quality declining as beam width grows past roughly 5–10) is well documented for translation: the most likely sequence under a maximum-likelihood-trained model is often degenerate (empty, extremely short, or repetitive). Production systems typically use beam widths of 4 to 10.
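The two normalisation schemes above can be sketched as plain score functions (a minimal sketch; the default `alpha` is one typical value from the range in the text):

```python
def simple_length_norm(logprob_sum, length, alpha=0.7):
    # Divide the cumulative log-probability by |y|^alpha.
    return logprob_sum / (length ** alpha)

def gnmt_length_penalty(logprob_sum, length, alpha=0.7):
    # Wu et al. (2016) length penalty: ((5 + |y|) / 6)^alpha.
    lp = ((5 + length) / 6) ** alpha
    return logprob_sum / lp

# Raw log-probability favours the shorter hypothesis (-5.0 > -8.0),
# but the longer one has a better per-token score and wins after normalisation.
short_raw, long_raw = -5.0, -8.0
assert short_raw > long_raw
assert simple_length_norm(long_raw, 10) > simple_length_norm(short_raw, 4)
```

The GNMT penalty grows more slowly than $|y|^\alpha$ for short sequences, so it penalises extra tokens more gently early in decoding.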
12.14.3 Sampling: temperature, top-$k$, top-$p$
For open-ended generation (story writing, dialogue, summarisation), some randomness is desirable. Pure sampling from $P_\theta$ is too noisy; greedy is too dull. The intermediate strategies are:
Temperature sampling. Replace the softmax $P(w) \propto \exp(z_w)$ by $P_\tau(w) \propto \exp(z_w / \tau)$ for a temperature $\tau > 0$.
- $\tau \to 0$: distribution concentrates on the argmax (recovers greedy).
- $\tau = 1$: original distribution.
- $\tau > 1$: distribution flattens, more diversity, more risk of incoherence.
- $0 < \tau < 1$: distribution sharpens but remains stochastic.
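A minimal sketch of temperature scaling, assuming a 1-D PyTorch logits tensor:

```python
import torch
import torch.nn.functional as F

def temperature_sample(logits, temperature=1.0):
    # Divide logits by tau before the softmax; tau < 1 sharpens, tau > 1 flattens.
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))

logits = torch.tensor([2.0, 1.0, 0.1])
# Low temperature concentrates probability mass on the argmax;
# high temperature spreads it out.
sharp = F.softmax(logits / 0.1, dim=-1)
flat = F.softmax(logits / 10.0, dim=-1)
assert sharp[0] > flat[0]
```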
Top-$k$ sampling. Restrict the sampling distribution to the top $k$ most probable tokens; renormalise; sample.
import torch
import torch.nn.functional as F

def top_k_sample(logits, k, temperature=1.0):
    logits = logits / temperature
    top_vals, top_idxs = logits.topk(k)           # keep the k largest logits
    probs = F.softmax(top_vals, dim=-1)           # renormalise over the top k
    choice = torch.multinomial(probs, num_samples=1)
    return top_idxs[choice]
This excludes the long tail of low-probability tokens that pure sampling would occasionally emit. Typical $k$ ranges from 10 to 100.
Top-$p$ (nucleus) sampling. Holtzman et al. 2019 argue that the appropriate cutoff is not a fixed $k$ but a fixed cumulative probability mass $p$. Sort tokens by probability, take the smallest set whose cumulative probability is at least $p$, renormalise, and sample.
def top_p_sample(logits, p, temperature=1.0):
    logits = logits / temperature
    sorted_logits, sorted_idxs = logits.sort(descending=True)
    probs = F.softmax(sorted_logits, dim=-1)
    cum = probs.cumsum(0)
    # Smallest prefix whose cumulative probability reaches p.
    cutoff = int((cum < p).sum().item()) + 1
    top_probs = probs[:cutoff]
    top_probs = top_probs / top_probs.sum()       # renormalise over the nucleus
    choice = torch.multinomial(top_probs, num_samples=1)
    return sorted_idxs[choice]
Top-$p$ adapts the truncation per-step: when the model is confident (peaked distribution), only a few tokens enter the nucleus; when it is uncertain (broad distribution), more tokens are included. Typical $p$ is 0.9 to 0.95.
Combining strategies. Modern systems often combine: temperature 0.8, top-$p$ 0.95, optionally top-$k$ 50 as a hard cap. The interaction of the three knobs is empirically determined and somewhat task-dependent.
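The combination can be sketched as a single sampler (a sketch only; applying temperature first, then the top-$k$ cap, then the nucleus cutoff is one common ordering, not the only one):

```python
import torch
import torch.nn.functional as F

def sample(logits, temperature=0.8, top_k=50, top_p=0.95):
    logits = logits / temperature
    # Hard cap: mask everything outside the top-k logits.
    if top_k is not None and top_k < logits.numel():
        kth = logits.topk(top_k).values[-1]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = F.softmax(logits, dim=-1)
    # Nucleus: keep the smallest prefix with cumulative mass >= top_p.
    sorted_probs, sorted_idxs = probs.sort(descending=True)
    cum = sorted_probs.cumsum(0)
    cutoff = int((cum < top_p).sum().item()) + 1
    keep = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    choice = torch.multinomial(keep, num_samples=1)
    return int(sorted_idxs[choice])
```

With a very small `top_p` the nucleus collapses to a single token and the sampler degenerates to greedy decoding, which is a useful sanity check.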
12.14.4 Repetition penalties
Sequence models, especially language models, tend to fall into degenerate repetition loops at low temperatures. Repetition penalty (Keskar et al. 2019) discounts the logits of tokens that have already appeared in the output:
$$z_w \gets z_w / r \quad \text{if } w \in y_{\lt t} \text{ and } z_w > 0; \qquad z_w \gets z_w \cdot r \quad \text{if } w \in y_{\lt t} \text{ and } z_w \le 0,$$
with $r \approx 1.1$ to $1.3$. Variants include n-gram blocking (forbid producing any n-gram that has already appeared) and presence/frequency penalties (additive rather than multiplicative).
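The multiplicative rule above can be applied directly to a logits vector (a minimal sketch, assuming a 1-D PyTorch tensor and a list of previously generated token ids):

```python
import torch

def apply_repetition_penalty(logits, generated_ids, r=1.2):
    # Discount already-generated tokens: divide positive logits by r,
    # multiply negative logits by r. Both push the logit downward.
    logits = logits.clone()
    for idx in set(generated_ids):
        if logits[idx] > 0:
            logits[idx] = logits[idx] / r
        else:
            logits[idx] = logits[idx] * r
    return logits

logits = torch.tensor([2.0, -1.0, 0.5])
penalised = apply_repetition_penalty(logits, [0, 1], r=2.0)
# token 0: 2.0 / 2 = 1.0; token 1: -1.0 * 2 = -2.0; token 2 untouched
```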