Added Advanced Attention Methods

Christian Risi 2025-11-29 21:02:47 +01:00
parent 9640ae1898
commit 3a7f2efa3e
4 changed files with 176 additions and 0 deletions


@ -0,0 +1,176 @@
# Advanced Attention
## KV Caching
The idea behind this is that during autoregressive decoding in transformers such as GPT,
the values of $K$ and $V$ are recomputed at every step, wasting compute.
![autoregression without caching](./pngs/decoder-autoregression-no-cache.gif)
> GIF taken from https://medium.com/@joaolages/kv-caching-explained-276520203249
As we can notice, the tokens from all previous steps get recomputed, together with
their $K$ and $V$ values. A solution is therefore to cache all keys and values up to
step $n$ and compute only the ones for the new token.
![autoregression with caching](./pngs/decoder-autoregression-cache.gif)
> GIF taken from https://medium.com/@joaolages/kv-caching-explained-276520203249
Moreover, if we skip computing $QK^T$ for the previous steps, we directly obtain
the attention output for the one token we are interested in, which is all we need
during autoregressive decoding.
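To make this concrete, below is a minimal single-head decoding sketch (a toy illustration, not any library's actual implementation; the weight matrices and dimensions are made up):

```python
import torch

d_model, d_head = 64, 64
# hypothetical projection weights for a single attention head
W_q, W_k, W_v = (torch.randn(d_model, d_head) for _ in range(3))

def decode_step(x_new, k_cache, v_cache):
    """One autoregressive step; x_new is the embedding of the newest token, shape (1, d_model)."""
    q = x_new @ W_q                                     # query only for the new token
    k_cache = torch.cat([k_cache, x_new @ W_k], dim=0)  # append the new key
    v_cache = torch.cat([v_cache, x_new @ W_v], dim=0)  # append the new value
    scores = (q @ k_cache.T) / d_head ** 0.5            # (1, tokens_so_far)
    out = torch.softmax(scores, dim=-1) @ v_cache       # output for the new token only
    return out, k_cache, v_cache

# usage: start with empty caches and feed tokens one at a time
k_cache, v_cache = torch.empty(0, d_head), torch.empty(0, d_head)
for _ in range(5):
    out, k_cache, v_cache = decode_step(torch.randn(1, d_model), k_cache, v_cache)
```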
To compute the size needed for a $KV$ cache, let's go step by step:
- For each layer, we need to store both $K$ and $V$, which (in this context)
  have the same dimensions. Their size is given by the number of heads, i.e. the
  *"number"* of $K$ matrices, the head dimension, the number of incoming tokens
  (the sequence length), and the number of bytes of the `d_type`, usually a
  `float16`, thus 2 bytes:
$$
\text{Layer\_Space} = 2 \times \text{n\_heads} \times \text{d\_heads} \times
\text{seq\_len} \times \text{d\_type}
$$
- Now, when we process a minibatch of $N$ sequences, we pass a tensor of
  dimensions $N \times \text{seq\_len} \times d_{\text{model}}$. Once it is
  projected by $W_K$ and $W_V$, we have to store $N$ times more values:
$$
\text{Batch\_Layer\_Space} = \text{Layer\_Space} \times N
$$
- If the model has $L$ layers, in total you'll need space equivalent to:
$$
\text{Model\_Space} = \text{Batch\_Layer\_Space} \times L
$$
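As a quick sanity check, here is a short sketch of this formula; the model parameters below are illustrative, not taken from any specific model:

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, batch_size, dtype_bytes=2):
    """2 (K and V) * n_heads * d_head * seq_len * bytes per element, per layer and per batch element."""
    layer_space = 2 * n_heads * d_head * seq_len * dtype_bytes
    return layer_space * batch_size * n_layers

# e.g. a hypothetical 32-layer model with 32 heads of size 128, 4096 tokens, batch of 1, float16
print(kv_cache_bytes(n_layers=32, n_heads=32, d_head=128, seq_len=4096, batch_size=1) / 2**30, "GiB")  # 2.0 GiB
```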
## Multi-Query and Grouped-Query Attention
The idea here is that we don't need different keys and values for every head,
only different queries. These approaches drastically reduce memory consumption
at a slight cost in accuracy.
In the **multi-query** approach, we have only one $K$ and one $V$, while the
number of queries stays equal to the number of heads. In the **grouped-query**
approach, a hyperparameter $G$ determines how many $K$ and $V$ matrices the
layer has, while the number of queries again remains equal to the number of
attention heads. Each group of queries then attends to its own $K_g$ and $V_g$.
![grouped head attention](./pngs/grouped-head-attention.png)
> Image taken from [Multi-Query & Grouped-Query Attention](https://tinkerd.net/blog/machine-learning/multi-query-attention/#multi-query-attention-mqa)
Now, the new layer size becomes:
$$
\text{Layer\_Space} = 2 \times G \times \text{d\_heads} \times
\text{seq\_len} \times \text{d\_type}
$$
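A minimal sketch of the shapes involved (the dimensions are hypothetical; `repeat_interleave` broadcasts each group's keys and values to the query heads that share them):

```python
import torch

batch, seq_len, n_heads, d_head, G = 2, 16, 8, 32, 2   # 8 query heads share G = 2 KV groups
q = torch.randn(batch, n_heads, seq_len, d_head)
k = torch.randn(batch, G, seq_len, d_head)             # only G key matrices...
v = torch.randn(batch, G, seq_len, d_head)             # ...and G value matrices are stored/cached

# expand each KV group to the n_heads // G query heads that share it
k = k.repeat_interleave(n_heads // G, dim=1)           # (batch, n_heads, seq_len, d_head)
v = v.repeat_interleave(n_heads // G, dim=1)

scores = q @ k.transpose(-2, -1) / d_head ** 0.5
out = torch.softmax(scores, dim=-1) @ v                # (batch, n_heads, seq_len, d_head)
```

With $G = 1$ this reduces to multi-query attention, and with $G$ equal to the number of heads it is standard multi-head attention.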
## Multi-Head Latent Attention
The idea is to reduce memory consumption by low-rank factorizing the $Q$, $K$, and $V$
projections. This means that each weight matrix is factorized into two matrices
such that:
$$
A = M_l \times M_r \\
A \in \mathbb{R}^{\text{in} \times \text{out}}, \quad
M_l \in \mathbb{R}^{\text{in} \times \text{rank}}, \quad
M_r \in \mathbb{R}^{\text{rank} \times \text{out}}
$$
The problem, though, is that this method introduces compression, and the lower
the $\text{rank}$, the larger the approximation error.
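A quick way to see this trade-off is to truncate an SVD at different ranks (purely illustrative, on a random matrix):

```python
import torch

A = torch.randn(512, 512)
U, S, Vh = torch.linalg.svd(A)

for rank in (512, 128, 32, 8):
    M_l = U[:, :rank] * S[:rank]   # left factor,  (in x rank), columns scaled by singular values
    M_r = Vh[:rank, :]             # right factor, (rank x out)
    err = torch.linalg.norm(A - M_l @ M_r) / torch.linalg.norm(A)
    print(rank, float(err))        # relative error grows as the rank shrinks
```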
What we are going to compress are the weight matrices for $Q$, $K$ and $V$:
$$
\begin{aligned}
Q &= X \times W_Q^L \times W_Q^R \\
X &\in \mathbb{R}^{N \times S \times d_{\text{model}}}, \quad
W_Q \in \mathbb{R}^{d_{\text{model}} \times (n_\text{head} \cdot d_\text{head})} \\
W_Q^L &\in \mathbb{R}^{d_{\text{model}} \times \text{rank}}, \quad
W_Q^R \in \mathbb{R}^{\text{rank} \times (n_\text{head} \cdot d_\text{head})} \\
W_Q &\simeq W_Q^L \times W_Q^R
\end{aligned}
$$
For simplicity, we didn't write the equations for $K$ and $V$, which are basically
the same. At first glance it may seem that we have only increased the number of
operations from one matmul to two per matrix, but the real advantage appears
when we look at the actual attention computation:
$$
\begin{aligned}
H_i &= \text{softmax}\left(
    \frac{
        XW_Q^LW_{Q,i}^R \times (XW_{KV}^LW_{K,i}^R)^T
    }{
        \sqrt{d_\text{head}}
    }
\right) \times X W_{KV}^L W_{V,i}^R \\
&= \text{softmax}\left(
    \frac{
        XW_Q^LW_{Q,i}^R \times W_{K,i}^{R^T} W_{KV}^{L^T} X^T
    }{
        \sqrt{d_\text{head}}
    }
\right) \times X W_{KV}^L W_{V,i}^R \\
&= \text{softmax}\left(
    \frac{
        C_{Q} \times W_{Q,i}^R W_{K,i}^{R^T} \times C_{KV}^T
    }{
        \sqrt{d_\text{head}}
    }
\right) \times C_{KV} W_{V,i}^R
\end{aligned}
$$
As can be seen, $C_Q = XW_Q^L$ and $C_{KV} = XW_{KV}^L$ do not depend on the head,
so they can be computed once and shared across all heads. The product
$W_{Q,i}^R \times W_{K,i}^{R^T}$ does depend on the head, but it can be computed
ahead of time because it does not depend on the input. So, at runtime each
attention head only needs four matmuls on the small latent representations, on
top of the two shared projections that produce $C_Q$ and $C_{KV}$. Moreover, if
we want to, we can still apply caching, and it is enough to cache $C_{KV}$
instead of the full $K$ and $V$.
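The sketch below (dimensions and weights are hypothetical, random, and for illustration only) shows the shared latents being computed once per input and the per-head product $W_{Q,i}^R W_{K,i}^{R^T}$ being precomputed offline:

```python
import torch

batch, seq_len = 2, 16
d_model, rank, n_heads, d_head = 256, 32, 8, 32

# low-rank factors (randomly initialised here, learned in a real model)
W_q_L  = torch.randn(d_model, rank)
W_kv_L = torch.randn(d_model, rank)            # shared latent projection for K and V
W_q_R  = torch.randn(n_heads, rank, d_head)
W_k_R  = torch.randn(n_heads, rank, d_head)
W_v_R  = torch.randn(n_heads, rank, d_head)

# precomputed offline: W_{Q,i}^R W_{K,i}^{R^T} does not depend on the input
W_qk = W_q_R @ W_k_R.transpose(-2, -1)         # (n_heads, rank, rank)

x = torch.randn(batch, seq_len, d_model)
C_q  = x @ W_q_L                               # (batch, seq_len, rank), shared by all heads
C_kv = x @ W_kv_L                              # (batch, seq_len, rank), this is what gets cached

# per-head attention computed directly on the compressed latents
scores = torch.einsum('bsr,hrt,but->bhsu', C_q, W_qk, C_kv) / d_head ** 0.5
attn = torch.softmax(scores, dim=-1)           # (batch, n_heads, seq_len, seq_len)
out = torch.einsum('bhsu,bur,hrd->bhsd', attn, C_kv, W_v_R)
```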
## Decoupled RoPE for Multi-Head Latent Attention
Since `RoPE` is a positional embedding applied **during** attention, it causes
problems if we use standard Multi-Head Latent Attention. In fact, it has to be
applied separately to both $Q$ and $K$.
Since these come from $C_Q \times W_{Q, i}^R$ and
$W_{K, i}^{R^T} \times C_{KV}^T$, we can no longer precompute
$W_{Q,i}^R \times W_{K,i}^{R^T}$, because the position-dependent rotation now
sits between the two factors.
A solution is to keep the per-head factors, but not their precomputed product,
and to compute new, dedicated pieces on which `RoPE` is applied and which are
then concatenated to the actual queries and keys:
$$
\begin{aligned}
Q_{R,i} &= \text{RoPE}(C_Q \times W_{QR,i}^R) \\
K_{R,i} &= \text{RoPE}(X \times W_{KR,i}^L)
\end{aligned}
$$
These matrices will then be concatenated with the reconstruction of $Q$ and $K$.
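A single-head sketch of this concatenation (the RoPE implementation and all dimensions below are illustrative assumptions, not the exact DeepSeek formulation):

```python
import torch

def rope(x):
    """Minimal rotate-half RoPE over the last dimension of x, shape (..., seq_len, dim)."""
    seq_len, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # (seq_len, dim/2)
    cos, sin = angles.cos().repeat(1, 2), angles.sin().repeat(1, 2)      # (seq_len, dim)
    x1, x2 = x[..., : dim // 2], x[..., dim // 2:]
    return x * cos + torch.cat([-x2, x1], dim=-1) * sin                  # position-dependent rotation

batch, seq_len, d_model, rank, d_head, d_rope = 2, 16, 256, 32, 32, 16

W_q_L, W_kv_L = torch.randn(d_model, rank), torch.randn(d_model, rank)
W_q_R, W_k_R  = torch.randn(rank, d_head), torch.randn(rank, d_head)    # "content" factors
W_v_R         = torch.randn(rank, d_head)
W_qr_R        = torch.randn(rank, d_rope)                               # decoupled RoPE query factor
W_kr_L        = torch.randn(d_model, d_rope)                            # decoupled RoPE key projection

x = torch.randn(batch, seq_len, d_model)
C_q, C_kv = x @ W_q_L, x @ W_kv_L

# RoPE is applied only to the dedicated pieces, then concatenated to the content parts
q = torch.cat([C_q @ W_q_R,  rope(C_q @ W_qr_R)], dim=-1)  # (batch, seq_len, d_head + d_rope)
k = torch.cat([C_kv @ W_k_R, rope(x @ W_kr_L)],   dim=-1)  # RoPE key part comes straight from X
scores = q @ k.transpose(-2, -1) / (d_head + d_rope) ** 0.5
out = torch.softmax(scores, dim=-1) @ (C_kv @ W_v_R)       # value path, also low-rank
```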
## References
- [KV Caching Explained: Optimizing Transformer Inference Efficiency](https://huggingface.co/blog/not-lain/kv-caching)
- [Transformers KV Caching Explained](https://medium.com/@joaolages/kv-caching-explained-276520203249)
- [How to calculate size of KV cache](https://www.rohan-paul.com/p/how-to-calculate-size-of-kv-cache)
- [Multi-Query & Grouped-Query Attention](https://tinkerd.net/blog/machine-learning/multi-query-attention/#multi-query-attention-mqa)
- [Understanding Multi-Head Latent Attention](https://planetbanatt.net/articles/mla.html)
- [A Gentle Introduction to Multi-Head Latent Attention (MLA)](https://machinelearningmastery.com/a-gentle-introduction-to-multi-head-latent-attention-mla/)
- [DeepSeek-V3 Explained 1: Multi-head Latent Attention](https://medium.com/data-science/deepseek-v3-explained-1-multi-head-latent-attention-ed6bee2a67c4)
