diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb index d95c77016b150ff219acb1da22657a5994bbad3b..a9d65ee24701f9458b591a0d099a587efc82f298 100644 --- a/Neural graph module/ngm.ipynb +++ b/Neural graph module/ngm.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 25, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -196,7 +196,7 @@ " if triplet_i[0] is not None and triplet_i[1] is not None:\n", " #seq += \"[SUB] [SEP] \" + triplet[0]\n", " # , padding=True, truncation=True)\n", - " tokenized_seq = tokenizer(question.lower(), \"[OBJ] [SEP] \" + triplet_i[0].split(\":\")[1].lower(), padding=True, truncation=True)\n", + " tokenized_seq = tokenizer(question.lower(), \"[SUB] [SEP] \" + triplet_i[0].split(\":\")[1].lower(), padding=True, truncation=True)\n", " sub_obj_ents.append(\"[SUB] \" + prefixes_reverse[\"\".join(triplet_i[0].split(\":\")[0]) + \":\"] + triplet_i[0].split(\":\")[1])\n", " elif triplet_i[2] is not None and triplet_i[1] is not None:\n", " #seq += \"[OBJ] [SEP] \" + triplet[2]\n", @@ -221,7 +221,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -258,7 +258,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2052/2052 [00:00<00:00, 2906.57it/s]\n" + "100%|██████████| 2052/2052 [00:01<00:00, 1068.18it/s]\n" ] }, { @@ -266,17 +266,17 @@ "output_type": "stream", "text": [ "Finished with batches\n", - "features: ['[CLS] name the home stadium of fc spartak moscow season 2011 - 12 [SEP] [ obj ] [SEP] 2011 – 12 _ fc _ spartak _ moscow _ season [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'] mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(431)\n", - "valid features: tensor([[ 101, 2054, 3063, 2001, 3378, 4876, 2007, 1996, 6436, 26785,\n", - " 7971, 6447, 1998, 2600, 1037, 20160, 10362, 1029, 102, 1031,\n", - " 27885, 3501, 1033, 102, 2600, 1011, 1037, 1011, 20160, 10362,\n", - " 102, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + "features: ['[CLS] what ingredients are used in preparing the dish of ragout fin? 
[SEP] [ sub ] [SEP] ragout _ fin [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'] mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(336)\n", + "valid features: tensor([[ 101, 2054, 2003, 1996, 2344, 1997, 2577, 10424, 2483, 11283,\n", + " 7570, 2906, 1029, 102, 1031, 4942, 1033, 102, 2577, 1035,\n", + " 10424, 2483, 11283, 1035, 7570, 2906, 102, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(149)\n" + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(297)\n" ] } ], @@ -319,14 +319,14 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']\n", + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -350,29 +350,29 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1 Train 4.168464240410345 , Valid 6.253535368863274\n", - "2 Train 4.158141554688378 , Valid 6.238900885862463\n", - "3 Train 4.135342076527986 , Valid 6.225469000199261\n", - "4 Train 4.12202568705991 , Valid 6.212939725202673\n", - "5 Train 4.109015821552963 , Valid 6.205244933857637\n", - "6 Train 4.098594773587563 , Valid 6.194438050774967\n", - "7 Train 4.08812516884838 , Valid 6.178463767556583\n", - "8 Train 4.079949097667667 , Valid 6.175140450982487\n", - "9 Train 4.07018859094853 , Valid 6.167128044016221\n", - "10 Train 4.065771084037616 , Valid 6.161452812307021\n" + "1 Train 3.517506525670882 , Valid 6.241430044174194\n", + "2 Train 3.514909356618099 , Valid 6.230757559047026\n", + "3 Train 3.494271710622225 , Valid 6.218190235250137\n", + "4 Train 3.487279334514261 , Valid 6.216006854001214\n", + "5 Train 3.4676955844000945 , Valid 6.193478163550882\n", + "6 Train 3.4382689976863725 , Valid 6.1805572509765625\n", + "7 Train 3.442131721716133 , Valid 6.178038597106934\n", + "8 Train 3.433108813471074 , Valid 6.17140617090113\n", + "9 Train 3.424916239951154 , Valid 6.16327678456026\n", + "10 Train 3.4210985067079394 , Valid 6.16277868607465\n" ] } ], "source": [ "# Train with data loader.\n", - "criterion = nn.CrossEntropyLoss()\n", + "criterion = nn.CrossEntropyLoss(reduction=\"none\")\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", "#optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.5)\n", "\n", @@ -423,17 +423,26 @@ " loss_rels = []\n", " for i in range(len(relations)):\n", " if prefixes_reverse[\"\".join(relations[i].split(\":\")[0]) + \":\"] + relations[i].split(\":\")[1] in list(res_rels.keys()):\n", - " loss_rels.append(5) \n", + " loss_rels.append(0) \n", " else:\n", - " loss_rels.append(1/20)\n", + " loss_rels.append(1)\n", " loss_gs.append(loss_rels)\n", " loss_gs = torch.FloatTensor(loss_gs).to(device)\n", "\n", " #print(response)\n", - " output_gs = output * (1-alpha) + (loss_gs) * alpha\n", - " loss = criterion(output_gs, label_index)\n", + " preds = [relations[np.argmax(pred).item()]for pred in output.detach().cpu().numpy()]\n", + " relation_loss = []\n", + " for i in range(len(preds)):\n", + " if prefixes_reverse[\"\".join(preds[i].split(\":\")[0]) + \":\"] + preds[i].split(\":\")[1] in list(res_rels.keys()):\n", + " relation_loss.append(0)\n", + " else:\n", + " relation_loss.append(1)\n", "\n", + " relation_loss = torch.FloatTensor(relation_loss).to(device)\n", + " loss = criterion(output, label_index)\n", + " loss = loss * (1-alpha) + (relation_loss) * alpha\n", " # backward and optimize\n", + " loss = loss.mean()\n", " loss.backward()\n", " optimizer.step()\n", " train_loss_epoch = train_loss_epoch + loss.item()\n", @@ -448,14 +457,14 @@ " output = model(valid, valid_mask)\n", " loss = criterion(output, label_index)\n", "\n", - " valid_loss_epoch = valid_loss_epoch + loss.item()\n", + " valid_loss_epoch = valid_loss_epoch + loss.mean().item()\n", "\n", " print(e+1, \"Train\", train_loss_epoch/i_train, \", Valid \", 
valid_loss_epoch/i_valid)" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -469,7 +478,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2052/2052 [00:00<00:00, 2654.58it/s]\n" + "100%|██████████| 2052/2052 [00:02<00:00, 992.25it/s] \n" ] }, { @@ -484,7 +493,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1161/1161 [00:00<00:00, 2831.73it/s]\n" + "100%|██████████| 1161/1161 [00:01<00:00, 1040.33it/s]\n" ] }, { @@ -492,11 +501,11 @@ "output_type": "stream", "text": [ "Finished with batches\n", - "test loss 6.039710998535156\n", - "lowest confidence train 0.07557418\n", - "lowest confidence test 0.08204544\n", - "Accuracy train: 0.2349570200573066\n", - "Accuracy test: 0.007633587786259542\n" + "test loss 6.041214942932129\n", + "lowest confidence train 0.07357887\n", + "lowest confidence test 0.07673955\n", + "Accuracy train: 0.24283667621776503\n", + "Accuracy test: 0.015267175572519083\n" ] } ], @@ -513,7 +522,7 @@ " output_train = model(train, train_mask)\n", " output_test = model(test_batch, test_mask_batch)\n", " loss = criterion(output_test, corr_rels_test_batch)\n", - " print(\"test loss\", loss.item())\n", + " print(\"test loss\", loss.mean().item())\n", "\n", " loss_gs = []\n", " for j in range(len(sub_objs)):\n",
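
Note on the training-loop change above: the diff switches to `nn.CrossEntropyLoss(reduction="none")` so the cross-entropy stays per-sample, blends it with a 0/1 penalty for predictions whose relation does not appear in the knowledge-graph query results, and only then reduces with `.mean()` before `backward()`. The sketch below is a minimal, self-contained illustration of that blending pattern; the names `blended_loss`, `logits`, `labels`, and `penalty`, as well as the `alpha` value, are illustrative assumptions and not taken from the notebook.

```python
import torch
import torch.nn as nn

alpha = 0.1  # assumed blending weight; the notebook defines its own alpha

# reduction="none" keeps one loss value per sample instead of averaging
criterion = nn.CrossEntropyLoss(reduction="none")

def blended_loss(logits, labels, penalty, alpha=alpha):
    """Blend per-sample cross-entropy with a per-sample KG penalty."""
    ce = criterion(logits, labels)             # shape: (batch,)
    loss = ce * (1 - alpha) + penalty * alpha  # weight CE against the penalty
    return loss.mean()                         # reduce once, then backprop

# Toy usage with random data (4 samples, 10 candidate relations)
logits = torch.randn(4, 10, requires_grad=True)
labels = torch.randint(0, 10, (4,))
penalty = torch.tensor([0., 1., 0., 1.])  # 1 = predicted relation absent from KG results
loss = blended_loss(logits, labels, penalty)
loss.backward()
```

With `reduction="none"` the validation loss must also be reduced explicitly, which is why the diff changes the accumulation to `loss.mean().item()` in both the validation and test cells.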