diff --git a/labs/l3/NLP-L3.ipynb b/labs/l3/NLP-L3.ipynb index 960a939b6f376619c6c0e1a238e5a4f48035fe6b..ade668f1a342e78db6632b1fd5b2a04d1e1198b4 100644 --- a/labs/l3/NLP-L3.ipynb +++ b/labs/l3/NLP-L3.ipynb @@ -907,7 +907,7 @@ " train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=batcher)\n", "\n", " # Build the translator\n", - " translator = Translator(src_vocab, tgt_vocab, ScaledDotProductAttention(), device=device)\n", + " translator = Translator(vocab_src, vocab_tgt, ScaledDotProductAttention(), device=device)\n", "\n", " # Initialise the optimiser\n", " optimizer = torch.optim.Adam(translator.model.parameters(), lr=lr)\n", @@ -929,7 +929,7 @@ "\n", " # Forward pass\n", " scores = translator.model(src_batch, tgt_batch_shifted)\n", - " scores = scores.view(-1, len(tgt_vocab))\n", + " scores = scores.view(-1, len(vocab_tgt))\n", "\n", " # Backward pass\n", " optimizer.zero_grad()\n",