diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index 8bddb696fb833290afd664e8215874bfdc5168ba..14d945f4bdc93c70b54d5eb307a085a1777c93db 100644
--- a/Neural graph module/ngm.ipynb	
+++ b/Neural graph module/ngm.ipynb	
@@ -2,18 +2,9 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "b:\\Programs\\Miniconda\\envs\\tdde19\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import datasets\n",
     "import torch\n",
@@ -29,7 +20,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -38,17 +29,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "cuda\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# Use GPU if available\n",
     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
@@ -58,7 +41,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -78,7 +61,7 @@
     "        with torch.no_grad():\n",
     "            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
     "            x = x[0][:,0,:].to(self.device)\n",
-    "        \n",
+    "            \n",
     "        x = self.linear(x)\n",
     "        x = self.softmax(x)\n",
     "        return x"
@@ -86,7 +69,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -124,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -209,7 +192,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -226,347 +209,661 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "\n",
+    "class MyDataset(Dataset):\n",
+    "    def __init__(self, inputs, attention_mask, correct_rels, relations):\n",
+    "        self.inputs = inputs\n",
+    "        self.attention_mask = attention_mask\n",
+    "        self.correct_rels = correct_rels\n",
+    "        self.relations = relations\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.inputs)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])\n",
+    "\n",
+    "\n",
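+    "# Minimal usage sketch (assumes make_batch() returns tokenised inputs,\n",
+    "# attention masks and gold relations in that order, as in the cell below):\n",
+    "# ds = MyDataset(*make_batch(), relations=relations)\n",
+    "# ids, mask, label_idx = ds[0]\n",
+    "\n",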
+    "# Alternative sketch: build the dataset straight from the QALD JSON file (unused).\n",
+    "# class MyDataset(Dataset):\n",
+    "#     def __init__(self, json_file, transform=None):\n",
+    "#         self.qald_data = json.load(json_file)\n",
+    "\n",
+    "#     def __len__(self):\n",
+    "#         return len(self.qald_data)\n",
+    "\n",
+    "#     def __getitem__(self, idx):\n",
+    "#         return self.inputs[idx], self.attention_mask[idx], self.labels[idx]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Beginning making batch\n"
+     ]
+    },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
-      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
-      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+      "100%|██████████| 408/408 [00:00<00:00, 955.51it/s] \n",
+      "100%|██████████| 408/408 [00:00<00:00, 136139.70it/s]"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Beginning making batch\n"
+      "Finished with batches\n",
+      "features: tensor([[  101,  2040,  2003,  1996,  3677,  1997,  1996,  4035,  6870, 19247,\n",
+      "          1029,   102,  1031,  4942,  1033,   102,     0,     0,     0,     0,\n",
+      "             0,     0,     0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "100%|██████████| 408/408 [00:00<00:00, 2684.45it/s]\n",
-      "100%|██████████| 408/408 [00:00<00:00, 204160.82it/s]\n"
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Prepare the data\n",
+    "\n",
+    "def open_json(file):\n",
+    "    with open(file, \"r\") as f:\n",
+    "        return json.load(f)\n",
+    "\n",
+    "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
+    "train_set = MyDataset(*make_batch(), relations=relations)\n",
+    "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n",
+    "\n",
+    "# Show the first entry\n",
+    "train_features, train_mask, train_label = next(iter(train_dataloader))\n",
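+    "# Optionally decode the ids back to text (assumes the BertTokenizer used in\n",
+    "# make_batch is available as `tokenizer`):\n",
+    "# print(tokenizer.decode(train_features[0]))\n",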
+    "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
+      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Finished with batches\n",
-      "Epoch: 0001 loss = 5.509346\n",
-      "Epoch: 0002 loss = 5.508417\n",
-      "[... epochs 0003-0298 elided; loss decreases steadily ...]\n",
-      "Epoch: 0299 loss = 5.070775\n",
-      "Epoch: 0300 loss = 5.070364\n"
+      "Epoch 0 batch: 0 , loss = 5.509496\n",
+      "Epoch 1 batch: 0 , loss = 5.507990\n",
+      "[... epochs 2-497 elided; loss falls slowly to ~5.01 ...]\n",
+      "Epoch 498 batch: 0 , loss = 5.007164\n",
+      "Epoch 499 batch: 0 , loss = 5.007099\n"
      ]
     }
    ],
+   "source": [
+    "# Train with the DataLoader.\n",
+    "\n",
+    "model = NgmOne(device)\n",
+    "criterion = nn.CrossEntropyLoss()\n",
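+    "# Note: nn.CrossEntropyLoss applies log-softmax internally, so feeding it the\n",
+    "# softmax output from NgmOne.forward squashes the gradients; returning raw\n",
+    "# logits from the model would likely train faster.\n",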
+    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
+    "\n",
+    "epochs = 500\n",
+    "batch_size = 200\n",
+    "# shuffle=True reshuffles every epoch, so one DataLoader suffices\n",
+    "train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
+    "for e in range(epochs):\n",
+    "    for i_batch, sample_batched in enumerate(train_dataloader):\n",
+    "        optimizer.zero_grad()\n",
+    "        train = sample_batched[0]\n",
+    "        train_mask = sample_batched[1]\n",
+    "        label_index = sample_batched[2].to(device)\n",
+    "\n",
+    "        # Forward pass\n",
+    "        output = model(train, train_mask)\n",
+    "        loss = criterion(output, label_index)\n",
+    "\n",
+    "        # backward and optimize\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        if i_batch % batch_size == 0:  # batch_size doubles as the print interval here\n",
+    "            print(\"Epoch\", e, \"batch:\", i_batch, ', loss =', '{:.6f}'.format(loss))\n",
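+    "\n",
+    "# Optionally persist the trained weights (path is illustrative):\n",
+    "# torch.save(model.state_dict(), 'ngm_one.pt')\n",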
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
     "model = NgmOne(device)\n",
     "\n",
-    "EPOCHS = 300\n",
+    "EPOCHS = 1500\n",
     "criterion = nn.CrossEntropyLoss()\n",
-    "optimizer = optim.Adam(model.parameters(), lr=0.05)\n",
+    "optimizer = optim.Adam(model.parameters(), lr=0.007)\n",
     "\n",
     "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n",
     "    relations = json.load(f)\n",
@@ -591,7 +888,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 66,
    "metadata": {},
    "outputs": [
     {
@@ -605,8 +902,8 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "100%|██████████| 408/408 [00:00<00:00, 814.37it/s] \n",
-      "100%|██████████| 408/408 [00:00<00:00, 101922.34it/s]\n"
+      "100%|██████████| 408/408 [00:00<00:00, 962.29it/s] \n",
+      "100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]\n"
      ]
     },
     {
@@ -614,17 +911,8 @@
      "output_type": "stream",
      "text": [
       "Finished with batches\n",
-      "lowest confidence 0.19197923\n",
-      "Accuracy: 0.6323529411764706\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "KeyboardInterrupt\n",
-      "\n"
+      "lowest confidence 0.14628477\n",
+      "Accuracy: 0.5245098039215687\n"
      ]
     }
    ],
@@ -636,7 +924,6 @@
     "\n",
     "output = output.detach().cpu().numpy()\n",
     "\n",
-    "\n",
    "prediction = [relations[np.argmax(pred).item()] for pred in output]\n",
     "probability = [pred[np.argmax(pred)] for pred in output]\n",
     "correct_pred = [corr_rels[i] for i in range(len(output))]\n",