Efficient Memory Management for Large Language Model Serving with PagedAttention
These are short notes on Efficient Memory Management for Large Language Model Serving with PagedAttention; see the original paper for more details. Here we mainly focus on the computation of block-based attention.
Let's first review some basics of conventional attention (assuming you are already familiar with self-attention):
Suppose we have a sequence of tokens $(x_1, …, x_n)$ with embeddings $x_i \in \mathbb{R}^{d}$, so the stacked sequence lies in $\mathbb{R}^{n\times d}$. For each token, we obtain its query, key, and value vectors via the following linear transformations:
\[q_i = W_q x_i,\: k_i = W_k x_i, \:v_i = W_v x_i \qquad (2)\]With these, the following formulas give the attention scores and the final output. Suppose we want to compute the $i$-th token (we already have $i-1$ tokens, and the $i-1$-th token is the input used to generate the $i$-th one):
We denote the attention score of the $i$-th token against each token so far (including itself) by $a_{ij}$; by definition it is: \(a_{ij} = \frac{\exp(q_i^T k_j / \sqrt{d})}{ \sum_{t=1}^i \exp(q_i^T k_t / \sqrt{d})} \qquad (3)\)
And the output of the attention layer for the $i$-th token is $o_i = \sum_{j=1}^i a_{ij} v_j$ (in the original paper this is also part of formula $(3)$).
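To make formulas $(2)$–$(3)$ concrete, here is a minimal NumPy sketch of one decoding step (the names `W_q`, `decode_step`, and the caches are illustrative assumptions of this note, not code from the paper): project the newest token, append its key/value to the cache, and compute $o_i$ over all cached tokens.

```python
import numpy as np

d = 8                                  # head dimension (illustrative)
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

def decode_step(x_i, k_cache, v_cache):
    """One auto-regressive step, following formulas (2) and (3).

    x_i     : (d,)  embedding of the newest token
    k_cache : list of (d,) keys   for tokens 1..i-1
    v_cache : list of (d,) values for tokens 1..i-1
    """
    q_i, k_i, v_i = W_q @ x_i, W_k @ x_i, W_v @ x_i     # formula (2)
    k_cache, v_cache = k_cache + [k_i], v_cache + [v_i]  # KV cache grows by one entry
    K = np.stack(k_cache)                                # (i, d)
    V = np.stack(v_cache)                                # (i, d)
    scores = K @ q_i / np.sqrt(d)                        # q_i^T k_j / sqrt(d), j = 1..i
    a_i = np.exp(scores - scores.max())                  # softmax of formula (3),
    a_i /= a_i.sum()                                     # with max-subtraction for stability
    o_i = a_i @ V                                        # o_i = sum_j a_ij v_j
    return o_i, k_cache, v_cache

# usage: run a few decoding steps over pretend token embeddings
k_cache, v_cache = [], []
for x in rng.standard_normal((3, d)):
    o, k_cache, v_cache = decode_step(x, k_cache, v_cache)
```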
Now let's go back to paged attention. First, we define some symbols:
- Block size $B$ (the number of tokens per block, i.e. the number of KV-cache entries each block holds)
- Key block $K_j = (k_{(j-1)B + 1}, …, k_{jB})$ (each block stores $jB -[(j-1)B + 1] + 1 = B$ tokens' keys, just as defined; see the small indexing sketch after this list)
- Value block $V_j = (v_{(j-1)B + 1}, …, v_{jB})$
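As a quick sanity check on these definitions (indices are 1-based, as in the formulas), token $x_t$ lives in block $j = \lceil t/B \rceil$ at offset $(t-1) \bmod B$ inside that block. A minimal sketch (the helper `block_of` is just for illustration):

```python
import math

B = 4  # block size (illustrative value)

def block_of(t: int) -> tuple[int, int]:
    """Map a 1-based token index t to (1-based block index j, 0-based offset in the block)."""
    return math.ceil(t / B), (t - 1) % B

# tokens 1..B fall in block 1, tokens B+1..2B in block 2, and so on
assert block_of(1) == (1, 0)
assert block_of(B) == (1, B - 1)
assert block_of(B + 1) == (2, 0)
```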
So the attention score computed in formula $(3)$ is modified to:
\[A_{ij} = \frac{ \exp(q_i^T K_j / \sqrt{d}) }{\sum_{t=1}^{\lceil i/B \rceil} \exp(q_i^T K_t / \sqrt{d})} , \: o_i = \sum_{j=1}^{\lceil i/B \rceil} V_j A_{ij}^T \qquad (4)\]Where $A_{ij} = (a_{i,(j-1)B + 1},…, a_{i, jB})$ is the row vector of attention scores of $q_i$ against the $j$-th key block, and the denominator is the same normalizer as in formula $(3)$: the exponentiated scores summed over all (valid) tokens across all blocks. Because the computation is auto-regressive, in the last block $j = \lceil i/B \rceil$ only the positions up to token $i$ are valid: the entries $a_{i,k'}$ with $i < k' \leq jB$ must be masked (forced to essentially zero, e.g. by setting their scores to $-\infty$ before the softmax), while the entries with $(j-1)B + 1 \leq k' \leq i$ are kept by the mask.
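The following NumPy sketch computes formula $(4)$ block by block and checks that it matches the flat computation of formula $(3)$. It is only an illustration of the math, not the paper's fused CUDA kernel; `key_blocks` and `value_blocks` are plain lists of $B \times d$ arrays standing in for the physical KV-cache blocks, and the last block is zero-padded up to size $B$.

```python
import numpy as np

def paged_attention(q_i, key_blocks, value_blocks, i, B):
    """Formula (4): attend block by block over a paged KV cache.

    q_i          : (d,) query of the i-th token
    key_blocks   : list of (B, d) arrays K_1 .. K_ceil(i/B) (last one may be padded)
    value_blocks : list of (B, d) arrays V_1 .. V_ceil(i/B)
    """
    d = q_i.shape[0]
    num_blocks = -(-i // B)                      # ceil(i / B)
    scores = []
    for j in range(num_blocks):
        s_j = key_blocks[j] @ q_i / np.sqrt(d)   # q_i^T K_j / sqrt(d), shape (B,)
        # causal mask: in the last block only tokens (j-1)B+1 .. i are valid
        valid = np.arange(j * B + 1, (j + 1) * B + 1) <= i
        s_j = np.where(valid, s_j, -np.inf)      # masked entries -> exp() ~ 0
        scores.append(s_j)
    scores = np.concatenate(scores)
    A_i = np.exp(scores - scores.max())
    A_i /= A_i.sum()                             # softmax over all blocks: denominator of (4)
    return sum(value_blocks[j].T @ A_i[j * B:(j + 1) * B]   # V_j A_ij^T
               for j in range(num_blocks))

# sanity check against the flat formula (3)
rng = np.random.default_rng(0)
d, B, i = 8, 4, 10
K, V = rng.standard_normal((i, d)), rng.standard_normal((i, d))
q = rng.standard_normal(d)
pad = (-i) % B                                   # zero-pad the last block to size B
K_blocks = list(np.vstack([K, np.zeros((pad, d))]).reshape(-1, B, d))
V_blocks = list(np.vstack([V, np.zeros((pad, d))]).reshape(-1, B, d))
flat = np.exp(K @ q / np.sqrt(d)); flat /= flat.sum()
assert np.allclose(paged_attention(q, K_blocks, V_blocks, i, B), flat @ V)
```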
Benefits of Paged Attention:
- GPU memory waste is reduced to the block level: blocks are allocated on demand as the sequence grows, so each sequence wastes less than one block ($< B$ unused KV slots, only in its last, partially filled block) instead of an entire pre-allocated maximum-length buffer (a small back-of-the-envelope sketch follows this list).
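A tiny back-of-the-envelope sketch of that claim; the block size, context length, and sequence length below are made-up illustrative numbers, not measurements from the paper:

```python
import math

B, max_len = 16, 2048          # illustrative block size and maximum context length

def wasted_slots(seq_len: int) -> tuple[int, int]:
    """KV slots allocated but unused: (paged allocation, max-length pre-allocation)."""
    paged = math.ceil(seq_len / B) * B - seq_len   # < B, only in the last block
    prealloc = max_len - seq_len                    # everything past the actual length
    return paged, prealloc

print(wasted_slots(1000))   # (8, 1048): at most B-1 wasted slots vs. over half the buffer
```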