Build DeepSeek from scratch - Part 19: DeepSeek's Multi-Token Prediction (MTP)

December 17, 2025

Introduction: Why Multi-Token Prediction Matters

Welcome to this part of our series on building DeepSeek! If you are studying computer science and AI, understanding how large models learn efficiently is key to your future success.

Standard language models are trained with Single Token Prediction (STP). At each position of the input text, the model predicts only the next word, or "token".

Our goal in this post is to understand Multi-Token Prediction (MTP), the training technique DeepSeek uses to get more learning out of its data. With MTP, the model predicts several future tokens for each input token instead of just one.

MTP matters for two related reasons: better data efficiency and a denser training signal. Each input position contributes several prediction targets instead of one, so the model receives more supervision from the same amount of text. DeepSeek used MTP mainly during its pre-training phase.
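A quick count makes "denser training signal" concrete. Assume a training sequence of $T$ tokens in which every position that still has enough tokens after it is supervised on each of its next $K$ tokens. This is a simplifying assumption; the exact bookkeeping depends on how the heads are indexed. The token $d$ steps ahead can be predicted from $T - d$ positions, so the number of supervised predictions is:

$$
N_{\text{STP}} = T - 1, \qquad
N_{\text{MTP}} = \sum_{d=1}^{K} (T - d) = KT - \frac{K(K+1)}{2}.
$$

For example, with $T = 8$ and $K = 3$, STP gives $7$ supervised predictions, while MTP gives $7 + 6 + 5 = 18$. When $T \gg K$, this is roughly $K$ times as many supervised predictions from the same tokens.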

The Core Idea: Predicting Multiple Tokens

To predict multiple future tokens, we need more than one prediction output.

1. Heads for Depth

If we want to predict $K$ future tokens, we need $K$ separate prediction heads. This number, $K$, is called the prediction depth.

For example, if we set the depth $K=3$, we use three heads, one for each of the three future tokens.

We repeat this process for every input token in the sequence.
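The plain-Python sketch below lists which future tokens each position would be trained on at depth $K=3$. The function name `future_targets` and the example sentence are hypothetical. The sketch shows only the supervision targets, not the heads.

```python
def future_targets(tokens, depth):
    """For each input position i, return the next `depth` tokens it is trained on.

    Positions near the end of the sequence get fewer targets because
    there are not enough tokens left after them.
    """
    return [tokens[i + 1 : i + 1 + depth] for i in range(len(tokens))]


tokens = ["The", "cat", "sat", "on", "the", "mat"]
for i, targets in enumerate(future_targets(tokens, depth=3)):
    print(tokens[i], "->", targets)
```

With `depth=1` the same function gives the single STP target per position. Counting the targets for both settings shows the denser signal: 12 targets versus 5 for this six-token sentence.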

2. DeepSeek’s Innovation: The Causal Chain

Earlier MTP approaches predicted the future tokens independently of each other, using parallel output heads. DeepSeek introduced a key innovation: maintaining a complete causal chain between predictions.

This means that the output from the first prediction step affects the input of the second prediction step.

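How the link works: the hidden state produced by one head becomes the hidden-state input of the next head. Head 2 builds on head 1's output, head 3 builds on head 2's output, and so on.

This linking carries information from the earlier predictions into the later ones. Each deeper head builds on what the shallower heads computed instead of predicting its token in isolation.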
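Here is a minimal Python sketch of the causal chain. It uses a hypothetical `toy_head` as a stand-in for a real prediction head, which would normalize, merge, project, and run a Transformer block. The point is the data flow: each head receives the previous head's hidden state. For contrast, the last lines run the same heads independently, each starting from the main model's hidden state.

```python
def toy_head(h_prev, emb):
    """Hypothetical stand-in for one MTP head: mixes the incoming hidden state
    with the embedding of the head's input token and returns a new hidden state."""
    return [0.5 * (a + b) for a, b in zip(h_prev, emb)]


def run_causal_chain(h_main, future_embs, heads):
    """h_main: hidden state of token i from the main Transformer blocks.
    future_embs[k-1]: embedding of token i+k, for k = 1..K.
    heads[k-1]: the head at depth k.
    Returns the hidden state produced at every depth."""
    hidden_states = []
    h = h_main
    for head, emb in zip(heads, future_embs):
        h = head(h, emb)  # depth k consumes the output of depth k-1
        hidden_states.append(h)
    return hidden_states


h_main = [0.0, 0.0]
future_embs = [[4.0, 4.0], [8.0, 8.0], [0.0, 0.0]]
heads = [toy_head] * 3  # real heads would each have their own weights

chained = run_causal_chain(h_main, future_embs, heads)
independent = [head(h_main, emb) for head, emb in zip(heads, future_embs)]
print(chained)      # [[2.0, 2.0], [5.0, 5.0], [2.5, 2.5]]
print(independent)  # [[2.0, 2.0], [4.0, 4.0], [0.0, 0.0]]
```

In the chained run, the third head's state still reflects what the first two heads saw. In the independent run, it does not.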

Building One Prediction Head

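Let us look closely at what happens inside one prediction head, the head at depth $k$ (where $k = 1, \dots, K$). Every head performs the same sequence of operations, but each has its own projection matrix and Transformer block.

For depth $k$ and an input token at position $i$, the head needs two main inputs:

  1. Input Embedding: The embedding vector of the token at the future position $i+k$. Because the head already sees this token, it predicts the token one step further, at position $i+k+1$. (In DeepSeek-V3, the main model's own output predicts position $i+1$.)
  2. Hidden State: The hidden state produced at the previous depth. For the very first head ($k=1$), this hidden state comes directly from the main stack of Transformer blocks.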

Here are the four key operations inside the head:

Step 1: Normalization and Merging

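Before we join the two inputs, we normalize them so they are on a comparable scale.

  1. Normalize (RMS Norm): We apply Root Mean Square (RMS) normalization separately to the hidden state and to the input embedding. RMS normalization divides each vector by the square root of the mean of its squared entries, then multiplies by a learned per-dimension scale. As a result, both vectors end up with a similar magnitude.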
  2. Merge (Concatenation): We join these two normalized vectors together. If both vectors are 8 dimensions long, joining them creates one larger vector that is 16 dimensions long (1x8 + 1x8 = 1x16).
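Here is a small pure-Python sketch of Step 1 with made-up 8-dimensional vectors. The hidden state goes first in the concatenation, as in the DeepSeek-V3 formulation. The function name and values are illustrative only, and the learned scale is replaced by all ones.

```python
import math


def rms_norm(x, gain=None, eps=1e-6):
    """RMSNorm: divide x by sqrt(mean(x**2) + eps), then apply a per-dimension gain.
    In a trained model the gain is learned; here it defaults to all ones."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if gain is None:
        gain = [1.0] * len(x)
    return [g * v / rms for g, v in zip(gain, x)]


# Made-up 8-dimensional inputs for one position i at depth k
hidden = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5, 1.0, 3.0]     # hidden state from the previous depth
embedding = [1.0, 1.0, -1.0, 2.0, 0.0, 0.5, -2.0, 1.0]  # embedding of token i+k

# Normalize each input separately, then concatenate: 8 + 8 = 16 dimensions
merged = rms_norm(hidden) + rms_norm(embedding)
print(len(merged))  # 16
```

After normalization, both halves of `merged` have a root mean square of about 1, whatever their original magnitudes.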

Step 2: Linear Projection

The merged vector is twice as long as the model's hidden dimension, so a learned linear projection maps it back to the original size. For example, a 16x8 matrix takes the 1x16 vector back to 1x8.
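Below is a sketch of the projection with random stand-in weights. `W_proj` is a hypothetical name; in the real model this matrix is learned, and each head has its own. The sketch uses the row-vector convention, so the matrix is 16x8.

```python
import random


def linear(x, weight):
    """Row-vector times matrix: x has length d_in, weight is d_in rows x d_out columns."""
    if len(x) != len(weight):
        raise ValueError("input length must match the number of weight rows")
    d_out = len(weight[0])
    return [sum(x[r] * weight[r][c] for r in range(len(x))) for c in range(d_out)]


d_model = 8
rng = random.Random(0)

# Hypothetical projection matrix of shape (2 * d_model) x d_model; learned in practice
W_proj = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(2 * d_model)]

merged = [rng.gauss(0.0, 1.0) for _ in range(2 * d_model)]  # stand-in for the 1x16 merged vector
projected = linear(merged, W_proj)
print(len(projected))  # 8
```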

Step 3: Transformation

The resulting 1x8 vector is then passed through a single, dedicated Transformer block.

This new hidden state is critical: It serves as the hidden state input for the next prediction head in the causal chain.
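The dedicated block in DeepSeek is a full Transformer block. The sketch below deliberately swaps in a much smaller one: single-head causal self-attention plus a ReLU feed-forward layer, each with pre-RMSNorm and a residual connection, using random weights. It shows two things. First, the block keeps the 1x8 shape at every position. Second, because attention mixes positions, the block runs over the whole sequence of projected vectors under a causal mask. All names here are hypothetical.

```python
import math
import random


def rms_norm(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def matvec(x, w):
    """Row vector x (length d_in) times matrix w (d_in x d_out)."""
    return [sum(xi * w[r][c] for r, xi in enumerate(x)) for c in range(len(w[0]))]


def rand_matrix(rng, rows, cols, scale=0.1):
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


class ToyTransformerBlock:
    """Pre-norm block: causal single-head self-attention, then a ReLU MLP,
    each followed by a residual connection. Weights are random, not trained."""

    def __init__(self, d_model, d_ff, seed=0):
        rng = random.Random(seed)
        self.wq, self.wk, self.wv, self.wo = (rand_matrix(rng, d_model, d_model) for _ in range(4))
        self.w1 = rand_matrix(rng, d_model, d_ff)
        self.w2 = rand_matrix(rng, d_ff, d_model)
        self.scale = 1.0 / math.sqrt(d_model)

    def __call__(self, xs):
        """xs: one d_model-dimensional vector per position; returns the same shape."""
        normed = [rms_norm(x) for x in xs]
        qs = [matvec(x, self.wq) for x in normed]
        ks = [matvec(x, self.wk) for x in normed]
        vs = [matvec(x, self.wv) for x in normed]
        out = []
        for i, x in enumerate(xs):
            # Causal mask: position i attends only to positions 0..i
            scores = [self.scale * sum(q * k for q, k in zip(qs[i], ks[j])) for j in range(i + 1)]
            top = max(scores)
            weights = [math.exp(s - top) for s in scores]
            total = sum(weights)
            attended = [sum(w * vs[j][d] for j, w in enumerate(weights)) / total
                        for d in range(len(x))]
            h = [a + b for a, b in zip(x, matvec(attended, self.wo))]   # residual 1
            ff = [max(0.0, v) for v in matvec(rms_norm(h), self.w1)]     # ReLU MLP
            out.append([a + b for a, b in zip(h, matvec(ff, self.w2))])  # residual 2
        return out


rng = random.Random(1)
projected_seq = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(4)]  # 4 positions, 1x8 each
block = ToyTransformerBlock(d_model=8, d_ff=32)
new_hidden = block(projected_seq)  # becomes the hidden-state input of the next head
print(len(new_hidden), len(new_hidden[0]))  # 4 8
```

Because of the causal mask, changing a later position never changes the outputs at earlier positions.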

Step 4: Logits and Loss

Finally, we use the output of the transformer block (the new hidden state) to make the actual prediction for that depth.

  1. Logits Calculation: We pass the hidden state through the unembedding matrix (or "logits matrix"), which is shared with the main model and across all heads. This matrix converts the vector from the embedding dimension (e.g., 8) to the size of the vocabulary (e.g., 50,000), giving one score (logit) per token.
  2. Prediction: A softmax turns the logits into probabilities. The token index with the highest probability (equivalently, the highest logit) is our predicted token.
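The following sketch shows Step 4's prediction path with a tiny vocabulary of 10 tokens instead of 50,000. `W_unembed` is a hypothetical random stand-in for the learned, shared unembedding matrix. It is written in the row-vector convention (d_model x vocab_size).

```python
import math
import random


def logits_from_hidden(h, w_unembed):
    """Project a d_model hidden state to one score per vocabulary entry."""
    return [sum(h[r] * w_unembed[r][c] for r in range(len(h))) for c in range(len(w_unembed[0]))]


def softmax(logits):
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


d_model, vocab_size = 8, 10  # tiny stand-ins for e.g. 8 and 50,000
rng = random.Random(0)
# Hypothetical shared unembedding matrix; the same matrix would serve every head
W_unembed = [[rng.gauss(0.0, 0.5) for _ in range(vocab_size)] for _ in range(d_model)]

hidden = [rng.gauss(0.0, 1.0) for _ in range(d_model)]  # output of the head's Transformer block
logits = logits_from_hidden(hidden, W_unembed)
probs = softmax(logits)
predicted_id = max(range(vocab_size), key=lambda t: logits[t])  # argmax of the logits
print(predicted_id, round(probs[predicted_id], 3))
```

Softmax is monotonic, so taking the argmax of the logits picks the same token as taking the argmax of the probabilities.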

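For training, we compare the predicted probability distribution with the actual token at that depth's target position, using the cross-entropy loss. The MTP loss for one input token combines the losses from all prediction depths (for $K=3$: depths 1, 2, and 3). In the simplest form it is their sum. DeepSeek-V3 instead averages the per-depth losses over the depths, scales the result by a weighting factor $\lambda$, and adds it to the main model's next-token loss.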
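Here is a minimal sketch of the loss, using made-up logits for three heads over a 4-token vocabulary. `cross_entropy` is written out by hand for clarity. The last lines show the averaged-and-scaled combination; `lam` is a hypothetical value chosen for illustration, not a recommended setting.

```python
import math


def cross_entropy(logits, target_id):
    """-log(softmax(logits)[target_id]), computed stably with log-sum-exp."""
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[target_id]


# Made-up logits over a 4-token vocabulary from the heads at depths 1, 2, 3,
# plus the true token id each head should have predicted
head_logits = [
    [2.0, 0.5, 0.1, -1.0],
    [0.2, 1.5, 0.3, 0.0],
    [0.0, 0.0, 0.0, 0.0],  # uniform logits, so the loss is log(4)
]
true_ids = [0, 1, 3]

per_depth = [cross_entropy(z, t) for z, t in zip(head_logits, true_ids)]
summed_loss = sum(per_depth)  # simple sum over depths

# Averaged over depths and scaled by a weighting factor lambda
lam = 0.3  # hypothetical weighting factor
mtp_loss = lam / len(per_depth) * sum(per_depth)
print([round(l, 3) for l in per_depth], round(summed_loss, 3), round(mtp_loss, 3))
```

Averaging keeps the size of the MTP term roughly independent of the number of heads, and the weighting factor controls how much that term counts relative to the main next-token loss.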

Summary

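DeepSeek's Multi-Token Prediction gives the model a denser training signal: each input token is trained to predict several future tokens instead of one. By linking the prediction steps through a causal chain (passing hidden states from one head to the next), DeepSeek improved upon earlier MTP methods that predicted future tokens independently.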

Remember that while this architecture makes pre-training more data-efficient, DeepSeek returned to Single Token Prediction during inference (when the model is used in the real world). The MTP heads can be set aside, and the main model predicts one token at a time. DeepSeek gets the training benefits of MTP without relying on it at inference time.

Understanding the four steps (normalization and merging, projection, transformation, and logits and loss) is key to mastering how complex models operate.