Advanced Attention

KV Caching

During autoregressive generation in decoder-only transformers, such as GPT, the K and V values of all previous tokens are recomputed at every step, wasting compute.

autoregression without caching

GIF taken from https://medium.com/@joaolages/kv-caching-explained-276520203249

As we can see, at every step the K and V values of all previous tokens get recomputed. The solution is to cache all keys and values up to step n and compute only those of the new token.

autoregression with caching

GIF taken from https://medium.com/@joaolages/kv-caching-explained-276520203249

Moreover, since we no longer need the QK^T rows of previous steps, we can compute only the row for the token we are currently generating.
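
As a minimal sketch of the mechanism (a single head with made-up dimensions, not any particular model), the decode step below appends the new token's key and value to the cache and attends with only the newest query:

```python
import torch
import torch.nn.functional as F

d_model = 64                               # hypothetical width; single head, so d_head == d_model
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

k_cache, v_cache = [], []                  # grows by one entry per generated token

def decode_step(x_new):
    """x_new: (1, d_model) embedding of the newest token only."""
    q = x_new @ W_q                        # query of the new token
    k_cache.append(x_new @ W_k)            # K and V of the new token are computed once...
    v_cache.append(x_new @ W_v)            # ...and reused at every later step
    K = torch.cat(k_cache, dim=0)          # (t, d_model)
    V = torch.cat(v_cache, dim=0)          # (t, d_model)
    scores = q @ K.T / d_model ** 0.5      # (1, t): only one row of QK^T is computed
    return F.softmax(scores, dim=-1) @ V   # (1, d_model) output for the new token

for _ in range(5):                         # toy autoregressive loop
    out = decode_step(torch.randn(1, d_model))
```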

To compute the size needed to have a KV cache, let's go step by step:

  • For each layer, we need to store both K and V, which (in this context) have the same dimensions. The size depends on the number of heads (i.e. how many K matrices there are), the head dimension, the number of incoming tokens (the sequence length), and the number of bytes of the data type, usually float16, hence 2 bytes:

    
    \text{Layer\_Space} = 2 \times \text{n\_heads} \times \text{d\_heads} \times
    \text{seq\_len} \times \text{d\_type}
    
  • Now, if we process a batch of N sequences at once, the input is a tensor of dimensions N \times \text{seq\_len} \times \text{d\_model}. After the W_K and W_V projections, we have to store N times more values:

    
    \text{Batch\_Layer\_Space} = \text{Layer\_Space} \times N
    
  • If the model has L layers, the total cache space needed (evaluated in the small calculator after this list) is:

    
    \text{Model\_Space} = \text{Batch\_Layer\_Space} \times L
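
A small calculator that simply evaluates the formulas above; the numbers plugged in are illustrative, not taken from any specific model:

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, batch, dtype_bytes=2):
    """Total KV-cache size: 2 (K and V) * n_heads * d_head * seq_len * dtype bytes
    per layer, multiplied by the batch size and the number of layers."""
    layer_space = 2 * n_heads * d_head * seq_len * dtype_bytes
    batch_layer_space = layer_space * batch
    return batch_layer_space * n_layers

# e.g. 32 layers, 32 heads of size 128, 4096 cached tokens, batch of 8, float16
size = kv_cache_bytes(n_layers=32, n_heads=32, d_head=128,
                      seq_len=4096, batch=8, dtype_bytes=2)
print(f"{size / 2**30:.1f} GiB")           # 16.0 GiB
```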
    

Multi-Query and Grouped-Query Attention

The idea here is that we don't need different keys and values for every head, only different queries. These approaches drastically reduce memory consumption at a slight cost in accuracy.

In the multi-query approach we have a single K and a single V, with as many queries as there are heads. In the grouped-query approach, a hyperparameter G determines how many K and V matrices the layer has, while the number of queries still equals the number of attention heads; the query heads are then grouped and assigned to some K_g and V_g.

grouped head attention

Image taken from Multi-Query & Grouped-Query Attention

Now, the new layer size becomes:


\text{Layer\_Space} = 2 \times G \times \text{d\_heads} \times
\text{seq\_len} \times \text{d\_type}
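
A minimal sketch of grouped-query attention with made-up sizes (G = 1 recovers multi-query attention, G = n_heads recovers standard multi-head attention): the layer has only G key/value heads, each shared by n_heads / G query heads.

```python
import torch
import torch.nn.functional as F

n_heads, G, d_head, seq_len = 8, 2, 16, 10     # made-up sizes; n_heads must be divisible by G
d_model = n_heads * d_head

x = torch.randn(seq_len, d_model)
W_q = torch.randn(d_model, n_heads * d_head)   # one query projection per head
W_k = torch.randn(d_model, G * d_head)         # only G key heads...
W_v = torch.randn(d_model, G * d_head)         # ...and G value heads

Q = (x @ W_q).view(seq_len, n_heads, d_head).transpose(0, 1)   # (n_heads, seq_len, d_head)
K = (x @ W_k).view(seq_len, G, d_head).transpose(0, 1)         # (G, seq_len, d_head)
V = (x @ W_v).view(seq_len, G, d_head).transpose(0, 1)

# every group of n_heads // G query heads attends to the same K/V head
K = K.repeat_interleave(n_heads // G, dim=0)                   # (n_heads, seq_len, d_head)
V = V.repeat_interleave(n_heads // G, dim=0)

scores = Q @ K.transpose(-1, -2) / d_head ** 0.5               # (n_heads, seq_len, seq_len)
out = F.softmax(scores, dim=-1) @ V                            # (n_heads, seq_len, d_head)
```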

Multi-Head Latent Attention

The idea is to reduce memory consumption by low-rank factoring the Q, K, and V projections. Each weight matrix is factorized into two matrices so that:


A = M_l \times M_r
\\
A \in \R^{in \times out}, M_l \in \R^{in \times rank},
M_r \in \R^{rank \times out}

The catch is that this factorization is a form of compression: the lower the rank, the more information is lost.
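
A quick numerical sketch of this trade-off, with hypothetical sizes (in MLA the two factors are learned directly during training; the truncated SVD here is only a way to illustrate the approximation): the parameter count drops from in × out to rank × (in + out), at the price of a reconstruction error that grows as the rank shrinks.

```python
import torch

d_in, d_out, rank = 512, 512, 64            # hypothetical sizes
A = torch.randn(d_in, d_out)

# best rank-`rank` approximation via truncated SVD: A ≈ M_l @ M_r
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
M_l = U[:, :rank] * S[:rank]                # (d_in, rank)
M_r = Vh[:rank, :]                          # (rank, d_out)

print((A - M_l @ M_r).norm() / A.norm())    # relative error: grows as the rank shrinks
print(d_in * d_out, rank * (d_in + d_out))  # 262144 vs 65536 stored parameters
```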

What we are going to compress are the weight matrices for Q, K and V:


\begin{aligned}
    Q = X \times W_Q^L \times W_Q^R
\end{aligned} \\
X \in \R^{N\times S \times d_{\text{model}}},
W_Q \in \R^{d_{\text{model}} \times (n_\text{head} \cdot d_\text{head})}
\\
W_Q^L\in \R^{d_{\text{model}} \times rank},
W_Q^R\in \R^{rank \times (n_\text{head} \cdot d_\text{head})}
\\
\text{}
\\
W_Q \simeq W_Q^L \times W_Q^R

For simplicity, we don't write the equations for K and V, which are essentially the same. At this point it may seem that we have merely increased the number of operations from one matmul to two per matrix, but the real benefit appears when we look at the actual computation:


\begin{aligned}
H_i &= \text{softmax}\left(
    \frac{
        XW_Q^LW_{Q,i}^R \times (XW_{KV}^LW_{K,i}^R)^T
    }{
        \sqrt{\text{d\_model}}
    }
\right) \times X W_{KV}^L W_{V,i}^R \\
&= \text{softmax}\left(
    \frac{
        XW_Q^LW_{Q,i}^R \times W_{K,i}^{R^T} W_{KV}^{L^T} X^T
    }{
        \sqrt{\text{d\_model}}
    }
\right) \times X W_{KV}^L W_{V,i}^R \\
&= \text{softmax}\left(
    \frac{
        C_{Q} \times W_{Q,i}^R  W_{K,i}^{R^T} \times C_{KV}^T
    }{
        \sqrt{\text{d\_model}}
    }
\right) \times C_{KV} W_{V,i}^R \\
\end{aligned}

As can be seen, C_Q and C_{KV} do not depend on the head, so they can be computed once and shared across all heads. The product W_{Q,i}^R \times W_{K,i}^{R^T} does depend on the head, but it can be computed ahead of time because it does not depend on the input.

So, besides the shared C_Q and C_{KV} (computed once per forward pass) and the precomputed W_{Q,i}^R \times W_{K,i}^{R^T}, each attention head only needs a few matmuls at runtime. Moreover, if we want to, we can still apply caching to C_{KV}.
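
A sketch of this computation for a single head, with made-up sizes and RoPE ignored for now: C_Q and C_{KV} are shared across heads (and C_{KV} is what would be cached), while the per-head product W_{Q,i}^R W_{K,i}^{R^T} is folded into one matrix ahead of time.

```python
import torch
import torch.nn.functional as F

N, S, d_model, rank, d_head = 2, 10, 64, 16, 32   # made-up sizes

X = torch.randn(N, S, d_model)
W_Q_L  = torch.randn(d_model, rank)        # shared down-projections (one pair per layer)
W_KV_L = torch.randn(d_model, rank)
W_Q_R  = torch.randn(rank, d_head)         # up-projections of head i
W_K_R  = torch.randn(rank, d_head)
W_V_R  = torch.randn(rank, d_head)

# shared across heads, computed once per forward pass; C_KV is what gets cached
C_Q  = X @ W_Q_L                           # (N, S, rank)
C_KV = X @ W_KV_L                          # (N, S, rank)

# per-head but input-independent, so it can be precomputed offline
W_QK = W_Q_R @ W_K_R.T                     # (rank, rank)

scores = (C_Q @ W_QK) @ C_KV.transpose(-1, -2) / d_model ** 0.5   # (N, S, S)
H_i = F.softmax(scores, dim=-1) @ (C_KV @ W_V_R)                  # (N, S, d_head)
```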

Decoupled RoPE for Multi-Head Latent Attention

RoPE is a positional encoding applied inside attention, separately to both the Q and K matrices, and this causes a problem with standard Multi-Head Latent Attention.

Since Q and K come from C_Q \times W_{Q,i}^R and W_{K,i}^{R^T} \times C_{KV}^T, the rotation would have to sit between W_{Q,i}^R and W_{K,i}^{R^T}, which means we can no longer precompute their product.

A solution is to keep the head matrices cached, but not their product, and to compute new pieces that go through RoPE and are then concatenated to the actual queries and keys:


\begin{aligned}
Q_{R,i} &= RoPE(C_Q \times W_{QR,i}^R) \\
K_{R,i} &= RoPE(X \times W_{KR,i}^L)
\end{aligned}

These matrices are then concatenated with the reconstructions of Q and K.
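
A shape-level sketch of the decoupling, using a simplified rotate-half RoPE and made-up sizes (the names d_rope, W_QR_R, and W_KR_L are illustrative, and the per-head rotary key follows the formula above; actual implementations may share it across heads): the rotary pieces are computed from small extra projections and concatenated with the content parts reconstructed from the latents.

```python
import torch

def rope(x):
    """Simplified rotary embedding applied over the last dimension (must be even)."""
    S, d = x.shape[-2], x.shape[-1]
    pos = torch.arange(S, dtype=torch.float32).unsqueeze(-1)            # (S, 1)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))     # (d/2,)
    angles = pos * inv_freq                                             # (S, d/2)
    cos = angles.cos().repeat_interleave(2, dim=-1)                     # (S, d)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((-x_odd, x_even), dim=-1).flatten(-2)         # rotate each pair
    return x * cos + rotated * sin

S, d_model, rank, d_head, d_rope = 10, 64, 16, 32, 8   # made-up sizes
X    = torch.randn(S, d_model)
C_Q  = torch.randn(S, rank)            # shared latents, as before
C_KV = torch.randn(S, rank)

W_QR_R = torch.randn(rank, d_rope)     # extra projections used only for the rotary part
W_KR_L = torch.randn(d_model, d_rope)
W_Q_R  = torch.randn(rank, d_head)     # content up-projections of head i
W_K_R  = torch.randn(rank, d_head)

Q_rot = rope(C_Q @ W_QR_R)             # (S, d_rope), position-dependent
K_rot = rope(X @ W_KR_L)               # (S, d_rope)
Q_i = torch.cat([C_Q @ W_Q_R, Q_rot], dim=-1)   # (S, d_head + d_rope)
K_i = torch.cat([C_KV @ W_K_R, K_rot], dim=-1)  # scores are then Q_i @ K_i.T as usual
```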

References