Understanding Grouped Query Attention (GQA)
1. Introduction: The Memory Challenge
Welcome, future AI engineers! In this part of our series, we focus on a critical technique used in modern large language models (LLMs): Grouped Query Attention (GQA).
When we run a big AI model like DeepSeek during inference, we save the Key (K) and Value (V) matrices computed for earlier tokens. This stored data is called the KV Cache. Caching K and V saves computation time because we do not have to recalculate them for every new token prediction. With the cache, the attention cost of each new token grows linearly with the sequence length, which is much better than the quadratic per-token cost we would pay without caching.
However, for huge models, the KV cache takes up a lot of memory space. For a complex model like DeepSeek, the KV cache size can grow to 400 GB, even for a single input. This large memory requirement makes inference slow and expensive.
To build efficient systems, we must solve this memory problem. The goal of GQA is to reduce the memory needed for the KV cache without losing too much performance.
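To make the caching idea concrete, here is a minimal single-head decoding loop with a KV cache. This is an illustrative NumPy sketch, not DeepSeek's actual implementation; the dimension `d_model` and the projection matrices `W_q`, `W_k`, `W_v` are assumed toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8  # toy embedding size (assumption, not a real model's)
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

k_cache, v_cache = [], []  # grows by one row per generated token


def decode_step(x):
    """Attend the new token vector x over all cached Keys/Values."""
    q = x @ W_q
    k_cache.append(x @ W_k)  # K/V of past tokens are reused, never recomputed
    v_cache.append(x @ W_v)
    K = np.stack(k_cache)    # shape: (tokens_so_far, d_model)
    V = np.stack(v_cache)
    scores = K @ q / np.sqrt(d_model)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V       # attention output for the new token


for _ in range(5):
    out = decode_step(rng.standard_normal(d_model))

print(len(k_cache))  # one cached K row per processed token -> 5
```

Each step does work proportional to the current sequence length (one dot product per cached row), which is exactly the linear per-token cost described above. The price we pay is the memory held in `k_cache` and `v_cache`, and that is the problem GQA attacks.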
2. A Quick Look Back: MHA and MQA
Before GQA, engineers used two main methods for attention, each with a trade-off:
2.1 Multi-Head Attention (MHA)
MHA is the traditional method.
- In MHA, every attention head has its own unique Key and Value projection matrices ($W_K$ and $W_V$).
- Because every head has different values for K and V, MHA captures many different perspectives of the text. This gives the best performance and context understanding (high diversity).
- The problem: Since every head is different, we must cache the K and V matrices for all attention heads separately. This leads to the largest KV cache size, which grows linearly with the number of attention heads ($N$).
2.2 Multi-Query Attention (MQA)
MQA is a memory-saving trick.
- In MQA, all attention heads share the exact same Key and Value projection matrices ($W_K$ and $W_V$).
- The benefit: Since all heads are the same, we only need to store one set of K and V matrices in the cache. This gives the smallest KV cache size, achieving the maximum memory reduction (e.g., reducing 400 GB to just 3 GB in DeepSeek's context).
- The problem: This trick severely limits diversity. When all heads see the same K and V, the model cannot capture many different perspectives. This leads to significant performance degradation.
3. Grouped Query Attention (GQA): The Middle Path
Grouped Query Attention (GQA) is a smart solution that sits between MHA and MQA.
3.1 How GQA Works
Instead of forcing all attention heads to share the same K and V (like MQA), GQA creates groups of attention heads.
- Grouping: The total attention heads (N) are split into smaller groups (G).
- Sharing within a Group: Heads that belong to the same group share the identical Key and Value matrices.
- Difference Across Groups: Heads that belong to different groups use different Key and Value matrices.
For example, if a model has 32 heads ($N=32$) and we create 8 groups ($G=8$), there will be 4 heads in each group. All 4 heads in Group 1 share the same K and V. But the K and V used by Group 1 are completely different from the K and V used by Group 2.
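The grouping can be sketched in a few lines of NumPy. This is an illustrative toy, not a real model's code: the head count, group count, head dimension, and sequence length are all assumptions chosen to match the $N=32$, $G=8$ example above.

```python
import numpy as np

rng = np.random.default_rng(0)
N, G, d, seq = 32, 8, 16, 10   # query heads, KV groups, head dim, seq length

Q = rng.standard_normal((N, seq, d))  # every head keeps its own queries
K = rng.standard_normal((G, seq, d))  # but only G sets of K/V are computed
V = rng.standard_normal((G, seq, d))  # and cached -- one per group

heads_per_group = N // G  # 4 heads share each K/V set
# Broadcast each group's K/V to the heads in that group:
K_shared = np.repeat(K, heads_per_group, axis=0)  # (N, seq, d)
V_shared = np.repeat(V, heads_per_group, axis=0)

scores = Q @ K_shared.transpose(0, 2, 1) / np.sqrt(d)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V_shared  # (N, seq, d): one output per query head

# Heads 0-3 read Group 1's K/V; head 4 belongs to Group 2:
assert np.allclose(K_shared[0], K_shared[3])
assert not np.allclose(K_shared[3], K_shared[4])
```

Note that only `K` and `V` (shape `(G, seq, d)`) would live in the cache; the repeated `K_shared` is a view of the same content. Setting `G = 1` recovers MQA, and `G = N` recovers MHA, which is why GQA is the middle path between the two.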
3.2 GQA Trade-offs
GQA successfully optimizes both memory and performance.
| Feature | MHA (Multi-Head) | GQA (Grouped Query) | MQA (Multi-Query) |
| :--- | :--- | :--- | :--- |
| KV Cache Size | Largest (stores N matrices) | Middle (stores G matrices) | Smallest (stores 1 matrix) |
| Performance (Diversity) | Best (full diversity) | Middle (some diversity lost) | Worst (low diversity) |
1. Improved Diversity: Since different groups learn different K/V projections, GQA captures more perspectives than MQA. A head in Group 1 and a head in Group 8 attend over different K and V, so they can capture different information. This makes the model perform much better than MQA.
2. Reduced Memory: We only need to cache one K and one V matrix for each group.
- The KV cache shrinks by a factor of $N / G$ compared to MHA.
- If a model has 128 heads ($N=128$) and 8 groups ($G=8$), the memory is reduced by 16 times (128/8) compared to MHA. This is a massive saving while maintaining strong performance.
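The arithmetic behind that saving can be checked with a back-of-the-envelope sizing function. The head dimension, layer count, sequence length, and fp16 precision below are illustrative assumptions, not the specs of any particular model; the point is that the ratio depends only on $N / G$.

```python
def kv_cache_bytes(n_kv_heads, head_dim, n_layers, seq_len, bytes_per_param=2):
    """Approximate KV cache size; the factor 2 accounts for K and V."""
    return 2 * n_kv_heads * head_dim * n_layers * seq_len * bytes_per_param

N, G = 128, 8  # the example from the text
# head_dim=128, n_layers=61, seq_len=4096 are illustrative assumptions:
mha = kv_cache_bytes(n_kv_heads=N, head_dim=128, n_layers=61, seq_len=4096)
gqa = kv_cache_bytes(n_kv_heads=G, head_dim=128, n_layers=61, seq_len=4096)

print(mha // gqa)  # -> 16, i.e. exactly the N/G = 128/8 reduction
```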
By using GQA, engineers get the memory efficiency needed for large-scale systems while minimizing the loss of context understanding, moving closer to the performance of MHA. This approach is why major models, like Llama 3 (8B and 70B), use Grouped Query Attention.
4. Next Steps in DeepSeek’s Architecture
GQA solves the immediate problem of MQA's low performance, but it still does not achieve the "best of both worlds"—MHA performance with MQA memory size.
DeepSeek introduces an even newer innovation: Multi-Head Latent Attention. This technique attempts to answer the golden question: can we have the highest performance and the smallest KV cache size? We will explore this beautiful solution in the next lecture; it is the final key innovation in the DeepSeek architecture.