diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index 14d945f4bdc93c70b54d5eb307a085a1777c93db..303e0909ef4f68fccbddc758a6c5d890c1eebfe3 100644
--- a/Neural graph module/ngm.ipynb	
+++ b/Neural graph module/ngm.ipynb	
@@ -1,999 +1,514 @@
 {
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import datasets\n",
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "import torch.optim as optim\n",
-    "import pandas as pd\n",
-    "import numpy as np\n",
-    "from transformers import BertTokenizer, BertModel\n",
-    "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
-    "from tqdm import tqdm\n",
-    "import json\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Use GPU if available\n",
-    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
-    "# Print cuda version\n",
-    "print(device)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "\n",
-    "class NgmOne(nn.Module):\n",
-    "    def __init__(self, device):\n",
-    "        super(NgmOne, self).__init__()\n",
-    "        self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
-    "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
-    "        self.linear = nn.Linear(768, 247).to(device)\n",
-    "        self.softmax = nn.Softmax(dim=1).to(device)\n",
-    "        self.device = device\n",
-    "    \n",
-    "    def forward(self, tokenized_seq, tokenized_mask):\n",
-    "        tokenized_seq= tokenized_seq.to(self.device)\n",
-    "        tokenized_mask = tokenized_mask.to(self.device)\n",
-    "        with torch.no_grad():\n",
-    "            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
-    "            x = x[0][:,0,:].to(self.device)\n",
-    "            \n",
-    "        x = self.linear(x)\n",
-    "        x = self.softmax(x)\n",
-    "        return x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# def encode(batch):\n",
-    "#   return tokenizer(batch, padding=\"max_length\", max_length=256, return_tensors=\"pt\")\n",
-    "\n",
-    "\n",
-    "# def convert_to_features(example_batch):\n",
-    "#     input_encodings = encode(example_batch['text'])\n",
-    "#     target_encodings = encode(example_batch['summary'])\n",
-    "\n",
-    "#     labels = target_encodings['input_ids']\n",
-    "#     decoder_input_ids = shift_tokens_right(\n",
-    "#         labels, model.config.pad_token_id, model.config.decoder_start_token_id)\n",
-    "#     labels[labels[:, :] == model.config.pad_token_id] = -100\n",
-    "\n",
-    "#     encodings = {\n",
-    "#         'input_ids': input_encodings['input_ids'],\n",
-    "#         'attention_mask': input_encodings['attention_mask'],\n",
-    "#         'decoder_input_ids': decoder_input_ids,\n",
-    "#         'labels': labels,\n",
-    "#     }\n",
-    "\n",
-    "#     return encodings\n",
-    "\n",
-    "\n",
-    "# def get_dataset(path):\n",
-    "#   df = pd.read_csv(path, sep=\",\", on_bad_lines='skip')\n",
-    "#   dataset = datasets.Dataset.from_pandas(df)\n",
-    "#   dataset = dataset.map(convert_to_features, batched=True)\n",
-    "#   columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]\n",
-    "#   dataset.set_format(type='torch', columns=columns)\n",
-    "#   return dataset\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def make_batch():\n",
-    "    \"\"\"Triplet is a list of [subject entity, relation, object entity], None if not present\"\"\"\n",
-    "\n",
-    "    # Load predicted data\n",
-    "    pred = \"../data/qald-9-train-linked.json\"\n",
-    "\n",
-    "    #Load gold data\n",
-    "    gold = \"../data/qald-9-train-linked.json\"\n",
-    "    print(\"Beginning making batch\")\n",
-    "    with open(pred, \"r\") as p, open(gold, \"r\") as g:\n",
-    "        pred = json.load(p)\n",
-    "        gold = json.load(g)\n",
-    "\n",
-    "    inputs = []\n",
-    "    inputs_max_len = 0\n",
-    "    for d in tqdm(pred[\"questions\"]):\n",
-    "        question = d[\"question\"][0][\"string\"]\n",
-    "        query = d[\"query\"][\"sparql\"]\n",
-    "\n",
-    "        #Take the first tripletin query\n",
-    "        trip = query.split(\"WHERE\")[1]\n",
-    "        trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n",
-    "        triplet = trip.split(\" \")\n",
-    "\n",
-    "        #remove empty strings\n",
-    "        triplet = [x for x in triplet if x != \"\"]\n",
-    "\n",
-    "        if len(triplet) != 3:\n",
-    "            continue\n",
-    "\n",
-    "        for t in triplet:\n",
-    "            if not(t.find(\"?\")):\n",
-    "                triplet[triplet.index(t)] = None\n",
-    "\n",
-    "        #seq = \"[CLS] \" + question + \" [SEP] \"\n",
-    "        if triplet[0] is not None:\n",
-    "            #seq += \"[SUB] [SEP] \" + triplet[0]\n",
-    "            # , padding=True, truncation=True)\n",
-    "            tokenized_seq = tokenizer(question, \"[SUB]\", triplet[0], padding=True, truncation=True)\n",
-    "        elif triplet[2] is not None:\n",
-    "            #seq += \"[OBJ] [SEP] \" + triplet[2]\n",
-    "            tokenized_seq = tokenizer(question, \"[OBJ]\", triplet[2], padding=True, truncation=True)\n",
-    "\n",
-    "        if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n",
-    "            inputs_max_len = len(tokenized_seq[\"input_ids\"])\n",
-    "        inputs.append(list(tokenized_seq.values())[0])\n",
-    "\n",
-    "    correct_rels_max_len = 0\n",
-    "    correct_rels = []\n",
-    "    for d in tqdm(gold[\"questions\"]):\n",
-    "        question = d[\"question\"][0][\"string\"]\n",
-    "        query = d[\"query\"][\"sparql\"]\n",
-    "\n",
-    "        #Take the first tripletin query\n",
-    "        trip = query.split(\"WHERE\")[1]\n",
-    "        trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n",
-    "        triplet = trip.split(\" \")\n",
-    "    \n",
-    "\n",
-    "        #remove empty strings\n",
-    "        triplet = [x for x in triplet if x != \"\"]\n",
-    "\n",
-    "        if len(triplet) != 3:\n",
-    "            continue\n",
-    "        # tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n",
-    "        # if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n",
-    "        #     correct_rels_max_len = len(tokenized[\"input_ids\"])\n",
-    "\n",
-    "        correct_rels.append(triplet[1])#list(tokenized.values())[0])\n",
-    "\n",
-    "    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n",
-    "    #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n",
-    "\n",
-    "    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n",
-    "    #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n",
-    "    print(\"Finished with batches\")\n",
-    "    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# training_args = Seq2SeqTrainingArguments(\n",
-    "#     output_dir='./models/blackbox',\n",
-    "#     num_train_epochs=1,\n",
-    "#     per_device_train_batch_size=1,\n",
-    "#     per_device_eval_batch_size=1,\n",
-    "#     warmup_steps=10,\n",
-    "#     weight_decay=0.01,\n",
-    "#     logging_dir='./logs',\n",
-    "# )\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 44,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from torch.utils.data import Dataset, DataLoader\n",
-    "\n",
-    "class MyDataset(Dataset):\n",
-    "    def __init__(self, inputs, attention_mask, correct_rels, relations):\n",
-    "        self.inputs = inputs\n",
-    "        self.attention_mask = attention_mask\n",
-    "        self.correct_rels = correct_rels\n",
-    "        self.relations = relations\n",
-    "        \n",
-    "\n",
-    "    def __len__(self):\n",
-    "        return len(self.inputs)\n",
-    "\n",
-    "    def __getitem__(self, idx):\n",
-    "        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])\n",
-    "\n",
-    "\n",
-    "\n",
-    "\n",
-    "#From scratch json creates data set.\n",
-    "# class MyDataset(Dataset):\n",
-    "#     def __init__(self, json_file, transform=None):\n",
-    "#         self.qald_data = json.load(json_file)\n",
-    "\n",
-    "#     def __len__(self):\n",
-    "#         return len(self.qald_data)\n",
-    "\n",
-    "#     def __getitem__(self, idx):\n",
-    "#         self.inputs[idx], self.attention_mask[idx], self.labels[idx]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 45,
-   "metadata": {},
-   "outputs": [
+  "cells": [
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Beginning making batch\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "c:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+            "  from .autonotebook import tqdm as notebook_tqdm\n"
+          ]
+        }
+      ],
+      "source": [
+        "import datasets\n",
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "import torch.optim as optim\n",
+        "import pandas as pd\n",
+        "import numpy as np\n",
+        "from transformers import BertTokenizer, BertModel\n",
+        "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
+        "from tqdm import tqdm\n",
+        "import json\n"
+      ]
     },
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "100%|██████████| 408/408 [00:00<00:00, 955.51it/s] \n",
-      "100%|██████████| 408/408 [00:00<00:00, 136139.70it/s]"
-     ]
+      "cell_type": "code",
+      "execution_count": 2,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
+      ]
     },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Finished with batches\n",
-      "features: tensor([[  101,  2040,  2003,  1996,  3677,  1997,  1996,  4035,  6870, 19247,\n",
-      "          1029,   102,  1031,  4942,  1033,   102,     0,     0,     0,     0,\n",
-      "             0,     0,     0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 3,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Use GPU if available\n",
+        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+        "# Print cuda version\n",
+        "print(device)"
+      ]
     },
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
-    }
-   ],
-   "source": [
-    "#Prepare data\n",
-    "\n",
-    "def open_json(file):\n",
-    "    with open(file, \"r\") as f:\n",
-    "        return json.load(f)\n",
-    "\n",
-    "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
-    "train_set = MyDataset(*make_batch(), relations=relations)\n",
-    "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n",
-    "\n",
-    "#show first entry\n",
-    "train_features, train_mask, train_label = next(iter(train_dataloader))\n",
-    "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 65,
-   "metadata": {},
-   "outputs": [
+      "cell_type": "code",
+      "execution_count": 4,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "\n",
+        "class NgmOne(nn.Module):\n",
+        "    def __init__(self, device):\n",
+        "        super(NgmOne, self).__init__()\n",
+        "        self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+        "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
+        "        self.linear = nn.Linear(768, 247).to(device)\n",
+        "        self.softmax = nn.Softmax(dim=1).to(device)\n",
+        "        self.device = device\n",
+        "    \n",
+        "    def forward(self, tokenized_seq, tokenized_mask):\n",
+        "        tokenized_seq= tokenized_seq.to(self.device)\n",
+        "        tokenized_mask = tokenized_mask.to(self.device)\n",
+        "        with torch.no_grad():\n",
+        "            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
+        "            x = x[0][:,0,:].to(self.device)\n",
+        "            \n",
+        "        x = self.linear(x)\n",
+        "        x = self.softmax(x)\n",
+        "        return x"
+      ]
+    },
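+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Sanity-check sketch (not part of the original pipeline): run one dummy\n",
+        "# forward pass and confirm the head emits one score per relation class.\n",
+        "# Assumes the tokenizer and NgmOne cells above; the question is made up.\n",
+        "toks = tokenizer(\"Who is the mayor of Berlin?\", return_tensors=\"pt\")\n",
+        "probs = NgmOne(device)(toks[\"input_ids\"], toks[\"attention_mask\"])\n",
+        "print(probs.shape)  # expected: torch.Size([1, 247])"
+      ]
+    },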
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
-      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
-      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 5,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# def encode(batch):\n",
+        "#   return tokenizer(batch, padding=\"max_length\", max_length=256, return_tensors=\"pt\")\n",
+        "\n",
+        "\n",
+        "# def convert_to_features(example_batch):\n",
+        "#     input_encodings = encode(example_batch['text'])\n",
+        "#     target_encodings = encode(example_batch['summary'])\n",
+        "\n",
+        "#     labels = target_encodings['input_ids']\n",
+        "#     decoder_input_ids = shift_tokens_right(\n",
+        "#         labels, model.config.pad_token_id, model.config.decoder_start_token_id)\n",
+        "#     labels[labels[:, :] == model.config.pad_token_id] = -100\n",
+        "\n",
+        "#     encodings = {\n",
+        "#         'input_ids': input_encodings['input_ids'],\n",
+        "#         'attention_mask': input_encodings['attention_mask'],\n",
+        "#         'decoder_input_ids': decoder_input_ids,\n",
+        "#         'labels': labels,\n",
+        "#     }\n",
+        "\n",
+        "#     return encodings\n",
+        "\n",
+        "\n",
+        "# def get_dataset(path):\n",
+        "#   df = pd.read_csv(path, sep=\",\", on_bad_lines='skip')\n",
+        "#   dataset = datasets.Dataset.from_pandas(df)\n",
+        "#   dataset = dataset.map(convert_to_features, batched=True)\n",
+        "#   columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]\n",
+        "#   dataset.set_format(type='torch', columns=columns)\n",
+        "#   return dataset\n"
+      ]
     },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 0 batch: 0 , loss = 5.509496\n",
-      "Epoch 1 batch: 0 , loss = 5.507990\n",
-      "Epoch 2 batch: 0 , loss = 5.506450\n",
-      "Epoch 3 batch: 0 , loss = 5.504800\n",
-      "Epoch 4 batch: 0 , loss = 5.502756\n",
-      "Epoch 5 batch: 0 , loss = 5.500911\n",
-      "Epoch 6 batch: 0 , loss = 5.498587\n",
-      "Epoch 7 batch: 0 , loss = 5.496222\n",
-      "Epoch 8 batch: 0 , loss = 5.493753\n",
-      "Epoch 9 batch: 0 , loss = 5.492003\n",
-      "Epoch 10 batch: 0 , loss = 5.489583\n",
-      "Epoch 11 batch: 0 , loss = 5.487563\n",
-      "Epoch 12 batch: 0 , loss = 5.485127\n",
-      "Epoch 13 batch: 0 , loss = 5.481408\n",
-      "Epoch 14 batch: 0 , loss = 5.479365\n",
-      "Epoch 15 batch: 0 , loss = 5.476476\n",
-      "Epoch 16 batch: 0 , loss = 5.475063\n",
-      "Epoch 17 batch: 0 , loss = 5.473976\n",
-      "Epoch 18 batch: 0 , loss = 5.471124\n",
-      "Epoch 19 batch: 0 , loss = 5.472827\n",
-      "Epoch 20 batch: 0 , loss = 5.466723\n",
-      "Epoch 21 batch: 0 , loss = 5.464960\n",
-      "Epoch 22 batch: 0 , loss = 5.466041\n",
-      "Epoch 23 batch: 0 , loss = 5.462078\n",
-      "Epoch 24 batch: 0 , loss = 5.460968\n",
-      "Epoch 25 batch: 0 , loss = 5.463361\n",
-      "Epoch 26 batch: 0 , loss = 5.462362\n",
-      "Epoch 27 batch: 0 , loss = 5.459974\n",
-      "Epoch 28 batch: 0 , loss = 5.462227\n",
-      "Epoch 29 batch: 0 , loss = 5.457432\n",
-      "Epoch 30 batch: 0 , loss = 5.456643\n",
-      "Epoch 31 batch: 0 , loss = 5.455560\n",
-      "Epoch 32 batch: 0 , loss = 5.454978\n",
-      "Epoch 33 batch: 0 , loss = 5.452355\n",
-      "Epoch 34 batch: 0 , loss = 5.449059\n",
-      "Epoch 35 batch: 0 , loss = 5.450156\n",
-      "Epoch 36 batch: 0 , loss = 5.443484\n",
-      "Epoch 37 batch: 0 , loss = 5.441142\n",
-      "Epoch 38 batch: 0 , loss = 5.438660\n",
-      "Epoch 39 batch: 0 , loss = 5.437134\n",
-      "Epoch 40 batch: 0 , loss = 5.434308\n",
-      "Epoch 41 batch: 0 , loss = 5.431943\n",
-      "Epoch 42 batch: 0 , loss = 5.429959\n",
-      "Epoch 43 batch: 0 , loss = 5.428108\n",
-      "Epoch 44 batch: 0 , loss = 5.426363\n",
-      "Epoch 45 batch: 0 , loss = 5.428246\n",
-      "Epoch 46 batch: 0 , loss = 5.424921\n",
-      "Epoch 47 batch: 0 , loss = 5.425705\n",
-      "Epoch 48 batch: 0 , loss = 5.424184\n",
-      "Epoch 49 batch: 0 , loss = 5.423537\n",
-      "Epoch 50 batch: 0 , loss = 5.424012\n",
-      "Epoch 51 batch: 0 , loss = 5.419333\n",
-      "Epoch 52 batch: 0 , loss = 5.414816\n",
-      "Epoch 53 batch: 0 , loss = 5.418826\n",
-      "Epoch 54 batch: 0 , loss = 5.418625\n",
-      "Epoch 55 batch: 0 , loss = 5.419236\n",
-      "Epoch 56 batch: 0 , loss = 5.421367\n",
-      "Epoch 57 batch: 0 , loss = 5.418901\n",
-      "Epoch 58 batch: 0 , loss = 5.416471\n",
-      "Epoch 59 batch: 0 , loss = 5.414352\n",
-      "Epoch 60 batch: 0 , loss = 5.416772\n",
-      "Epoch 61 batch: 0 , loss = 5.412184\n",
-      "Epoch 62 batch: 0 , loss = 5.407087\n",
-      "Epoch 63 batch: 0 , loss = 5.404146\n",
-      "Epoch 64 batch: 0 , loss = 5.405499\n",
-      "Epoch 65 batch: 0 , loss = 5.404672\n",
-      "Epoch 66 batch: 0 , loss = 5.399035\n",
-      "Epoch 67 batch: 0 , loss = 5.407500\n",
-      "Epoch 68 batch: 0 , loss = 5.403162\n",
-      "Epoch 69 batch: 0 , loss = 5.397717\n",
-      "Epoch 70 batch: 0 , loss = 5.398447\n",
-      "Epoch 71 batch: 0 , loss = 5.398669\n",
-      "Epoch 72 batch: 0 , loss = 5.392782\n",
-      "Epoch 73 batch: 0 , loss = 5.392102\n",
-      "Epoch 74 batch: 0 , loss = 5.386912\n",
-      "Epoch 75 batch: 0 , loss = 5.384616\n",
-      "Epoch 76 batch: 0 , loss = 5.382891\n",
-      "Epoch 77 batch: 0 , loss = 5.381707\n",
-      "Epoch 78 batch: 0 , loss = 5.376785\n",
-      "Epoch 79 batch: 0 , loss = 5.374002\n",
-      "Epoch 80 batch: 0 , loss = 5.376191\n",
-      "Epoch 81 batch: 0 , loss = 5.381535\n",
-      "Epoch 82 batch: 0 , loss = 5.376332\n",
-      "Epoch 83 batch: 0 , loss = 5.372051\n",
-      "Epoch 84 batch: 0 , loss = 5.367785\n",
-      "Epoch 85 batch: 0 , loss = 5.367027\n",
-      "Epoch 86 batch: 0 , loss = 5.366450\n",
-      "Epoch 87 batch: 0 , loss = 5.364227\n",
-      "Epoch 88 batch: 0 , loss = 5.364299\n",
-      "Epoch 89 batch: 0 , loss = 5.363710\n",
-      "Epoch 90 batch: 0 , loss = 5.356105\n",
-      "Epoch 91 batch: 0 , loss = 5.353946\n",
-      "Epoch 92 batch: 0 , loss = 5.356158\n",
-      "Epoch 93 batch: 0 , loss = 5.355119\n",
-      "Epoch 94 batch: 0 , loss = 5.348075\n",
-      "Epoch 95 batch: 0 , loss = 5.351038\n",
-      "Epoch 96 batch: 0 , loss = 5.349224\n",
-      "Epoch 97 batch: 0 , loss = 5.344652\n",
-      "Epoch 98 batch: 0 , loss = 5.341354\n",
-      "Epoch 99 batch: 0 , loss = 5.342059\n",
-      "Epoch 100 batch: 0 , loss = 5.344870\n",
-      "Epoch 101 batch: 0 , loss = 5.335486\n",
-      "Epoch 102 batch: 0 , loss = 5.339808\n",
-      "Epoch 103 batch: 0 , loss = 5.330384\n",
-      "Epoch 104 batch: 0 , loss = 5.333795\n",
-      "Epoch 105 batch: 0 , loss = 5.336934\n",
-      "Epoch 106 batch: 0 , loss = 5.334551\n",
-      "Epoch 107 batch: 0 , loss = 5.328543\n",
-      "Epoch 108 batch: 0 , loss = 5.329463\n",
-      "Epoch 109 batch: 0 , loss = 5.327895\n",
-      "Epoch 110 batch: 0 , loss = 5.324828\n",
-      "Epoch 111 batch: 0 , loss = 5.325212\n",
-      "Epoch 112 batch: 0 , loss = 5.320701\n",
-      "Epoch 113 batch: 0 , loss = 5.327697\n",
-      "Epoch 114 batch: 0 , loss = 5.321108\n",
-      "Epoch 115 batch: 0 , loss = 5.318196\n",
-      "Epoch 116 batch: 0 , loss = 5.313715\n",
-      "Epoch 117 batch: 0 , loss = 5.311478\n",
-      "Epoch 118 batch: 0 , loss = 5.310149\n",
-      "Epoch 119 batch: 0 , loss = 5.304658\n",
-      "Epoch 120 batch: 0 , loss = 5.299352\n",
-      "Epoch 121 batch: 0 , loss = 5.297422\n",
-      "Epoch 122 batch: 0 , loss = 5.296243\n",
-      "Epoch 123 batch: 0 , loss = 5.299733\n",
-      "Epoch 124 batch: 0 , loss = 5.296957\n",
-      "Epoch 125 batch: 0 , loss = 5.296990\n",
-      "Epoch 126 batch: 0 , loss = 5.295414\n",
-      "Epoch 127 batch: 0 , loss = 5.291314\n",
-      "Epoch 128 batch: 0 , loss = 5.288688\n",
-      "Epoch 129 batch: 0 , loss = 5.290800\n",
-      "Epoch 130 batch: 0 , loss = 5.290220\n",
-      "Epoch 131 batch: 0 , loss = 5.292923\n",
-      "Epoch 132 batch: 0 , loss = 5.282723\n",
-      "Epoch 133 batch: 0 , loss = 5.281114\n",
-      "Epoch 134 batch: 0 , loss = 5.284384\n",
-      "Epoch 135 batch: 0 , loss = 5.285651\n",
-      "Epoch 136 batch: 0 , loss = 5.279850\n",
-      "Epoch 137 batch: 0 , loss = 5.276040\n",
-      "Epoch 138 batch: 0 , loss = 5.279545\n",
-      "Epoch 139 batch: 0 , loss = 5.275866\n",
-      "Epoch 140 batch: 0 , loss = 5.275192\n",
-      "Epoch 141 batch: 0 , loss = 5.278750\n",
-      "Epoch 142 batch: 0 , loss = 5.281165\n",
-      "Epoch 143 batch: 0 , loss = 5.281812\n",
-      "Epoch 144 batch: 0 , loss = 5.270550\n",
-      "Epoch 145 batch: 0 , loss = 5.269819\n",
-      "Epoch 146 batch: 0 , loss = 5.268355\n",
-      "Epoch 147 batch: 0 , loss = 5.266929\n",
-      "Epoch 148 batch: 0 , loss = 5.261328\n",
-      "Epoch 149 batch: 0 , loss = 5.268577\n",
-      "Epoch 150 batch: 0 , loss = 5.260501\n",
-      "Epoch 151 batch: 0 , loss = 5.263970\n",
-      "Epoch 152 batch: 0 , loss = 5.261120\n",
-      "Epoch 153 batch: 0 , loss = 5.258860\n",
-      "Epoch 154 batch: 0 , loss = 5.253738\n",
-      "Epoch 155 batch: 0 , loss = 5.261181\n",
-      "Epoch 156 batch: 0 , loss = 5.246259\n",
-      "Epoch 157 batch: 0 , loss = 5.245531\n",
-      "Epoch 158 batch: 0 , loss = 5.244577\n",
-      "Epoch 159 batch: 0 , loss = 5.243086\n",
-      "Epoch 160 batch: 0 , loss = 5.241098\n",
-      "Epoch 161 batch: 0 , loss = 5.245851\n",
-      "Epoch 162 batch: 0 , loss = 5.245241\n",
-      "Epoch 163 batch: 0 , loss = 5.244030\n",
-      "Epoch 164 batch: 0 , loss = 5.238179\n",
-      "Epoch 165 batch: 0 , loss = 5.232537\n",
-      "Epoch 166 batch: 0 , loss = 5.238233\n",
-      "Epoch 167 batch: 0 , loss = 5.244127\n",
-      "Epoch 168 batch: 0 , loss = 5.238196\n",
-      "Epoch 169 batch: 0 , loss = 5.231448\n",
-      "Epoch 170 batch: 0 , loss = 5.231340\n",
-      "Epoch 171 batch: 0 , loss = 5.230532\n",
-      "Epoch 172 batch: 0 , loss = 5.232080\n",
-      "Epoch 173 batch: 0 , loss = 5.228147\n",
-      "Epoch 174 batch: 0 , loss = 5.225359\n",
-      "Epoch 175 batch: 0 , loss = 5.229460\n",
-      "Epoch 176 batch: 0 , loss = 5.223203\n",
-      "Epoch 177 batch: 0 , loss = 5.226225\n",
-      "Epoch 178 batch: 0 , loss = 5.224785\n",
-      "Epoch 179 batch: 0 , loss = 5.232825\n",
-      "Epoch 180 batch: 0 , loss = 5.227349\n",
-      "Epoch 181 batch: 0 , loss = 5.222055\n",
-      "Epoch 182 batch: 0 , loss = 5.212518\n",
-      "Epoch 183 batch: 0 , loss = 5.221508\n",
-      "Epoch 184 batch: 0 , loss = 5.218023\n",
-      "Epoch 185 batch: 0 , loss = 5.221085\n",
-      "Epoch 186 batch: 0 , loss = 5.212870\n",
-      "Epoch 187 batch: 0 , loss = 5.217089\n",
-      "Epoch 188 batch: 0 , loss = 5.220237\n",
-      "Epoch 189 batch: 0 , loss = 5.220267\n",
-      "Epoch 190 batch: 0 , loss = 5.216769\n",
-      "Epoch 191 batch: 0 , loss = 5.218822\n",
-      "Epoch 192 batch: 0 , loss = 5.220790\n",
-      "Epoch 193 batch: 0 , loss = 5.206959\n",
-      "Epoch 194 batch: 0 , loss = 5.204648\n",
-      "Epoch 195 batch: 0 , loss = 5.202166\n",
-      "Epoch 196 batch: 0 , loss = 5.199206\n",
-      "Epoch 197 batch: 0 , loss = 5.200701\n",
-      "Epoch 198 batch: 0 , loss = 5.187927\n",
-      "Epoch 199 batch: 0 , loss = 5.194535\n",
-      "Epoch 200 batch: 0 , loss = 5.190029\n",
-      "Epoch 201 batch: 0 , loss = 5.189040\n",
-      "Epoch 202 batch: 0 , loss = 5.190596\n",
-      "Epoch 203 batch: 0 , loss = 5.189606\n",
-      "Epoch 204 batch: 0 , loss = 5.186440\n",
-      "Epoch 205 batch: 0 , loss = 5.191073\n",
-      "Epoch 206 batch: 0 , loss = 5.184719\n",
-      "Epoch 207 batch: 0 , loss = 5.179452\n",
-      "Epoch 208 batch: 0 , loss = 5.173555\n",
-      "Epoch 209 batch: 0 , loss = 5.180843\n",
-      "Epoch 210 batch: 0 , loss = 5.186315\n",
-      "Epoch 211 batch: 0 , loss = 5.182886\n",
-      "Epoch 212 batch: 0 , loss = 5.179914\n",
-      "Epoch 213 batch: 0 , loss = 5.175862\n",
-      "Epoch 214 batch: 0 , loss = 5.183718\n",
-      "Epoch 215 batch: 0 , loss = 5.162323\n",
-      "Epoch 216 batch: 0 , loss = 5.170660\n",
-      "Epoch 217 batch: 0 , loss = 5.169362\n",
-      "Epoch 218 batch: 0 , loss = 5.162789\n",
-      "Epoch 219 batch: 0 , loss = 5.160868\n",
-      "Epoch 220 batch: 0 , loss = 5.164267\n",
-      "Epoch 221 batch: 0 , loss = 5.164003\n",
-      "Epoch 222 batch: 0 , loss = 5.165540\n",
-      "Epoch 223 batch: 0 , loss = 5.161615\n",
-      "Epoch 224 batch: 0 , loss = 5.152071\n",
-      "Epoch 225 batch: 0 , loss = 5.162641\n",
-      "Epoch 226 batch: 0 , loss = 5.158351\n",
-      "Epoch 227 batch: 0 , loss = 5.167762\n",
-      "Epoch 228 batch: 0 , loss = 5.166461\n",
-      "Epoch 229 batch: 0 , loss = 5.157466\n",
-      "Epoch 230 batch: 0 , loss = 5.156347\n",
-      "Epoch 231 batch: 0 , loss = 5.156344\n",
-      "Epoch 232 batch: 0 , loss = 5.164078\n",
-      "Epoch 233 batch: 0 , loss = 5.155646\n",
-      "Epoch 234 batch: 0 , loss = 5.154130\n",
-      "Epoch 235 batch: 0 , loss = 5.152800\n",
-      "Epoch 236 batch: 0 , loss = 5.161687\n",
-      "Epoch 237 batch: 0 , loss = 5.151773\n",
-      "Epoch 238 batch: 0 , loss = 5.146783\n",
-      "Epoch 239 batch: 0 , loss = 5.146691\n",
-      "Epoch 240 batch: 0 , loss = 5.151015\n",
-      "Epoch 241 batch: 0 , loss = 5.145954\n",
-      "Epoch 242 batch: 0 , loss = 5.154754\n",
-      "Epoch 243 batch: 0 , loss = 5.144487\n",
-      "Epoch 244 batch: 0 , loss = 5.152771\n",
-      "Epoch 245 batch: 0 , loss = 5.156934\n",
-      "Epoch 246 batch: 0 , loss = 5.146832\n",
-      "Epoch 247 batch: 0 , loss = 5.148869\n",
-      "Epoch 248 batch: 0 , loss = 5.142589\n",
-      "Epoch 249 batch: 0 , loss = 5.141822\n",
-      "Epoch 250 batch: 0 , loss = 5.144300\n",
-      "Epoch 251 batch: 0 , loss = 5.140250\n",
-      "Epoch 252 batch: 0 , loss = 5.135976\n",
-      "Epoch 253 batch: 0 , loss = 5.144660\n",
-      "Epoch 254 batch: 0 , loss = 5.144534\n",
-      "Epoch 255 batch: 0 , loss = 5.140327\n",
-      "Epoch 256 batch: 0 , loss = 5.139380\n",
-      "Epoch 257 batch: 0 , loss = 5.149578\n",
-      "Epoch 258 batch: 0 , loss = 5.139646\n",
-      "Epoch 259 batch: 0 , loss = 5.139930\n",
-      "Epoch 260 batch: 0 , loss = 5.143685\n",
-      "Epoch 261 batch: 0 , loss = 5.138163\n",
-      "Epoch 262 batch: 0 , loss = 5.138759\n",
-      "Epoch 263 batch: 0 , loss = 5.133649\n",
-      "Epoch 264 batch: 0 , loss = 5.137669\n",
-      "Epoch 265 batch: 0 , loss = 5.133360\n",
-      "Epoch 266 batch: 0 , loss = 5.137102\n",
-      "Epoch 267 batch: 0 , loss = 5.144728\n",
-      "Epoch 268 batch: 0 , loss = 5.135507\n",
-      "Epoch 269 batch: 0 , loss = 5.133213\n",
-      "Epoch 270 batch: 0 , loss = 5.130654\n",
-      "Epoch 271 batch: 0 , loss = 5.128548\n",
-      "Epoch 272 batch: 0 , loss = 5.118666\n",
-      "Epoch 273 batch: 0 , loss = 5.121534\n",
-      "Epoch 274 batch: 0 , loss = 5.122122\n",
-      "Epoch 275 batch: 0 , loss = 5.127708\n",
-      "Epoch 276 batch: 0 , loss = 5.132959\n",
-      "Epoch 277 batch: 0 , loss = 5.129933\n",
-      "Epoch 278 batch: 0 , loss = 5.125852\n",
-      "Epoch 279 batch: 0 , loss = 5.120491\n",
-      "Epoch 280 batch: 0 , loss = 5.123435\n",
-      "Epoch 281 batch: 0 , loss = 5.129238\n",
-      "Epoch 282 batch: 0 , loss = 5.124735\n",
-      "Epoch 283 batch: 0 , loss = 5.116185\n",
-      "Epoch 284 batch: 0 , loss = 5.113766\n",
-      "Epoch 285 batch: 0 , loss = 5.114601\n",
-      "Epoch 286 batch: 0 , loss = 5.112402\n",
-      "Epoch 287 batch: 0 , loss = 5.111151\n",
-      "Epoch 288 batch: 0 , loss = 5.117523\n",
-      "Epoch 289 batch: 0 , loss = 5.117987\n",
-      "Epoch 290 batch: 0 , loss = 5.112226\n",
-      "Epoch 291 batch: 0 , loss = 5.103069\n",
-      "Epoch 292 batch: 0 , loss = 5.105778\n",
-      "Epoch 293 batch: 0 , loss = 5.117328\n",
-      "Epoch 294 batch: 0 , loss = 5.117546\n",
-      "Epoch 295 batch: 0 , loss = 5.114745\n",
-      "Epoch 296 batch: 0 , loss = 5.111210\n",
-      "Epoch 297 batch: 0 , loss = 5.116782\n",
-      "Epoch 298 batch: 0 , loss = 5.126735\n",
-      "Epoch 299 batch: 0 , loss = 5.110609\n",
-      "Epoch 300 batch: 0 , loss = 5.119059\n",
-      "Epoch 301 batch: 0 , loss = 5.115992\n",
-      "Epoch 302 batch: 0 , loss = 5.106349\n",
-      "Epoch 303 batch: 0 , loss = 5.104899\n",
-      "Epoch 304 batch: 0 , loss = 5.104232\n",
-      "Epoch 305 batch: 0 , loss = 5.112337\n",
-      "Epoch 306 batch: 0 , loss = 5.111980\n",
-      "Epoch 307 batch: 0 , loss = 5.106858\n",
-      "Epoch 308 batch: 0 , loss = 5.109094\n",
-      "Epoch 309 batch: 0 , loss = 5.111865\n",
-      "Epoch 310 batch: 0 , loss = 5.096600\n",
-      "Epoch 311 batch: 0 , loss = 5.106102\n",
-      "Epoch 312 batch: 0 , loss = 5.102217\n",
-      "Epoch 313 batch: 0 , loss = 5.111903\n",
-      "Epoch 314 batch: 0 , loss = 5.102543\n",
-      "Epoch 315 batch: 0 , loss = 5.101692\n",
-      "Epoch 316 batch: 0 , loss = 5.114990\n",
-      "Epoch 317 batch: 0 , loss = 5.108889\n",
-      "Epoch 318 batch: 0 , loss = 5.102445\n",
-      "Epoch 319 batch: 0 , loss = 5.096200\n",
-      "Epoch 320 batch: 0 , loss = 5.100635\n",
-      "Epoch 321 batch: 0 , loss = 5.099531\n",
-      "Epoch 322 batch: 0 , loss = 5.100465\n",
-      "Epoch 323 batch: 0 , loss = 5.095500\n",
-      "Epoch 324 batch: 0 , loss = 5.100604\n",
-      "Epoch 325 batch: 0 , loss = 5.099518\n",
-      "Epoch 326 batch: 0 , loss = 5.094840\n",
-      "Epoch 327 batch: 0 , loss = 5.093679\n",
-      "Epoch 328 batch: 0 , loss = 5.093080\n",
-      "Epoch 329 batch: 0 , loss = 5.095733\n",
-      "Epoch 330 batch: 0 , loss = 5.099861\n",
-      "Epoch 331 batch: 0 , loss = 5.090419\n",
-      "Epoch 332 batch: 0 , loss = 5.099437\n",
-      "Epoch 333 batch: 0 , loss = 5.098624\n",
-      "Epoch 334 batch: 0 , loss = 5.093289\n",
-      "Epoch 335 batch: 0 , loss = 5.097763\n",
-      "Epoch 336 batch: 0 , loss = 5.097146\n",
-      "Epoch 337 batch: 0 , loss = 5.101106\n",
-      "Epoch 338 batch: 0 , loss = 5.082275\n",
-      "Epoch 339 batch: 0 , loss = 5.087108\n",
-      "Epoch 340 batch: 0 , loss = 5.092202\n",
-      "Epoch 341 batch: 0 , loss = 5.082821\n",
-      "Epoch 342 batch: 0 , loss = 5.092401\n",
-      "Epoch 343 batch: 0 , loss = 5.097392\n",
-      "Epoch 344 batch: 0 , loss = 5.097518\n",
-      "Epoch 345 batch: 0 , loss = 5.095610\n",
-      "Epoch 346 batch: 0 , loss = 5.089358\n",
-      "Epoch 347 batch: 0 , loss = 5.084134\n",
-      "Epoch 348 batch: 0 , loss = 5.093090\n",
-      "Epoch 349 batch: 0 , loss = 5.094424\n",
-      "Epoch 350 batch: 0 , loss = 5.089258\n",
-      "Epoch 351 batch: 0 , loss = 5.089724\n",
-      "Epoch 352 batch: 0 , loss = 5.092799\n",
-      "Epoch 353 batch: 0 , loss = 5.089324\n",
-      "Epoch 354 batch: 0 , loss = 5.093751\n",
-      "Epoch 355 batch: 0 , loss = 5.088273\n",
-      "Epoch 356 batch: 0 , loss = 5.083136\n",
-      "Epoch 357 batch: 0 , loss = 5.087886\n",
-      "Epoch 358 batch: 0 , loss = 5.092871\n",
-      "Epoch 359 batch: 0 , loss = 5.077843\n",
-      "Epoch 360 batch: 0 , loss = 5.087354\n",
-      "Epoch 361 batch: 0 , loss = 5.077330\n",
-      "Epoch 362 batch: 0 , loss = 5.096845\n",
-      "Epoch 363 batch: 0 , loss = 5.081802\n",
-      "Epoch 364 batch: 0 , loss = 5.086153\n",
-      "Epoch 365 batch: 0 , loss = 5.081354\n",
-      "Epoch 366 batch: 0 , loss = 5.084242\n",
-      "Epoch 367 batch: 0 , loss = 5.094817\n",
-      "Epoch 368 batch: 0 , loss = 5.089900\n",
-      "Epoch 369 batch: 0 , loss = 5.089128\n",
-      "Epoch 370 batch: 0 , loss = 5.084967\n",
-      "Epoch 371 batch: 0 , loss = 5.085038\n",
-      "Epoch 372 batch: 0 , loss = 5.084878\n",
-      "Epoch 373 batch: 0 , loss = 5.085416\n",
-      "Epoch 374 batch: 0 , loss = 5.084856\n",
-      "Epoch 375 batch: 0 , loss = 5.085449\n",
-      "Epoch 376 batch: 0 , loss = 5.080372\n",
-      "Epoch 377 batch: 0 , loss = 5.079864\n",
-      "Epoch 378 batch: 0 , loss = 5.089789\n",
-      "Epoch 379 batch: 0 , loss = 5.084702\n",
-      "Epoch 380 batch: 0 , loss = 5.080189\n",
-      "Epoch 381 batch: 0 , loss = 5.080185\n",
-      "Epoch 382 batch: 0 , loss = 5.080244\n",
-      "Epoch 383 batch: 0 , loss = 5.080095\n",
-      "Epoch 384 batch: 0 , loss = 5.088795\n",
-      "Epoch 385 batch: 0 , loss = 5.084120\n",
-      "Epoch 386 batch: 0 , loss = 5.087992\n",
-      "Epoch 387 batch: 0 , loss = 5.082991\n",
-      "Epoch 388 batch: 0 , loss = 5.072812\n",
-      "Epoch 389 batch: 0 , loss = 5.077140\n",
-      "Epoch 390 batch: 0 , loss = 5.081978\n",
-      "Epoch 391 batch: 0 , loss = 5.086642\n",
-      "Epoch 392 batch: 0 , loss = 5.091150\n",
-      "Epoch 393 batch: 0 , loss = 5.081849\n",
-      "Epoch 394 batch: 0 , loss = 5.071902\n",
-      "Epoch 395 batch: 0 , loss = 5.090557\n",
-      "Epoch 396 batch: 0 , loss = 5.085775\n",
-      "Epoch 397 batch: 0 , loss = 5.076406\n",
-      "Epoch 398 batch: 0 , loss = 5.086689\n",
-      "Epoch 399 batch: 0 , loss = 5.082210\n",
-      "Epoch 400 batch: 0 , loss = 5.091569\n",
-      "Epoch 401 batch: 0 , loss = 5.082448\n",
-      "Epoch 402 batch: 0 , loss = 5.081671\n",
-      "Epoch 403 batch: 0 , loss = 5.086580\n",
-      "Epoch 404 batch: 0 , loss = 5.081904\n",
-      "Epoch 405 batch: 0 , loss = 5.082079\n",
-      "Epoch 406 batch: 0 , loss = 5.081985\n",
-      "Epoch 407 batch: 0 , loss = 5.086331\n",
-      "Epoch 408 batch: 0 , loss = 5.076781\n",
-      "Epoch 409 batch: 0 , loss = 5.086619\n",
-      "Epoch 410 batch: 0 , loss = 5.085722\n",
-      "Epoch 411 batch: 0 , loss = 5.081203\n",
-      "Epoch 412 batch: 0 , loss = 5.080491\n",
-      "Epoch 413 batch: 0 , loss = 5.086140\n",
-      "Epoch 414 batch: 0 , loss = 5.076138\n",
-      "Epoch 415 batch: 0 , loss = 5.090469\n",
-      "Epoch 416 batch: 0 , loss = 5.075500\n",
-      "Epoch 417 batch: 0 , loss = 5.080328\n",
-      "Epoch 418 batch: 0 , loss = 5.084836\n",
-      "Epoch 419 batch: 0 , loss = 5.079198\n",
-      "Epoch 420 batch: 0 , loss = 5.068509\n",
-      "Epoch 421 batch: 0 , loss = 5.081526\n",
-      "Epoch 422 batch: 0 , loss = 5.075669\n",
-      "Epoch 423 batch: 0 , loss = 5.070409\n",
-      "Epoch 424 batch: 0 , loss = 5.069685\n",
-      "Epoch 425 batch: 0 , loss = 5.074168\n",
-      "Epoch 426 batch: 0 , loss = 5.073371\n",
-      "Epoch 427 batch: 0 , loss = 5.077053\n",
-      "Epoch 428 batch: 0 , loss = 5.076373\n",
-      "Epoch 429 batch: 0 , loss = 5.065866\n",
-      "Epoch 430 batch: 0 , loss = 5.069530\n",
-      "Epoch 431 batch: 0 , loss = 5.063504\n",
-      "Epoch 432 batch: 0 , loss = 5.066977\n",
-      "Epoch 433 batch: 0 , loss = 5.061393\n",
-      "Epoch 434 batch: 0 , loss = 5.064750\n",
-      "Epoch 435 batch: 0 , loss = 5.063208\n",
-      "Epoch 436 batch: 0 , loss = 5.062410\n",
-      "Epoch 437 batch: 0 , loss = 5.056867\n",
-      "Epoch 438 batch: 0 , loss = 5.061165\n",
-      "Epoch 439 batch: 0 , loss = 5.055538\n",
-      "Epoch 440 batch: 0 , loss = 5.064088\n",
-      "Epoch 441 batch: 0 , loss = 5.068861\n",
-      "Epoch 442 batch: 0 , loss = 5.058971\n",
-      "Epoch 443 batch: 0 , loss = 5.054216\n",
-      "Epoch 444 batch: 0 , loss = 5.053985\n",
-      "Epoch 445 batch: 0 , loss = 5.053171\n",
-      "Epoch 446 batch: 0 , loss = 5.052192\n",
-      "Epoch 447 batch: 0 , loss = 5.055376\n",
-      "Epoch 448 batch: 0 , loss = 5.048327\n",
-      "Epoch 449 batch: 0 , loss = 5.047246\n",
-      "Epoch 450 batch: 0 , loss = 5.056007\n",
-      "Epoch 451 batch: 0 , loss = 5.045246\n",
-      "Epoch 452 batch: 0 , loss = 5.048958\n",
-      "Epoch 453 batch: 0 , loss = 5.043387\n",
-      "Epoch 454 batch: 0 , loss = 5.052229\n",
-      "Epoch 455 batch: 0 , loss = 5.036552\n",
-      "Epoch 456 batch: 0 , loss = 5.047226\n",
-      "Epoch 457 batch: 0 , loss = 5.047490\n",
-      "Epoch 458 batch: 0 , loss = 5.045058\n",
-      "Epoch 459 batch: 0 , loss = 5.034773\n",
-      "Epoch 460 batch: 0 , loss = 5.037952\n",
-      "Epoch 461 batch: 0 , loss = 5.040687\n",
-      "Epoch 462 batch: 0 , loss = 5.031157\n",
-      "Epoch 463 batch: 0 , loss = 5.032485\n",
-      "Epoch 464 batch: 0 , loss = 5.017200\n",
-      "Epoch 465 batch: 0 , loss = 5.020655\n",
-      "Epoch 466 batch: 0 , loss = 5.024525\n",
-      "Epoch 467 batch: 0 , loss = 5.021727\n",
-      "Epoch 468 batch: 0 , loss = 5.015924\n",
-      "Epoch 469 batch: 0 , loss = 5.019942\n",
-      "Epoch 470 batch: 0 , loss = 5.024877\n",
-      "Epoch 471 batch: 0 , loss = 5.019931\n",
-      "Epoch 472 batch: 0 , loss = 5.023381\n",
-      "Epoch 473 batch: 0 , loss = 5.007413\n",
-      "Epoch 474 batch: 0 , loss = 5.015634\n",
-      "Epoch 475 batch: 0 , loss = 5.019332\n",
-      "Epoch 476 batch: 0 , loss = 5.023258\n",
-      "Epoch 477 batch: 0 , loss = 5.007919\n",
-      "Epoch 478 batch: 0 , loss = 5.012198\n",
-      "Epoch 479 batch: 0 , loss = 5.016301\n",
-      "Epoch 480 batch: 0 , loss = 5.015327\n",
-      "Epoch 481 batch: 0 , loss = 5.015544\n",
-      "Epoch 482 batch: 0 , loss = 5.015000\n",
-      "Epoch 483 batch: 0 , loss = 5.009738\n",
-      "Epoch 484 batch: 0 , loss = 5.004784\n",
-      "Epoch 485 batch: 0 , loss = 5.009272\n",
-      "Epoch 486 batch: 0 , loss = 5.014519\n",
-      "Epoch 487 batch: 0 , loss = 5.014775\n",
-      "Epoch 488 batch: 0 , loss = 5.018719\n",
-      "Epoch 489 batch: 0 , loss = 5.009684\n",
-      "Epoch 490 batch: 0 , loss = 5.014743\n",
-      "Epoch 491 batch: 0 , loss = 5.015655\n",
-      "Epoch 492 batch: 0 , loss = 5.011545\n",
-      "Epoch 493 batch: 0 , loss = 5.015726\n",
-      "Epoch 494 batch: 0 , loss = 5.009212\n",
-      "Epoch 495 batch: 0 , loss = 5.008894\n",
-      "Epoch 496 batch: 0 , loss = 5.002744\n",
-      "Epoch 497 batch: 0 , loss = 5.007370\n",
-      "Epoch 498 batch: 0 , loss = 5.007164\n",
-      "Epoch 499 batch: 0 , loss = 5.007099\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Train with data loader.\n",
-    "\n",
-    "\n",
-    "\n",
-    "model = NgmOne(device)\n",
-    "criterion = nn.CrossEntropyLoss()\n",
-    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
-    "\n",
-    "epoch = 500\n",
-    "batch_size = 200\n",
-    "for e in range(epoch):\n",
-    "    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
-    "    for i_batch, sample_batched in enumerate(train_dataloader):\n",
-    "        optimizer.zero_grad()\n",
-    "        train = sample_batched[0]\n",
-    "        train_mask = sample_batched[1]\n",
-    "        label_index = sample_batched[2].to(device)\n",
-    "\n",
-    "        # Forward pass\n",
-    "        output = model(train, train_mask)\n",
-    "        loss = criterion(output, label_index)\n",
-    "\n",
-    "        # backward and optimize\n",
-    "        loss.backward()\n",
-    "        optimizer.step()\n",
-    "        if i_batch % batch_size == 0:\n",
-    "            print(\"Epoch\", e, \"batch:\",i_batch, ', loss =', '{:.6f}'.format(loss))\n",
-    "    \n",
-    "    \n",
-    "    \n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model = NgmOne(device)\n",
-    "\n",
-    "EPOCHS = 1500\n",
-    "criterion = nn.CrossEntropyLoss()\n",
-    "optimizer = optim.Adam(model.parameters(), lr=0.007)\n",
-    "\n",
-    "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n",
-    "    relations = json.load(f)\n",
-    "\n",
-    "train, train_mask, corr_rels = make_batch()\n",
-    "corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)\n",
-    "\n",
-    "for epoch in range(EPOCHS):\n",
-    "    optimizer.zero_grad()\n",
-    "\n",
-    "    # Forward pass\n",
-    "    output = model(train, train_mask)\n",
-    "    loss = criterion(output, corr_indx)\n",
-    "\n",
-    "    if (epoch + 1) % 1 == 0:\n",
-    "        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n",
-    "    # Backward pass\n",
-    "    loss.backward()\n",
-    "    optimizer.step()\n",
-    "    \n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 66,
-   "metadata": {},
-   "outputs": [
+      "cell_type": "code",
+      "execution_count": 42,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "def make_batch():\n",
+        "    \"\"\"Triplet is a list of [subject entity, relation, object entity], None if not present\"\"\"\n",
+        "\n",
+        "    # Load predicted data\n",
+        "    pred = \"../data/qald-9-train-linked.json\"\n",
+        "\n",
+        "    #Load gold data\n",
+        "    gold = \"../data/qald-9-train-linked.json\"\n",
+        "    print(\"Beginning making batch\")\n",
+        "    with open(pred, \"r\") as p, open(gold, \"r\") as g:\n",
+        "        pred = json.load(p)\n",
+        "        gold = json.load(g)\n",
+        "\n",
+        "    inputs = []\n",
+        "    correct_rels = []\n",
+        "    inputs_max_len = 0\n",
+        "    for d in tqdm(pred[\"questions\"]):\n",
+        "        question = d[\"question\"][0][\"string\"]\n",
+        "        query = d[\"query\"][\"sparql\"]\n",
+        "\n",
+        "        #Take the first tripletin query\n",
+        "        trip = query.split(\"WHERE {\")[1]\n",
+        "        trip = trip.split(\"}\")[0]\n",
+        "        trip = trip.split(\"FILTER \")[0]\n",
+        "        trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n",
+        "        trip = trip.replace(\".\", \"\")\n",
+        "        trip = trip.replace(\";\", \"\")\n",
+        "        triplet = trip.split(\" \")\n",
+        "\n",
+        "        #remove empty strings\n",
+        "        triplet = [x for x in triplet if x != \"\"]\n",
+        "\n",
+        "        if len(triplet) % 3 == 0 and \" \".join(triplet).find(\"rdf\") == -1:\n",
+        "            for i in range(len(triplet)//3):\n",
+        "                triplet_i = triplet[i*3:i*3+3]\n",
+        "                for t in triplet_i:\n",
+        "                    if not(t.find(\"?\")):\n",
+        "                        triplet_i[triplet_i.index(t)] = None\n",
+        "\n",
+        "                #seq = \"[CLS] \" + question + \" [SEP] \"\n",
+        "                if triplet_i[0] is not None and triplet_i[1] is not None:\n",
+        "                    #seq += \"[SUB] [SEP] \" + triplet[0]\n",
+        "                    # , padding=True, truncation=True)\n",
+        "                    tokenized_seq = tokenizer(question, \"[SUB]\", triplet_i[0], padding=True, truncation=True)\n",
+        "                elif triplet_i[2] is not None and triplet_i[1] is not None:\n",
+        "                    #seq += \"[OBJ] [SEP] \" + triplet[2]\n",
+        "                    tokenized_seq = tokenizer(question, \"[OBJ]\", triplet_i[2], padding=True, truncation=True)\n",
+        "                else:\n",
+        "                    continue\n",
+        "\n",
+        "                if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n",
+        "                    inputs_max_len = len(tokenized_seq[\"input_ids\"])\n",
+        "                inputs.append(list(tokenized_seq.values())[0])\n",
+        "                correct_rels.append(triplet_i[1])\n",
+        "\n",
+        "    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n",
+        "    #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n",
+        "\n",
+        "    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n",
+        "    #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n",
+        "    print(\"Finished with batches\")\n",
+        "    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n"
+      ]
+    },
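+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Illustration sketch (made-up query, not from the dataset): how the string\n",
+        "# surgery in make_batch turns a SPARQL WHERE clause into triple tokens.\n",
+        "q = \"SELECT ?s WHERE { ?s <http://dbpedia.org/property/launchPad> <http://dbpedia.org/resource/Baikonur_Cosmodrome> . }\"\n",
+        "body = q.split(\"WHERE {\")[1].split(\"}\")[0]\n",
+        "tokens = [t for t in body.split(\" \") if t not in (\"\", \".\", \";\")]\n",
+        "print(tokens)  # ['?s', '<http://dbpedia.org/property/launchPad>', '<http://dbpedia.org/resource/Baikonur_Cosmodrome>']"
+      ]
+    },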
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# training_args = Seq2SeqTrainingArguments(\n",
+        "#     output_dir='./models/blackbox',\n",
+        "#     num_train_epochs=1,\n",
+        "#     per_device_train_batch_size=1,\n",
+        "#     per_device_eval_batch_size=1,\n",
+        "#     warmup_steps=10,\n",
+        "#     weight_decay=0.01,\n",
+        "#     logging_dir='./logs',\n",
+        "# )\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 44,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from torch.utils.data import Dataset, DataLoader\n",
+        "\n",
+        "class MyDataset(Dataset):\n",
+        "    def __init__(self, inputs, attention_mask, correct_rels, relations):\n",
+        "        self.inputs = inputs\n",
+        "        self.attention_mask = attention_mask\n",
+        "        self.correct_rels = correct_rels\n",
+        "        self.relations = relations\n",
+        "        \n",
+        "\n",
+        "    def __len__(self):\n",
+        "        return len(self.inputs)\n",
+        "\n",
+        "    def __getitem__(self, idx):\n",
+        "        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])\n",
+        "\n",
+        "\n",
+        "\n",
+        "\n",
+        "#From scratch json creates data set.\n",
+        "# class MyDataset(Dataset):\n",
+        "#     def __init__(self, json_file, transform=None):\n",
+        "#         self.qald_data = json.load(json_file)\n",
+        "\n",
+        "#     def __len__(self):\n",
+        "#         return len(self.qald_data)\n",
+        "\n",
+        "#     def __getitem__(self, idx):\n",
+        "#         self.inputs[idx], self.attention_mask[idx], self.labels[idx]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 45,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Beginning making batch\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
+            "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+            "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Finished with batches\n",
+            "features: tensor([[  101,  2040,  2003,  1996,  3677,  1997,  1996,  4035,  6870, 19247,\n",
+            "          1029,   102,  1031,  4942,  1033,   102,     0,     0,     0,     0,\n",
+            "             0,     0,     0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "\n"
+          ]
+        }
+      ],
+      "source": [
+        "#Prepare data\n",
+        "\n",
+        "def open_json(file):\n",
+        "    with open(file, \"r\") as f:\n",
+        "        return json.load(f)\n",
+        "\n",
+        "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
+        "train_set = MyDataset(*make_batch(), relations=relations)\n",
+        "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n",
+        "\n",
+        "#show first entry\n",
+        "train_features, train_mask, train_label = next(iter(train_dataloader))\n",
+        "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])"
+      ]
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Beginning making batch\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 65,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
+            "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+            "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Finished with batches\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "\n"
+          ]
+        },
+        {
+          "ename": "ValueError",
+          "evalue": "'<http://dbpediaorg/property/launchPad>' is not in list",
+          "output_type": "error",
+          "traceback": [
+            "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+            "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
+            "Cell \u001b[1;32mIn [43], line 11\u001b[0m\n\u001b[0;32m      8\u001b[0m     relations \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mload(f)\n\u001b[0;32m     10\u001b[0m train, train_mask, corr_rels \u001b[39m=\u001b[39m make_batch()\n\u001b[1;32m---> 11\u001b[0m corr_indx \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mLongTensor([relations\u001b[39m.\u001b[39mindex(r) \u001b[39mfor\u001b[39;00m r \u001b[39min\u001b[39;00m corr_rels])\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m     13\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(EPOCHS):\n\u001b[0;32m     14\u001b[0m     optimizer\u001b[39m.\u001b[39mzero_grad()\n",
+            "Cell \u001b[1;32mIn [43], line 11\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m      8\u001b[0m     relations \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mload(f)\n\u001b[0;32m     10\u001b[0m train, train_mask, corr_rels \u001b[39m=\u001b[39m make_batch()\n\u001b[1;32m---> 11\u001b[0m corr_indx \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mLongTensor([relations\u001b[39m.\u001b[39;49mindex(r) \u001b[39mfor\u001b[39;00m r \u001b[39min\u001b[39;00m corr_rels])\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m     13\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(EPOCHS):\n\u001b[0;32m     14\u001b[0m     optimizer\u001b[39m.\u001b[39mzero_grad()\n",
+            "\u001b[1;31mValueError\u001b[0m: '<http://dbpediaorg/property/launchPad>' is not in list"
+          ]
+        }
+      ],
+      "source": [
+        "# Train with data loader.\n",
+        "\n",
+        "\n",
+        "\n",
+        "model = NgmOne(device)\n",
+        "criterion = nn.CrossEntropyLoss()\n",
+        "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
+        "\n",
+        "epoch = 500\n",
+        "batch_size = 200\n",
+        "for e in range(epoch):\n",
+        "    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
+        "    for i_batch, sample_batched in enumerate(train_dataloader):\n",
+        "        optimizer.zero_grad()\n",
+        "        train = sample_batched[0]\n",
+        "        train_mask = sample_batched[1]\n",
+        "        label_index = sample_batched[2].to(device)\n",
+        "\n",
+        "        # Forward pass\n",
+        "        output = model(train, train_mask)\n",
+        "        loss = criterion(output, label_index)\n",
+        "\n",
+        "        # backward and optimize\n",
+        "        loss.backward()\n",
+        "        optimizer.step()\n",
+        "        if i_batch % batch_size == 0:\n",
+        "            print(\"Epoch\", e, \"batch:\",i_batch, ', loss =', '{:.6f}'.format(loss))\n",
+        "    \n",
+        "    \n",
+        "    \n",
+        "\n"
+      ]
     },
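+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Side-note sketch with toy numbers: nn.CrossEntropyLoss applies log-softmax\n",
+        "# internally and expects raw logits, so softmaxing inside NgmOne.forward\n",
+        "# (as above) still trains but flattens the gradient signal.\n",
+        "logits = torch.tensor([[2.0, 0.5, -1.0]])\n",
+        "target = torch.tensor([0])\n",
+        "ce = nn.CrossEntropyLoss()\n",
+        "print(ce(logits, target))                 # intended usage: raw logits\n",
+        "print(ce(logits.softmax(dim=1), target))  # softmaxed input: weaker signal"
+      ]
+    },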
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "100%|██████████| 408/408 [00:00<00:00, 962.29it/s] \n",
-      "100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]\n"
-     ]
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "model = NgmOne(device)\n",
+        "\n",
+        "EPOCHS = 1500\n",
+        "criterion = nn.CrossEntropyLoss()\n",
+        "optimizer = optim.Adam(model.parameters(), lr=0.007)\n",
+        "\n",
+        "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n",
+        "    relations = json.load(f)\n",
+        "\n",
+        "train, train_mask, corr_rels = make_batch()\n",
+        "corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)\n",
+        "\n",
+        "for epoch in range(EPOCHS):\n",
+        "    optimizer.zero_grad()\n",
+        "\n",
+        "    # Forward pass\n",
+        "    output = model(train, train_mask)\n",
+        "    loss = criterion(output, corr_indx)\n",
+        "\n",
+        "    if (epoch + 1) % 1 == 0:\n",
+        "        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n",
+        "    # Backward pass\n",
+        "    loss.backward()\n",
+        "    optimizer.step()\n",
+        "    \n"
+      ]
     },
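+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Sketch: persist the trained weights so the prediction cell below can be\n",
+        "# re-run without retraining. The file name is an assumption.\n",
+        "torch.save(model.state_dict(), \"ngm_one.pt\")\n",
+        "# model.load_state_dict(torch.load(\"ngm_one.pt\"))"
+      ]
+    },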
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Finished with batches\n",
-      "lowest confidence 0.14628477\n",
-      "Accuracy: 0.5245098039215687\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 66,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Beginning making batch\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "100%|██████████| 408/408 [00:00<00:00, 962.29it/s] \n",
+            "100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Finished with batches\n",
+            "lowest confidence 0.14628477\n",
+            "Accuracy: 0.5245098039215687\n"
+          ]
+        }
+      ],
+      "source": [
+        "# Predict\n",
+        "train, train_mask, corr_rels = make_batch()\n",
+        "with torch.no_grad():\n",
+        "    output = model(train, train_mask)\n",
+        "\n",
+        "output = output.detach().cpu().numpy()\n",
+        "\n",
+        "prediction = [relations[np.argmax(pred).item()]for pred in output]\n",
+        "probability = [pred[np.argmax(pred)] for pred in output]\n",
+        "correct_pred = [corr_rels[i] for i in range(len(output))]\n",
+        "\n",
+        "\n",
+        "\n",
+        "preds = [\n",
+        "    {\"pred\": relations[np.argmax(pred).item()], \n",
+        "     \"prob\": pred[np.argmax(pred)],\n",
+        "     \"correct\": corr_rels[count]} \n",
+        "    for count, pred in enumerate(output)\n",
+        "    ]\n",
+        "\n",
+        "\n",
+        "\n",
+        "# for pred in preds:\n",
+        "#     print(\"pred\", pred[\"pred\"], \"prob\", pred[\"prob\"], \"correct\", pred[\"correct\"])\n",
+        "\n",
+        "# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):\n",
+        "#     print(\"pred:\", pred, \"prob:\", prob, \"| correct\", correct_pred)\n",
+        "\n",
+        "print(\"lowest confidence\", min(probability))\n",
+        "\n",
+        "def accuracy_score(y_true, y_pred):\n",
+        "    corr_preds=0\n",
+        "    wrong_preds=0\n",
+        "    for pred, correct in zip(y_pred, y_true):\n",
+        "        if pred == correct:\n",
+        "            corr_preds += 1\n",
+        "        else:\n",
+        "            wrong_preds += 1\n",
+        "    return corr_preds/(corr_preds+wrong_preds)\n",
+        "\n",
+        "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n",
+        "\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": []
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "Python 3.9.11 64-bit",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.9.11"
+    },
+    "orig_nbformat": 4,
+    "vscode": {
+      "interpreter": {
+        "hash": "64e7cd3b4b88defe39dd61a4584920400d6beb2615ab2244e340c2e20eecdfe9"
+      }
     }
-   ],
-   "source": [
-    "# Predict\n",
-    "train, train_mask, corr_rels = make_batch()\n",
-    "with torch.no_grad():\n",
-    "    output = model(train, train_mask)\n",
-    "\n",
-    "output = output.detach().cpu().numpy()\n",
-    "\n",
-    "prediction = [relations[np.argmax(pred).item()]for pred in output]\n",
-    "probability = [pred[np.argmax(pred)] for pred in output]\n",
-    "correct_pred = [corr_rels[i] for i in range(len(output))]\n",
-    "\n",
-    "\n",
-    "\n",
-    "preds = [\n",
-    "    {\"pred\": relations[np.argmax(pred).item()], \n",
-    "     \"prob\": pred[np.argmax(pred)],\n",
-    "     \"correct\": corr_rels[count]} \n",
-    "    for count, pred in enumerate(output)\n",
-    "    ]\n",
-    "\n",
-    "\n",
-    "\n",
-    "# for pred in preds:\n",
-    "#     print(\"pred\", pred[\"pred\"], \"prob\", pred[\"prob\"], \"correct\", pred[\"correct\"])\n",
-    "\n",
-    "# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):\n",
-    "#     print(\"pred:\", pred, \"prob:\", prob, \"| correct\", correct_pred)\n",
-    "\n",
-    "print(\"lowest confidence\", min(probability))\n",
-    "\n",
-    "def accuracy_score(y_true, y_pred):\n",
-    "    corr_preds=0\n",
-    "    wrong_preds=0\n",
-    "    for pred, correct in zip(y_pred, y_true):\n",
-    "        if pred == correct:\n",
-    "            corr_preds += 1\n",
-    "        else:\n",
-    "            wrong_preds += 1\n",
-    "    return corr_preds/(corr_preds+wrong_preds)\n",
-    "\n",
-    "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3.10.4 ('tdde19')",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.4"
   },
-  "orig_nbformat": 4,
-  "vscode": {
-   "interpreter": {
-    "hash": "8e4aa0e1a1e15de86146661edda0b2884b54582522f7ff2b916774ba6b8accb1"
-   }
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
+  "nbformat": 4,
+  "nbformat_minor": 2
 }