This post walks through the paper that introduced Multi-query Attention (MQA). MQA was devised at Google, so the code snippets here are all written in TensorFlow (which was still very popular at the time).

The paper first reviews attention and multi-head attention, and then introduces MQA.

If you are already familiar with attention and multi-head attention, you can jump to the last part for MQA.

Multi-head Attention

Some terminology:

  • $P_q$: learned linear projections that map the input vector $x$ to the queries.
  • $P_k$: learned linear projections for the keys.
  • $P_v$: learned linear projections for the values.
  • $M$, $m$: the sequence that $x$ attends to and its length (e.g. in an encoder-decoder model, $M$ usually represents the sentence to be translated).
  • $h$: the number of attention heads (the original paper calls them "$h$ different attention layers", but they are what we now call attention heads).
  • $d$: the dimension of the hidden state.
import tensorflow as tf

def MultiHeadAttention(x, M, P_q, P_k, P_v, P_o):
  """Multi-head attention on one query
  Args:
    x: a vector, with shape [d]
    M: a matrix with shape [m,d]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a vector with shape [d]
  """
  q = tf.einsum("d, hdk->hk", x, P_q)       # queries, one per head: [h, k]
  K = tf.einsum("md, hdk->hmk", M, P_k)     # keys:   [h, m, k]
  V = tf.einsum("md, hdv->hmv", M, P_v)     # values: [h, m, v]
  logits = tf.einsum("hk, hmk->hm", q, K)   # attention logits: [h, m]
  weights = tf.nn.softmax(logits)           # softmax over the m memory positions
  o = tf.einsum("hm, hmv->hv", weights, V)  # per-head outputs: [h, v]
  y = tf.einsum("hv, hdv->d", o, P_o)       # combine heads back to [d]
  return y
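
As a quick sanity check, here is a minimal sketch that feeds random tensors through the function above (the sizes d=64, m=10, h=8, k=v=16 are arbitrary choices of mine, not values from the paper) and confirms that the output has the same dimensionality as the input vector:

import tensorflow as tf

d, m, h, k, v = 64, 10, 8, 16, 16   # illustrative sizes
x = tf.random.normal([d])           # one query vector
M = tf.random.normal([m, d])        # the memory being attended to
P_q = tf.random.normal([h, d, k])
P_k = tf.random.normal([h, d, k])
P_v = tf.random.normal([h, d, v])
P_o = tf.random.normal([h, d, v])

y = MultiHeadAttention(x, M, P_q, P_k, P_v, P_o)
print(y.shape)                      # (64,): same dimensionality d as the input x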

After this, the paper also gives a batched version of MultiHeadAttention:

def MultiHeadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
  """Multi-head attention
  Args:
    X: a tensor with shape [b, n, d] (Here X is going to attend to M)
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    Y: a tensor with shape [b, n, d]
  """
  Q = tf.einsum("bnd, hdk->bhnk", X, P_q)        # [b, h, n, k]
  K = tf.einsum("bmd, hdk->bhmk", M, P_k)        # [b, h, m, k]
  V = tf.einsum("bmd, hdv->bhmv", M, P_v)        # [b, h, m, v]
  logits = tf.einsum("bhnk, bhmk->bhnm", Q, K)   # [b, h, n, m]
  weights = tf.nn.softmax(logits + mask)         # mask out disallowed positions, then softmax
  O = tf.einsum("bhnm, bhmv->bhnv", weights, V)  # [b, h, n, v]
  Y = tf.einsum("bhnv, hdv->bnd", O, P_o)        # [b, n, d]
  return Y
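
The mask is additive: positions that a query should not see get a large negative number added to their logits, so their weights vanish after the softmax. The paper does not spell out how to construct it, but for decoder-style (causal) attention a mask of the expected [b, h, n, m] shape can be sketched as follows (the helper name causal_mask is mine):

def causal_mask(b, h, n, m, dtype=tf.float32):
  """Additive causal mask: 0 where attention is allowed, -1e9 where it is not."""
  # Lower-triangular matrix of ones: query position i may attend to key positions <= i.
  allowed = tf.linalg.band_part(tf.ones([n, m], dtype=dtype), -1, 0)
  blocked = (1.0 - allowed) * -1e9
  # Broadcast up to the [b, h, n, m] shape expected by the batched functions.
  return tf.broadcast_to(blocked[tf.newaxis, tf.newaxis, :, :], [b, h, n, m])

The result can be passed directly as the mask argument of MultiHeadAttentionBatched (and of the multi-query version below); for encoder-decoder attention one would put a padding mask in the same slot instead.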

Next, let's look at incremental self-attention for autoregressive decoding, where each newly generated token attends to all previously generated tokens. Since earlier keys and values never change, they are cached and reused; only the projections for the new token are computed at each step (the body below follows the paper's incremental formulation):

def MultiheadSelfAttentionIncrement(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
  """Multi-head Self-attention (one step)
  Args:
    x: a tensor with shape [b, d]
    prev_K: tensor with shape [b, h, m, k]
    prev_V: tensor with shape [b, h, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a tensor with shape [b,d],
    new_K: tensor with shape [b, h, m+1, k]
    new_V: tensor with shape [b, h, m+1, v]
  """
  q = tf.einsum("bd, hdk->bhk", x, P_q)           # queries for the new position: [b, h, k]
  # Append the new position's key and value to the caches.
  new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd, hdk->bhk", x, P_k), axis=2)], axis=2)
  new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv->bhv", x, P_v), axis=2)], axis=2)
  logits = tf.einsum("bhk, bhmk->bhm", q, new_K)  # [b, h, m+1]
  weights = tf.nn.softmax(logits)
  o = tf.einsum("bhm, bhmv->bhv", weights, new_V) # [b, h, v]
  y = tf.einsum("bhv, hdv->bd", o, P_o)           # [b, d]
  return y, new_K, new_V
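
To make the caching concrete, here is a hedged sketch of a decoding-style loop (the sizes, and the shortcut of feeding the output straight back in as the next input, are illustrative only): the caches start empty and grow by one position per step.

import tensorflow as tf

b, h, d, k, v = 2, 8, 64, 16, 16    # illustrative sizes
P_q = tf.random.normal([h, d, k])
P_k = tf.random.normal([h, d, k])
P_v = tf.random.normal([h, d, v])
P_o = tf.random.normal([h, d, v])

prev_K = tf.zeros([b, h, 0, k])     # empty key cache
prev_V = tf.zeros([b, h, 0, v])     # empty value cache
x = tf.random.normal([b, d])        # stand-in for the first token's hidden state

for _ in range(4):
  y, prev_K, prev_V = MultiheadSelfAttentionIncrement(x, prev_K, prev_V, P_q, P_k, P_v, P_o)
  x = y                             # in a real decoder, x would come from the next token's embedding
print(prev_K.shape)                 # (2, 8, 4, 16): one cached key per head per generated position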

Multi-Query Attention

What exactly is multi-query attention?

In short, it is a variation of multi-head attention in which all the heads share a single set of keys and values; only the queries (and the output projection) stay per-head. The following code illustrates this "Multi-Query Attention" much better than words:

def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
  """Multi-query attention
  Args:
    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    Y: a tensor with shape [b, n, d]
  """
  Q = tf.einsum("bnd, hdk->bhnk", X, P_q)        # queries are still per-head: [b, h, n, k]
  K = tf.einsum("bmd, dk->bmk", M, P_k)          # a single set of keys:   [b, m, k]
  V = tf.einsum("bmd, dv->bmv", M, P_v)          # a single set of values: [b, m, v]
  logits = tf.einsum("bhnk, bmk->bhnm", Q, K)    # [b, h, n, m]
  weights = tf.nn.softmax(logits + mask)         # mask out disallowed positions, then softmax
  O = tf.einsum("bhnm, bmv->bhnv", weights, V)   # [b, h, n, v]
  Y = tf.einsum("bhnv, hdv->bnd", O, P_o)        # [b, n, d]
  return Y
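
Note that only $P_q$ and $P_o$ keep the head dimension $h$; $P_k$ and $P_v$ do not, so every head reads the same keys and values. The payoff comes at incremental decoding time, where the cached keys and values shrink from [b, h, m, k] and [b, h, m, v] to [b, m, k] and [b, m, v], which is exactly the memory-bandwidth saving the paper is after. Here is a sketch of the incremental multi-query step, written to mirror the multi-head incremental function above (the function name is mine; the structure follows the paper):

def MultiquerySelfAttentionIncrement(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
  """Multi-query self-attention (one step)
  Args:
    x: a tensor with shape [b, d]
    prev_K: tensor with shape [b, m, k] (no head dimension)
    prev_V: tensor with shape [b, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a tensor with shape [b, d]
    new_K: tensor with shape [b, m+1, k]
    new_V: tensor with shape [b, m+1, v]
  """
  q = tf.einsum("bd, hdk->bhk", x, P_q)           # queries are still per-head: [b, h, k]
  # Append the new position's (shared) key and value to the caches.
  new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd, dk->bk", x, P_k), axis=1)], axis=1)
  new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv->bv", x, P_v), axis=1)], axis=1)
  logits = tf.einsum("bhk, bmk->bhm", q, new_K)   # [b, h, m+1]
  weights = tf.nn.softmax(logits)
  o = tf.einsum("bhm, bmv->bhv", weights, new_V)  # [b, h, v]
  y = tf.einsum("bhv, hdv->bd", o, P_o)           # [b, d]
  return y, new_K, new_V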

References:

  • Noam Shazeer. Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150, 2019.