Multi-Head Attention
Category: Generative AI
Difficulty: Advanced
Time Complexity: O(h × n² × d_k)
Space Complexity: O(h × n² + n × d)
Overview
Multi-Head Attention is a core mechanism in the Transformer architecture that runs several attention functions in parallel. Instead of computing a single attention pass with d_model-dimensional keys, queries, and values, the model splits them into h heads, each operating on d_k = d_model / h dimensions. Each head learns to attend to different aspects of the input — one head might focus on syntactic relationships while another captures semantic similarity. After all heads compute their attention independently, their outputs are concatenated back into a d_model-dimensional vector and passed through a final linear projection W_O. This projection allows the model to mix information across heads. The key insight is that multiple smaller attention operations are more expressive than a single large one, enabling the model to jointly attend to information from different representation subspaces at different positions.
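As a toy illustration of the head split described above (illustrative sizes only), the following sketch views an [n, d_model] activation as h slices of d_k dimensions each. Note that a real Transformer applies learned per-head projections W_Q, W_K, W_V rather than a bare reshape; this only shows the dimension bookkeeping.

```python
import numpy as np

# Toy sizes: 4 tokens, d_model = 8, h = 2 heads -> d_k = 4 per head.
n, d_model, h = 4, 8, 2
d_k = d_model // h

X = np.arange(n * d_model, dtype=float).reshape(n, d_model)

# View the feature axis as h heads of d_k dims: [n, d_model] -> [h, n, d_k].
heads = X.reshape(n, h, d_k).transpose(1, 0, 2)

print(heads.shape)  # (2, 4, 4)
```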
Try It
- Web: Open in Eigenvue →
- Python:

```python
import eigenvue

eigenvue.show("multi-head-attention")
```
Default Inputs
```json
{ "tokens": ["The", "cat", "sat", "on"], "embeddingDim": 8, "numHeads": 2 }
```

Input Examples

2 heads, dim=8

```json
{ "tokens": ["The", "cat", "sat", "on"], "embeddingDim": 8, "numHeads": 2 }
```

4 heads, dim=8

```json
{ "tokens": ["I", "love", "deep", "learning"], "embeddingDim": 8, "numHeads": 4 }
```

Pseudocode
```
function multiHeadAttention(X, numHeads):
    d_k = d_model / numHeads
    for h = 1 to numHeads:
        Q_h = X × W_Q_h                       // [n, d_k]
        K_h = X × W_K_h                       // [n, d_k]
        V_h = X × W_V_h                       // [n, d_k]
        scores_h = (Q_h × K_hᵀ) / √d_k
        weights_h = softmax(scores_h)         // row-wise
        head_h = weights_h × V_h
    concat = [head_1 ; head_2 ; ... ; head_h] // [n, d_model]
    output = concat × W_O
    return output
```

Python
```python
import math

import numpy as np

def softmax(x, axis=-1):
    # Numerically stable row-wise softmax.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    d_model = X.shape[-1]
    d_k = d_model // num_heads
    heads = []
    for h in range(num_heads):
        Q_h = X @ W_Q[h]  # [n, d_k]
        K_h = X @ W_K[h]  # [n, d_k]
        V_h = X @ W_V[h]  # [n, d_k]
        scores = (Q_h @ K_h.T) / math.sqrt(d_k)
        weights = softmax(scores, axis=-1)
        heads.append(weights @ V_h)
    concat = np.concatenate(heads, axis=-1)  # [n, d_model]
    return concat @ W_O
```

JavaScript
```javascript
// matMul, transpose, scale, softmaxRows, and concatenate are assumed
// matrix helpers (concatenate joins head outputs along the feature axis).
function multiHeadAttention(X, numHeads, W_Qs, W_Ks, W_Vs, W_O) {
  const d_k = X[0].length / numHeads;
  const heads = [];
  for (let h = 0; h < numHeads; h++) {
    const Q = matMul(X, W_Qs[h]);
    const K = matMul(X, W_Ks[h]);
    const V = matMul(X, W_Vs[h]);
    const scores = scale(matMul(Q, transpose(K)), 1 / Math.sqrt(d_k));
    const weights = softmaxRows(scores);
    heads.push(matMul(weights, V));
  }
  const concat = concatenate(heads); // [n, d_model]
  return matMul(concat, W_O);
}
```

Key Concepts
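To see the shapes end to end, here is a self-contained NumPy run with the default inputs above (4 tokens, embeddingDim 8, 2 heads) and randomly initialized weights. The weight names mirror the pseudocode; the sizes and initialization are illustrative, not tied to any trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, h = 4, 8, 2          # 4 tokens, dim 8, 2 heads
d_k = d_model // h               # 4 dims per head

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

X = rng.standard_normal((n, d_model))
W_Q = rng.standard_normal((h, d_model, d_k))
W_K = rng.standard_normal((h, d_model, d_k))
W_V = rng.standard_normal((h, d_model, d_k))
W_O = rng.standard_normal((d_model, d_model))

heads = []
for i in range(h):
    Q, K, V = X @ W_Q[i], X @ W_K[i], X @ W_V[i]      # each [n, d_k]
    w = softmax((Q @ K.T) / np.sqrt(d_k))             # [n, n], rows sum to 1
    heads.append(w @ V)                               # [n, d_k]

out = np.concatenate(heads, axis=-1) @ W_O            # [n, d_model]
print(out.shape)  # (4, 8)
```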
Why Multiple Heads?
A single attention head can only learn one type of relationship between tokens. Multiple heads allow the model to simultaneously attend to different aspects — for example, one head might learn syntactic dependencies (subject-verb agreement) while another learns semantic relationships (word meaning similarity). This is analogous to having multiple ‘perspectives’ on the same data.
d_k vs d_model
Each head operates on d_k = d_model / numHeads dimensions, NOT the full d_model. This means the total computation is roughly the same as single-head attention with full dimensionality, but the model gains the expressiveness of multiple independent attention patterns. The scaling factor in each head uses √d_k, not √d_model.
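A quick arithmetic check of the "roughly the same computation" claim: across all heads, the per-head projection matrices contain exactly as many parameters as one full-width head. The sizes below are illustrative.

```python
# h heads of width d_k = d_model / h vs. one head of width d_model.
d_model, h = 512, 8
d_k = d_model // h

per_head = 3 * d_model * d_k    # W_Q, W_K, W_V for one head
multi = h * per_head            # all heads together
single = 3 * d_model * d_model  # one full-width head

print(multi == single)  # True
```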
Concatenation and W_O Projection
After all heads compute their outputs (each of shape [n, d_k]), they are concatenated along the feature dimension to form a [n, d_model] matrix. The final W_O projection (shape [d_model, d_model]) mixes information across heads, allowing the model to combine the different perspectives into a unified representation.
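A toy shape check of the concatenation step (illustrative sizes; constant arrays stand in for real head outputs). It also shows what goes wrong with summing: the feature dimension collapses back to d_k instead of growing to d_model.

```python
import numpy as np

n, h, d_k = 4, 2, 4                       # 2 heads of width 4
heads = [np.full((n, d_k), float(i)) for i in range(h)]

concat = np.concatenate(heads, axis=-1)   # [n, h * d_k] = [n, d_model]
summed = sum(heads)                       # [n, d_k] -- dimensionality lost

print(concat.shape)  # (4, 8)
print(summed.shape)  # (4, 4)
```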
Common Pitfalls
- d_model must be divisible by numHeads: If d_model is not evenly divisible by numHeads, the dimension split is impossible. For example, d_model=7 with numHeads=2 fails because 7/2 = 3.5 is not an integer. Always choose numHeads to be a factor of d_model.
- Scaling factor uses d_k, not d_model: A common mistake is scaling by √d_model instead of √d_k. Each head’s attention scores are dot products of d_k-dimensional vectors, so their expected magnitude scales with d_k. Dividing by the larger √d_model shrinks the scores too much, flattening the softmax and producing overly uniform attention weights.
- Head outputs must be concatenated, not summed: The outputs from different heads are concatenated along the feature axis, not summed or averaged. Summing would lose information and reduce the effective dimensionality. Concatenation preserves all head outputs and lets W_O learn how to combine them.
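A minimal guard for the first pitfall might look like this (the helper name is hypothetical):

```python
def split_heads_check(d_model, num_heads):
    # The head split must be exact: d_model has to be a multiple of num_heads.
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by num_heads={num_heads}"
        )
    return d_model // num_heads

print(split_heads_check(8, 2))  # 4
# split_heads_check(7, 2) would raise ValueError
```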
Q1: If d_model = 12 and numHeads = 3, what is d_k?
- A) 12
- B) 3
- C) 4
- D) 36
Answer: C) 4
d_k = d_model / numHeads = 12 / 3 = 4. Each head operates on 4-dimensional queries, keys, and values.
Q2: Why is multi-head attention more expressive than single-head attention?
- A) It uses more total parameters
- B) Each head can learn different attention patterns in different subspaces
- C) It is computationally faster
- D) It uses a larger scaling factor
Answer: B) Each head can learn different attention patterns in different subspaces
Multiple heads allow the model to jointly attend to information from different representation subspaces. Each head can specialize in different types of relationships (syntactic, semantic, positional, etc.).
Q3: What is the shape of the concatenated output before the W_O projection?
- A) [seqLen, d_k]
- B) [seqLen, d_model]
- C) [numHeads, seqLen, d_k]
- D) [seqLen, numHeads]
Answer: B) [seqLen, d_model]
Each head produces [seqLen, d_k]. Concatenating numHeads outputs gives [seqLen, numHeads × d_k] = [seqLen, d_model], since d_k = d_model / numHeads.
Further Reading
- Attention Is All You Need (Vaswani et al., 2017) (article)
- The Illustrated Transformer — Jay Alammar (article)
- Multi-Head Attention — Wikipedia (reference)