Paper: Attention Is All You Need
TLDR
The Transformer architecture was designed to address the limitations of Recurrent Neural Networks (RNNs) and Convolutional Neural Networks (CNNs) in sequence modeling by introducing self-attention mechanisms. In prior work, the sequential nature of RNNs restricts parallelization and makes it harder to capture dependencies between distant tokens, and in CNNs the number of operations needed to relate two positions grows with the distance between them. Transformers remove recurrence altogether and rely on attention, which lets every token attend directly to every other token.
In this work we propose the Transformer, a model architecture eschewing recurrence and instead relying entirely on an attention mechanism to draw global dependencies between input and output. The Transformer allows for significantly more parallelization and can reach a new state of the art in translation quality after being trained for as little as twelve hours on eight P100 GPUs.
Why RNNs Are Problematic
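The main computational bottleneck with RNNs is the sequential nature of the hidden states. The hidden state \(h_t\) at time step \(t\) is computed from the previous hidden state \(h_{t-1}\) and the current input \(x_t\), so it implicitly depends on all \(t - 1\) earlier tokens. As a result, the hidden states of a sequence cannot be computed in parallel across time steps. The same chain of dependencies causes the exploding / vanishing gradient problem: during backpropagation through time the gradient is multiplied by one factor per step, so over long sequences it tends to blow up or shrink toward zero. This makes RNNs slow to train and poor at learning long-range dependencies.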
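A minimal NumPy sketch of the recurrence described above, assuming a vanilla tanh RNN (the names `rnn_forward`, `W_x`, `W_h`, `b`, and `linear_rnn_gradient` are illustrative, not from the paper). The loop over `t` shows why time steps cannot run in parallel. The second function uses a scalar linear RNN, \(h_t = w\,h_{t-1} + x_t\), for which \(\partial h_n / \partial h_0 = w^n\), making the exploding / vanishing behaviour easy to see.

```python
import numpy as np


def rnn_forward(x, W_x, W_h, b):
    """Vanilla tanh RNN over x of shape (n, d_in); returns hidden states of shape (n, d_h)."""
    h = np.zeros(W_h.shape[0])  # h_0 = 0
    states = []
    for t in range(x.shape[0]):
        # h_t needs h_{t-1}, so step t cannot start before step t-1 finishes
        h = np.tanh(x[t] @ W_x + h @ W_h + b)
        states.append(h)
    return np.stack(states)


def linear_rnn_gradient(w, n):
    """dh_n/dh_0 for the scalar linear RNN h_t = w * h_{t-1} + x_t: one factor of w per step."""
    return w ** n


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))     # 6 time steps, 4 input features
W_x = rng.normal(size=(4, 3))   # input-to-hidden weights
W_h = rng.normal(size=(3, 3))   # hidden-to-hidden weights
b = np.zeros(3)

H = rnn_forward(x, W_x, W_h, b)
print(H.shape)                        # (6, 3)
print(linear_rnn_gradient(0.5, 50))   # |w| < 1: shrinks toward 0 (vanishing)
print(linear_rnn_gradient(1.5, 50))   # |w| > 1: grows very large (exploding)
```

In a real RNN the per-step factor is a Jacobian matrix rather than a scalar, but the same repeated multiplication drives the effect.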
Imagine you are reading the sentence: “The animal didn’t cross the street because it was too tired.” To work out what “it” refers to, an RNN has to carry information about “animal” through every intermediate step in its hidden state. A Transformer instead uses self-attention to connect “it” directly to every other token in the sentence, and it can learn to put a large weight on the link between “it” and “animal”.
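To make “many steps” concrete, here is a small sketch that counts how many recurrent steps separate the two words, assuming simple whitespace tokenization (real models use subword tokenizers, so the exact count would differ; `path_lengths` is a hypothetical helper). In self-attention any two positions are linked by a single attention step, the constant maximum path length that the paper contrasts with the \(O(n)\) path of a recurrent layer.

```python
def path_lengths(sentence, source, target):
    """Steps needed to carry information from `source` to `target`.

    Assumes whitespace tokenization and that both words occur in the sentence.
    """
    tokens = [tok.strip(".,") for tok in sentence.split()]
    i, j = tokens.index(source), tokens.index(target)
    rnn_steps = abs(j - i)   # one hidden-state update per step between the two words
    attention_steps = 1      # self-attention links every pair of positions directly
    return rnn_steps, attention_steps


sentence = "The animal didn't cross the street because it was too tired."
print(path_lengths(sentence, "animal", "it"))  # (6, 1)
```

The gap grows with sentence length for the RNN but stays at one step for attention.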
What Is Attention
Self-attention, sometimes called intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence. Self-attention has been used successfully in a variety of tasks including reading comprehension, abstractive summarization, textual entailment and learning task-independent sentence representations.
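The paper computes attention with scaled dot-product attention,
\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,\]
where in self-attention the queries \(Q\), keys \(K\), and values \(V\) are all linear projections of the same sequence. The NumPy sketch below is a single-head version in which random matrices stand in for learned projection weights (the function names and sizes are illustrative). Every position is scored against every other position in one matrix product, so there is no loop over time steps.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over X of shape (n, d_model)."""
    Q = X @ W_q                         # (n, d_k) queries
    K = X @ W_k                         # (n, d_k) keys
    V = X @ W_v                         # (n, d_v) values
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n, n): every position scores every other position
    weights = softmax(scores, axis=-1)  # each row is a distribution over positions
    return weights @ V, weights         # (n, d_v) outputs and (n, n) attention weights


rng = np.random.default_rng(0)
n, d_model, d_k = 5, 8, 4
X = rng.normal(size=(n, d_model))       # 5 token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)                 # (5, 4) (5, 5)
print(np.allclose(weights.sum(axis=-1), 1.0))   # True
```

The paper additionally runs several such heads in parallel (multi-head attention) and, in the decoder, masks the scores of future positions; both are omitted here for brevity.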
Architecture Overview
Why This Works
Takeaways
\[L = \sum_{i=1}^{n} (y_i - \hat{y}_i)^2\]
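import numpy as np
def loss_function(y, y_hat):
    # Sum of squared errors: L = sum_i (y_i - y_hat_i)^2; accepts lists or arrays
    return np.sum((np.asarray(y) - np.asarray(y_hat)) ** 2)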