Attention

Shortcoming of Seq2Seq

The encoder's final state is a single fixed-size vector, so it cannot faithfully remember a long source sequence; information from early tokens is easily lost.

Seq2Seq Model with Attention

  • Attention tremendously improves the Seq2Seq model.

  • With attention, the Seq2Seq model does not forget the source input.

  • With attention, the decoder knows where to focus.

  • Downside: much more computation.

Simple RNN + Attention

There are two options to calculate the attention weights $\alpha_i = align(h_i, s_0)$, where $h_i$ is the $i$-th encoder hidden state and $s_0$ is the decoder's initial state.

Option 1 (used in the original paper)

First compute an unnormalized alignment score for each encoder state; the original paper uses an additive score of the form $\tilde\alpha_i = v^T \cdot tanh(W \cdot [h_i; s_0])$ for $i = 1$ to $m$, with trainable $W$ and $v$. Then normalize $\tilde\alpha_1, ..., \tilde\alpha_m$ so that they sum to 1 (a NumPy sketch of both steps follows):

  • $[\alpha_1, ..., \alpha_m] = Softmax([\tilde\alpha_1, ..., \tilde\alpha_m])$
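
A minimal NumPy sketch of Option 1 under the additive score above; the shapes and the parameter names `W` and `v` are illustrative, not taken from the notes:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array of scores.
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

def option1_weights(H, s0, W, v):
    """Option 1 (additive) attention weights -- a sketch.

    H  : (m, d_h) encoder hidden states h_1 ... h_m, one per row
    s0 : (d_s,)   initial decoder state
    W  : (d_a, d_h + d_s) trainable matrix (illustrative name)
    v  : (d_a,)   trainable vector (illustrative name)
    """
    m = H.shape[0]
    scores = np.empty(m)
    for i in range(m):
        concat = np.concatenate([H[i], s0])   # [h_i; s_0]
        scores[i] = v @ np.tanh(W @ concat)   # unnormalized score, tilde-alpha_i
    return softmax(scores)                    # alpha_1 ... alpha_m, summing to 1
```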

Option 2 (more popular; the same as in the Transformer)

  1. Linear maps:

    1. $k_i = W_K \cdot h_i$, for $i = 1$ to $m$

    2. $q_0 = W_Q \cdot s_0$

  2. Inner product:

    1. $\tilde\alpha_i = k_i^T q_0$, for $i = 1$ to $m$

  3. Normalization (a sketch of all three steps follows this list):

    1. $[\alpha_1, ..., \alpha_m] = Softmax([\tilde\alpha_1, ..., \tilde\alpha_m])$
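
A matching NumPy sketch of Option 2; the shapes and the names `W_K`, `W_Q` are assumptions for illustration:

```python
import numpy as np

def option2_weights(H, s0, W_K, W_Q):
    """Option 2 (dot-product) attention weights -- a sketch.

    H   : (m, d_h) encoder hidden states, one per row
    s0  : (d_s,)   decoder state
    W_K : (d_k, d_h) and W_Q : (d_k, d_s), trainable matrices
    """
    K = H @ W_K.T                   # rows are k_i = W_K h_i        -> (m, d_k)
    q0 = W_Q @ s0                   # q_0 = W_Q s_0                 -> (d_k,)
    scores = K @ q0                 # tilde-alpha_i = k_i^T q_0     -> (m,)
    scores = scores - scores.max()  # numerical stability
    e = np.exp(scores)
    return e / e.sum()              # Softmax over the m scores
```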

Calculate the next state

Simple RNN: $s_1 = tanh(A' \cdot \begin{bmatrix} x'_1 \\ s_0 \end{bmatrix} + b)$

Simple RNN + Attention: $s_1 = tanh(A' \cdot \begin{bmatrix} x'_1 \\ s_0 \\ c_0 \end{bmatrix} + b)$

Context vector: $c_0 = \alpha_1 h_1 + ... + \alpha_m h_m$

For the next state $s_2$, do not reuse the previously calculated $\alpha_i$; recompute the weights using the new decoder state $s_1$.
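
A sketch of the context vector and one decoder step with attention (NumPy); shapes and names are illustrative. Before the next step, the weights would be recomputed from the new state $s_1$:

```python
import numpy as np

def context_vector(alpha, H):
    # c_0 = alpha_1 h_1 + ... + alpha_m h_m ; alpha: (m,), H: (m, d_h)
    return alpha @ H

def rnn_attention_step(x, s_prev, c_prev, A, b):
    """One step of Simple RNN + Attention -- a sketch.

    x      : (d_x,) current input x'_1
    s_prev : (d_s,) previous decoder state s_0
    c_prev : (d_h,) context vector c_0 computed with s_0
    A      : (d_s, d_x + d_s + d_h) parameter matrix A'
    b      : (d_s,) bias
    """
    z = np.concatenate([x, s_prev, c_prev])   # [x'_1; s_0; c_0]
    return np.tanh(A @ z + b)                 # s_1
```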

Compute parameters for the $j^{th}$ state

  • Query: $q_{:j} = W_Q s_j$ (to match others)

  • Key: $k_{:i} = W_K h_i$ (to be matched)

  • Value: $v_{:i} = W_V h_i$ (to be averaged with the weights)

  • Weights: $\alpha_{ij} = align(h_i, s_j)$

    • Compute $k_{:i} = W_K h_i$ and $q_{:j} = W_Q s_j$

    • Compute weights: $\alpha_{:j} = Softmax(K^T q_{:j}) \in R^m$

  • Context vector: $c_j = \alpha_{1j} v_{:1} + ... + \alpha_{mj} v_{:m}$ (a sketch combining these steps follows this list)
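
Putting the query, key, value, and weight steps together for a single decoder state; a NumPy sketch with illustrative names and shapes:

```python
import numpy as np

def attention_context(H, s_j, W_Q, W_K, W_V):
    """Context vector c_j for the j-th decoder state -- a sketch.

    H   : (m, d_h) encoder states h_1 ... h_m, one per row
    s_j : (d_s,)   decoder state
    W_Q : (d_k, d_s), W_K : (d_k, d_h), W_V : (d_v, d_h), trainable
    """
    q_j = W_Q @ s_j                 # query:  q_:j = W_Q s_j
    K = H @ W_K.T                   # keys:   rows are k_:i = W_K h_i
    V = H @ W_V.T                   # values: rows are v_:i = W_V h_i
    scores = K @ q_j                # K^T q_:j                      -> (m,)
    scores = scores - scores.max()  # numerical stability
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum()     # alpha_:j = Softmax(K^T q_:j)
    return alpha @ V                # c_j = sum_i alpha_ij v_:i
```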

Time complexity

  • Question: how many weights $\alpha_{ij}$ have been computed?

    • To compute one context vector $c_j$, we compute $m$ weights: $\alpha_{1j}, ..., \alpha_{mj}$.

    • The decoder has $t$ states, so in total $m \cdot t$ weights are computed (e.g., $m = 100$ source tokens and $t = 50$ decoder steps give $5000$ weights).

Weights Visualization

Summary

  • Standard Seq2Seq model: the decoder looks only at its current state.

  • Attention: the decoder additionally looks at all the states of the encoder.

  • Attention: the decoder knows where to focus.

  • Downside: higher time complexity.

    • $m$: source sequence length

    • $t$: target sequence length

    • Standard Seq2Seq: $O(m + t)$ time complexity

    • Seq2Seq + attention: $O(m \cdot t)$ time complexity

References:

  1. Bahdanau, Cho, & Bengio. Neural Machine Translation by Jointly Learning to Align and Translate. In ICLR, 2015.

  2. YouTube video.

  3. Figure from https://distill.pub/2016/augmented-rnns/