Shortcoming of Seq2Seq
The encoder's final state is incapable of remembering a long source sequence.
Seq2Seq Model with Attention
Attention tremendously improves the Seq2Seq model.
With attention, the Seq2Seq model does not forget the source input.
With attention, the decoder knows where to focus.
Downside: much more computation.
Simple RNN + Attention
There are two options to calculate the weights α_i = align(h_i, s_0):
Option 1 (used in the original paper):
α~_i = v^T · tanh(W · [h_i; s_0]), for i = 1 to m
Then normalize α~_1, ..., α~_m (so that they sum to 1):
[α_1, ..., α_m] = Softmax([α~_1, ..., α~_m])
Option 2 (more popular; the same as in the Transformer):
Linear maps:
k_i = W_K · h_i, for i = 1 to m
q_0 = W_Q · s_0
Inner product:
α~_i = k_i^T · q_0, for i = 1 to m
Normalization:
[α_1, ..., α_m] = Softmax([α~_1, ..., α~_m])
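Below is a minimal NumPy sketch of both alignment options. The dimensions, the random parameters, and names such as W, v, W_K, W_Q are hypothetical placeholders chosen to match the notation above, not an implementation from the paper.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

# Hypothetical sizes: m encoder states of dim d_h, decoder state of dim d_s.
m, d_h, d_s, d_k = 5, 8, 8, 8
rng = np.random.default_rng(0)
H = rng.normal(size=(m, d_h))    # encoder states h_1, ..., h_m (one per row)
s0 = rng.normal(size=d_s)        # initial decoder state s_0

# Option 1 (additive alignment, as in Bahdanau et al., 2015):
#   α~_i = v^T · tanh(W · [h_i; s_0])
W = rng.normal(size=(d_k, d_h + d_s))
v = rng.normal(size=d_k)
scores1 = np.array([v @ np.tanh(W @ np.concatenate([h, s0])) for h in H])
alpha1 = softmax(scores1)        # α_1, ..., α_m, sum to 1

# Option 2 (dot-product alignment, as in the Transformer):
#   k_i = W_K · h_i,  q_0 = W_Q · s_0,  α~_i = k_i^T · q_0
W_K = rng.normal(size=(d_k, d_h))
W_Q = rng.normal(size=(d_k, d_s))
K = H @ W_K.T                    # rows are k_1, ..., k_m
q0 = W_Q @ s0
alpha2 = softmax(K @ q0)         # α_1, ..., α_m, sum to 1
```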
Calculate the next state
Simple RNN: s_1 = tanh(A' · [x'_1; s_0] + b)
Simple RNN + Attention: s_1 = tanh(A' · [x'_1; s_0; c_0] + b)
Context vector: c_0 = α_1·h_1 + ... + α_m·h_m
For the next state s_2, do not reuse the previously calculated α_i; compute the weights again with the new state: α_i = align(h_i, s_1).
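A hedged sketch of one decoder step, reusing H, s0, K, W_Q, alpha2, and softmax from the block above; A', b, and the input embedding x'_1 are hypothetical random placeholders.

```python
# One step of Simple RNN + Attention (reuses variables from the previous sketch).
d_x = 6                                    # hypothetical embedding size of x'_1
x1 = rng.normal(size=d_x)                  # decoder input x'_1
A_prime = rng.normal(size=(d_s, d_x + d_s + d_h))
b = rng.normal(size=d_s)

c0 = alpha2 @ H                            # context vector c_0 = α_1·h_1 + ... + α_m·h_m
s1 = np.tanh(A_prime @ np.concatenate([x1, s0, c0]) + b)   # s_1 = tanh(A'·[x'_1; s_0; c_0] + b)

# For s_2: do NOT reuse the old α_i; recompute them with the new state s_1.
alpha_next = softmax(K @ (W_Q @ s1))
c1 = alpha_next @ H
```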
Compute the attention parameters for the j-th decoder state
Query: q_:j = W_Q · s_j (to match others)
Key: k_:i = W_K · h_i (to be matched)
Value: v_:i = W_V · h_i (to be weighted-averaged)
Weights: α_ij = align(h_i, s_j)
Compute k_:i = W_K · h_i and q_:j = W_Q · s_j
Compute weights: α_:j = Softmax(K^T · q_:j) ∈ R^m
Context vector: c_j = α_1j·v_:1 + ... + α_mj·v_:m
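A minimal self-contained NumPy sketch of the query/key/value computation for a single decoder state s_j; all dimensions and parameter matrices here are hypothetical.

```python
import numpy as np

def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

# Hypothetical dimensions and random parameters.
m, d_h, d_s, d_k, d_v = 5, 8, 8, 8, 8
rng = np.random.default_rng(1)
H = rng.normal(size=(m, d_h))      # encoder states h_1, ..., h_m (one per row)
s_j = rng.normal(size=d_s)         # the j-th decoder state
W_Q = rng.normal(size=(d_k, d_s))
W_K = rng.normal(size=(d_k, d_h))
W_V = rng.normal(size=(d_v, d_h))

q_j = W_Q @ s_j                    # query  q_:j = W_Q · s_j
K = H @ W_K.T                      # keys   k_:i = W_K · h_i (rows here, columns in the notes)
V = H @ W_V.T                      # values v_:i = W_V · h_i

alpha_j = softmax(K @ q_j)         # α_:j = Softmax(K^T · q_:j) ∈ R^m
c_j = alpha_j @ V                  # c_j = α_1j·v_:1 + ... + α_mj·v_:m
```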
Time complexity
Question: how many weights α_ij have been computed in total?
To compute one context vector c_j, we compute m weights: α_1j, ..., α_mj.
The decoder has t states, so in total m·t weights are computed.
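For example, with m = 50 source tokens and t = 40 target tokens, the attention layer computes 50 · 40 = 2000 weights, whereas a standard Seq2Seq model computes none.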
Weights Visualization
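The slide's figure is not reproduced here. As a purely illustrative sketch, the m×t matrix of weights α_ij can be shown as a heatmap, e.g. with matplotlib; the weights below are random placeholders, not learned alignments.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical attention weights: rows = source positions i (m), columns = target positions j (t).
m, t = 6, 5
rng = np.random.default_rng(2)
scores = rng.normal(size=(m, t))
weights = np.exp(scores) / np.exp(scores).sum(axis=0, keepdims=True)  # each column sums to 1

plt.imshow(weights, aspect="auto")
plt.xlabel("target position j")
plt.ylabel("source position i")
plt.colorbar(label="alpha_ij")
plt.title("Attention weights (hypothetical example)")
plt.show()
```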
Summary
Standard Seq2Seq model: the decoder looks at only its current state.
Attention: the decoder additionally looks at all the states of the encoder.
Attention: the decoder knows where to focus.
Downside: higher time complexity.
m: source sequence length
t: target sequence length
Standard Seq2Seq: O(m+t) time complexity
Seq2Seq + attention: O(m⋅t) time complexity
References:
Bahdanau, Cho, & Bengio. Neural machine translation by jointly learning to align and translate. In ICLR, 2015.