Introduction: Why Multi-Head Attention Matters
Welcome to the DeepSeek architecture series! This part focuses on a critical building block: Multi-Head Attention (MHA).
If you want to understand how large language models (LLMs) like DeepSeek work, you must understand MHA. MHA is one of the core mechanisms inside these models, and it is what lets them model complex relationships between the words in a text.
For students learning AI and systems engineering, knowing how to implement MHA from scratch is essential. It helps you see how foundational concepts like self-attention, linear algebra, and masking work together to build powerful AI. Our goal is deep technical understanding, not just using libraries.
The Problem: One View is Not Enough
First, let us recall Self-Attention (SA). The main goal of SA is to take the input token embeddings and turn them into context embeddings. Context embeddings are important because they encode how each token relates to the other tokens in the sequence.
Self-Attention is powerful, but it has a major limitation: a single self-attention mechanism can capture only one perspective, or interpretation, of a given sentence.
Consider this simple sentence: "The artist painted the portrait of a woman with a brush."
This sentence can have two possible meanings (perspectives):
- Perspective 1: The artist used the brush to create the portrait of the woman.
- Perspective 2: The portrait shows a woman who is holding a brush in her hand.
Because a single Self-Attention mechanism produces only one Attention Scores Matrix, it effectively commits to one perspective and ignores the other possible meaning. We want our model to be expressive enough to capture several possible meanings or views in the text at the same time.
The Solution: Divide and Conquer
To solve the single-perspective problem, we use Multi-Head Attention.
The core idea is simple: If one Self-Attention mechanism (or one "head") captures one perspective, then multiple Self-Attention mechanisms (multiple heads) can capture multiple perspectives.
You can think of this approach as "divide and conquer." We divide the process into smaller parts, and each part (head) can focus on a different relationship or meaning in the text. For example, one head might pay more attention to verbs, while another pays more attention to the relationships between nouns. These roles are not assigned by hand; whatever specialization a head ends up with is learned during training.
When we use MHA, we run several attention mechanisms (heads) side by side. Each head has its own learned parameters, so each can pick up a different perspective on the same input. We then combine the results of all heads.
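To make the idea of several heads running side by side concrete, here is a minimal NumPy sketch (NumPy is our choice for illustration, and the batch dimension is left out for readability). Each head is an independent single-head causal self-attention computation with its own randomly initialized weights, and the heads' outputs are concatenated along the last dimension. The helper names softmax and single_head_attention are hypothetical. This loop-per-head version is deliberately naive; the step-by-step implementation below produces an equivalent result for all heads at once.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(X, W_q, W_k, W_v):
    # X: (T, d_in); each weight matrix: (d_in, head_dim)
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    T = X.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])            # (T, T) scaled scores
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True above the diagonal = future tokens
    scores = np.where(future, -np.inf, scores)          # causal mask
    weights = softmax(scores, axis=-1)                  # each row sums to 1
    return weights @ V                                  # (T, head_dim)

rng = np.random.default_rng(0)
T, d_in, num_heads, head_dim = 4, 5, 2, 3
X = rng.normal(size=(T, d_in))

# Each head gets its own (hypothetical, randomly initialized) W_q, W_k, W_v
heads = [tuple(rng.normal(size=(d_in, head_dim)) for _ in range(3)) for _ in range(num_heads)]
outputs = [single_head_attention(X, *w) for w in heads]
context = np.concatenate(outputs, axis=-1)              # (T, num_heads * head_dim)
print(context.shape)  # (4, 6)
```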
Multi-Head Attention: Step-by-Step Implementation
We will now look at the step-by-step mathematical flow that makes MHA work, linking the concept to the code implementation.
Step 1: Define Dimensions and Input
We start with the input embedding matrix, X. It typically has three dimensions: (Batch Size, Number of Tokens, $D_{in}$), or (B, T, $D_{in}$) for short, where $D_{in}$ is the input embedding dimension.
Before we begin the calculations, we must decide on two fixed values (hyperparameters):
- Output Dimension ($D_{out}$): This is the final size of the output for each token.
- Number of Attention Heads ($N_{heads}$): The number of perspectives we want to capture.
Using these, we calculate the size of each head, called the Head Dimension ($H_{dim}$): $$H_{dim} = \frac{D_{out}}{N_{heads}}$$ This requires $D_{out}$ to be divisible by $N_{heads}$. For example, if $D_{out} = 6$ and $N_{heads} = 2$, then $H_{dim} = 3$.
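The divisibility requirement is easy to enforce up front. A minimal Python sketch, where compute_head_dim is a hypothetical helper name:

```python
def compute_head_dim(d_out: int, num_heads: int) -> int:
    # D_out has to split evenly into num_heads chunks of equal size
    if d_out % num_heads != 0:
        raise ValueError(f"d_out={d_out} is not divisible by num_heads={num_heads}")
    return d_out // num_heads

print(compute_head_dim(6, 2))  # 3
```

Calling compute_head_dim(6, 4) would raise a ValueError, since 6 cannot be split into 4 equal heads.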
Step 2: Calculate Keys, Queries, and Values
Next, we multiply the input matrix X by three randomly initialized, trainable weight matrices: $W^Q$ (Query), $W^K$ (Key), and $W^V$ (Value), each of shape ($D_{in}$, $D_{out}$).
This gives us the Queries (Q), Keys (K), and Values (V) matrices: $$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$ The shape of Q, K, and V is now (Batch Size, Number of Tokens, $D_{out}$).
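A minimal NumPy sketch of this projection, assuming small example sizes and plain random matrices standing in for the trained weights. NumPy's matrix multiplication broadcasts a 2-D weight matrix across the batch dimension, so one product per matrix covers the whole batch. (In PyTorch these projections are usually nn.Linear layers, which store the weight transposed as (out_features, in_features) and compute x @ weight.T plus an optional bias.)

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d_in, d_out = 2, 4, 5, 6

X = rng.normal(size=(B, T, d_in))      # input embeddings
W_q = rng.normal(size=(d_in, d_out))   # trainable in a real model; random here
W_k = rng.normal(size=(d_in, d_out))
W_v = rng.normal(size=(d_in, d_out))

# (B, T, d_in) @ (d_in, d_out) -> (B, T, d_out), broadcast over the batch
Q, K, V = X @ W_q, X @ W_k, X @ W_v
print(Q.shape, K.shape, V.shape)  # (2, 4, 6) (2, 4, 6) (2, 4, 6)
```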
Step 3: Reshaping for Multiple Heads
This is where the multi-head logic begins. We must prepare the Q, K, and V matrices for the separate attention heads.
We split the output dimension ($D_{out}$) into ($N_{heads}, H_{dim}$). The matrix shape changes from (B, T, $D_{out}$) to: $$(B, \text{Number of Tokens}, \mathbf{N_{heads}}, \mathbf{H_{dim}})$$
This reshaping splits each token's $D_{out}$-dimensional vector into $N_{heads}$ consecutive chunks of $H_{dim}$ values, one chunk per head. No values change; only the way the data is indexed does.
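The following NumPy sketch makes the split visible by filling a stand-in for Q with the numbers 0, 1, 2, ... (an illustrative assumption; real queries are arbitrary floats). After the reshape, each token's six values appear as two rows of three, one row per head.

```python
import numpy as np

B, T, d_out, num_heads = 2, 4, 6, 2
head_dim = d_out // num_heads

# Stand-in for Q with recognizable values, in row-major order
Q = np.arange(B * T * d_out).reshape(B, T, d_out)

# (B, T, d_out) -> (B, T, num_heads, head_dim)
Q_heads = Q.reshape(B, T, num_heads, head_dim)
print(Q_heads.shape)   # (2, 4, 2, 3)
print(Q[0, 0])         # [0 1 2 3 4 5]
print(Q_heads[0, 0])   # [[0 1 2]
                       #  [3 4 5]]
```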
Step 4: Grouping by Head (Transposing)
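Currently, the matrix is grouped by token. Batched matrix multiplication (as in NumPy or PyTorch) treats the last two dimensions as the matrices to multiply and runs over all leading dimensions in parallel. To compute attention for every head at once, we therefore need the data grouped by head, with (Number of Tokens, $H_{dim}$) as the last two dimensions.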
We interchange the Token dimension and the Head dimension (a process called transposing). The shape becomes: $$(B, \mathbf{N_{heads}}, \text{Number of Tokens}, H_{dim})$$
Now, all the data for Head 1 is together, all data for Head 2 is together, and so on. This setup is crucial for the next calculation.
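A minimal NumPy sketch of the transpose, assuming the reshaped shape from Step 3. Swapping axes 1 and 2 moves the head axis in front of the token axis; in PyTorch the equivalent call is tensor.transpose(1, 2).

```python
import numpy as np

B, T, num_heads, head_dim = 2, 4, 2, 3
Q_heads = np.arange(B * T * num_heads * head_dim).reshape(B, T, num_heads, head_dim)

# Swap the token axis (1) and the head axis (2): (B, T, N, H) -> (B, N, T, H)
Q_by_head = Q_heads.transpose(0, 2, 1, 3)   # same as np.swapaxes(Q_heads, 1, 2)
print(Q_by_head.shape)  # (2, 2, 4, 3)

# Q_by_head[b, h] is a (T, head_dim) matrix holding every token's slice for head h
print((Q_by_head[0, 1] == Q_heads[0, :, 1, :]).all())  # True
```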
Step 5: Calculating Attention Scores
With the matrices grouped by head, we compute the Attention Scores for each head by multiplying that head's Queries by the transpose of its Keys:
- Head 1: $Q_1 \times K_1^T$
- Head 2: $Q_2 \times K_2^T$
- ...and so on.
The resulting Attention Scores Matrix has the shape: $$(B, N_{heads}, \text{Number of Tokens}, \text{Number of Tokens})$$ In other words, we get one attention scores matrix per head, and because each head has its own slice of the Query and Key weights, each matrix can capture a different perspective.
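In code, all heads are handled by a single batched product. This NumPy sketch, using random stand-ins for Q and K, transposes only the last two axes of K and then checks that the result matches a separate $Q_h \times K_h^T$ for every batch element and head.

```python
import numpy as np

rng = np.random.default_rng(0)
B, num_heads, T, head_dim = 2, 2, 4, 3
Q = rng.normal(size=(B, num_heads, T, head_dim))
K = rng.normal(size=(B, num_heads, T, head_dim))

# Transpose only the last two axes of K: (B, N, T, H) -> (B, N, H, T)
scores = Q @ K.swapaxes(-1, -2)
print(scores.shape)  # (2, 2, 4, 4)

# The batched product equals one Q_h @ K_h^T per batch element and head
for b in range(B):
    for h in range(num_heads):
        assert np.allclose(scores[b, h], Q[b, h] @ K[b, h].T)
```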
Step 6: Normalizing Attention Scores (Attention Weights)
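The attention scores are converted into Attention Weights in three steps:
- Scaling: We divide the attention scores by $\sqrt{H_{dim}}$. If the query and key entries are independent with zero mean and unit variance, their dot product has variance $H_{dim}$, and very large scores push the softmax into a saturated region where gradients become tiny. Dividing by $\sqrt{H_{dim}}$ brings the variance back to about 1, which keeps training stable.
- Causal Attention: Since we are building an LLM that predicts the next token, a token must not be able to "peek" at future tokens. We use a mask to set every score where the key position comes after the query position to negative infinity ($-\infty$), so those entries become exactly 0 after the softmax.
- Softmax: We apply the Softmax function along the last dimension (over the keys). This makes each token's attention weights non-negative and ensures they sum to 1.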
The result is the Attention Weights Matrix, which retains the same shape: (B, $N_{heads}$, T, T).
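A minimal NumPy sketch of the three steps, assuming random numbers as a stand-in for the scores from Step 5. One (T, T) boolean mask broadcasts over the batch and head dimensions. In PyTorch the same steps are commonly written with Tensor.masked_fill and torch.softmax(..., dim=-1).

```python
import numpy as np

rng = np.random.default_rng(0)
B, num_heads, T, head_dim = 2, 2, 4, 3
scores = rng.normal(size=(B, num_heads, T, T))   # stand-in for Q @ K^T

# 1) Scale by sqrt(head_dim)
scores = scores / np.sqrt(head_dim)

# 2) Causal mask: True strictly above the diagonal (key position j > query position i)
future = np.triu(np.ones((T, T), dtype=bool), k=1)
scores = np.where(future, -np.inf, scores)       # broadcasts over (B, N)

# 3) Softmax over the last axis (keys); exp(-inf) = 0, so future tokens get zero weight
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)

print(weights.shape)                             # (2, 2, 4, 4)
print(np.allclose(weights.sum(axis=-1), 1.0))    # True
print(np.all(weights[..., future] == 0))         # True
```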
Step 7: Calculating Context Vectors
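We use the Attention Weights to compute a weighted sum of the value vectors: for each head, we multiply its Attention Weights Matrix by its Values (V) Matrix. $$\text{Context} = \text{Attention Weights} \times V$$
This results in one Context Vector Matrix per head. Multiplying a (T, T) weights matrix by a (T, $H_{dim}$) values matrix for every batch element and head gives the shape (B, $N_{heads}$, T, $H_{dim}$), the same shape as V.
Each token's context vector is now much richer: it is a mix of the value vectors of the tokens it is allowed to attend to (itself and earlier tokens, under the causal mask), weighted by how relevant that specific head judged each of them to be.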
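A minimal NumPy sketch of this step, assuming random values and random causal attention weights built as in Step 6. Because the first token can only attend to itself, its weight row is [1, 0, 0, 0], so its context vector must equal its own value vector, which makes a handy sanity check.

```python
import numpy as np

rng = np.random.default_rng(0)
B, num_heads, T, head_dim = 2, 2, 4, 3
V = rng.normal(size=(B, num_heads, T, head_dim))

# Causal attention weights: random scores, masked, then softmax over the keys
scores = rng.normal(size=(B, num_heads, T, T))
scores = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, scores)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# (B, N, T, T) @ (B, N, T, H) -> (B, N, T, H)
context = weights @ V
print(context.shape)  # (2, 2, 4, 3)

# Token 0 attends only to itself, so its context vector is its own value vector
print(np.allclose(context[:, :, 0], V[:, :, 0]))  # True
```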
Step 8: Merging the Heads
The final step is to merge the outputs from all the different heads back together.
First, we transpose the context matrix back to group by tokens, returning the shape to (B, T, $N_{heads}$, $H_{dim}$). Then, we merge the last two dimensions ($N_{heads}$ and $H_{dim}$). Since $N_{heads} \times H_{dim} = D_{out}$, the final output matrix shape is: $$(B, \text{Number of Tokens}, \mathbf{D_{out}})$$
The output shape is the same as it would be with a single head. However, each token's output vector is now the concatenation of the context vectors produced by Head 1, Head 2, and all other heads, so this final context matrix carries the multiple perspectives they captured.
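Putting Steps 1 to 8 together, here is a compact NumPy sketch of causal multi-head attention, assuming random weights and small sizes; multi_head_attention and split_heads are hypothetical names, not a library API. NumPy's reshape copies data when the transposed array is not contiguous; in PyTorch you would call .contiguous() before .view(), or use .reshape(). Many implementations, including the original Transformer, also apply a learned linear output projection after the merge; the steps above stop at the merge, and so does this sketch.

```python
import numpy as np

def multi_head_attention(X, W_q, W_k, W_v, num_heads):
    """Causal multi-head self-attention following Steps 1-8 (no output projection)."""
    B, T, _ = X.shape
    d_out = W_q.shape[1]
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
    head_dim = d_out // num_heads                                   # Step 1

    def split_heads(M):
        # (B, T, d_out) -> (B, T, N, H) -> (B, N, T, H)
        return M.reshape(B, T, num_heads, head_dim).transpose(0, 2, 1, 3)

    Q, K, V = (split_heads(X @ W) for W in (W_q, W_k, W_v))         # Steps 2-4
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(head_dim)             # Steps 5-6: scaled scores
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)                      # Step 6: causal mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)                  # Step 6: softmax
    context = weights @ V                                           # Step 7: (B, N, T, H)
    return context.transpose(0, 2, 1, 3).reshape(B, T, d_out)       # Step 8: merge heads

rng = np.random.default_rng(0)
B, T, d_in, d_out, num_heads = 2, 4, 5, 6, 2
X = rng.normal(size=(B, T, d_in))
W_q, W_k, W_v = (rng.normal(size=(d_in, d_out)) for _ in range(3))

out = multi_head_attention(X, W_q, W_k, W_v, num_heads)
print(out.shape)  # (2, 4, 6)
```

As a design check, the first $H_{dim}$ output columns equal a single-head run that uses only the first $H_{dim}$ columns of each weight matrix, which is exactly the sense in which one large projection "contains" all the heads.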