Commit 4c3576e6 authored by Marco Kuhlmann

Merge branch 'MergeBranch' into 'master'

L3: Corrected vocab usage in training-loop to avoid potential collisions with earlier tests.

See merge request nlp/nlp-course!5
%% Cell type:markdown id: tags:
# L3: Attention
%% Cell type:markdown id: tags:
In this lab, you will implement the encoder–decoder architecture presented in Lecture 3.2 ([Sutskever et al., 2014](https://papers.nips.cc/paper/2014/file/a14ac55a4f27472c5d894ec1c3c743d2-Paper.pdf)), including the attention-based extension presented in Lecture 3.3 ([Bahdanau et al., 2015](https://arxiv.org/abs/1409.0473)), and evaluate this architecture on a machine translation task.
%% Cell type:code id: tags:
``` python
import torch
```
%% Cell type:markdown id: tags:
Training the models in this notebook requires significant compute power, and we strongly recommend using a GPU.
%% Cell type:code id: tags:
``` python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```
%% Cell type:markdown id: tags:
## The data
%% Cell type:markdown id: tags:
We will build a system that translates from German (our **source language**) to English (our **target language**). The dataset is a collection of parallel English–German sentences taken from translations of subtitles for TED talks. It was derived from the [TED2013](https://opus.nlpl.eu/TED2013-v1.1.php) dataset, which is available in the [OPUS](http://opus.nlpl.eu/) collection. The code cell below prints the first lines in the training data:
%% Cell type:code id: tags:
``` python
with open('train-de.txt') as src, open('train-en.txt') as tgt:
    for i, src_sentence, tgt_sentence in zip(range(5), src, tgt):
        print(f'{i}: {src_sentence.rstrip()} / {tgt_sentence.rstrip()}')
```
%% Cell type:markdown id: tags:
As you can see, some ‘sentences’ are actually *sequences* of sentences, but we will use the term *sentence* nevertheless. All sentences are whitespace-tokenised and lowercased. To make your life a bit easier, we have removed sentences longer than 25 words.
The next cell contains code that yields the sentences contained in a file as lists of strings:
%% Cell type:code id: tags:
``` python
def sentences(filename):
    with open(filename) as source:
        for line in source:
            yield line.rstrip().split()
```
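%% Cell type:markdown id: tags:
For example, the following quick check (not required for the lab) prints the first tokenised sentence of the German training data:
%% Cell type:code id: tags:
``` python
print(next(sentences('train-de.txt')))
```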
%% Cell type:markdown id: tags:
## Problem 1: Build the vocabularies
%% Cell type:markdown id: tags:
Your first task is to build the vocabularies for the data, one vocabulary for each language. Each vocabulary should contain the 10,000 most frequent words in the training data for the respective language.
%% Cell type:code id: tags:
``` python
def make_vocab(sentences, max_size):
    # TODO: Replace the next line with your own code
    raise NotImplementedError
```
%% Cell type:markdown id: tags:
Your implementation must comply with the following specification:
**make_vocab** (*sentences*, *max_size*)
> Returns a dictionary that maps the most frequent words in the *sentences* to a contiguous range of integers starting at&nbsp;0. The first four mappings in this dictionary are reserved for the pseudowords `<pad>` (padding, id&nbsp;0), `<bos>` (beginning of sequence, id&nbsp;1), `<eos>` (end of sequence, id&nbsp;2), and `<unk>` (unknown word, id&nbsp;3). The parameter *max_size* caps the size of the dictionary, including the pseudowords.
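For reference, the cell below shows a minimal sketch of one way to meet this specification, based on `collections.Counter`. The name `make_vocab_sketch` is only illustrative; your own `make_vocab` may be organised differently.
%% Cell type:code id: tags:
``` python
from collections import Counter

def make_vocab_sketch(sentences, max_size):
    # Count how often each word occurs across all sentences.
    counter = Counter(w for s in sentences for w in s)
    # Reserve the first four ids for the pseudowords.
    vocab = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
    # Add the most frequent words until the size cap is reached.
    for w, _ in counter.most_common(max_size - len(vocab)):
        vocab[w] = len(vocab)
    return vocab
```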
%% Cell type:markdown id: tags:
With this function, we can construct the vocabularies as follows:
%% Cell type:code id: tags:
``` python
src_vocab = make_vocab(sentences('train-de.txt'), 10000)
tgt_vocab = make_vocab(sentences('train-en.txt'), 10000)
```
%% Cell type:markdown id: tags:
### 🤞 Test your code
To test your code, check that each vocabulary contains 10,000 words, including the pseudowords.
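A minimal check along these lines (a sketch, assuming `make_vocab` follows the specification above) could look as follows:
%% Cell type:code id: tags:
``` python
assert len(src_vocab) == 10000
assert len(tgt_vocab) == 10000
# The pseudowords should occupy the first four ids.
assert src_vocab['<pad>'] == 0 and src_vocab['<unk>'] == 3
```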
%% Cell type:markdown id: tags:
## Load the data
%% Cell type:markdown id: tags:
The next cell defines a class for the parallel dataset. We sub-class the abstract [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) class, which represents map-style datasets in PyTorch. This will let us use standard infrastructure related to the loading and automatic batching of data.
%% Cell type:code id: tags:
``` python
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    def __init__(self, src_vocab, src_filename, tgt_vocab, tgt_filename):
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        # We hard-wire the codes for <bos> (1), <eos> (2), and <unk> (3).
        self.src = [[self.src_vocab.get(w, 3) for w in s] for s in sentences(src_filename)]
        self.tgt = [[self.tgt_vocab.get(w, 3) for w in s] + [2] for s in sentences(tgt_filename)]

    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]

    def __len__(self):
        return len(self.src)
```
%% Cell type:markdown id: tags:
We load the training data:
%% Cell type:code id: tags:
``` python
train_dataset = TranslationDataset(src_vocab, 'train-de.txt', tgt_vocab, 'train-en.txt')
```
%% Cell type:markdown id: tags:
The following function will be helpful for debugging. It extracts a single source–target pair of sentences from the specified *dataset* and converts it into batches of size&nbsp;1, which can be fed into the encoder–decoder model.
%% Cell type:code id: tags:
``` python
def example(dataset, i):
    src, tgt = dataset[i]
    return torch.LongTensor(src).unsqueeze(0), torch.LongTensor(tgt).unsqueeze(0)
```
%% Cell type:code id: tags:
``` python
example(train_dataset, 0)
```
%% Cell type:markdown id: tags:
## Problem 2: The encoder–decoder architecture
%% Cell type:markdown id: tags:
In this section, you will implement the encoder–decoder architecture, including the extension of that architecture by an attention mechanism. The implementation consists of four parts: the encoder, the attention mechanism, the decoder, and a class that wraps the complete architecture.
%% Cell type:markdown id: tags:
### Problem 2.1: Implement the encoder
%% Cell type:markdown id: tags:
The encoder is relatively straightforward. We look up word embeddings and unroll a bidirectional GRU over the embedding vectors to compute a representation at each token position. We then take the last hidden state of the forward GRU and the last hidden state of the backward GRU, concatenate them, and pass them through a linear layer. This produces a summary of the source sentence, which we will later feed into the decoder.
To solve this problem, complete the skeleton code in the next code cell:
%% Cell type:code id: tags:
``` python
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, num_words, embedding_dim=256, hidden_dim=512):
        super().__init__()
        # TODO: Add your code here

    def forward(self, src):
        # TODO: Replace the next line with your own code
        raise NotImplementedError
```
%% Cell type:markdown id: tags:
Your code must comply with the following specification:
**__init__** (*num_words*, *embedding_dim* = 256, *hidden_dim* = 512)
> Initialises the encoder. The encoder consists of an embedding layer that maps each of *num_words* words to an embedding vector of size *embedding_dim*, a bidirectional GRU that maps each embedding vector to a position-specific representation of size 2 × *hidden_dim*, and a final linear layer that projects these representations to new representations of size *hidden_dim*.
**forward** (*self*, *src*)
> Takes a tensor *src* with source-language word ids and sends it through the encoder. The input tensor has shape (*batch_size*, *src_len*), where *src_len* is the length of the sentences in the batch. (We will make sure that all sentences in the same batch have the same length.) The method returns a pair of tensors (*output*, *hidden*), where *output* has shape (*batch_size*, *src_len*, *hidden_dim*), and *hidden* has shape (*batch_size*, *hidden_dim*).
%% Cell type:markdown id: tags:
### 🤞 Test your code
To test your code, instantiate an encoder, feed it the first source sentence in the training data, and check that the tensors returned by the encoder have the expected shapes.
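A sketch of such a test, using the `example` helper from above (and assuming a completed `Encoder`):
%% Cell type:code id: tags:
``` python
src, tgt = example(train_dataset, 0)
encoder = Encoder(len(src_vocab))
output, hidden = encoder(src)
print(output.shape)   # expected: (1, src_len, 512)
print(hidden.shape)   # expected: (1, 512)
```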
%% Cell type:markdown id: tags:
### Problem 2.2: Implement the attention mechanism
%% Cell type:markdown id: tags:
Your next task is to implement the attention mechanism. Recall that the purpose of this mechanism is to inform the decoder about the relevant parts of the source sentence when it generates the next target word. For this, attention has access to the previous hidden state of the decoder, as well as the complete output of the encoder. It returns the attention-weighted sum of the encoder output, the so-called *context* vector. For later usage, we also return the attention weights.
As mentioned in Lecture&nbsp;3.3, attention can be implemented in various ways. One very simple implementation is *uniform attention*, which assigns equal weight to each position-specific representation in the output of the encoder, and completely ignores the hidden state of the decoder. This mechanism is implemented in the cell below.
%% Cell type:code id: tags:
``` python
import torch.nn.functional as F

class UniformAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, decoder_hidden, encoder_output, src_mask):
        batch_size, src_len, _ = encoder_output.shape
        # Set all attention scores to the same constant value (0). After
        # the softmax, we will have uniform weights.
        scores = torch.zeros(batch_size, src_len, device=encoder_output.device)
        # Mask out the attention scores for the padding tokens. We set
        # them to -inf. After the softmax, we will have 0.
        scores.data.masked_fill_(~src_mask, -float('inf'))
        # Convert scores into weights
        alpha = F.softmax(scores, dim=1)
        # The context is the alpha-weighted sum of the encoder outputs.
        context = torch.bmm(alpha.unsqueeze(1), encoder_output).squeeze(1)
        return context, alpha
```
%% Cell type:markdown id: tags:
One technical detail in this code is our use of a mask *src_mask* to compute attention weights only for the ‘real’ tokens in the source sentences, but not for the padding tokens that we introduce to bring all sentences in a batch to the same length.
Your task now is to implement the attention mechanism from the paper by [Bahdanau et al. (2015)](https://arxiv.org/abs/1409.0473). The relevant equation is in Section&nbsp;A.1.2:
$$
a(s_{i-1}, h_j) = v^{\top} \tanh(W s_{i-1} + U h_j)
$$
This equation specifies how to compute the attention score (a scalar) for the previous hidden state of the decoder, denoted by&nbsp;$s_{i-1}$, and the $j$th position-specific representation in the output of the encoder, denoted by&nbsp;$h_j$. The equation refers to three parameters: a vector $v$ and two matrices $W$ and $U$. In PyTorch, these parameters can be represented in terms of (bias-free) linear layers that are trained along with the other parameters of the model.
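For concreteness, here is a standalone sketch (with illustrative dimensions, separate from the classes in this lab) of how this score can be computed for a single decoder state $s_{i-1}$ and a single encoder representation $h_j$ using bias-free linear layers:
%% Cell type:code id: tags:
``` python
hidden_dim = 512
W = nn.Linear(hidden_dim, hidden_dim, bias=False)   # plays the role of W in the equation
U = nn.Linear(hidden_dim, hidden_dim, bias=False)   # plays the role of U
v = nn.Linear(hidden_dim, 1, bias=False)            # plays the role of v
s_prev = torch.zeros(1, hidden_dim)   # previous decoder hidden state (batch of 1)
h_j = torch.zeros(1, hidden_dim)      # j-th encoder representation (batch of 1)
score = v(torch.tanh(W(s_prev) + U(h_j)))  # shape (1, 1): a single attention score
```
%% Cell type:markdown id: tags: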
Here is the skeleton code for this problem. As you can see, your specific task is to initialise the required parameters and to compute the attention scores (*scores*); the rest of the code is the same as for the uniform attention.
%% Cell type:code id: tags:
``` python
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim=512):
        super().__init__()
        # TODO: Add your code here

    def forward(self, decoder_hidden, encoder_output, src_mask):
        batch_size, src_len, _ = encoder_output.shape
        # TODO: Replace the next line with your own code
        scores = torch.zeros(batch_size, src_len, device=encoder_output.device)
        # The rest of the code is as in UniformAttention
        # Mask out the attention scores for the padding tokens. We set
        # them to -inf. After the softmax, we will have 0.
        scores.data.masked_fill_(~src_mask, -float('inf'))
        # Convert scores into weights
        alpha = F.softmax(scores, dim=1)
        # The context vector is the alpha-weighted sum of the encoder outputs.
        context = torch.bmm(alpha.unsqueeze(1), encoder_output).squeeze(1)
        return context, alpha
```
%% Cell type:markdown id: tags:
Your code must comply with the following specification:
**forward** (*decoder_hidden*, *encoder_output*, *src_mask*)
> Takes the previous hidden state of the decoder (*decoder_hidden*), the encoder output (*encoder_output*), and the source mask (*src_mask*), and returns a pair (*context*, *alpha*), where *context* is the context vector computed as in [Bahdanau et al. (2015)](https://arxiv.org/abs/1409.0473) and *alpha* are the corresponding attention weights. The hidden state has shape (*batch_size*, *hidden_dim*), the encoder output has shape (*batch_size*, *src_len*, *hidden_dim*), the context has shape (*batch_size*, *hidden_dim*), and the attention weights have shape (*batch_size*, *src_len*).
%% Cell type:markdown id: tags:
### 🤞 Test your code
To test your code, extend your test from Problem&nbsp;2.1: Feed the output of your encoder into your attention class. As the previous hidden state of the decoder, you can use the hidden state returned by the encoder. You will also need to create a source mask; this can be done as follows:
``` python
src_mask = (src != 0)
```
Check that the context tensor and the attention weights returned by the attention class have the expected shapes.
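Building on the test from Problem&nbsp;2.1, a sketch of this check might look as follows (assuming completed `Encoder` and `BahdanauAttention` classes):
%% Cell type:code id: tags:
``` python
src, tgt = example(train_dataset, 0)
encoder = Encoder(len(src_vocab))
attention = BahdanauAttention()
output, hidden = encoder(src)
src_mask = (src != 0)
context, alpha = attention(hidden, output, src_mask)
print(context.shape)  # expected: (1, 512)
print(alpha.shape)    # expected: (1, src_len)
```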
%% Cell type:markdown id: tags:
### Problem 2.3: Implement the decoder
%% Cell type:markdown id: tags:
Now you are ready to implement the decoder. Like the encoder, the decoder is based on a GRU; but this time we use a unidirectional network, as we generate the target sentences left-to-right.
**⚠️ We expect that solving this problem will take you the longest time in this lab.**
Because the decoder is an autoregressive model, we need to unroll the GRU ‘manually’: At each position, we take the previous hidden state as well as the new input, and apply the GRU for one step. The initial hidden state comes from the encoder. The new input is the embedding of the previous word, concatenated with the context vector from the attention model. To produce the final output, we take the output of the GRU, concatenate it with the embedding vector and the context vector (a residual connection), and feed the result into a linear layer. Here is a graphical representation:
<img src="https://gitlab.liu.se/nlp/nlp-course/-/raw/master/labs/l3/decoder.svg" width="50%" alt="Decoder architecture"/>
We need to implement this manual unrolling for two very similar tasks: When *training*, both the inputs to and the target outputs of the GRU come from the training data. When *decoding*, the outputs of the GRU are used to generate new target-side words, and these words become the inputs to the next step of the unrolling. We have implemented methods `forward` and `decode` for these two different modes of usage. Your task is to implement a method `step` that takes a single step with the GRU.
%% Cell type:code id: tags:
``` python
class Decoder(nn.Module):
    def __init__(self, num_words, attention, embedding_dim=256, hidden_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(num_words, embedding_dim)
        self.attention = attention
        # TODO: Add your own code

    def forward(self, encoder_output, hidden, src_mask, tgt):
        batch_size, tgt_len = tgt.shape
        # Look up the embeddings for the previous words
        embedded = self.embedding(tgt)
        # Initialise the list of outputs (in each sentence)
        outputs = []
        for i in range(tgt_len):
            # Get the embedding for the previous word (in each sentence)
            prev_embedded = embedded[:, i]
            # Take one step with the RNN
            output, hidden, alpha = self.step(encoder_output, hidden, src_mask, prev_embedded)
            # Update the list of outputs (in each sentence)
            outputs.append(output.unsqueeze(1))
        return torch.cat(outputs, dim=1)

    def decode(self, encoder_output, hidden, src_mask, max_len):
        batch_size = encoder_output.size(0)
        # Initialise the lists of generated words and attention weights (in each sentence)
        generated = [torch.ones(batch_size, dtype=torch.long, device=hidden.device)]
        alphas = []
        for i in range(max_len):
            # Get the embedding for the previous word (in each sentence)
            prev_embedded = self.embedding(generated[-1])
            # Take one step with the RNN
            output, hidden, alpha = self.step(encoder_output, hidden, src_mask, prev_embedded)
            # Update the lists of generated words and attention weights (in each sentence)
            generated.append(output.argmax(-1))
            alphas.append(alpha)
        generated = [x.unsqueeze(1) for x in generated[1:]]
        alphas = [x.unsqueeze(1) for x in alphas]
        return torch.cat(generated, dim=1), torch.cat(alphas, dim=1)

    def step(self, encoder_output, hidden, src_mask, prev_embedded):
        # TODO: Replace the next line with your own code
        raise NotImplementedError
```
%% Cell type:markdown id: tags:
Your implementation should comply with the following specification:
**step** (*self*, *encoder_output*, *hidden*, *src_mask*, *prev_embedded*)
> Performs a single step in the manual unrolling of the decoder GRU. This takes the output of the encoder (*encoder_output*), the previous hidden state of the decoder (*hidden*), the source mask as described in Problem&nbsp;2.2 (*src_mask*), and the embedding vector of the previous word (*prev_embedded*), and computes the output as described above.
>
> The shape of *encoder_output* is (*batch_size*, *src_len*, *hidden_dim*); the shape of *hidden* is (*batch_size*, *hidden_dim*); the shape of *src_mask* is (*batch_size*, *src_len*); and the shape of *prev_embedded* is (*batch_size*, *embedding_dim*).
>
> The method returns a triple of tensors (*output*, *hidden*, *alpha*), where *output* is the position-specific output of the GRU, of shape (*batch_size*, *num_words*); *hidden* is the new hidden state, of shape (*batch_size*, *hidden_dim*); and *alpha* are the attention weights that were used to compute the *output*, of shape (*batch_size*, *src_len*).
#### 💡 Hints on the implementation
**Batch first!** By default, the GRU implementation in PyTorch (just like the LSTM implementation) expects its input to be a three-dimensional tensor of the form (*seq_len*, *batch_size*, *input_size*). We find it conceptually easier to change this default behaviour and let the models take their input in the form (*batch_size*, *seq_len*, *input_size*). To do so, set *batch_first=True* when instantiating the GRU.
**Unsqueeze and squeeze.** When doing the unrolling manually, we get the input in the form (*batch_size*, *input_size*). To convert between this representation and the (*batch_size*, *seq_len*, *input_size*) representation, you can use [`unsqueeze`](https://pytorch.org/docs/stable/generated/torch.unsqueeze.html) and [`squeeze`](https://pytorch.org/docs/stable/generated/torch.squeeze.html).
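The following minimal sketch (with illustrative dimensions, independent of the classes in this lab) shows these shape conversions for a single step with a `batch_first` GRU:
%% Cell type:code id: tags:
``` python
gru = nn.GRU(input_size=256, hidden_size=512, batch_first=True)
x = torch.zeros(4, 256)   # one input per sentence: (batch_size, input_size)
h = torch.zeros(4, 512)   # previous hidden state: (batch_size, hidden_dim)
# nn.GRU expects (batch_size, seq_len, input_size) and (num_layers, batch_size, hidden_dim)
output, h_new = gru(x.unsqueeze(1), h.unsqueeze(0))
output = output.squeeze(1)   # back to (batch_size, hidden_dim)
h_new = h_new.squeeze(0)     # back to (batch_size, hidden_dim)
```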
%% Cell type:markdown id: tags:
### 🤞 Test your code
To test your code, extend your test from the previous problems, and simulate a complete forward pass of the encoder–decoder architecture on the example sentence. Check the shapes of the resulting tensors.
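A sketch of such a test (assuming completed `Encoder`, `BahdanauAttention`, and `Decoder` classes):
%% Cell type:code id: tags:
``` python
src, tgt = example(train_dataset, 0)
encoder = Encoder(len(src_vocab))
decoder = Decoder(len(tgt_vocab), BahdanauAttention())
encoder_output, hidden = encoder(src)
src_mask = (src != 0)
outputs = decoder(encoder_output, hidden, src_mask, tgt)
print(outputs.shape)  # expected: (1, tgt_len, len(tgt_vocab))
```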
%% Cell type:markdown id: tags:
### Encoder–decoder wrapper class
%% Cell type:markdown id: tags:
The last part of the implementation is a class that wraps the encoder and the decoder as a single model:
%% Cell type:code id: tags:
``` python
class EncoderDecoder(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, attention):
        super().__init__()
        self.encoder = Encoder(src_vocab_size)
        self.decoder = Decoder(tgt_vocab_size, attention)

    def forward(self, src, tgt):
        encoder_output, hidden = self.encoder(src)
        return self.decoder.forward(encoder_output, hidden, src != 0, tgt)

    def decode(self, src, max_len):
        encoder_output, hidden = self.encoder(src)
        return self.decoder.decode(encoder_output, hidden, src != 0, max_len)
```
%% Cell type:markdown id: tags:
### 🤞 Test your code
As a final test, instantiate an encoder–decoder model and use it to decode the example sentence. Check the shapes of the resulting tensors.
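For example (a sketch, assuming the classes above are complete):
%% Cell type:code id: tags:
``` python
src, tgt = example(train_dataset, 0)
model = EncoderDecoder(len(src_vocab), len(tgt_vocab), BahdanauAttention())
generated, alphas = model.decode(src, max_len=10)
print(generated.shape)  # expected: (1, 10)
print(alphas.shape)     # expected: (1, 10, src_len)
```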
%% Cell type:markdown id: tags:
## Problem 3: Train a translator
%% Cell type:markdown id: tags:
We now have all the pieces to build and train a complete translation system.
%% Cell type:markdown id: tags:
### Translator class
We first define a class `Translator` that initialises an encoder–decoder model and uses it to translate sentences. It can also return the attention weights that were used to produce the translation of each sentence.
%% Cell type:code id: tags:
``` python
class Translator(object):
    def __init__(self, src_vocab, tgt_vocab, attention, device=torch.device('cpu')):
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.device = device
        self.model = EncoderDecoder(len(src_vocab), len(tgt_vocab), attention).to(device)

    def translate_with_attention(self, sentences):
        # Encode each sentence
        encoded = [[self.src_vocab.get(w, 3) for w in s.split()] for s in sentences]
        # Determine the maximal length of an encoded sentence
        max_len = max(len(e) for e in encoded)
        # Build the input tensor, padding all sequences to the same length
        src = torch.LongTensor([e + [0] * (max_len - len(e)) for e in encoded]).to(self.device)
        # Run the decoder and convert the result into nested lists
        with torch.no_grad():
            decoded, alphas = tuple(d.cpu().numpy().tolist() for d in self.model.decode(src, 2 * max_len))
        # Prune each decoded sentence after the first <eos>
        i2w = {i: w for w, i in self.tgt_vocab.items()}
        result = []
        for d, a in zip(decoded, alphas):
            d = [i2w[i] for i in d]
            try:
                eos_index = d.index('<eos>')
                del d[eos_index:]
                del a[eos_index:]
            except ValueError:
                pass
            result.append((' '.join(d), a))
        return result

    def translate(self, sentences):
        translated, alphas = zip(*self.translate_with_attention(sentences))
        return translated
```
%% Cell type:markdown id: tags:
The code below shows how this class is supposed to be used:
%% Cell type:code id: tags:
``` python
translator = Translator(src_vocab, tgt_vocab, BahdanauAttention())
translator.translate(['ich weiß nicht .', 'das haus ist klein .'])
```
%% Cell type:markdown id: tags:
### Evaluation function
As mentioned in Lecture&nbsp;3.1, machine translation systems are typically evaluated using the BLEU metric. Here we use the implementation of this metric from the `sacrebleu` library.
%% Cell type:code id: tags:
``` python
# If sacrebleu is not found, uncomment the next line:
# !pip install sacrebleu
import sacrebleu

def bleu(translator, src, ref):
    translated = translator.translate(src)
    return sacrebleu.raw_corpus_bleu(translated, [ref], 0.01).score
```
%% Cell type:markdown id: tags:
We will report the BLEU score on the validation data:
%% Cell type:code id: tags:
``` python
with open('valid-de.txt') as src, open('valid-en.txt') as ref:
    valid_src = [line.rstrip() for line in src]
    valid_ref = [line.rstrip() for line in ref]
```
%% Cell type:markdown id: tags:
### Batcher class
To prepare the training, we next create a class that takes a batch of encoded parallel sentences (a list of pairs of integer lists) and transforms it into two tensors, one for the source side and one for the target side. Each tensor contains sequences padded to the length of the longest sequence.
%% Cell type:code id: tags:
``` python
class TranslationBatcher(object):
    def __init__(self, device):
        self.device = device

    def __call__(self, batch):
        srcs, tgts = zip(*batch)
        # Determine the maximal length of a source/target sequence
        max_src_len = max(len(s) for s in srcs)
        max_tgt_len = max(len(t) for t in tgts)
        # Create the source/target tensors
        S = torch.LongTensor([s + [0] * (max_src_len - len(s)) for s in srcs])
        T = torch.LongTensor([t + [0] * (max_tgt_len - len(t)) for t in tgts])
        return S.to(self.device), T.to(self.device)
```
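%% Cell type:markdown id: tags:
As a quick sanity check (not required for the lab), the batcher can be applied to a small list of items from the training dataset:
%% Cell type:code id: tags:
``` python
batcher = TranslationBatcher(device)
S, T = batcher([train_dataset[0], train_dataset[1]])
print(S.shape, T.shape)  # (2, max_src_len) and (2, max_tgt_len) for this mini-batch
```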
%% Cell type:markdown id: tags:
### Training loop
The training loop resembles the training loops that you have seen in previous labs, except that we use a few new utilities from the PyTorch ecosystem.
%% Cell type:code id: tags:
``` python
from torch.utils.data import DataLoader
from tqdm import tqdm

def train(n_epochs=2, batch_size=128, lr=5e-4):
    # Build the vocabularies
    vocab_src = make_vocab(sentences('train-de.txt'), 10000)
    vocab_tgt = make_vocab(sentences('train-en.txt'), 10000)
    # Prepare the dataset
    train_dataset = TranslationDataset(vocab_src, 'train-de.txt', vocab_tgt, 'train-en.txt')
    # Prepare the data loader
    batcher = TranslationBatcher(device)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=batcher)
    # Build the translator
    translator = Translator(vocab_src, vocab_tgt, BahdanauAttention(), device=device)
    # Initialise the optimiser
    optimizer = torch.optim.Adam(translator.model.parameters(), lr=lr)
    # Make it possible to interrupt the training
    try:
        for epoch in range(n_epochs):
            losses = []
            bleu_valid = 0
            sample = '<none>'
            with tqdm(total=len(train_dataset)) as pbar:
                for i, (src_batch, tgt_batch) in enumerate(train_loader):
                    # Create a shifted version of tgt_batch containing the previous words
                    batch_size, tgt_len = tgt_batch.shape
                    bos = torch.ones(batch_size, 1, dtype=torch.long, device=tgt_batch.device)
                    tgt_batch_shifted = torch.cat((bos, tgt_batch[:, :-1]), dim=1)
                    translator.model.train()
                    # Forward pass
                    scores = translator.model(src_batch, tgt_batch_shifted)
                    scores = scores.view(-1, len(vocab_tgt))
                    # Backward pass
                    optimizer.zero_grad()
                    loss = F.cross_entropy(scores, tgt_batch.view(-1), ignore_index=0)
                    loss.backward()
                    optimizer.step()
                    # Update the diagnostics
                    losses.append(loss.item())
                    pbar.set_postfix(loss=(sum(losses) / len(losses)), bleu_valid=bleu_valid, sample=sample)
                    pbar.update(len(src_batch))
                    if i % 50 == 0:
                        translator.model.eval()
                        bleu_valid = int(bleu(translator, valid_src, valid_ref))
                        sample = translator.translate(['das haus ist klein .'])[0]
    except KeyboardInterrupt:
        pass
    return translator
```
%% Cell type:markdown id: tags:
Now it is time to train the system. During training, three diagnostics will be printed periodically: the running average of the training loss, the BLEU score on the validation data, and the translation of a sample sentence, *das haus ist klein* (which should translate into *the house is small*).
As mentioned before, training the translator takes quite a bit of compute power and time. Even with a GPU, you should expect training times of about 8–10 minutes per epoch. The default number of epochs is&nbsp;2; however, you may want to interrupt the training prematurely and use a partially trained model in case you run out of time.
%% Cell type:code id: tags:
``` python
translator = train()
```
%% Cell type:markdown id: tags:
**⚠️ Your submitted notebook must contain output demonstrating at least 16 BLEU points on the validation data.**
%% Cell type:markdown id: tags:
## Problem 4: Visualising attention (reflection)
%% Cell type:markdown id: tags:
Figure&nbsp;3 in the paper by [Bahdanau et al. (2015)](https://arxiv.org/abs/1409.0473) shows some heatmaps of attention weights in selected sentences. In the last problem of this lab, we ask you to inspect attention weights for your trained translation system. We define a function `plot_attention` that visualises the attention weights. The *x* axis corresponds to the words in the source sentence (German) and the *y* axis to the generated target sentence (English). The heatmap colours represent the strengths of the attention weights.
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
import numpy as np

%config InlineBackend.figure_format = 'svg'
plt.style.use('seaborn')

def plot_attention(translator, sentence):
    translation, weights = translator.translate_with_attention([sentence])[0]
    weights = np.array(weights)
    fig, ax = plt.subplots()
    heatmap = ax.pcolor(weights, cmap='Blues_r')
    ax.xaxis.tick_top()
    # Place the ticks in the middle of each cell, then label them.
    ax.set_xticks(np.arange(weights.shape[1]) + 0.5, minor=False)
    ax.set_yticks(np.arange(weights.shape[0]) + 0.5, minor=False)
    ax.set_xticklabels(sentence.split(), minor=False, rotation='vertical')
    ax.set_yticklabels(translation.split(), minor=False)
    ax.invert_yaxis()
    plt.colorbar(heatmap)
```
%% Cell type:markdown id: tags:
Here is an example:
%% Cell type:code id: tags:
``` python
plot_attention(translator, 'das haus ist klein .')
```
%% Cell type:markdown id: tags:
Use these heatmaps to inspect the attention patterns for selected German sentences. Try to find sentences for which the model produces reasonably good English translations. If your German is a bit rusty (or non-existent), use sentences from the validation data. It might be interesting to look at examples where the German and the English word order differ substantially. Document your exploration in a short reflection piece (ca. 150 words). Respond to the following prompts:
* What sentences did you try out? What patterns did you spot? Include example heatmaps in your notebook.
* Based on what you know about attention, did you expect your results? Was there anything surprising in them?
* What did you learn? How, exactly, did you learn it? Why does this learning matter?
%% Cell type:markdown id: tags:
**🥳 Congratulations on finishing this lab! 🥳**