
Multi-Head Attention

Category: Generative AI
Difficulty: Advanced
Time Complexity: O(h × n² × d_k)
Space Complexity: O(h × n² + n × d)

Multi-Head Attention is a core mechanism in the Transformer architecture that runs several attention functions in parallel. Instead of computing a single attention pass with d_model-dimensional keys, queries, and values, the model splits them into h heads, each operating on d_k = d_model / h dimensions. Each head learns to attend to different aspects of the input — one head might focus on syntactic relationships while another captures semantic similarity. After all heads compute their attention independently, their outputs are concatenated back into a d_model-dimensional vector and passed through a final linear projection W_O. This projection allows the model to mix information across heads. The key insight is that multiple smaller attention operations are more expressive than a single large one, enabling the model to jointly attend to information from different representation subspaces at different positions.

{
  "tokens": ["The", "cat", "sat", "on"],
  "embeddingDim": 8,
  "numHeads": 2
}
{
  "tokens": ["I", "love", "deep", "learning"],
  "embeddingDim": 8,
  "numHeads": 4
}
function multiHeadAttention(X, numHeads):
    d_k = d_model / numHeads
    for h = 1 to numHeads:
        Q_h = X × W_Q_h                  // [n, d_k]
        K_h = X × W_K_h                  // [n, d_k]
        V_h = X × W_V_h                  // [n, d_k]
        scores_h = (Q_h × K_hᵀ) / √d_k
        weights_h = softmax(scores_h)    // row-wise
        head_h = weights_h × V_h
    concat = [head_1 ; head_2 ; ... ; head_numHeads]   // [n, d_model]
    output = concat × W_O
    return output
import math
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    d_model = X.shape[-1]
    d_k = d_model // num_heads
    heads = []
    for h in range(num_heads):
        Q_h = X @ W_Q[h]                            # [n, d_k]
        K_h = X @ W_K[h]                            # [n, d_k]
        V_h = X @ W_V[h]                            # [n, d_k]
        scores = (Q_h @ K_h.T) / math.sqrt(d_k)     # [n, n]
        weights = softmax(scores, axis=-1)          # rows sum to 1
        heads.append(weights @ V_h)                 # [n, d_k]
    concat = np.concatenate(heads, axis=-1)         # [n, d_model]
    return concat @ W_O
// assumes matMul, transpose, scale, softmaxRows, and concatenate helpers
function multiHeadAttention(X, numHeads, W_Qs, W_Ks, W_Vs, W_O) {
  const d_k = X[0].length / numHeads;
  const heads = [];
  for (let h = 0; h < numHeads; h++) {
    const Q = matMul(X, W_Qs[h]); // [n, d_k]
    const K = matMul(X, W_Ks[h]); // [n, d_k]
    const V = matMul(X, W_Vs[h]); // [n, d_k]
    const scores = scale(matMul(Q, transpose(K)), 1 / Math.sqrt(d_k));
    const weights = softmaxRows(scores); // row-wise softmax
    heads.push(matMul(weights, V));
  }
  const concat = concatenate(heads); // [n, d_model]
  return matMul(concat, W_O);
}

A single attention head can only learn one type of relationship between tokens. Multiple heads allow the model to simultaneously attend to different aspects — for example, one head might learn syntactic dependencies (subject-verb agreement) while another learns semantic relationships (word meaning similarity). This is analogous to having multiple ‘perspectives’ on the same data.

Each head operates on d_k = d_model / numHeads dimensions, NOT the full d_model. This means the total computation is roughly the same as single-head attention with full dimensionality, but the model gains the expressiveness of multiple independent attention patterns. The scaling factor in each head uses √d_k, not √d_model.
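The √d_k scaling argument can be verified numerically. A minimal sketch, assuming NumPy: dot products of random d_k-dimensional vectors with unit-variance entries have a standard deviation that grows like √d_k, so dividing by √d_k brings the scores back to unit scale.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
n_samples = 10_000

# dot products of independent d_k-dimensional vectors with unit-variance entries
q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))
dots = np.sum(q * k, axis=-1)

print(np.std(dots))                 # close to sqrt(d_k) = 8
print(np.std(dots / np.sqrt(d_k)))  # close to 1 after scaling
```

Without this scaling, large raw scores would saturate the softmax, producing near-one-hot attention weights and tiny gradients.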

After all heads compute their outputs (each of shape [n, d_k]), they are concatenated along the feature dimension to form a [n, d_model] matrix. The final W_O projection (shape [d_model, d_model]) mixes information across heads, allowing the model to combine the different perspectives into a unified representation.
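The shape bookkeeping above can be checked directly. A small sketch, assuming NumPy, with n = 4 tokens, d_model = 8, and 2 heads (matching the example inputs earlier):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, num_heads = 4, 8, 2
d_k = d_model // num_heads

# pretend each head has already produced its [n, d_k] output
head_outputs = [rng.standard_normal((n, d_k)) for _ in range(num_heads)]

concat = np.concatenate(head_outputs, axis=-1)  # [n, num_heads * d_k] = [n, d_model]
W_O = rng.standard_normal((d_model, d_model))   # mixes information across heads
output = concat @ W_O                           # [n, d_model]

print(concat.shape)  # (4, 8)
print(output.shape)  # (4, 8)
```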

  • d_model must be divisible by numHeads: If d_model is not evenly divisible by numHeads, the dimension split is impossible. For example, d_model=7 with numHeads=2 fails because 7/2 = 3.5 is not an integer. Always choose numHeads to be a factor of d_model.
  • Scaling factor uses d_k, not d_model: A common mistake is scaling by √d_model instead of √d_k. Each head’s attention scores are dot products of d_k-dimensional vectors, so their expected magnitude grows with √d_k. Dividing by the larger √d_model shrinks the scores too much, pushing the softmax toward overly uniform attention weights.
  • Head outputs must be concatenated, not summed: The outputs from different heads are concatenated along the feature axis, not summed or averaged. Summing would lose information and reduce the effective dimensionality. Concatenation preserves all head outputs and lets W_O learn how to combine them.
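The first pitfall is cheap to guard against in code. A minimal sketch of an upfront divisibility check (the function name is illustrative, not from any library):

```python
def split_heads_check(d_model, num_heads):
    # fail fast if the head split is impossible
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by num_heads={num_heads}"
        )
    return d_model // num_heads  # d_k per head

print(split_heads_check(12, 3))  # 4
# split_heads_check(7, 2) would raise ValueError
```

Failing at construction time with a clear message is far easier to debug than a downstream shape mismatch inside the attention loop.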

Q1: If d_model = 12 and numHeads = 3, what is d_k?

  • A) 12
  • B) 3
  • C) 4
  • D) 36

Answer: C) 4

d_k = d_model / numHeads = 12 / 3 = 4. Each head operates on 4-dimensional queries, keys, and values.

Q2: Why is multi-head attention more expressive than single-head attention?

  • A) It uses more total parameters
  • B) Each head can learn different attention patterns in different subspaces
  • C) It is computationally faster
  • D) It uses a larger scaling factor

Answer: B) Each head can learn different attention patterns in different subspaces

Multiple heads allow the model to jointly attend to information from different representation subspaces. Each head can specialize in different types of relationships (syntactic, semantic, positional, etc.).

Q3: What is the shape of the concatenated output before the W_O projection?

  • A) [seqLen, d_k]
  • B) [seqLen, d_model]
  • C) [numHeads, seqLen, d_k]
  • D) [seqLen, numHeads]

Answer: B) [seqLen, d_model]

Each head produces [seqLen, d_k]. Concatenating numHeads outputs gives [seqLen, numHeads × d_k] = [seqLen, d_model], since d_k = d_model / numHeads.