One attention head learns one pattern. Why does GPT-3 have 96 of them? This is article 6 of 7 in Phase 1 of the DeepSeek Engineering Blog Series.
If You Read Nothing Else: A single attention head can only learn one type of relationship (e.g., "attend to the previous word"). Multi-Head Attention (MHA) runs multiple attention heads in parallel, each learning a different pattern. The outputs are concatenated and projected back. This gives the model the ability to simultaneously track syntax, semantics, coreference, and position — all in one layer.
Why one head isn't enough
Fig 1 — Multi-head attention: h parallel heads, each learning different patterns, concatenated and projected.
In the previous articles, we implemented single-head self-attention. It works — but it has a fundamental limitation. A single attention head produces one set of attention weights per token. That means token "it" gets one chance to decide what to attend to.
But language has multiple simultaneous relationships:
- "it" needs to attend to "animal" (coreference)
- "it" needs to attend to "was" (syntactic — subject-verb)
- "it" needs to attend to its position neighbors (local context)
One attention head can learn one of these patterns. Multi-head attention learns all of them in parallel.
How multi-head attention works
The idea from Section 3.2.2 of "Attention Is All You Need" (arXiv:1706.03762):
- Split: Divide the model dimension d into h heads, each of size d_h = d / h.
- Project: Each head has its own W_Q, W_K, W_V matrices of size (d, d_h).
- Attend: Run self-attention independently in each head.
- Concatenate: Stack all h output vectors side by side → dimension h × d_h = d.
- Output projection: Multiply by W_O (size d × d) to mix information across heads.
MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) · W_O
where head_i = Attention(X · W_Q_i, X · W_K_i, X · W_V_i)
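A quick sanity check on the bookkeeping, using GPT-2 Small's sizes as an assumed example: splitting d into h heads and concatenating them back recovers the original dimension exactly.

```python
d, h = 768, 12        # GPT-2 Small: model dim, number of heads
d_h = d // h          # 64: each head works in a d/h-dimensional subspace

# Each head maps (n, d) → (n, d_h); concatenating h heads restores width d
assert h * d_h == d
print(d_h)  # 64
```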
Concrete dimensions
| Model | d_model | Heads (h) | d_head | Attention params/layer (W_Q, W_K, W_V, W_O) |
|---|---|---|---|---|
| GPT-2 Small | 768 | 12 | 64 | 2.4M |
| GPT-3 175B | 12,288 | 96 | 128 | 604M |
| LLaMA-2 70B | 8,192 | 64 | 128 | 268M (as full MHA; the released model uses GQA) |
| DeepSeek-V2 | 5,120 | 128 | 128 (Q) / latent (KV) | MLA — different |
Notice DeepSeek-V2 has 128 attention heads — more than GPT-3's 96 — but uses MLA instead of standard MHA. The head dimension for queries is 128, but Keys and Values are compressed into a 512-dimensional latent space (covered in Phase 3).
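The per-layer parameter counts above can be reproduced by counting all four square projections of standard MHA, including the output projection W_O; a minimal sketch:

```python
def mha_params_per_layer(d_model: int) -> int:
    # W_Q, W_K, W_V, W_O are each (d_model, d_model) in standard MHA
    return 4 * d_model * d_model

for name, d_model in [("GPT-2 Small", 768), ("GPT-3", 12288), ("LLaMA-2 70B", 8192)]:
    print(f"{name}: {mha_params_per_layer(d_model) / 1e6:.1f}M")
```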
What different heads learn
Researchers have analyzed what trained attention heads actually do. Using tools like BertViz, several distinct patterns emerge:
Fig 2 — Different heads specialize: positional, syntactic, and coreference patterns emerge naturally.
Positional heads
Some heads consistently attend to the previous or next token. These learn bigram-like patterns — local context that helps with syntax and word prediction.
Syntactic heads
Some heads track grammatical structure. In English, specific heads learn that verbs attend to their subjects, adjectives attend to the nouns they modify, and prepositions attend to their objects.
Coreference heads
These resolve pronouns: in our running example, a coreference head is what lets "it" attend to "animal". Lena Voita's research ("Analyzing Multi-Head Self-Attention", arXiv:1905.09418) found specific heads that cleanly track coreference chains.
Rare-word / copy heads
Some heads attend strongly to rare or important tokens — proper nouns, numbers, technical terms. These help the model "copy" information forward when it needs to reference specific details.
Not all heads are equally useful. Research by Michel et al. ("Are Sixteen Heads Really Better than One?", arXiv:1905.10650) showed that many heads can be pruned after training with minimal quality loss. Some heads are redundant. This insight partly motivates the move from MHA to MQA/GQA/MLA — if many heads learn similar Key-Value patterns, why not share them?
The computational cost
Multi-head attention has the same asymptotic complexity as single-head attention: O(n² · d). This is because:
- Each head computes attention over (n, d_h) matrices: O(n² · d_h) per head
- With h heads: O(h · n² · d_h) = O(n² · d) — same as single-head with full dimension
The heads run in parallel on GPUs (different heads = different matrix multiplications that can execute simultaneously), so wall-clock time is similar to single-head. The total FLOPs are the same — you're just distributing the same work across specialized sub-problems.
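The FLOP-equivalence argument can be checked numerically. Using GPT-3-scale numbers and counting only the QK^T score matmul (one multiply-add per element):

```python
n, d, h = 4096, 12288, 96   # seq len, model dim, heads (GPT-3 scale)
d_h = d // h                # 128 per head

# QK^T score computation costs n*n*d_h multiply-adds per head
multi_head_flops = h * (n * n * d_h)
single_head_flops = n * n * d          # one head at full width
assert multi_head_flops == single_head_flops
```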
Memory cost: the KV cache problem
This is where multi-head attention gets expensive at inference. Each head caches its own K and V vectors. The KV cache size per layer:
```
KV cache per layer = 2 × n_heads × head_dim × seq_len × bytes_per_element
                   = 2 × h × d_h × n × sizeof(dtype)

# GPT-3 at 4096 context, FP16:
= 2 × 96 × 128 × 4096 × 2 bytes
≈ 201 MB per layer
× 96 layers
≈ 19.3 GB total KV cache
```
For DeepSeek-V2 with standard MHA, this would be enormous. That's why MLA exists — it replaces the per-head KV cache with a shared latent vector, reducing memory by 93%.
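A small helper (a sketch, using GPT-3's shape: 96 layers, 96 heads, head_dim 128) reproduces the arithmetic above:

```python
def mha_kv_cache_bytes(n_layers, n_heads, d_head, seq_len, bytes_per_elem=2):
    # Factor of 2: one K and one V vector cached per head, per token
    return 2 * n_heads * d_head * seq_len * bytes_per_elem * n_layers

total = mha_kv_cache_bytes(n_layers=96, n_heads=96, d_head=128, seq_len=4096)
print(f"{total / 96 / 1e6:.0f} MB per layer")  # 201 MB
print(f"{total / 1e9:.1f} GB total")           # 19.3 GB
```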
The output projection W_O
After concatenating all head outputs, we multiply by W_O ∈ ℝ^(d × d). This matrix serves a critical purpose: it lets information from different heads interact.
Without W_O, each head's output would remain in its own subspace. The output projection mixes them, allowing the model to combine syntactic information from head 3 with coreference information from head 7 into a single unified representation.
```python
# Pseudocode: one attention computation per head
heads = []
for i in range(h):
    Q_i = X @ W_Q[i]                   # (n, d_h)
    K_i = X @ W_K[i]                   # (n, d_h)
    V_i = X @ W_V[i]                   # (n, d_h)
    head_i = attention(Q_i, K_i, V_i)  # (n, d_h)
    heads.append(head_i)
concat = torch.cat(heads, dim=-1)      # (n, h*d_h) = (n, d)
output = concat @ W_O                  # (n, d)
```
Efficient batched implementation
In practice, no one runs a for-loop over heads. Instead, Q, K, V are computed with a single large matrix multiplication, then reshaped:
```python
# Efficient: one matmul, then reshape
QKV = X @ W_QKV                 # (n, 3*d) — single projection
Q, K, V = QKV.chunk(3, dim=-1)  # each (n, d)

# Reshape: (n, d) → (n, h, d_h) → (h, n, d_h)
Q = Q.view(n, h, d_h).transpose(0, 1)
K = K.view(n, h, d_h).transpose(0, 1)
V = V.view(n, h, d_h).transpose(0, 1)

# Batched attention over all heads simultaneously
attn = scaled_dot_product_attention(Q, K, V)  # (h, n, d_h)

# Reshape back: (h, n, d_h) → (n, h, d_h) → (n, d)
output = attn.transpose(0, 1).contiguous().view(n, d)
output = output @ W_O
```
This is exactly how it's implemented in nanoGPT (model.py) and in the official PyTorch nn.MultiheadAttention. We'll build this from scratch in Article 1.7.
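Since the batched reshape is easy to get wrong, here is a small self-check (a sketch with assumed toy dimensions) showing that the batched path produces the same output as a plain per-head loop:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, h, d_h = 8, 4, 16          # toy sizes: seq len, heads, head dim
d = h * d_h
X = torch.randn(n, d)
W_QKV = torch.randn(d, 3 * d) / d ** 0.5
W_O = torch.randn(d, d) / d ** 0.5

# Batched path: one projection, reshape to (h, n, d_h)
Q, K, V = (X @ W_QKV).chunk(3, dim=-1)
Qb = Q.view(n, h, d_h).transpose(0, 1)
Kb = K.view(n, h, d_h).transpose(0, 1)
Vb = V.view(n, h, d_h).transpose(0, 1)
batched = F.scaled_dot_product_attention(Qb, Kb, Vb)
batched = batched.transpose(0, 1).contiguous().view(n, d) @ W_O

# Looped path: plain softmax attention, head by head
heads = []
for i in range(h):
    sl = slice(i * d_h, (i + 1) * d_h)
    Q_i, K_i, V_i = Q[:, sl], K[:, sl], V[:, sl]
    scores = (Q_i @ K_i.T) / d_h ** 0.5
    heads.append(scores.softmax(dim=-1) @ V_i)
looped = torch.cat(heads, dim=-1) @ W_O

assert torch.allclose(batched, looped, atol=1e-5)
```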
From MHA to what comes next
Understanding MHA is essential because every optimization in Phase 2 and 3 is a modification of this exact mechanism:
- MQA (Multi-Query Attention): All heads share the same K and V. Reduces KV cache by factor h. Used by PaLM, Falcon.
- GQA (Grouped Query Attention): Group heads into g groups, each sharing K and V. Compromise between MHA and MQA. Used by LLaMA-2/3, Mistral.
- MLA (Multi-Head Latent Attention): DeepSeek's approach. Compress all KV into a single learned latent vector. Different philosophy — compress rather than share.
All three are trying to solve the same problem: MHA's KV cache is too large for long-context inference. They just take different approaches to reducing it.
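To make the trade-off concrete, here is a sketch of per-token KV cache size under each scheme, using LLaMA-2-70B-like dimensions (64 query heads, head_dim 128, 80 layers) as an assumed example; MLA is omitted since it caches a latent vector rather than per-head K/V:

```python
def kv_cache_per_token_kb(n_kv_heads, d_head=128, n_layers=80, bytes_per_elem=2):
    # 2 = K and V; the cache scales with the number of *KV* heads, not query heads
    return 2 * n_kv_heads * d_head * n_layers * bytes_per_elem / 1024

print(f"MHA (64 KV heads): {kv_cache_per_token_kb(64):.0f} KB/token")
print(f"GQA (8 KV heads):  {kv_cache_per_token_kb(8):.0f} KB/token")
print(f"MQA (1 KV head):   {kv_cache_per_token_kb(1):.0f} KB/token")
```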
5 things to remember
- Multiple heads: Split d into h heads of size d/h. Each learns a different attention pattern.
- Concat + W_O: Concatenate all heads, project with W_O to mix across heads.
- Same FLOPs: MHA has the same compute cost as single-head attention with full dimension.
- KV cache: Each head stores its own K and V → memory scales with h × d_h × n × layers.
- Head specialization: Different heads learn syntax, coreference, position, and copy patterns.
Go deeper
- Paper: "Attention Is All You Need" (Section 3.2.2) — arXiv:1706.03762
- Paper: "Are Sixteen Heads Really Better than One?" — arXiv:1905.10650
- Tool: BertViz — attention head visualizer — GitHub
- Blog: Lena Voita — Attention head analysis — lena-voita.github.io