MQA: multi-query attention
This post walks through the paper that introduced multi-query attention (MQA). MQA was devised at Google, so the code snippets here are all based on TensorFlow, which was still very popular at the time.
The paper first reviews attention and multi-head attention, and then introduces multi-query attention.
If you are already familiar with attention and multi-head attention, you can jump to the last part for MQA.
Multi-head Attention
Some terminology:
- $P_q$: learned linear projections that map an input vector $x$ to the queries
- $P_k$: learned linear projections for the keys
- $P_v$: learned linear projections for the values
- $P_o$: learned linear projections for the output
- $M$: the sequence that $x$ attends to, with length $m$ (e.g. in an encoder-decoder model, $M$ usually represents the sentence to be translated)
- $h$: the number of attention heads (the original paper calls them “$h$ different attention layers”, but these are the attention heads)
- $d$: the dimension of the hidden state
- $k$, $v$: the per-head dimensions of the queries/keys and of the values
import tensorflow as tf

def MultiHeadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head attention on one query.

    x: a vector with shape [d]
    M: a matrix with shape [m, d]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
    y: a vector with shape [d]
    """
    q = tf.einsum("d,hdk->hk", x, P_q)  # one query per head: [h, k]
    K = tf.einsum("md,hdk->hmk", M, P_k)  # per-head keys: [h, m, k]
    V = tf.einsum("md,hdv->hmv", M, P_v)  # per-head values: [h, m, v]
    logits = tf.einsum("hk,hmk->hm", q, K)  # [h, m]
    weights = tf.nn.softmax(logits)  # softmax over the m positions
    o = tf.einsum("hm,hmv->hv", weights, V)  # [h, v]
    y = tf.einsum("hv,hdv->d", o, P_o)  # sum over heads: [d]
    return y
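Written out per head, the einsum calls above amount to the following. This is only a sketch: it treats $x$ as a row vector and writes $P^{(i)}$ for the slice of a projection tensor that belongs to head $i$, a notation assumed here for illustration rather than taken from the paper. Like the code, it applies no $1/\sqrt{k}$ scaling to the logits.

$$
\begin{aligned}
q_i &= x P_q^{(i)}, \quad K_i = M P_k^{(i)}, \quad V_i = M P_v^{(i)} \\
w_i &= \mathrm{softmax}\left(q_i K_i^\top\right) \\
o_i &= w_i V_i \\
y &= \sum_{i=1}^{h} o_i \left(P_o^{(i)}\right)^\top
\end{aligned}
$$

Here $q_i \in \mathbb{R}^{k}$, $K_i \in \mathbb{R}^{m \times k}$, $V_i \in \mathbb{R}^{m \times v}$, $w_i \in \mathbb{R}^{m}$, $o_i \in \mathbb{R}^{v}$ and $y \in \mathbb{R}^{d}$, matching the shapes in the docstring.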
The paper then gives a batched version of multi-head attention, in which $n$ queries per batch element attend to $M$ at once and a mask is added to the logits:
def MultiHeadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-head attention.

    X: a tensor with shape [b, n, d] (X attends to M)
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
    Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)  # [b, h, n, k]
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)  # [b, h, m, k]
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)  # [b, h, m, v]
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)  # [b, h, n, m]
    # The mask is added to the logits before the softmax over the m positions.
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)  # [b, h, n, v]
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)  # [b, n, d]
    return Y
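As a sanity check on the shapes, the batched version should agree with the single-query version applied to every query separately. The sketch below checks this using NumPy's `einsum` in place of TensorFlow (a substitution made here so the example runs without TensorFlow). The helper names `softmax`, `mha_single` and `mha_batched`, the random inputs and the all-zero mask are assumptions for illustration only.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax over the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_single(x, M, P_q, P_k, P_v, P_o):
    # Multi-head attention for one query vector x of shape [d].
    q = np.einsum("d,hdk->hk", x, P_q)
    K = np.einsum("md,hdk->hmk", M, P_k)
    V = np.einsum("md,hdv->hmv", M, P_v)
    weights = softmax(np.einsum("hk,hmk->hm", q, K))
    o = np.einsum("hm,hmv->hv", weights, V)
    return np.einsum("hv,hdv->d", o, P_o)

def mha_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # Multi-head attention for X of shape [b, n, d] attending to M of shape [b, m, d].
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)
    V = np.einsum("bmd,hdv->bhmv", M, P_v)
    logits = np.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = softmax(logits + mask)
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

# Small, arbitrary sizes chosen only for the check.
rng = np.random.default_rng(0)
b, n, m, d, h, k, v = 2, 3, 4, 8, 2, 5, 6
X = rng.normal(size=(b, n, d))
M = rng.normal(size=(b, m, d))
P_q = rng.normal(size=(h, d, k))
P_k = rng.normal(size=(h, d, k))
P_v = rng.normal(size=(h, d, v))
P_o = rng.normal(size=(h, d, v))
mask = np.zeros((b, h, n, m))  # all zeros: nothing is masked out

Y = mha_batched(X, M, mask, P_q, P_k, P_v, P_o)
# Apply the single-query version to every (batch element, query) pair.
Y_loop = np.stack([
    np.stack([mha_single(X[i, j], M[i], P_q, P_k, P_v, P_o) for j in range(n)])
    for i in range(b)
])
batched_matches_loop = np.allclose(Y, Y_loop)
```

With a non-zero mask the two versions would differ, since the single-query version has no mask argument.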
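Next, let’s look at incremental self-attention as used during decoding. Each newly generated token depends on all previous tokens, so the keys and values from earlier steps are kept and extended by one position per step: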
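def MultiheadSelfAttentionIncrement(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-head self-attention (one step).
    x: a tensor with shape [b, d]
    prev_K: a tensor with shape [b, h, m, k]
    prev_V: a tensor with shape [b, h, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
    y: a tensor with shape [b, d]
    new_K: a tensor with shape [b, h, m+1, k]
    new_V: a tensor with shape [b, h, m+1, v]
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    # Project x to this step's key and value, then append them to the cache.
    k_new = tf.expand_dims(tf.einsum("bd,hdk->bhk", x, P_k), axis=2)
    v_new = tf.expand_dims(tf.einsum("bd,hdv->bhv", x, P_v), axis=2)
    new_K = tf.concat([prev_K, k_new], axis=2)
    new_V = tf.concat([prev_V, v_new], axis=2)
    logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", o, P_o)
    return y, new_K, new_V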
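To see how the cached keys and values grow, here is a minimal NumPy sketch of the step described by the docstring, run for a few steps from an empty cache. It assumes each step projects `x` to a query, a key and a value, appends the key and value to the cache, and attends over the extended cache. NumPy stands in for TensorFlow, and the names `softmax` and `mha_self_attention_step`, the sizes and the random inputs are illustrative assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax over the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_self_attention_step(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    # One decoding step: x has shape [b, d]; the caches have shapes
    # [b, h, m, k] and [b, h, m, v].
    q = np.einsum("bd,hdk->bhk", x, P_q)
    # New key/value for this position, with a length-1 axis for concatenation.
    k_new = np.einsum("bd,hdk->bhk", x, P_k)[:, :, None, :]
    v_new = np.einsum("bd,hdv->bhv", x, P_v)[:, :, None, :]
    new_K = np.concatenate([prev_K, k_new], axis=2)  # [b, h, m+1, k]
    new_V = np.concatenate([prev_V, v_new], axis=2)  # [b, h, m+1, v]
    logits = np.einsum("bhk,bhmk->bhm", q, new_K)
    weights = softmax(logits)
    o = np.einsum("bhm,bhmv->bhv", weights, new_V)
    y = np.einsum("bhv,hdv->bd", o, P_o)
    return y, new_K, new_V

# Small, arbitrary sizes chosen only for the demonstration.
rng = np.random.default_rng(0)
b, d, h, k, v = 2, 8, 2, 5, 6
P_q = rng.normal(size=(h, d, k))
P_k = rng.normal(size=(h, d, k))
P_v = rng.normal(size=(h, d, v))
P_o = rng.normal(size=(h, d, v))

# Start from empty caches (m = 0) and run three steps.
K_cache = np.zeros((b, h, 0, k))
V_cache = np.zeros((b, h, 0, v))
cache_lengths = []
for step in range(3):
    x = rng.normal(size=(b, d))  # stand-in for the current token's hidden state
    y, K_cache, V_cache = mha_self_attention_step(
        x, K_cache, V_cache, P_q, P_k, P_v, P_o)
    cache_lengths.append(K_cache.shape[2])
```

Each call reads the whole cache, whose size carries the head axis `h`.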
Multi-Query Attention
What exactly is multi-query attention?
In short, it is a variation of multi-head attention in which all heads share a single set of keys and values, while the queries stay per-head. Accordingly, $P_k$ and $P_v$ lose their head dimension. The following code shows this more clearly:
def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-query attention.

    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
    Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)  # [b, h, n, k], still per-head
    K = tf.einsum("bmd,dk->bmk", M, P_k)  # [b, m, k], shared by all heads
    V = tf.einsum("bmd,dv->bmv", M, P_v)  # [b, m, v], shared by all heads
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)  # [b, h, n, m]
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)  # [b, h, n, v]
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)  # [b, n, d]
    return Y
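One way to read the code above: multi-query attention gives the same output as multi-head attention whose key and value projections are identical across heads. The NumPy sketch below checks that and counts the entries of `K` and `V` in both variants. NumPy again stands in for TensorFlow, and the helper names, sizes and random inputs are illustrative assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax over the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # Multi-head attention: P_k is [h, d, k] and P_v is [h, d, v].
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)
    V = np.einsum("bmd,hdv->bhmv", M, P_v)
    weights = softmax(np.einsum("bhnk,bhmk->bhnm", Q, K) + mask)
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

def mqa_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # Multi-query attention: P_k is [d, k] and P_v is [d, v], with no head axis.
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,dk->bmk", M, P_k)
    V = np.einsum("bmd,dv->bmv", M, P_v)
    weights = softmax(np.einsum("bhnk,bmk->bhnm", Q, K) + mask)
    O = np.einsum("bhnm,bmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

# Small, arbitrary sizes chosen only for the check.
rng = np.random.default_rng(0)
b, n, m, d, h, k, v = 2, 3, 4, 8, 4, 5, 6
X = rng.normal(size=(b, n, d))
M = rng.normal(size=(b, m, d))
mask = np.zeros((b, h, n, m))
P_q = rng.normal(size=(h, d, k))
P_o = rng.normal(size=(h, d, v))
P_k = rng.normal(size=(d, k))  # single key projection shared by all heads
P_v = rng.normal(size=(d, v))  # single value projection shared by all heads

Y_mqa = mqa_batched(X, M, mask, P_q, P_k, P_v, P_o)
# Repeat the shared projections along a new head axis and run plain MHA.
Y_mha = mha_batched(X, M, mask, P_q,
                    np.broadcast_to(P_k, (h, d, k)),
                    np.broadcast_to(P_v, (h, d, v)), P_o)
mqa_matches_tiled_mha = np.allclose(Y_mqa, Y_mha)

# Number of entries held in K and V, read off the shapes in the two variants.
kv_entries_mha = b * h * m * (k + v)  # K: [b, h, m, k], V: [b, h, m, v]
kv_entries_mqa = b * m * (k + v)      # K: [b, m, k],    V: [b, m, v]
```

Because `K` and `V` no longer carry the head axis, they hold `h` times fewer entries than in the multi-head version.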