Encoder-Decoder Attention in the Transformer

Understand how encoder-decoder attention powers tasks like machine translation by dynamically linking input and output sequences using queries, keys, and values.


The encoder-decoder attention mechanism is a pivotal component in modern sequence-to-sequence models, particularly in tasks like machine translation, text summarization, and speech recognition. It allows the model to dynamically focus on different parts of the input sequence when generating each element of the output sequence. In this section, we delve deeply into how encoder-decoder attention works, explaining the roles of the query, key, and value matrices, why they are derived from the decoder and encoder respectively, and how they operate within multi-layer architectures like the Transformer.

Figure: Encoder-decoder attention — keys and values from the encoder (left) and queries from the decoder (right).

Queries from the Decoder

In the right section of the figure, the queries $q_1, q_2, \dots, q_4$ (yellow blocks) are generated by the decoder based on the current decoding state, which reflects the tokens produced so far. At each decoding step, the decoder generates a query vector, such as $q_2$ for the word "for," representing the specific information the decoder seeks from the encoder's output.

Queries ($Q$) come from the decoder because they need to represent what the decoder requires at each time step, based on the output sequence generated so far.
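As a minimal sketch, the queries can be pictured as a learned linear projection of the decoder's hidden states. The shapes, the projection matrix `W_q`, and the random values below are illustrative assumptions, not the Transformer's trained weights:

```python
import numpy as np

# Illustrative shapes: 4 decoder positions, model width d_model = 8, query width d_k = 8.
T_dec, d_model, d_k = 4, 8, 8

rng = np.random.default_rng(0)
decoder_states = rng.normal(size=(T_dec, d_model))  # stand-in for decoder states of the tokens produced so far
W_q = rng.normal(size=(d_model, d_k))               # hypothetical learned query projection

Q = decoder_states @ W_q                            # one query vector per decoder position: q_1 ... q_4
print(Q.shape)                                      # (4, 8)
```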

Keys and Values from the Encoder

In the left section of the figure, the encoder's output provides the keys $k_1, k_2, k_3$ and values $v_1, v_2, v_3$ (shown as the "Keys" and "Values" boxes). These vectors are derived from the encoder's processing of the input sequence (the outputs of the last encoder layer), which captures the contextual meaning of the input tokens.
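Analogously, here is a minimal sketch of how keys and values could be projected from the last encoder layer's outputs; `W_k`, `W_v`, and the shapes are again illustrative assumptions:

```python
import numpy as np

# Illustrative shapes: 3 encoder positions, model width d_model = 8, key/value width d_k = 8.
T_enc, d_model, d_k = 3, 8, 8

rng = np.random.default_rng(1)
encoder_outputs = rng.normal(size=(T_enc, d_model))  # stand-in for the last encoder layer's outputs
W_k = rng.normal(size=(d_model, d_k))                # hypothetical learned key projection
W_v = rng.normal(size=(d_model, d_k))                # hypothetical learned value projection

K = encoder_outputs @ W_k  # keys k_1, k_2, k_3
V = encoder_outputs @ W_v  # values v_1, v_2, v_3
```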

The Attention Computation Process

At each decoding step, the attention mechanism performs the following operations:

  1. Compute Attention Scores: Measure the similarity between the query and each key to determine how much attention the decoder should pay to each position in the input sequence.

    For each decoder time step $t$:

    • Query Vector: $q_t \in \mathbb{R}^{d_k}$ (from the decoder's current state).
    • Key Vectors: $K = [k_1, k_2, \dots, k_{T_{enc}}]^\top \in \mathbb{R}^{T_{enc} \times d_k}$ (from the outputs of the last encoder layer).
    • Attention Scores: Compute the compatibility between $q_t$ and each $k_i$:

    $$e_{ti} = \text{score}(q_t, k_i)$$

    For the scaled dot-product, the score is:

    $$e_{ti} = \frac{q_t \cdot k_i}{\sqrt{d_k}}$$
  2. Calculate Attention Weights: Normalize the attention scores using a softmax function to obtain a probability distribution over the input positions.

    $$\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{j=1}^{T_{enc}} \exp(e_{tj})}$$
  3. Compute the Context Vector: Take a weighted sum of the value vectors $V = [v_1, v_2, \dots, v_{T_{enc}}]^\top$ (from the outputs of the last encoder layer), using the attention weights. This context vector summarizes the relevant information from the encoder outputs for the current decoding step; all three steps are combined in the code sketch after this list.

    $$c_t = \sum_{i=1}^{T_{enc}} \alpha_{ti} v_i$$
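The three steps above can be combined into a single-head, unmasked cross-attention routine. The sketch below uses NumPy and randomly generated `Q`, `K`, and `V` matrices purely for illustration; in a real Transformer they would come from the decoder states and encoder outputs as described earlier:

```python
import numpy as np

def cross_attention(Q, K, V):
    """Scaled dot-product encoder-decoder attention (single head, no masking)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # step 1: e_{ti} for every (t, i) pair
    scores -= scores.max(axis=-1, keepdims=True)     # shift for numerical stability of the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # step 2: alpha_{ti}, softmax over encoder positions
    context = weights @ V                            # step 3: c_t = sum_i alpha_{ti} v_i
    return context, weights

# Toy example: 4 decoder queries attending over 3 encoder positions, d_k = 8.
rng = np.random.default_rng(2)
Q = rng.normal(size=(4, 8))   # queries from the decoder
K = rng.normal(size=(3, 8))   # keys from the encoder
V = rng.normal(size=(3, 8))   # values from the encoder

context, weights = cross_attention(Q, K, V)
print(weights.shape)  # (4, 3): one attention distribution over encoder positions per decoder step
print(context.shape)  # (4, 8): one context vector per decoder step
```

Each row of `weights` sums to one, mirroring the softmax normalization in step 2, and each row of `context` is the corresponding weighted sum of the encoder's value vectors.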

Conclusion

By leveraging the interplay of queries, keys, and values—derived from the decoder and encoder, respectively—the mechanism facilitates a fine-grained and context-aware information flow between the encoder's representation of the input and the decoder's generation of the output. This dynamic interaction is crucial for tasks requiring high-level comprehension and structured output, such as machine translation and text summarization.

Next Steps