Shortcoming of Seq2Seq
The encoder's final state cannot remember a long source sequence.
Seq2Seq Model with Attention
Attention tremendously improves the Seq2Seq model.
With attention, the Seq2Seq model does not forget the source input.
With attention, the decoder knows where to focus.
Downside: much more computation.
Simple RNN + Attention
There are two options for computing the weights $\alpha_i = \mathrm{align}(h_i, s_0)$.
Option 1 (used in the original paper):
$\tilde{\alpha}_i = \mathbf{v}^\top \cdot \tanh\!\left(\mathbf{W} \cdot \begin{bmatrix} h_i \\ s_0 \end{bmatrix}\right)$, for $i = 1$ to $m$
Then normalize $\tilde{\alpha}_1, \ldots, \tilde{\alpha}_m$ (so that they sum to 1):
$[\alpha_1, \ldots, \alpha_m] = \mathrm{Softmax}\left([\tilde{\alpha}_1, \ldots, \tilde{\alpha}_m]\right)$
Option 2 (dot-product attention with linear maps):
Linear maps:
$k_i = W_K \cdot h_i$, for $i = 1$ to $m$
$q_0 = W_Q \cdot s_0$
Inner product:
$\tilde{\alpha}_i = k_i^\top q_0$, for $i = 1$ to $m$
Normalization:
$[\alpha_1, \ldots, \alpha_m] = \mathrm{Softmax}\left([\tilde{\alpha}_1, \ldots, \tilde{\alpha}_m]\right)$
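Below is a minimal NumPy sketch of both weight-computation options. The dimensions, the random initialization, and the variable names (`H`, `s0`, `W`, `v`, `W_K`, `W_Q`) are illustrative assumptions, not part of the original material.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d_h, d_s, d_k = 6, 8, 8, 8           # assumed source length and dimensions
H = rng.normal(size=(m, d_h))           # encoder states h_1, ..., h_m
s0 = rng.normal(size=d_s)               # decoder's initial state

def softmax(x):
    e = np.exp(x - x.max())             # subtract max for numerical stability
    return e / e.sum()

# Option 1: additive alignment, alpha_tilde_i = v^T tanh(W [h_i; s_0])
W = rng.normal(size=(d_k, d_h + d_s))
v = rng.normal(size=d_k)
alpha_tilde = np.array([v @ np.tanh(W @ np.concatenate([H[i], s0])) for i in range(m)])
alpha_opt1 = softmax(alpha_tilde)

# Option 2: k_i = W_K h_i, q_0 = W_Q s_0, alpha_tilde_i = k_i^T q_0
W_K = rng.normal(size=(d_k, d_h))
W_Q = rng.normal(size=(d_k, d_s))
K = H @ W_K.T                           # rows are k_1, ..., k_m
q0 = W_Q @ s0
alpha_opt2 = softmax(K @ q0)

print(alpha_opt1.sum(), alpha_opt2.sum())   # each set of weights sums to 1
```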
Calculate the next state
Simple RNN: $s_1 = \tanh\!\left(A' \cdot \begin{bmatrix} x_1' \\ s_0 \end{bmatrix} + b\right)$
Simple RNN + Attention: $s_1 = \tanh\!\left(A' \cdot \begin{bmatrix} x_1' \\ s_0 \\ c_0 \end{bmatrix} + b\right)$
Context vector: $c_0 = \alpha_1 h_1 + \ldots + \alpha_m h_m$
For the next state $s_2$, do not reuse the previously computed weights $\alpha_i$; compute them again with the new state: $\alpha_i = \mathrm{align}(h_i, s_1)$.
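The sketch below runs this decoder update for a few steps, recomputing the attention weights (Option 2) from the current state before every update. The single hidden size `d`, the random parameters, the decoder inputs `X_dec`, and taking `s_0` to be the encoder's last state are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
m, t_steps, d = 6, 4, 8                 # assumed source length, target length, dimension
H = rng.normal(size=(m, d))             # encoder states h_1, ..., h_m
X_dec = rng.normal(size=(t_steps, d))   # decoder inputs x'_1, ..., x'_t
W_K, W_Q = rng.normal(size=(d, d)), rng.normal(size=(d, d))
A_prime = rng.normal(size=(d, 3 * d))   # acts on the stacked vector [x'_t; s_{t-1}; c_{t-1}]
b = np.zeros(d)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

K = H @ W_K.T                           # keys are computed once from the encoder states
s = H[-1].copy()                        # initial decoder state s_0 (here: the encoder's last state)
for t in range(t_steps):
    alpha = softmax(K @ (W_Q @ s))      # weights recomputed from the current decoder state
    c = alpha @ H                       # context vector: weighted sum of encoder states
    s = np.tanh(A_prime @ np.concatenate([X_dec[t], s, c]) + b)   # next decoder state
```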
Compute the parameters for the $j$-th state:
Query: $q_{:j} = W_Q \, s_j$ (to match others)
Key: $k_{:i} = W_K \, h_i$ (to be matched)
Value: $v_{:i} = W_V \, h_i$ (to be averaged with the weights)
Weights: $\alpha_{ij} = \mathrm{align}(h_i, s_j)$
Compute $k_{:i} = W_K h_i$ and $q_{:j} = W_Q s_j$.
Compute the weights: $\alpha_{:j} = \mathrm{Softmax}(K^\top q_{:j}) \in \mathbb{R}^m$
Context vector: $c_j = \alpha_{1j} v_{:1} + \ldots + \alpha_{mj} v_{:m}$
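A minimal NumPy sketch of these formulas for a single decoder state $s_j$; the dimension, the random initialization, and the variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 6, 8                            # assumed source length and dimension
H = rng.normal(size=(d, m))            # columns are the encoder states h_1, ..., h_m
s_j = rng.normal(size=d)               # the j-th decoder state
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

q_j = W_Q @ s_j                        # query:  q_:j = W_Q s_j
K = W_K @ H                            # keys:   k_:i = W_K h_i (columns of K)
V = W_V @ H                            # values: v_:i = W_V h_i (columns of V)

scores = K.T @ q_j                     # K^T q_:j, one score per source position
alpha_j = np.exp(scores - scores.max())
alpha_j /= alpha_j.sum()               # alpha_:j = Softmax(K^T q_:j), a vector in R^m
c_j = V @ alpha_j                      # c_j = alpha_1j v_:1 + ... + alpha_mj v_:m
```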
Time complexity
Question: how many weights $\alpha_i$ have been computed?
To compute one context vector $c_j$, we compute $m$ weights: $\alpha_1, \ldots, \alpha_m$.
The decoder has $t$ states, so $m \cdot t$ weights are computed in total.
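For example, with a source sentence of m = 30 words and a target of t = 20 words, attention computes 30 · 20 = 600 weights, while the standard Seq2Seq model performs only m + t = 50 recurrent updates.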
Weights Visualization
Summary
Standard Seq2Seq model: the decoder looks at only its current state.
Attention: decoder additionally looks at all the states of the encoder.
Attention: decoder knows where to focus.
Downside: higher time complexity.
m: source sequence length
t: target sequence length
Standard Seq2Seq: O(m+t) time complexity
Seq2Seq + attention: O(m·t) time complexity
References:
Bahdanau, Cho, & Bengio. Neural machine translation by jointly learning to align and translate.
In ICLR, 2015.