Query, Key and Value in Attention mechanism

Nikhil Verma
5 min read · Mar 25, 2022


Transformers have become the bread and butter of new research methodologies and business ideas in deep learning, especially for Natural Language Processing tasks. The credit goes to two pioneering papers:

  1. Neural Machine Translation by Jointly Learning to Align and Translate [Bahdanau et al., 2015]
  2. Attention Is All You Need [Vaswani et al., 2017]

The key idea of the Transformer is the Attention mechanism, which does not consume the input from one direction only (as an RNN/LSTM does) but looks at all tokens in the input holistically. Attention, in very simple terms, is a weighted average of some values. While recurrent networks have to read each word one by one, and as a result struggle with long-term memory, a Transformer has the entire sentence look at itself simultaneously to determine what each word should “pay attention to.”

Vaswani’s paper calculates attention using Query, Key and Value vectors. What these vectors represent, how they are generated, and how to interpret their interactions is the main discussion point of this article.

Starting from the beginning: the Q, K and V vectors are generated simply by multiplying a vector with three different matrices. The vector is the word embedding of an input token from the corpus (or the output of a previous attention layer). The matrices are randomly initialized weight matrices, usually rectangular, so that they project the input vector into smaller Query, Key and Value vectors.
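A minimal NumPy sketch of this projection step (the dimensions d_model = 512 and d_k = d_v = 64 follow the original paper; the random weights here merely stand in for learned parameters):

```python
import numpy as np

# Dimensions from the original Transformer: d_model = 512, d_k = d_v = 64
d_model, d_k, d_v = 512, 64, 64
rng = np.random.default_rng(0)

# Word embedding of one input token (or the output of a previous attention layer)
x = rng.normal(size=d_model)

# Randomly initialised, rectangular weight matrices (learned during training)
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_v))

# Vector-to-matrix multiplication reshapes x into the smaller Q, K, V vectors
q = x @ W_q   # shape (64,)
k = x @ W_k   # shape (64,)
v = x @ W_v   # shape (64,)
```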

(Figure: generating the Query, Key and Value vectors — from Jay Alammar’s blog)

The three vectors obtained here can be understood well with a simple analogy based on Medium itself.

  • You log in to Medium and search for some topic of interest — this is the Query
  • Medium has a database of article titles and, for each title, a hash key pointing to the article itself — these are the Keys and Values
  • Define some similarity between your query and the titles in the database — e.g. the dot product of Query and Key
  • Extract the hash key with the maximum match
  • Return the article (Value) corresponding to the hash key obtained

Seen this way, attention is effectively a retrieval operation over a database. As mentioned in paper 1, attention by definition is just a weighted average of values

c = ∑ⱼ αⱼ hⱼ, where ∑ⱼ αⱼ = 1.

If we restrict 𝛼 to be a one-hot vector, this operation becomes the same as retrieving from a set of elements ℎ with index α. With the restriction removed, the attention operation can be thought of as doing “proportional retrieval” according to the probability vector α.
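A toy illustration of that distinction (made-up numbers; the rows of h play the role of value vectors):

```python
import numpy as np

# Three stored elements h (think of them as value vectors)
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.5, 0.5]])

alpha_hard = np.array([0.0, 1.0, 0.0])   # one-hot: plain retrieval of h[1]
alpha_soft = np.array([0.1, 0.7, 0.2])   # relaxed:  "proportional retrieval"

print(alpha_hard @ h)   # [0. 1.]   -> exactly h[1]
print(alpha_soft @ h)   # [0.2 0.8] -> mostly h[1], blended with the others
```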

It should be clear that h in this context is the value. The difference between the two papers (1 and 2 above) lies in how the probability vector α is calculated, as the equations below show.

The first paper (Bahdanau et al., 2015) computes the attention score through a small neural network:
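Roughly, in Bahdanau et al.’s notation, where s_{i−1} is the decoder state, h_j are the encoder hidden states, and v_a, W_a, U_a are learned parameters:

e_ij = v_aᵀ · tanh(W_a · s_{i−1} + U_a · h_j),   α_ij = exp(e_ij) / ∑ₖ exp(e_ik)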

A more efficient model would be to first project s and h onto a common space, then choose a similarity measure (e.g. dot product) as the attention score, like
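For a single query s and key h_j this looks like score(s, h_j) = (W_q · s)ᵀ (W_k · h_j). In matrix form, this is exactly the Scaled Dot-Product attention of paper 2:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

where d_k is the dimension of the key vectors and the √d_k scaling keeps the dot products from growing too large.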

So basically:

  • q = the vector representing a word
  • K and V = your memory, thus all the words that have been generated before. Note that K and V can be the same (but don’t have to).

So, with attention, you take your current query (a word, in most cases) and look in your memory for similar keys. The softmax function is then used to turn the similarity scores into a distribution over the relevant words.

For example, consider a sentence with 9 input tokens:

  • I like Natural Language Processing, a lot !

Walking through an example for the first word ‘I’:

  • The query is the input word vector for the token “I”
  • The keys are the input word vectors for all the other tokens, and for the query token too, i.e. (semicolon-delimited in the list below):
    [like;Natural;Language;Processing;,;a;lot;!] + [I]
  • The query’s word vector is then dot-producted with the word vector of each of the keys, giving 9 scalars, a.k.a. “weights”
  • These weights are then scaled by dividing by √d_k (the square root of the key dimension), but this is not crucial for the intuition
  • The weights then go through a ‘softmax’, which normalizes the 9 weights to values between 0 and 1 that sum to 1. This matters for taking a “weighted average” of the value vectors, which we see in the next step
  • Finally, the initial 9 input word vectors, a.k.a. the values, are summed in a “weighted average” using the normalised weights of the previous step. This final step results in a single output word vector representing the word “I”

Now that we have the process for the word “I”, rinse and repeat to get output vectors for the remaining 8 tokens. We then have 9 output word vectors, each produced by the Scaled Dot-Product attention mechanism. You can stack a new attention layer in the encoder by taking these 9 outputs (a.k.a. “hidden vectors”) as the inputs to the next attention layer, which outputs 9 new word vectors of its own.
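Putting the whole walkthrough together, here is a minimal single-head NumPy sketch of Scaled Dot-Product attention over those 9 tokens (random embeddings and weights stand in for learned ones; no multi-head attention, masking or output projection):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

tokens = ["I", "like", "Natural", "Language", "Processing", ",", "a", "lot", "!"]
d_model, d_k = 512, 64
rng = np.random.default_rng(0)

X = rng.normal(size=(len(tokens), d_model))   # stand-in embeddings, one row per token
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q, K, V = X @ W_q, X @ W_k, X @ W_v           # project every token to q, k, v

scores  = Q @ K.T / np.sqrt(d_k)              # 9x9 similarity scores, scaled by sqrt(d_k)
weights = softmax(scores, axis=-1)            # each row sums to 1
output  = weights @ V                         # 9 output vectors, one per token

print(output.shape)         # (9, 64)
print(weights[0].round(2))  # how much "I" attends to each of the 9 tokens
```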

To summarise: each token (query) is free to gather as much information as it needs from the other tokens (values) via the dot-product mechanism, and it decides how much or how little attention to pay to each of them by how strongly its query matches their keys.

Keep Learning Keep Hustling
