diff --git a/Chapters/18-Advanced-Attention/INDEX.md b/Chapters/18-Advanced-Attention/INDEX.md new file mode 100644 index 0000000..5e059af --- /dev/null +++ b/Chapters/18-Advanced-Attention/INDEX.md @@ -0,0 +1,176 @@
# Advanced Attention

## KV Caching

The idea is that during autoregressive decoding in transformers such as GPT,
the $K$ and $V$ values of all previous tokens are recomputed at every step,
wasting compute.

![autoregression without caching](./pngs/decoder-autoregression-no-cache.gif)
> GIF taken from https://medium.com/@joaolages/kv-caching-explained-276520203249

As we can see, at every step all previous tokens are reprocessed and their
$K$ and $V$ values recomputed. The solution is to cache the keys and values of
all tokens up to step $n$ and compute only those of the new token.

![autoregression with caching](./pngs/decoder-autoregression-cache.gif)
> GIF taken from https://medium.com/@joaolages/kv-caching-explained-276520203249

Moreover, if we skip computing the $QK^T$ rows for previous steps, we obtain
only the attention output of the token we are interested in.

To compute the size needed for the $KV$ cache, let's go step by step:

- For each layer we need to store both $K$ and $V$, which have the same
  dimensions (in this context). The space depends on the number of heads
  (i.e. the *"number"* of $K$ matrices), the head dimension, the number of
  incoming tokens (the sequence length), and the number of bytes of
  `d_type`, usually a `float16`, thus 2 bytes:

  $$
  \text{Layer\_Space} = 2 \times \text{n\_heads} \times \text{d\_heads} \times
  \text{seq\_len} \times \text{d\_type}
  $$

- When we process a batch of $N$ sequences, we have a tensor of dimensions
  $N \times \text{seq\_len} \times \text{d\_model}$. Once it is projected by
  $W_K$ and $W_V$, we have to store $N$ times more values:

  $$
  \text{Batch\_Layer\_Space} = \text{Layer\_Space} \times N
  $$

- Finally, with $L$ layers, the total space needed is:

  $$
  \text{Model\_Space} = \text{Batch\_Layer\_Space} \times L
  $$

## Multi-Query and Grouped-Query Attention

The idea here is that we don't need different keys and values for every head,
only different queries. These approaches drastically reduce memory consumption
at a slight cost in accuracy.

In the **multi-query** approach there is a single $K$ and a single $V$, shared
by a number of queries equal to the number of heads. In the **grouped-query**
approach, a hyperparameter $G$ determines how many $K$ and $V$ matrices the
layer has, while the number of queries remains equal to the number of
attention heads. Queries are then grouped and assigned to some $K_g$ and
$V_g$.

![grouped head attention](./pngs/grouped-head-attention.png)
> Image taken from [Multi-Query & Grouped-Query Attention](https://tinkerd.net/blog/machine-learning/multi-query-attention/#multi-query-attention-mqa)

The per-layer cache size then becomes:

$$
\text{Layer\_Space} = 2 \times G \times \text{d\_heads} \times
\text{seq\_len} \times \text{d\_type}
$$

## Multi-Head Latent Attention

The idea is to reduce memory consumption by low-rank factoring the $Q$, $K$,
and $V$ computation. Each matrix is factorized into 2 matrices so that:

$$
A = M_l \times M_r
\\
A \in \R^{in \times out}, M_l \in \R^{in \times rank},
M_r \in \R^{rank \times out}
$$

The catch is that this factorization is a form of compression, and the lower
$rank$ is, the more compression artifacts we get.
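To make the factorization concrete, here is a minimal NumPy sketch (the sizes and variable names are illustrative assumptions, not taken from any particular model) that approximates a weight matrix with a rank-$r$ product via truncated SVD and reports the parameter savings and the reconstruction error, i.e. the compression artifacts mentioned above:

```python
import numpy as np

d_in, d_out, rank = 512, 512, 64          # illustrative sizes, not from a real model

# A full matrix and its rank-`rank` approximation via truncated SVD
A = np.random.randn(d_in, d_out)
U, S, Vt = np.linalg.svd(A, full_matrices=False)
M_l = U[:, :rank] * S[:rank]              # (d_in, rank)
M_r = Vt[:rank, :]                        # (rank, d_out)

A_approx = M_l @ M_r
print("params full:", A.size, "factored:", M_l.size + M_r.size)
print("relative reconstruction error:",
      np.linalg.norm(A - A_approx) / np.linalg.norm(A))
```

For a random matrix the reconstruction error stays large even at moderate ranks; the bet Multi-Head Latent Attention makes is that the trained projection matrices tolerate a small $rank$ with little loss.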
What we are going to compress are the weight matrices for $Q$, $K$ and $V$:

$$
\begin{aligned}
  Q = X \times W_Q^L \times W_Q^R
\end{aligned} \\
X \in \R^{N\times S \times d_{\text{model}}},
W_Q \in \R^{d_{\text{model}} \times (n_\text{head} \cdot d_\text{head})}
\\
W_Q^L\in \R^{d_{\text{model}} \times rank},
W_Q^R\in \R^{rank \times (n_\text{head} \cdot d_\text{head})}
\\
\text{}
\\
W_Q \simeq W_Q^L \times W_Q^R
$$

For simplicity, we don't write the equations for $K$ and $V$, which are
essentially the same. At first glance it may seem we have only increased the
number of operations from one matmul to two per matrix, but the real gain
appears when we look at the full attention computation, writing
$C_Q = X W_Q^L$ and $C_{KV} = X W_{KV}^L$ for the compressed latents:

$$
\begin{aligned}
H_i &= \text{softmax}\left(
  \frac{
    XW_Q^LW_{Q,i}^R \times (XW_{KV}^LW_{K,i}^R)^T
  }{
    \sqrt{\text{d\_model}}
  }
\right) \times X W_{KV}^L W_{V,i}^R \\
&= \text{softmax}\left(
  \frac{
    XW_Q^LW_{Q,i}^R \times W_{K,i}^{R^T} W_{KV}^{L^T} X^T
  }{
    \sqrt{\text{d\_model}}
  }
\right) \times X W_{KV}^L W_{V,i}^R \\
&= \text{softmax}\left(
  \frac{
    C_{Q} \times W_{Q,i}^R W_{K,i}^{R^T} \times C_{KV}^T
  }{
    \sqrt{\text{d\_model}}
  }
\right) \times C_{KV} W_{V,i}^R \\
\end{aligned}
$$

As can be seen, $C_Q$ and $C_{KV}$ do not depend on the head, so they can be
computed once and shared across all heads. The product
$W_{Q,i}^R \times W_{K,i}^{R^T}$ does depend on the head, but it can be
computed ahead of time because it does not depend on the input.

So, for each attention head only a handful of matmuls involving $C_Q$ and
$C_{KV}$ have to happen at runtime. Moreover, if we want to, we can still
apply caching to $C_{KV}$.

## Decoupled RoPE for Multi-Head Latent Attention

Since `RoPE` is a positional embedding applied **inside** attention, it causes
problems for the standard Multi-Head Latent Attention described above: RoPE
has to be applied, separately, to both $Q$ and $K$.

Since those come from $C_Q \times W_{Q,i}^R$ and
$W_{K,i}^{R^T} \times C_{KV}^T$, inserting the position-dependent rotation
between them means we can no longer precompute
$W_{Q,i}^R \times W_{K,i}^{R^T}$.

A solution is to cache the head matrices, but not their product, and to
compute new pieces that go through `RoPE` and are then concatenated
to the actual queries and keys:

$$
\begin{aligned}
Q_{R,i} &= RoPE(C_Q \times W_{QR,i}^R) \\
K_{R,i} &= RoPE(X \times W_{KR,i}^L)
\end{aligned}
$$

These matrices are then concatenated with the reconstructions of $Q$ and $K$.
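To see how the decoupled pieces fit together, below is a minimal NumPy sketch of a single head (the shapes, names, and toy `rope` function are illustrative assumptions, not DeepSeek's actual implementation). It reconstructs the content part of $Q$ and $K$ from the shared latents, computes the small RoPE'd pieces, and checks that concatenating the two along the feature dimension is the same as summing their contributions to the attention scores:

```python
import numpy as np

def rope(x):
    """Toy RoPE: rotate feature pairs by position-dependent angles."""
    seq_len, d = x.shape
    half = d // 2
    angle = np.arange(seq_len)[:, None] * (1.0 / 10000 ** (np.arange(half) / half))
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * np.cos(angle) - x2 * np.sin(angle),
                           x1 * np.sin(angle) + x2 * np.cos(angle)], axis=-1)

# Toy sizes (assumptions for illustration only)
seq_len, d_model, rank, d_head, d_rope = 8, 64, 16, 32, 16
rng = np.random.default_rng(0)
X = rng.normal(size=(seq_len, d_model))

# Shared compressed latents C_Q and C_KV (the latter is what gets cached)
W_Q_L, W_KV_L = rng.normal(size=(d_model, rank)), rng.normal(size=(d_model, rank))
C_Q, C_KV = X @ W_Q_L, X @ W_KV_L

# One head: the content parts of Q and K, reconstructed from the latents...
W_Q_R, W_K_R = rng.normal(size=(rank, d_head)), rng.normal(size=(rank, d_head))
Q_C, K_C = C_Q @ W_Q_R, C_KV @ W_K_R

# ...and the small decoupled parts that actually go through RoPE
W_QR, W_KR = rng.normal(size=(rank, d_rope)), rng.normal(size=(d_model, d_rope))
Q_R, K_R = rope(C_Q @ W_QR), rope(X @ W_KR)

# Concatenating along the feature dimension means the raw attention scores are
# the sum of the content scores and the RoPE scores
Q = np.concatenate([Q_C, Q_R], axis=-1)
K = np.concatenate([K_C, K_R], axis=-1)
assert np.allclose(Q @ K.T, Q_C @ K_C.T + Q_R @ K_R.T)
```

Because the scores decompose additively, positional information enters only through the small $Q_{R,i}$ and $K_{R,i}$ pieces, while the content part keeps working on the cached latents exactly as before.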
## References

- [KV Caching Explained: Optimizing Transformer Inference Efficiency](https://huggingface.co/blog/not-lain/kv-caching)
- [Transformers KV Caching Explained](https://medium.com/@joaolages/kv-caching-explained-276520203249)
- [How to calculate size of KV cache](https://www.rohan-paul.com/p/how-to-calculate-size-of-kv-cache)
- [Multi-Query & Grouped-Query Attention](https://tinkerd.net/blog/machine-learning/multi-query-attention/#multi-query-attention-mqa)
- [Understanding Multi-Head Latent Attention](https://planetbanatt.net/articles/mla.html)
- [A Gentle Introduction to Multi-Head Latent Attention (MLA)](https://machinelearningmastery.com/a-gentle-introduction-to-multi-head-latent-attention-mla/)
- [DeepSeek-V3 Explained 1: Multi-head Latent Attention](https://medium.com/data-science/deepseek-v3-explained-1-multi-head-latent-attention-ed6bee2a67c4)

diff --git a/Chapters/18-Advanced-Attention/pngs/decoder-autoregression-cache.gif b/Chapters/18-Advanced-Attention/pngs/decoder-autoregression-cache.gif new file mode 100644 index 0000000..0509c45 Binary files /dev/null and b/Chapters/18-Advanced-Attention/pngs/decoder-autoregression-cache.gif differ
diff --git a/Chapters/18-Advanced-Attention/pngs/decoder-autoregression-no-cache.gif b/Chapters/18-Advanced-Attention/pngs/decoder-autoregression-no-cache.gif new file mode 100644 index 0000000..9e9d599 Binary files /dev/null and b/Chapters/18-Advanced-Attention/pngs/decoder-autoregression-no-cache.gif differ
diff --git a/Chapters/18-Advanced-Attention/pngs/grouped-head-attention.png b/Chapters/18-Advanced-Attention/pngs/grouped-head-attention.png new file mode 100644 index 0000000..33790a0 Binary files /dev/null and b/Chapters/18-Advanced-Attention/pngs/grouped-head-attention.png differ