Attention

Shortcoming of Seq2Seq

The encoder compresses the entire source sentence into its final state, which cannot retain the information of a long sequence.

Seq2Seq Model with Attention

  • Attention tremendously improves the Seq2Seq model.

  • With attention, the Seq2Seq model does not forget the source input.

  • With attention, the decoder knows where to focus.

  • Downside: much more computation.

Simple RNN + Attention

There are two options to calculate the weights $\alpha_i = \mathrm{align}(h_i, s_0)$:

Option 1 (used in the original paper):

  1. Compute scores: $\tilde\alpha_i = \mathbf{v}^T \cdot \tanh\left(\mathbf{W} \cdot \begin{bmatrix} h_i \\ s_0 \end{bmatrix}\right)$, for i = 1 to m.

  2. Normalize $\tilde\alpha_1, ..., \tilde\alpha_m$ (so that they sum to 1):

    1. $[\alpha_1,...,\alpha_m] = \mathrm{Softmax}([\tilde\alpha_1, ..., \tilde\alpha_m])$

Option 2 (the dot-product form also used in the Transformer); a code sketch follows this list:

  1. Linear maps:

    1. $k_i = W_K \cdot h_i$, for i = 1 to m

    2. $q_0 = W_Q \cdot s_0$

  2. Inner product:

    1. $\tilde\alpha_i = k_i^T q_0$, for i = 1 to m

  3. Normalization:

    1. $[\alpha_1,...,\alpha_m] = \mathrm{Softmax}([\tilde\alpha_1, ..., \tilde\alpha_m])$
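Below is a minimal NumPy sketch of Option 2; the dimensions, random initialization, and variable names are illustrative assumptions, not values from the lecture.

```python
import numpy as np

def softmax(x):
    # subtract the max for numerical stability
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
m, enc_dim, dec_dim, d = 5, 8, 8, 6      # assumed sizes: source length and hidden dims

h  = rng.normal(size=(m, enc_dim))       # encoder states h_1, ..., h_m (one per row)
s0 = rng.normal(size=dec_dim)            # initial decoder state s_0
W_K = rng.normal(size=(d, enc_dim))      # key projection W_K
W_Q = rng.normal(size=(d, dec_dim))      # query projection W_Q

k  = h @ W_K.T                           # step 1: k_i = W_K h_i, shape (m, d)
q0 = W_Q @ s0                            # step 1: q_0 = W_Q s_0, shape (d,)
alpha_tilde = k @ q0                     # step 2: inner products k_i^T q_0
alpha = softmax(alpha_tilde)             # step 3: normalized weights, sum to 1

print(alpha, alpha.sum())                # m weights that sum to 1.0
```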

Calculate the next state

Simple RNN: $s_1 = \tanh\left(A' \cdot \begin{bmatrix} x'_1 \\ s_0 \end{bmatrix} + b\right)$

Simple RNN + Attention: $s_1 = \tanh\left(A' \cdot \begin{bmatrix} x'_1 \\ s_0 \\ c_0 \end{bmatrix} + b\right)$

Context vector: $c_0 = \alpha_1 h_1 + ... + \alpha_m h_m$

For the next state $s_2$, do not reuse the previously computed weights; recompute $\alpha_i = \mathrm{align}(h_i, s_1)$ and build a new context vector $c_1$.
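A minimal sketch of one Simple RNN + Attention decoder step under the same assumptions; the matrix $A'$, the bias $b$, and the embedded input $x'_1$ are placeholders, and the weights $\alpha_i$ are taken as given from the align step.

```python
import numpy as np

def decoder_step(x, s_prev, h, alpha, A, b):
    """One step: c = sum_i alpha_i * h_i, then s_new = tanh(A' [x; s_prev; c] + b)."""
    c = alpha @ h                           # context vector: weighted sum of encoder states
    z = np.concatenate([x, s_prev, c])      # stack input x'_1, old state s_0, context c_0
    s_new = np.tanh(A @ z + b)              # Simple RNN update with the context appended
    return s_new, c

rng = np.random.default_rng(1)
m, enc_dim, dec_dim, in_dim = 5, 8, 8, 4    # assumed sizes
h     = rng.normal(size=(m, enc_dim))       # encoder states
s0    = rng.normal(size=dec_dim)            # previous decoder state
x1    = rng.normal(size=in_dim)             # embedded decoder input x'_1
alpha = np.full(m, 1.0 / m)                 # weights from the align step (placeholder)
A     = rng.normal(size=(dec_dim, in_dim + dec_dim + enc_dim))  # parameter matrix A'
b     = np.zeros(dec_dim)

s1, c0 = decoder_step(x1, s0, h, alpha, A, b)
```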

Compute parameters for the $j^{th}$ decoder state

  • Query: $q_{:j} = W_Q s_j$ (to match others)

  • Key: $k_{:i} = W_K h_i$ (to be matched)

  • Value: $v_{:i} = W_V h_i$ (to be weighted averaged)

  • Weights: $\alpha_{ij} = \mathrm{align}(h_i, s_j)$

    • Compute $k_{:i} = W_K h_i$ and $q_{:j} = W_Q s_j$.

    • Compute weights: $\alpha_{:j} = \mathrm{Softmax}(K^T q_{:j}) \in \mathbb{R}^m$, where $K = [k_{:1}, ..., k_{:m}]$.

  • Context vector: $c_j = \alpha_{1j} v_{:1} + ... + \alpha_{mj} v_{:m}$ (see the code sketch below)
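A minimal sketch of the query/key/value computation for the $j^{th}$ decoder state; the projection matrices and sizes are again illustrative assumptions. Keys and values are stacked as rows of `K` and `V` here, so `K @ q_j` plays the role of $K^T q_{:j}$.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def context_for_state(s_j, h, W_Q, W_K, W_V):
    """c_j = sum_i alpha_ij * v_:i, with alpha_:j = Softmax(K^T q_:j)."""
    q_j = W_Q @ s_j                 # query  q_:j = W_Q s_j   (to match others)
    K   = h @ W_K.T                 # keys   k_:i = W_K h_i   (to be matched)
    V   = h @ W_V.T                 # values v_:i = W_V h_i   (to be weighted averaged)
    alpha_j = softmax(K @ q_j)      # weights alpha_:j in R^m
    return alpha_j @ V              # weighted average of the values

rng = np.random.default_rng(2)
m, enc_dim, dec_dim, d = 5, 8, 8, 6
h   = rng.normal(size=(m, enc_dim))     # encoder states
s_j = rng.normal(size=dec_dim)          # j-th decoder state
W_Q = rng.normal(size=(d, dec_dim))
W_K = rng.normal(size=(d, enc_dim))
W_V = rng.normal(size=(d, enc_dim))

c_j = context_for_state(s_j, h, W_Q, W_K, W_V)   # context vector, shape (d,)
```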

Time complexity

  • Question: how many weights $\alpha_{ij}$ must be computed in total?

    • To compute one context vector $c_j$, we compute $m$ weights: $\alpha_{1j}, ..., \alpha_{mj}$.

    • The decoder has $t$ states, so there are $m \cdot t$ weights in total (a concrete count follows below).
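As a concrete, purely illustrative count: for a source sentence of length $m = 50$ and a target sentence of length $t = 30$, attention requires $50 \times 30 = 1500$ weights $\alpha_{ij}$, whereas a standard Seq2Seq model performs only about $m + t = 80$ recurrent updates.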

Weights Visualization

The learned weights $\alpha_{ij}$ can be visualized to show which source words the decoder focuses on while generating each output word. (Figure from https://distill.pub/2016/augmented-rnns/.)

Summary

  • Standard Seq2Seq model: the decoder looks only at its current state.

  • Attention: the decoder additionally looks at all the states of the encoder.

  • Attention: the decoder knows where to focus.

  • Downside: higher time complexity.

    • $m$: source sequence length

    • $t$: target sequence length

    • Standard Seq2Seq: $O(m+t)$ time complexity

    • Seq2Seq + attention: $O(m \cdot t)$ time complexity

References:

  1. Bahdanau, Cho, & Bengio. Neural Machine Translation by Jointly Learning to Align and Translate. In ICLR, 2015.
