13.3 Multi-head attention
A single attention head looks at every position in the sequence and produces, for each position, one weighted blend of the value vectors at all the positions it can see (including its own). That blend is governed by exactly one set of compatibility scores, which means a single head commits to one way of mixing information. If a head learns to track subject–verb agreement, the same head must also handle long-range references and prepositional attachment using the very same softmax. Real language demands several relationships at once: the word "it" might need to retrieve its antecedent, while the same position simultaneously needs to know what the previous adjective was. One blend per position is not enough.
Multi-head attention solves this by running several attention computations side by side, each with its own learned projections. We call each parallel computation a head. Each head produces its own weighted blend, and because each head has its own queries, keys and values, each head can specialise in a different pattern. One head might focus on local context: the previous one or two tokens. Another might attend to the start-of-sequence token as a default fallback. A third might learn syntactic dependencies, attending from a verb back to its subject. The model is no longer forced to compress every relationship through a single softmax.
Once each head has produced its output, the heads are concatenated end to end and passed through a single learned linear projection $\mathbf{W}_O$. The concatenation places the per-head vectors next to each other; the projection then mixes them so that downstream layers see a single $d_{\text{model}}$-dimensional representation per position. This concatenate-then-project pattern is what gives multi-head attention its expressive power: heads can specialise in isolation, and $\mathbf{W}_O$ then decides how the layer's output is composed from those specialists.
This section bridges scaled dot-product attention (§13.2) and the full transformer block (§13.6).
The formula
For $h$ heads, each indexed by $i$:
$\mathrm{head}_i = \mathrm{Attention}(\mathbf{Q}\mathbf{W}_Q^i, \mathbf{K}\mathbf{W}_K^i, \mathbf{V}\mathbf{W}_V^i)$
$\mathrm{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,\mathbf{W}_O.$
The first line says that head $i$ takes the same input matrices $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ that arrive at the layer, but projects each of them through its own private matrices $\mathbf{W}_Q^i$, $\mathbf{W}_K^i$, $\mathbf{W}_V^i$ before passing them into the scaled dot-product attention of §13.2. The query and key projections have shape $d_{\text{model}} \times d_k$ and the value projection has shape $d_{\text{model}} \times d_v$; in the standard design $d_v = d_k$. The projected $\mathbf{Q}\mathbf{W}_Q^i$ is therefore a smaller matrix with $d_k$ columns instead of $d_{\text{model}}$. The smaller dimension is chosen so that $h \cdot d_k = d_{\text{model}}$: the per-head dimension is precisely the total dimension divided by the number of heads. This split is deliberate, because it means the multi-head computation costs about the same as a single-head computation at the full dimension.
The output of head $i$ is therefore a matrix with one row per input position and $d_v = d_k$ columns. Once we have all $h$ such matrices, we glue them together side by side using $\mathrm{Concat}$. Concatenation along the last axis yields a matrix with one row per position and $h \cdot d_v = d_{\text{model}}$ columns. The final step multiplies this concatenated matrix by $\mathbf{W}_O$, which has shape $d_{\text{model}} \times d_{\text{model}}$ and is itself learned during training. $\mathbf{W}_O$'s job is to mix information across heads: without it, each head's contribution would sit in its own slice of the output and never interact with the others. With it, the model can learn to combine, weight and re-route what each head has discovered.
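The two lines of the formula can be traced step by step in a minimal, unoptimised plain-Python sketch. It rests on several assumptions of our own: matrices are lists of rows, the weights are small random numbers rather than trained values, there is no masking or batching, we show the self-attention case where $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ are all the same input `x`, and the helper names (`matmul`, `attention`, `multi_head` and so on) are hypothetical, not any library's API.

```python
import math
import random


def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x p) matrix; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Scaled dot-product attention (section 13.2), without masking."""
    d_k = len(q[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v)


def multi_head(q, k, v, w_q, w_k, w_v, w_o):
    """w_q, w_k, w_v: lists of h per-head (d_model x d_k) matrices; w_o: (h*d_k x d_model)."""
    heads = [attention(matmul(q, wq), matmul(k, wk), matmul(v, wv))
             for wq, wk, wv in zip(w_q, w_k, w_v)]
    # Concat along the last axis: row r is head_1[r] followed by head_2[r], and so on.
    concat = [sum((head[r] for head in heads), []) for r in range(len(q))]
    return matmul(concat, w_o)


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n, d_model, h = 5, 16, 4
d_k = d_model // h                       # h * d_k == d_model
x = random_matrix(n, d_model, rng)
w_q = [random_matrix(d_model, d_k, rng) for _ in range(h)]
w_k = [random_matrix(d_model, d_k, rng) for _ in range(h)]
w_v = [random_matrix(d_model, d_k, rng) for _ in range(h)]
w_o = random_matrix(h * d_k, d_model, rng)

out = multi_head(x, x, x, w_q, w_k, w_v, w_o)  # self-attention: Q = K = V = x
print(len(out), len(out[0]))  # 5 16
```

Each head works in only $d_k = 4$ dimensions, yet the concatenation restores a full $d_{\text{model}} = 16$ columns before $\mathbf{W}_O$ mixes them.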
In practice, we never implement multi-head attention as $h$ separate matrix multiplications. The per-head projection matrices $\mathbf{W}_Q^i$ for $i = 1, \ldots, h$ are stacked side by side into a single $\mathbf{W}_Q$ of shape $d_{\text{model}} \times d_{\text{model}}$. We multiply the input by this one matrix and then reshape the output into a tensor of shape $n \times h \times d_k$. The same trick applies to $\mathbf{K}$ and $\mathbf{V}$. The projections therefore take four large matrix multiplications (for queries, keys, values and the output $\mathbf{W}_O$) regardless of how many heads we use, and the attention computation itself runs for all heads at once as a single batched operation. GPUs are extremely efficient at exactly these large, dense matrix multiplications.
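To check that stacking really is just a reorganisation, the plain-Python sketch below compares one fused projection followed by a reshape against $h$ separate per-head projections. The sizes are hypothetical, the weights are random, and `matmul` is our own helper; a real implementation would use a tensor library's reshape instead of list slicing.

```python
import random


def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x p) matrix; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


rng = random.Random(1)
n, d_model, h = 3, 8, 2
d_k = d_model // h
x = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(n)]
# h separate per-head query projections, each d_model x d_k.
per_head = [[[rng.uniform(-1, 1) for _ in range(d_k)] for _ in range(d_model)]
            for _ in range(h)]

# Stack them side by side: row t of the big matrix is row t of W_Q^1, then of W_Q^2, ...
w_q_big = [sum((w[t] for w in per_head), []) for t in range(d_model)]  # d_model x d_model

# One big multiplication, then reshape each length-d_model row into h chunks of d_k.
fused = matmul(x, w_q_big)
fused_heads = [[row[i * d_k:(i + 1) * d_k] for i in range(h)] for row in fused]  # n x h x d_k

# Reference: h separate small multiplications, giving h x n x d_k.
separate = [matmul(x, w) for w in per_head]

max_diff = max(abs(fused_heads[r][i][c] - separate[i][r][c])
               for r in range(n) for i in range(h) for c in range(d_k))
print(max_diff < 1e-12)  # True
```

The fused version performs exactly the same products and sums; it simply hands the hardware one large multiplication instead of $h$ small ones.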
Why multiple heads
A single head must compress every relationship in a sentence into one set of attention weights. Consider "the cat that the dog chased was scared". When reading "scared", a single head would have to decide simultaneously whether to attend to "cat" (the subject of "was scared"), "dog" (the agent inside the relative clause), or "chased" (the verb of that clause). Whatever single distribution it produces will be a compromise.
Multi-head attention removes this tension. Each head has independent parameters and its own random initialisation, so nothing forces two heads to learn the same pattern, and in practice gradient descent tends to drive them toward different roles. The result is a division of labour, although, as we will see, an imperfect one: some heads end up redundant.
Empirical studies of trained transformers catalogue the kinds of specialisation that emerge. Some heads become local-context experts, attending almost entirely to the previous one or two tokens; they behave like delay lines copying recent information forward. Others roam much further, reaching back hundreds of tokens for a matching word. Some track syntactic relationships (verbs to subjects, prepositions to their nouns, pronouns to antecedents). Others track semantic features such as topic or named-entity type. None of this is hand-coded; it emerges as a side effect of training a network with many heads against a loss that rewards better predictions.
A particularly informative experiment is head ablation. We zero out the contribution of a single head and measure the effect on overall performance. Often, ablating one head produces only a small increase in average loss but a sharp drop in performance on a specific, narrow task. A head that tracks subject–verb agreement, for example, may matter only when a long noun phrase intervenes between subject and verb; on most sentences its absence is invisible, but on those particular sentences it is critical. This combination of small average effects and sharp task-specific effects is strong evidence that heads specialise, even though nothing in training explicitly asked them to.
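The algebra behind ablation is simple enough to check directly. Because $\mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,\mathbf{W}_O$ splits into a sum of per-head terms, each head multiplied by its own $d_v$ rows of $\mathbf{W}_O$, zeroing one head's output removes exactly that head's term and nothing else. The sketch is illustrative only: the "head outputs" are random stand-ins rather than real attention results, and `layer_output` and `ablate` are hypothetical helper names.

```python
import random


def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x p) matrix; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


rng = random.Random(2)
n, h, d_v = 3, 4, 2
d_model = h * d_v
# Random stand-ins for the h head outputs (each n x d_v) and for W_O (h*d_v x d_model).
heads = [[[rng.uniform(-1, 1) for _ in range(d_v)] for _ in range(n)] for _ in range(h)]
w_o = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(h * d_v)]


def layer_output(head_outputs):
    """Concat the head outputs along the last axis, then apply W_O."""
    concat = [sum((hd[r] for hd in head_outputs), []) for r in range(n)]
    return matmul(concat, w_o)


def ablate(head_outputs, idx):
    """Replace head idx's output with zeros and leave every other head untouched."""
    return [[[0.0] * d_v for _ in range(n)] if i == idx else hd
            for i, hd in enumerate(head_outputs)]


idx = 1
full = layer_output(heads)
ablated = layer_output(ablate(heads, idx))
# Head idx's own term: its output times its d_v rows of W_O.
own_term = matmul(heads[idx], w_o[idx * d_v:(idx + 1) * d_v])

gap = max(abs(full[r][c] - ablated[r][c] - own_term[r][c])
          for r in range(n) for c in range(d_model))
print(gap < 1e-9)  # True
```

This is why ablation is a clean probe: any change in loss can be attributed to the one term that was removed.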
Not every head turns out to be useful. Many heads in trained transformers can be pruned without measurable harm: they have learned redundant or near-trivial patterns. This redundancy helps motivate the efficiency variants we discuss below.
Worked example: 8 heads with $d_{\text{model}} = 64$
To make the parameter count concrete, take $h = 8$ and $d_{\text{model}} = 64$. The per-head dimension is then $d_k = d_v = 64 / 8 = 8$. Each head has its own query projection $\mathbf{W}_Q^i$ of shape $64 \times 8$. With eight such matrices, the total number of parameters used for query projections is $8 \times 64 \times 8 = 4096$. Now compare this with a single-head attention layer that uses the full model dimension: it needs one $\mathbf{W}_Q$ of shape $64 \times 64$, which contains $64 \times 64 = 4096$ parameters. The two are identical. Multi-head attention does not increase the parameter count; it reorganises the same parameters into smaller, parallel slices.
The same argument applies to $\mathbf{W}_K$ and $\mathbf{W}_V$. Each contributes $4096$ parameters whether viewed as $h$ small matrices or one big one. The output projection $\mathbf{W}_O$ is a single $64 \times 64$ matrix, another $4096$ parameters. The multi-head attention sub-layer therefore has $4 \times d_{\text{model}}^2 = 16{,}384$ parameters in total, ignoring biases, exactly the same count as one single head at the full dimension.
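The same bookkeeping fits in a small function. This is a sketch that counts weights only (no biases) and assumes $d_k = d_v = d_{\text{model}}/h$ as in the worked example; the function name is our own.

```python
def mha_param_count(d_model: int, h: int) -> int:
    """Weight count of one multi-head attention sub-layer, ignoring biases.

    Assumes d_k = d_v = d_model // h, as in the worked example.
    """
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by h")
    d_k = d_model // h
    per_projection = h * (d_model * d_k)   # h matrices of shape d_model x d_k
    w_o = (h * d_k) * d_model              # output projection
    return 3 * per_projection + w_o        # W_Q, W_K, W_V plus W_O


for heads in (1, 8, 64):
    print(heads, mha_param_count(64, heads))  # 16384 each time
```

Whatever divisor of $d_{\text{model}}$ we pick for $h$, the count stays at $4\,d_{\text{model}}^2$.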
This equivalence is the design choice that makes multi-head attention practical. By setting $d_k = d_{\text{model}}/h$, the original transformer paper ensured that adding heads is free in parameters and roughly free in computation: the total scoring cost is $n^2 \times d_{\text{model}}$, independent of $h$. We can therefore choose $h$ purely on modelling grounds.
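Counting only the multiply-adds in the score matrices $\mathbf{Q}\mathbf{W}_Q^i(\mathbf{K}\mathbf{W}_K^i)^\top$ (a rough count that ignores the scaling and the softmax), each head multiplies an $n \times d_k$ matrix by a $d_k \times n$ matrix, so across all heads:

$h \cdot (n \cdot d_k \cdot n) = n^2\,(h\,d_k) = n^2\,d_{\text{model}}.$

The weighted sum of values has the same form with $d_v$ in place of $d_k$, so it too totals $n^2 d_{\text{model}}$ whatever the value of $h$.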
The trade-off shows up elsewhere. With more heads at fixed $d_{\text{model}}$, each head has fewer dimensions to work with. A head with $d_k = 8$ has plenty of room for tasks like "attend to the previous token" but may be too narrow for richer relational patterns. With fewer heads, each head is more expressive but the layer has fewer parallel specialists. Production transformers typically use $h$ between 8 and 96, with $d_k$ between 64 and 128; the original 2017 paper used $h = 8, d_k = 64$ in the base model.
What heads actually learn
Once a transformer has been trained, we can visualise what each head attends to by plotting the attention weights as a heatmap. Several studies have done this systematically and catalogued a recurring set of patterns that show up across architectures: Vig (2019) on GPT-2 and BERT, Clark et al. (2019) on BERT, and Voita et al. (2019) on the original transformer.
The simplest patterns are positional. A previous-token head concentrates almost all of its attention one position to the left; a next-token head concentrates it one position to the right. Next-token heads can only arise in bidirectional models such as BERT, because a causal mask hides future tokens, whereas previous-token heads appear in both bidirectional and causal models. These heads behave like a fixed shift register: their function is to make a copy of the adjacent token's representation available at the current position. Together with residual connections, they give the network a way to mix information across nearby positions without requiring a deep stack.
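A previous-token pattern under a causal mask can be reproduced with hand-picked scores. In this plain-Python sketch the numbers are hypothetical: the score is high when the key sits exactly one position before the query, whereas in a trained model such scores would come from learned query and key projections, typically with help from positional information. The first position has no predecessor, so it can only attend to itself.

```python
import math


def softmax(row):
    top = max(row)  # finite, because position 0 is always visible
    exps = [math.exp(x - top) for x in row]  # exp(-inf) == 0.0 for masked entries
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_weights(scores):
    """Apply a causal mask (position i may only see j <= i), then softmax each row."""
    n = len(scores)
    masked = [[scores[i][j] if j <= i else -math.inf for j in range(n)] for i in range(n)]
    return [softmax(row) for row in masked]


n = 5
# Hypothetical scores: 8.0 for the immediately preceding position, 0.0 elsewhere.
scores = [[8.0 if j == i - 1 else 0.0 for j in range(n)] for i in range(n)]
weights = causal_attention_weights(scores)
print([row.index(max(row)) for row in weights])  # [0, 0, 1, 2, 3]
```

From position 1 onward, each row puts nearly all of its weight on the previous position, and every future position receives exactly zero.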
Many heads attend strongly to the start-of-sequence token regardless of the input. This token serves as an attention sink: a near no-op default that lets a head abstain when it has nothing useful to say. Without such a sink, the softmax forces weights to sum to one across the sequence, so a head that wants to "do nothing" still smears weight across real tokens. The sink absorbs that weight, so the head's output is dominated by a single input-independent value vector and carries little information about the rest of the sequence.
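A few lines of arithmetic show the sink at work. The scores and value vectors below are hypothetical: position 0 stands for the start-of-sequence token, the head scores it highly, and we assume for illustration that its value vector is close to zero.

```python
import math


def softmax(row):
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Position 0 is the start-of-sequence token (the sink); positions 1-3 are real tokens.
scores = [5.0, 0.0, 0.0, 0.0]
values = [[0.01, 0.0],    # sink value vector, assumed near zero
          [1.0, -2.0],
          [0.5, 3.0],
          [-1.5, 1.0]]

weights = softmax(scores)
output = [sum(w * v[c] for w, v in zip(weights, values)) for c in range(2)]
print(round(weights[0], 3))  # 0.98
```

With these numbers the sink takes about 98% of the weight and the output stays close to zero, whereas spreading the same weight uniformly over the three real tokens would instead produce roughly $(0, 0.67)$.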
Heads tracking syntactic dependencies are particularly clean. A head may attend, for every verb, primarily to the verb's grammatical subject, even when the two are separated by relative clauses. Another may attend, for every preposition, to the noun it modifies. Yet another may track coreference, with pronouns attending to their antecedents. These patterns presumably emerge because they help the network with its training objective, whether that is predicting the next token or, in BERT's case, a masked one.
A particularly important emergent pattern is the induction head, identified in Anthropic's mechanistic interpretability work. An induction head implements the algorithm "if the current token is $A$, find a previous occurrence of $A$, then attend to whatever followed it". It does not work alone: it relies on a previous-token head in an earlier layer, which writes each token's predecessor into that token's position, so that the induction head in a later layer can use "my predecessor was $A$" as a search key. Strikingly, induction heads form at a sharp phase transition during training, and that transition coincides with a marked improvement in the model's in-context learning. This is one of the cleanest results in mechanistic interpretability: a specific circuit demonstrably implements a specific algorithm, which appears to account for much of a specific capability.
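The induction algorithm can be written out as a toy program. This sketch uses exact token matching and hard, argmax-style attention in place of learned soft attention, and its function names are hypothetical. It mirrors the two-step circuit: `previous_token_features` plays the previous-token head, and `induction_predict` plays the induction head, which searches those features for the current token and copies whatever sits at the matching position. A real model does all of this approximately and in parallel across positions.

```python
from typing import List, Optional


def previous_token_features(tokens: List[str]) -> List[Optional[str]]:
    """What a previous-token head writes at each position: the token just before it."""
    return [None] + tokens[:-1]


def induction_predict(tokens: List[str]) -> Optional[str]:
    """Predict the next token by copying what followed an earlier occurrence of the current token."""
    prev = previous_token_features(tokens)
    current = tokens[-1]
    # Query: the current token. Keys: the predecessor stored at each earlier position.
    matches = [k for k in range(len(tokens) - 1) if prev[k] == current]
    if not matches:
        return None                    # no earlier occurrence: nothing to copy
    return tokens[matches[-1]]         # hard attention to the most recent match; copy its token


print(induction_predict(["Mr", "D", "urs", "ley", "said", "Mr", "D"]))  # urs
```

Having seen "Mr D urs ley" once, the program completes a second "Mr D" with "urs", which is exactly the copy-from-context behaviour that in-context learning relies on.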
Multi-query attention and grouped-query attention
Multi-head attention is straightforward to train but expensive to run at inference time. The reason is the KV cache, which we will discuss in detail in §13.19. Briefly: when generating tokens one at a time, the keys and values for every previously generated position must be kept in memory so that the next query can attend to them. With $h$ heads, the cache stores a separate set of keys and values for every head in every layer. For long contexts, large batches and large models, this cache can rival or exceed the memory taken by the model's own parameters and become the dominant cost of serving the model.
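A back-of-the-envelope estimate makes the scale concrete. The function below is a simple sketch: it counts one key vector and one value vector per layer, per key-value head, per position and per sequence in the batch, assuming 16-bit (2-byte) storage and ignoring any implementation overhead. The configuration is hypothetical, chosen only to give round numbers.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_value=2):
    """Size of the KV cache: keys and values (factor 2) for every layer, head, position, sequence."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_value


GiB = 2**30
# Hypothetical configuration: 32 layers, 32 heads of size 128, 16-bit values.
print(kv_cache_bytes(32, 32, 128, seq_len=4096) / GiB)              # 2.0
print(kv_cache_bytes(32, 32, 128, seq_len=131072, batch=8) / GiB)   # 512.0
```

The cache grows linearly in context length, batch size and number of key-value heads. For comparison, 16-bit weights take 2 bytes per parameter, so 512 GiB is as much memory as the weights of a model with about 275 billion parameters.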
Multi-query attention (MQA), proposed by Shazeer in 2019, addresses this by sharing the keys and values across heads. The queries remain per-head: there are still $h$ separate query projections, so the heads can still specialise in what they look for, but only one set of keys and one set of values is computed and cached. This reduces the KV cache by a factor of $h$, often a tenfold or larger saving. The cost is a small drop in modelling quality, because all heads must now make do with the same key and value subspaces.
Grouped-query attention (GQA) generalises this idea. Heads are partitioned into $g$ groups, and within each group the keys and values are shared. With $g = 1$ this is multi-query attention; with $g = h$ it is full multi-head attention; intermediate values trade off cache size against quality. A typical configuration is $h = 32$ query heads and $g = 8$ key-value groups, giving four query heads per group and a fourfold reduction in cache size. Empirical studies show that GQA with a sensible group count recovers most of the quality of full multi-head attention while keeping most of the inference savings of multi-query.
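The grouping itself is just index arithmetic. This sketch assumes contiguous groups (query heads $0$ to $h/g - 1$ share group 0, and so on), which is one common convention rather than the only possible one; the function names are our own.

```python
def kv_group(query_head: int, n_heads: int, n_groups: int) -> int:
    """Index of the shared key/value group used by a query head (contiguous grouping assumed)."""
    if n_heads % n_groups != 0:
        raise ValueError("n_heads must be divisible by n_groups")
    heads_per_group = n_heads // n_groups
    return query_head // heads_per_group


def cache_reduction(n_heads: int, n_groups: int) -> int:
    """How many times smaller the KV cache is than with full multi-head attention."""
    return n_heads // n_groups


h = 32
print([kv_group(i, h, 8) for i in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print(cache_reduction(h, 8))    # 4   (GQA with g = 8)
print(cache_reduction(h, 1))    # 32  (MQA: every head shares one K/V set)
print(cache_reduction(h, h))    # 1   (ordinary multi-head attention)
```

Setting $g = 1$ sends every query head to group 0, and setting $g = h$ gives each head its own group, matching the two endpoints described above.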
Modern frontier models adopt these variants widely. Llama-3 uses grouped-query attention at every size, as do the larger Llama-2 models (34B and 70B). The architectures of the GPT-4, Gemini and Claude families are not fully published, but they are widely understood to use multi-query or grouped-query variants in production. The driving force is inference economics: with hundreds of millions of users and contexts of tens or hundreds of thousands of tokens, the cost of serving a model is dominated by KV cache memory and bandwidth. The principle from §13.2 (attention as content-based retrieval with separate $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ ports) is preserved; only the wiring has changed.
Where multi-head shows up everywhere
Every modern transformer uses multi-head attention. BERT uses bidirectional multi-head self-attention to build context-aware representations. GPT uses causal multi-head self-attention so that each position attends only to itself and earlier positions. T5 and BART use multi-head attention in both encoder and decoder. Vision Transformers apply multi-head self-attention to flattened image patches. CLIP uses multi-head self-attention inside its separate image and text encoders, and multimodal models such as Flamingo use multi-head cross-attention to let text queries reach into image keys and values.
The variants matter less than the core. FlashAttention rearranges the order of operations to fit the GPU's memory hierarchy; it computes the same multi-head attention (exact up to floating-point rounding), just faster. Multi-query and grouped-query attention change which heads share keys and values, but the $\mathrm{Concat}$-and-project structure remains. Linear-attention variants replace or approximate the softmax but still partition channels into heads. The pattern (run several attention computations in parallel, glue them together, mix them with a learned projection) has proven durable through years of architectural change since 2017.
What you should take away
- Multi-head attention runs $h$ scaled dot-product attentions in parallel. Each head has its own $\mathbf{W}_Q^i$, $\mathbf{W}_K^i$ and $\mathbf{W}_V^i$ at the smaller dimension $d_k = d_{\text{model}}/h$, so each head can specialise in a different relational pattern.
- Concatenation plus a single output projection is what makes the layer expressive. The $h$ head outputs are glued side by side, and $\mathbf{W}_O$ then mixes information across heads to produce the layer's $d_{\text{model}}$-dimensional output.
- Multi-head has the same parameter count as a single full-dimensional head. With $d_k = d_{\text{model}}/h$, the $h$ per-head query projections together hold $d_{\text{model}}^2$ parameters, exactly as one big projection would, and likewise for keys and values. Adding heads is essentially free in both parameters and computation.
- Trained heads specialise in interpretable ways. Visualisations show next-token heads, previous-token heads, attention-sink heads, syntactic-dependency heads, coreference heads and induction heads. None of these are hand-coded; they emerge from gradient descent.
- At inference time, multi-query and grouped-query variants share keys and values across heads to shrink the KV cache. This trades a small loss in quality for large savings in memory and bandwidth, and is now standard in models such as Llama-3 and widely understood to be used in frontier systems such as GPT-4 and Gemini.