13.8 Causal masking and autoregressive prediction

A decoder transformer must not look into the future. If, while learning to predict the next word, the model could simply read that word off the input, training would collapse into copying: the model would score perfectly on the training set and learn nothing about language. The mechanism that prevents this, and that, almost as a side effect, lets us train these models in massively parallel fashion, is the causal mask. It is a small piece of arithmetic with outsized consequences.

In §13.7 we surveyed the three flavours of transformer: encoder-only (BERT-style), decoder-only (GPT-style) and encoder–decoder (the original 2017 design). What distinguishes a decoder from an encoder is not the layer structure (both are stacks of attention plus feed-forward blocks) but the mask applied to attention. This section is about that mask: why it works, how it interacts with training and inference, and the modern variants that make long-context decoders feasible.

Symbols Used Here
$\mathbf{M}$: the causal mask matrix, an $n \times n$ matrix added to attention scores before softmax.

The mechanism

For a sequence of length $n$, the causal mask is the matrix $\mathbf{M} \in \mathbb{R}^{n \times n}$ with entries

$$ M_{ij} = \begin{cases} 0 & \text{if } j \le i, \\ -\infty & \text{if } j > i. \end{cases} $$

In practice $-\infty$ is a large negative constant, typically $-10^{9}$ in float32, or the most negative representable value of the floating-point format, large enough that after softmax the corresponding weight rounds to zero. The mask is added to the scaled dot-product logits before softmax:

$$ \mathbf{A} = \operatorname{softmax}\!\left( \frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}} + \mathbf{M} \right) \mathbf{V}. $$

Because $\exp(-\infty) = 0$, the row-wise softmax of any logit pushed to $-\infty$ contributes nothing to the normaliser, and the resulting attention weight is exactly zero. Position $i$ therefore attends to positions $1, 2, \dots, i$ and to none of the positions strictly after it. The output at position $i$ is a weighted sum of value vectors drawn only from the past and present.
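
To make the arithmetic concrete, here is a minimal sketch of single-head causal attention in PyTorch. The function and variable names are illustrative, not taken from any particular library.

```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    """Single-head causal attention. Q, K, V: (n, d_k) tensors; returns (n, d_k)."""
    n, d_k = Q.shape
    scores = Q @ K.T / d_k**0.5                        # (n, n) scaled dot-product logits
    # Causal mask: 0 on and below the diagonal, -inf strictly above it.
    mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
    weights = F.softmax(scores + mask, dim=-1)         # each row still sums to 1
    return weights @ V                                 # weighted sum over past and present only

# Tiny smoke test with random inputs.
n, d_k = 5, 8
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)
out = causal_attention(Q, K, V)                        # out[i] depends only on rows 0..i
```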

It is worth pausing on why we mask at the logit stage rather than the weight stage. Zeroing weights after the softmax would not work, because softmax normalises: the row would then sum to less than one, breaking the convex-combination semantics of attention. By placing $-\infty$ in the logits, we let softmax do the right thing automatically: the surviving past-and-present positions absorb the probability mass that would otherwise have leaked into the future, and the row still sums to exactly one.
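
A three-number toy example (the values are arbitrary, and the third position stands in for "the future") makes the difference visible:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 3.0])   # one row of attention logits

# Wrong: softmax first, then zero the future entry -> the row no longer sums to 1.
w = F.softmax(logits, dim=0)
w[2] = 0.0
print(w.sum())                           # ~0.33, no longer a convex combination

# Right: push the future logit to -inf, then softmax -> the row sums to exactly 1.
masked = logits.clone()
masked[2] = float("-inf")
print(F.softmax(masked, dim=0).sum())    # 1.0
```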

The mask is a constant of the architecture, not a learned parameter. In code it is built once and broadcast across heads, batches and layers. In multi-head attention the same mask applies to every head; each head still learns its own $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ projections, but all heads share the same view of which past positions are visible. In encoder–decoder cross-attention (§13.4) only the decoder's self-attention is masked; the cross-attention from decoder to encoder is unmasked, because the encoder represents an input that is fully observed.

Why this lets you train in parallel

The causal mask has a counter-intuitive consequence: it is precisely because the model is forced to be causal that we can train it on a whole sequence in a single forward pass. To see this, consider what we want the loss to be. Given a sequence of tokens $w_1, w_2, \dots, w_n$, a language model is trained to maximise

$$ \log p(w_1, \dots, w_n) \;=\; \sum_{t=1}^{n} \log p(w_t \mid w_1, \dots, w_{t-1}). $$

Each summand depends only on the past. With the causal mask, the transformer's output at position $t$ depends only on positions $1, \dots, t$. So the output at position $t$ is exactly the right thing to use as the prediction for $w_{t+1}$. (Equivalently, after shifting inputs and targets by one: the output at position $t$ predicts the input at position $t+1$.) A single forward pass on the sequence therefore produces $n$ next-token predictions in parallel, and the training loss is the mean of $n$ cross-entropy terms, all computed simultaneously, on a GPU, in dense matrix multiplications.
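
As a sketch of how this looks in code, assuming a hypothetical `model` that maps token ids of shape (batch, n) to next-token logits of shape (batch, n, vocab):

```python
import torch.nn.functional as F

def next_token_loss(model, tokens):
    """tokens: (batch, n) integer ids of a training sequence."""
    inputs  = tokens[:, :-1]            # what the model sees: positions 1 .. n-1
    targets = tokens[:, 1:]             # what it must predict: positions 2 .. n
    logits  = model(inputs)             # one forward pass; position t only saw positions <= t
    # n-1 cross-entropy terms, all produced by the same forward pass.
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```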

Compare this with a vanilla RNN. The hidden state at step $t$ depends sequentially on the hidden state at step $t-1$, which depends on $t-2$, and so on. You cannot start computing step $t$ until step $t-1$ has finished. RNN training time scales with sequence length even on perfectly parallel hardware. The transformer breaks this dependency: with the mask in place, the only constraint is "do not look at the future", and that constraint is enforced by a static matrix rather than by sequential execution. All $n$ positions can be processed in one go.

This is the practical reason transformers eclipsed RNNs for large-scale language modelling. On modern GPUs and TPUs, throughput on long sequences is enormous: a single A100 can push tens of thousands of tokens per second through a 1B-parameter decoder during training. The same model trained as an RNN would take vastly longer, because no amount of hardware can paper over a strictly sequential dependency.

At inference time, however, the symmetry breaks. You do not yet have the future tokens; you are generating them. To produce $w_{t+1}$ you sample from the model's predicted distribution, append the sampled token to the input, and run the model again. This is autoregressive decoding. Done naively, generating $n$ tokens costs $n$ forward passes, each on a sequence of growing length, for a total of $O(n^3 d)$ in attention work and $O(n^2 d^2)$ in feed-forward work. The training-time parallelism is gone, replaced by a fundamentally serial sampling loop, sketched below. The next subsection describes the trick that makes inference manageable.
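
A naive greedy decoding loop, using the same hypothetical `model` as above, makes the wasted work visible: every iteration re-runs the full forward pass on the whole prefix.

```python
import torch

@torch.no_grad()
def generate_naive(model, prompt_ids, n_new):
    """prompt_ids: (1, t0) integer ids. Greedy decoding, no KV cache."""
    ids = prompt_ids
    for _ in range(n_new):
        logits = model(ids)                                        # full pass over the whole prefix
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)    # greedy choice at the last position
        ids = torch.cat([ids, next_id], dim=1)                     # append and repeat
    return ids
```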

KV cache at inference

Most of the computation in autoregressive decoding is wasted. When generating token $t+1$, the model recomputes attention for positions $1, \dots, t$, even though those positions, their inputs, and consequently their keys and values, have not changed since the previous step. The only new thing is the query at position $t+1$.

The KV cache exploits this redundancy. After producing each token, we store its key and value vectors for every layer of the model. On the next step we compute only the query for the new token and the new key and value for that single position; we then attend to the entire cached set of keys and values. The attention computation at step $t+1$ becomes a vector-matrix product: one query against $t$ cached keys, then a softmax-weighted sum over $t$ cached values. Per-token work is $O(t \cdot d)$ rather than $O(t^2 \cdot d)$, and total work to generate $n$ tokens drops from $O(n^3 d)$ to $O(n^2 d)$, still quadratic, but now in the cumulative cache size rather than in repeated full passes.
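
One step of cached attention, sketched for a single layer and a single head (the names and shapes are illustrative), looks like this:

```python
import torch
import torch.nn.functional as F

def cached_attention_step(q_new, k_new, v_new, k_cache, v_cache):
    """q_new, k_new, v_new: (1, d_k) for the new token; caches: (t, d_k). Returns output and updated caches."""
    k_cache = torch.cat([k_cache, k_new], dim=0)       # (t+1, d_k)
    v_cache = torch.cat([v_cache, v_new], dim=0)
    d_k = q_new.size(-1)
    scores = q_new @ k_cache.T / d_k**0.5              # (1, t+1): one query against all cached keys
    weights = F.softmax(scores, dim=-1)                # no mask needed: the cache *is* the past
    out = weights @ v_cache                            # (1, d_k)
    return out, k_cache, v_cache
```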

The memory cost is significant. For a model with $L$ layers, $h$ heads of dimension $d_k$ each, and a sequence of length $n$, the cache holds $2 \cdot L \cdot n \cdot h \cdot d_k$ scalars: keys and values, every layer, every head, every position. For a 70B-parameter model with $L = 80$, $h = 64$, $d_k = 128$ and $n = 4096$ in float16, that is $2 \cdot 80 \cdot 4096 \cdot 64 \cdot 128 \cdot 2$ bytes $\approx 10.7$ GB per request. At long contexts the KV cache, not the model weights, becomes the dominant memory pressure on the GPU. Engineering work around inference (paged attention, KV quantisation, prefix caching, multi-query and grouped-query attention) is largely about taming this cost. We return to it in §13.19.
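
The back-of-envelope figure above is easy to reproduce; the configuration below is the hypothetical 70B one from the text.

```python
L, h, d_k, n = 80, 64, 128, 4096   # layers, heads, head dimension, context length
bytes_per_scalar = 2               # float16
cache_bytes = 2 * L * n * h * d_k * bytes_per_scalar   # factor 2: keys and values
print(cache_bytes / 1e9)           # ~10.7 GB per request
```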

Sliding-window attention

Even with a KV cache, the cost of attention grows linearly with context length on every generated token, and the cache itself grows without bound. Modern long-context decoders therefore relax the all-the-past assumption. Sliding-window attention, popularised by the Mistral models, restricts each position to attend only to the previous $W$ positions, for some fixed window $W$ (typically a few thousand). The mask is no longer strictly triangular: it is band-limited, with $-\infty$ both above the diagonal and more than $W$ positions below it.

The cost per token at inference is now $O(W)$ regardless of how far we are into the sequence, and the active KV cache is bounded by $W$. Information from earlier than $W$ steps ago is not lost entirely, however, because each layer slides its own window: position $t$ in layer $\ell$ attends to positions $t-W, \dots, t$ in layer $\ell - 1$, which themselves saw their own $W$-sized windows in layer $\ell - 2$, and so on. Stacking $L$ layers gives an effective receptive field of roughly $L \cdot W$, in much the same way that stacked convolutions extend their effective receptive field. With $L = 32$ and $W = 4096$, the model can in principle propagate information across $\sim$130k tokens.
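
Here is a sketch of the band-limited mask. The exact window convention varies between implementations; in this version position $i$ is allowed to see positions $i-W+1$ through $i$.

```python
import torch

def sliding_window_mask(n, W):
    """Band-limited causal mask: 0 where attention is allowed, -inf elsewhere."""
    i = torch.arange(n).unsqueeze(1)     # query (row) index
    j = torch.arange(n).unsqueeze(0)     # key (column) index
    allowed = (j <= i) & (j > i - W)     # causal AND within the window
    mask = torch.full((n, n), float("-inf"))
    mask[allowed] = 0.0
    return mask

print(sliding_window_mask(6, 3))         # row 5 permits columns 3, 4 and 5 only
```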

Variants combine sliding windows with a small number of "global" positions that everyone can attend to (Longformer, BigBird), or with periodic full-attention layers, or with attention sinks that preserve the first few tokens of the sequence (StreamingLLM). The common theme is: the strict $O(n^2)$ all-pairs causal mask is one design point, not the only one, and modern decoders pick whichever pattern of allowed pairs their context budget permits.

It is worth noting one subtlety here. Sliding-window attention does not change the causal nature of the mask, only its density. The window still extends only into the past; positions in the future remain masked out. So all the training-time benefits described above still apply: a single forward pass over a long sequence produces predictions at every position in parallel, and the only difference is that each position now attends to a bounded slice of its past rather than to all of it. This is why the technique composes cleanly with the standard transformer training recipe, no change to the loss, no change to the optimiser, only a change to the mask matrix.

What you should take away

  1. The causal mask is the matrix $\mathbf{M}$ with $0$ on and below the diagonal and $-\infty$ above; adding it to logits before softmax forces each position to attend only to itself and its past.
  2. Causality and parallel training are complementary, not opposed: because the mask makes the prediction at position $t$ depend only on positions $\le t$, a single forward pass produces $n$ next-token predictions and $n$ losses, all computed simultaneously.
  3. At inference the parallelism is gone, generation is autoregressive, but the KV cache preserves keys and values across steps so each new token costs $O(t \cdot d)$ rather than $O(t^2 \cdot d)$.
  4. KV cache memory grows linearly with context and rapidly dominates GPU memory at long sequence lengths; this drives much of the engineering of modern inference stacks.
  5. Sliding-window attention trades global reach for a fixed per-token cost; stacked layers recover an effective receptive field roughly equal to depth $\times$ window.
