Positional Encoding in Transformers
The Transformer architecture is known for its carefully designed components, such as the encoder-decoder stack, self-attention, multi-head attention and positional encoding. In this blog post I will talk about the positional encodings that are added to the word embeddings at the input of the Transformer, and why they are needed at all.
Why is positional encoding added?
We do not add any positional encoding to the input embeddings in recurrent architectures such as RNNs, LSTMs or GRUs, but we do so in Transformers. So the very first question is: why do Transformers need such encodings?
Recurrence-based models process their inputs sequentially by design: they take the first word, then the second, then the third, and so on. Because the words arrive one after another, these models do not need any extra signal about word order; the order is implicit in the processing itself. In a Transformer, however, all the words of the input are passed in at the same time, and the attention mechanism computes a weighted average over them. Nothing else in the architecture retains or learns the input order. It is therefore important to inject this critical ordering information through the input embeddings themselves, and that is exactly the reason for adding positional encoding to the Transformer input.
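To make this order-blindness concrete, here is a minimal sketch (my own toy example with a single attention head and random weights, not the paper's full multi-head layer) showing that self-attention without positional encoding is permutation-equivariant: shuffling the input words simply shuffles the output rows, so the layer cannot tell which word came first.

import numpy as np

rng = np.random.default_rng(0)
d = 8                        # toy embedding size (assumption for illustration)
X = rng.normal(size=(5, d))  # five "word" embeddings, no position added
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

def self_attention(X):
    # plain scaled dot-product attention, single head, no positions
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(d)
    weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
    return weights @ V

perm = [3, 0, 4, 1, 2]                  # shuffle the word order
out_original = self_attention(X)
out_shuffled = self_attention(X[perm])

# the output for the shuffled sentence is just a row-shuffled copy
print(np.allclose(out_shuffled, out_original[perm]))  # True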
How is positional encoding added?
There are multiple ways in which positional information could be added to the input embeddings. Let's explore some possible solutions:
- Add the index of the word in the sentence as its position: assign "1" to the first word, "2" to the second and so on. The issue is that these values grow unbounded with sentence length. The model may also see sentences longer than the ones it was trained on, so it may not generalize well in those cases.
- Add a normalised value of the word's index as its position: a simple fix for the problems of the previous approach is to normalise the positions into a range, say [0, 1] (not [-1, 1], since negative positions do not make sense). For a sentence of length 6 the positions would be [0, 0.2, 0.4, 0.6, 0.8, 1], while for a sentence of length 5 they would be [0, 0.25, 0.50, 0.75, 1]. Note that the second position now gets a different value in the two cases (0.2 for the 6-word sentence, 0.25 for the 5-word one), whereas the encoding of a given word index should stay constant regardless of sentence length (see the short sketch after this list).
- Add a vector representing the positional encoding: both techniques described above produce a single scalar that is added to each input embedding, acting merely as an offset on the word embedding at a particular location. One could instead add a vector per position, so that all the properties we want from a good encoding can be represented by these positional encoding vectors.
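As a quick illustration of the problem with the normalised scheme from the second bullet, here is a small sketch (my own example, using a hypothetical helper normalised_positions) showing that the same word index gets different encodings in sentences of different lengths.

# hypothetical helper, only to illustrate the normalised-index scheme
def normalised_positions(sentence_len):
    # scale word indices 0 .. sentence_len-1 into the range [0, 1]
    return [round(i / (sentence_len - 1), 2) for i in range(sentence_len)]

print(normalised_positions(6))  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
print(normalised_positions(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
# the second word is encoded as 0.2 in one sentence and 0.25 in the other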
What should the properties of good encoding vectors be, then…
- Encoding of position should be unique
- Distance between the encodings of any two consecutive time steps should be constant
- Encoding scheme should generalize well to longer sentences
- Encoding vectors should be deterministic
How is the positional encoding vector calculated?
To calculate the positional encoding vectors, the authors of [1] proposed using sine and cosine functions. But why? Although these trigonometric functions are periodic, the combination of several of them with varying wavelengths is unique for each position.
Let's take a sine wave with a wavelength of 2π and see what the encoding of different index values would look like.
We get different values for the 1st, 2nd, …, 6th positions, but if we look carefully, some positions (for example the 4th and 6th) end up with quite similar values, which makes it difficult to disambiguate between indexes. So instead of a single periodic sine wave, we consider multiple sine waves with different wavelengths, so that the combination of sine values (i.e. a vector of values) is unique for every position, including pairs like the 4th and 6th.
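Here is a minimal sketch (my own illustration; the extra wavelengths 20π and 200π are arbitrary choices) comparing a single sine wave of wavelength 2π with a small bank of sine waves of different wavelengths:

import numpy as np

positions = np.arange(7)

# one sine wave of wavelength 2*pi: distinct positions can land on
# close values, and the values repeat exactly every 2*pi positions
single_sine = np.sin(positions)
print(np.round(single_sine, 3))

# several sine waves with different wavelengths: each position now gets
# a vector of values, and the combination stays distinct
wavelengths = np.array([2 * np.pi, 20 * np.pi, 200 * np.pi])
multi_sine = np.sin(positions[:, None] * 2 * np.pi / wavelengths)
print(np.round(multi_sine, 3))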
To bring more diversity into the vector, we also use another periodic function, cosine. The formula used in the paper [1] is

PE(t, 2k) = sin(t / 10000^(2k/d))
PE(t, 2k+1) = cos(t / 10000^(2k/d))

where the wavelengths of the waves form a geometric progression from 2π to 10000 · 2π. Here
- t is the position index of the word in the sentence
- d is the length of the required encoding vector, and k is an integer with 0 ≤ k < d/2
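As a quick sanity check (my own worked example, assuming a tiny d = 4 for brevity rather than the paper's 512), here is the encoding of the word at position t = 1 computed directly from the formula:

import numpy as np

d, n, t = 4, 10000, 1                    # tiny d just for illustration
pe = np.zeros(d)
for k in range(d // 2):
    denominator = n ** (2 * k / d)       # wavelength grows with k
    pe[2 * k] = np.sin(t / denominator)
    pe[2 * k + 1] = np.cos(t / denominator)
print(np.round(pe, 5))  # [0.84147 0.5403  0.01    0.99995]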
Coding this up [link to code in the Appendix] and visualising the positional encodings of dimension 512 for a sentence of maximum length 100 looks like the following:
A quick check of whether the encodings obtained in this manner satisfy all the required properties…
- Uniqueness: the vector of values for each position is unique, as seen in the graph above
- Constant distance: the distance between any two consecutive position encodings, measured as the l2-norm of the difference of the two vectors, is a constant 3.71245. So the time-step delta is constant regardless of sequence length
- Generalization: this method of calculation generalizes to sentences of any length
- Determinism: the positional encodings are deterministic, i.e. for a given position index the corresponding encoding is always the same
References
[1] Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
[2] https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
Appendix
Code for calculating positional encodings
import torch
import numpy as np
from matplotlib import pyplot as plt

def posEnc(seq_len, n=10000, d=512):
    # sinusoidal positional encodings: even dimensions use sine,
    # odd dimensions use cosine, with wavelength increasing with i
    encoding = np.zeros((seq_len, d))
    for k in range(seq_len):
        for i in np.arange(int(d / 2)):
            denominator = np.power(n, (2 * i / d))
            encoding[k, 2 * i] = np.sin(k / denominator)
            encoding[k, 2 * i + 1] = np.cos(k / denominator)
    return encoding

sentence_len = 100
sent_enc = posEnc(seq_len=sentence_len)

# l2-norm of the difference between every pair of consecutive encodings;
# all values come out equal, confirming the constant-distance property
l2_norms_consecutive = []
for i in range(sentence_len - 1):
    arr = torch.from_numpy(sent_enc[i] - sent_enc[i + 1])
    l2_norm = round(torch.linalg.norm(arr, ord=2).item(), 5)
    l2_norms_consecutive.append(l2_norm)
print(l2_norms_consecutive)

# visualise the encodings as a heatmap (positions x dimensions)
colormap = plt.cm.get_cmap('cividis')
cax = plt.matshow(sent_enc, cmap=colormap)
plt.gcf().colorbar(cax)
plt.show()
Keep Learning, Keep Hustling