diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index a1576638b3ea4097ce2d5ff2925261ea74af84c4..8bddb696fb833290afd664e8215874bfdc5168ba 100644
--- a/Neural graph module/ngm.ipynb	
+++ b/Neural graph module/ngm.ipynb	
@@ -2,9 +2,18 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "b:\\Programs\\Miniconda\\envs\\tdde19\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
    "source": [
     "import datasets\n",
     "import torch\n",
@@ -20,7 +29,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -29,32 +38,55 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "cuda\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Use GPU if available\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "# Print cuda version\n",
+    "print(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
     "\n",
     "class NgmOne(nn.Module):\n",
-    "    def __init__(self):\n",
+    "    def __init__(self, device):\n",
     "        super(NgmOne, self).__init__()\n",
     "        self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
-    "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\")\n",
-    "        self.linear = nn.Linear(768, 247)\n",
-    "        self.softmax = nn.Softmax(dim=1)\n",
+    "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
+    "        self.linear = nn.Linear(768, 247).to(device)\n",
+    "        self.softmax = nn.Softmax(dim=1).to(device)\n",
+    "        self.device = device\n",
     "    \n",
     "    def forward(self, tokenized_seq, tokenized_mask):\n",
-    "        x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
-    "        x = x[0][:,0,:]\n",
+    "        tokenized_seq= tokenized_seq.to(self.device)\n",
+    "        tokenized_mask = tokenized_mask.to(self.device)\n",
+    "        with torch.no_grad():\n",
+    "            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
+    "            x = x[0][:,0,:].to(self.device)\n",
     "        \n",
     "        x = self.linear(x)\n",
-    "\n",
     "        x = self.softmax(x)\n",
     "        return x"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -92,7 +124,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -123,6 +155,9 @@
     "        #remove empty strings\n",
     "        triplet = [x for x in triplet if x != \"\"]\n",
     "\n",
+    "        if len(triplet) != 3:\n",
+    "            continue\n",
+    "\n",
     "        for t in triplet:\n",
     "            if not(t.find(\"?\")):\n",
     "                triplet[triplet.index(t)] = None\n",
@@ -150,28 +185,31 @@
     "        trip = query.split(\"WHERE\")[1]\n",
     "        trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n",
     "        triplet = trip.split(\" \")\n",
+    "    \n",
     "\n",
     "        #remove empty strings\n",
     "        triplet = [x for x in triplet if x != \"\"]\n",
     "\n",
-    "        tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n",
-    "        if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n",
-    "            correct_rels_max_len = len(tokenized[\"input_ids\"])\n",
+    "        if len(triplet) != 3:\n",
+    "            continue\n",
+    "        # tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n",
+    "        # if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n",
+    "        #     correct_rels_max_len = len(tokenized[\"input_ids\"])\n",
     "\n",
-    "        correct_rels.append(list(tokenized.values())[0])\n",
+    "        correct_rels.append(triplet[1])#list(tokenized.values())[0])\n",
     "\n",
     "    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n",
-    "    correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n",
+    "    #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n",
     "\n",
     "    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n",
-    "    correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n",
+    "    #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n",
     "    print(\"Finished with batches\")\n",
-    "    return torch.IntTensor(inputs_padded), torch.IntTensor(inputs_attention_mask), torch.IntTensor(correct_rels_padded), torch.IntTensor(correct_rels_attention_mask)\n"
+    "    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -188,14 +226,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
+      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
      ]
@@ -211,66 +249,442 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "100%|██████████| 408/408 [00:00<00:00, 688.03it/s]\n",
-      "100%|██████████| 408/408 [00:00<00:00, 2241.79it/s]\n"
+      "100%|██████████| 408/408 [00:00<00:00, 2684.45it/s]\n",
+      "100%|██████████| 408/408 [00:00<00:00, 204160.82it/s]\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Finished with batches\n"
+      "Finished with batches\n",
+      "Epoch: 0001 loss = 5.509346\n",
+      "Epoch: 0002 loss = 5.508417\n",
+      "Epoch: 0003 loss = 5.507416\n",
+      "Epoch: 0004 loss = 5.506336\n",
+      "Epoch: 0005 loss = 5.505169\n",
+      "Epoch: 0006 loss = 5.503905\n",
+      "Epoch: 0007 loss = 5.502542\n",
+      "Epoch: 0008 loss = 5.501083\n",
+      "Epoch: 0009 loss = 5.499537\n",
+      "Epoch: 0010 loss = 5.497901\n",
+      "Epoch: 0011 loss = 5.496177\n",
+      "Epoch: 0012 loss = 5.494359\n",
+      "Epoch: 0013 loss = 5.492444\n",
+      "Epoch: 0014 loss = 5.490432\n",
+      "Epoch: 0015 loss = 5.488323\n",
+      "Epoch: 0016 loss = 5.486121\n",
+      "Epoch: 0017 loss = 5.483822\n",
+      "Epoch: 0018 loss = 5.481424\n",
+      "Epoch: 0019 loss = 5.478921\n",
+      "Epoch: 0020 loss = 5.476310\n",
+      "Epoch: 0021 loss = 5.473589\n",
+      "Epoch: 0022 loss = 5.470768\n",
+      "Epoch: 0023 loss = 5.467863\n",
+      "Epoch: 0024 loss = 5.464887\n",
+      "Epoch: 0025 loss = 5.461853\n",
+      "Epoch: 0026 loss = 5.458760\n",
+      "Epoch: 0027 loss = 5.455590\n",
+      "Epoch: 0028 loss = 5.452313\n",
+      "Epoch: 0029 loss = 5.448884\n",
+      "Epoch: 0030 loss = 5.445263\n",
+      "Epoch: 0031 loss = 5.441427\n",
+      "Epoch: 0032 loss = 5.437370\n",
+      "Epoch: 0033 loss = 5.433111\n",
+      "Epoch: 0034 loss = 5.428682\n",
+      "Epoch: 0035 loss = 5.424133\n",
+      "Epoch: 0036 loss = 5.419532\n",
+      "Epoch: 0037 loss = 5.414931\n",
+      "Epoch: 0038 loss = 5.410347\n",
+      "Epoch: 0039 loss = 5.405736\n",
+      "Epoch: 0040 loss = 5.401030\n",
+      "Epoch: 0041 loss = 5.396172\n",
+      "Epoch: 0042 loss = 5.391167\n",
+      "Epoch: 0043 loss = 5.386050\n",
+      "Epoch: 0044 loss = 5.380856\n",
+      "Epoch: 0045 loss = 5.375600\n",
+      "Epoch: 0046 loss = 5.370294\n",
+      "Epoch: 0047 loss = 5.364953\n",
+      "Epoch: 0048 loss = 5.359621\n",
+      "Epoch: 0049 loss = 5.354363\n",
+      "Epoch: 0050 loss = 5.349256\n",
+      "Epoch: 0051 loss = 5.344321\n",
+      "Epoch: 0052 loss = 5.339508\n",
+      "Epoch: 0053 loss = 5.334712\n",
+      "Epoch: 0054 loss = 5.329853\n",
+      "Epoch: 0055 loss = 5.324935\n",
+      "Epoch: 0056 loss = 5.320047\n",
+      "Epoch: 0057 loss = 5.315296\n",
+      "Epoch: 0058 loss = 5.310766\n",
+      "Epoch: 0059 loss = 5.306486\n",
+      "Epoch: 0060 loss = 5.302420\n",
+      "Epoch: 0061 loss = 5.298513\n",
+      "Epoch: 0062 loss = 5.294705\n",
+      "Epoch: 0063 loss = 5.290928\n",
+      "Epoch: 0064 loss = 5.287130\n",
+      "Epoch: 0065 loss = 5.283293\n",
+      "Epoch: 0066 loss = 5.279424\n",
+      "Epoch: 0067 loss = 5.275546\n",
+      "Epoch: 0068 loss = 5.271703\n",
+      "Epoch: 0069 loss = 5.267990\n",
+      "Epoch: 0070 loss = 5.264547\n",
+      "Epoch: 0071 loss = 5.261484\n",
+      "Epoch: 0072 loss = 5.258725\n",
+      "Epoch: 0073 loss = 5.256050\n",
+      "Epoch: 0074 loss = 5.253313\n",
+      "Epoch: 0075 loss = 5.250482\n",
+      "Epoch: 0076 loss = 5.247588\n",
+      "Epoch: 0077 loss = 5.244689\n",
+      "Epoch: 0078 loss = 5.241833\n",
+      "Epoch: 0079 loss = 5.239061\n",
+      "Epoch: 0080 loss = 5.236411\n",
+      "Epoch: 0081 loss = 5.233915\n",
+      "Epoch: 0082 loss = 5.231617\n",
+      "Epoch: 0083 loss = 5.229539\n",
+      "Epoch: 0084 loss = 5.227572\n",
+      "Epoch: 0085 loss = 5.225513\n",
+      "Epoch: 0086 loss = 5.223309\n",
+      "Epoch: 0087 loss = 5.221033\n",
+      "Epoch: 0088 loss = 5.218764\n",
+      "Epoch: 0089 loss = 5.216528\n",
+      "Epoch: 0090 loss = 5.214312\n",
+      "Epoch: 0091 loss = 5.212102\n",
+      "Epoch: 0092 loss = 5.209961\n",
+      "Epoch: 0093 loss = 5.208022\n",
+      "Epoch: 0094 loss = 5.206336\n",
+      "Epoch: 0095 loss = 5.204650\n",
+      "Epoch: 0096 loss = 5.202743\n",
+      "Epoch: 0097 loss = 5.200671\n",
+      "Epoch: 0098 loss = 5.198618\n",
+      "Epoch: 0099 loss = 5.196756\n",
+      "Epoch: 0100 loss = 5.195175\n",
+      "Epoch: 0101 loss = 5.193856\n",
+      "Epoch: 0102 loss = 5.192677\n",
+      "Epoch: 0103 loss = 5.191495\n",
+      "Epoch: 0104 loss = 5.190241\n",
+      "Epoch: 0105 loss = 5.188928\n",
+      "Epoch: 0106 loss = 5.187597\n",
+      "Epoch: 0107 loss = 5.186284\n",
+      "Epoch: 0108 loss = 5.184998\n",
+      "Epoch: 0109 loss = 5.183724\n",
+      "Epoch: 0110 loss = 5.182417\n",
+      "Epoch: 0111 loss = 5.181022\n",
+      "Epoch: 0112 loss = 5.179502\n",
+      "Epoch: 0113 loss = 5.177876\n",
+      "Epoch: 0114 loss = 5.176267\n",
+      "Epoch: 0115 loss = 5.174932\n",
+      "Epoch: 0116 loss = 5.174131\n",
+      "Epoch: 0117 loss = 5.173597\n",
+      "Epoch: 0118 loss = 5.172731\n",
+      "Epoch: 0119 loss = 5.171515\n",
+      "Epoch: 0120 loss = 5.170227\n",
+      "Epoch: 0121 loss = 5.169064\n",
+      "Epoch: 0122 loss = 5.168075\n",
+      "Epoch: 0123 loss = 5.167188\n",
+      "Epoch: 0124 loss = 5.166307\n",
+      "Epoch: 0125 loss = 5.165361\n",
+      "Epoch: 0126 loss = 5.164318\n",
+      "Epoch: 0127 loss = 5.163185\n",
+      "Epoch: 0128 loss = 5.162001\n",
+      "Epoch: 0129 loss = 5.160835\n",
+      "Epoch: 0130 loss = 5.159757\n",
+      "Epoch: 0131 loss = 5.158821\n",
+      "Epoch: 0132 loss = 5.158039\n",
+      "Epoch: 0133 loss = 5.157377\n",
+      "Epoch: 0134 loss = 5.156784\n",
+      "Epoch: 0135 loss = 5.156210\n",
+      "Epoch: 0136 loss = 5.155619\n",
+      "Epoch: 0137 loss = 5.154994\n",
+      "Epoch: 0138 loss = 5.154339\n",
+      "Epoch: 0139 loss = 5.153667\n",
+      "Epoch: 0140 loss = 5.152996\n",
+      "Epoch: 0141 loss = 5.152328\n",
+      "Epoch: 0142 loss = 5.151660\n",
+      "Epoch: 0143 loss = 5.150975\n",
+      "Epoch: 0144 loss = 5.150242\n",
+      "Epoch: 0145 loss = 5.149413\n",
+      "Epoch: 0146 loss = 5.148431\n",
+      "Epoch: 0147 loss = 5.147236\n",
+      "Epoch: 0148 loss = 5.145802\n",
+      "Epoch: 0149 loss = 5.144192\n",
+      "Epoch: 0150 loss = 5.142632\n",
+      "Epoch: 0151 loss = 5.141498\n",
+      "Epoch: 0152 loss = 5.140967\n",
+      "Epoch: 0153 loss = 5.140542\n",
+      "Epoch: 0154 loss = 5.139736\n",
+      "Epoch: 0155 loss = 5.138566\n",
+      "Epoch: 0156 loss = 5.137214\n",
+      "Epoch: 0157 loss = 5.135843\n",
+      "Epoch: 0158 loss = 5.134551\n",
+      "Epoch: 0159 loss = 5.133349\n",
+      "Epoch: 0160 loss = 5.132212\n",
+      "Epoch: 0161 loss = 5.131142\n",
+      "Epoch: 0162 loss = 5.130210\n",
+      "Epoch: 0163 loss = 5.129475\n",
+      "Epoch: 0164 loss = 5.128819\n",
+      "Epoch: 0165 loss = 5.128070\n",
+      "Epoch: 0166 loss = 5.127200\n",
+      "Epoch: 0167 loss = 5.126271\n",
+      "Epoch: 0168 loss = 5.125329\n",
+      "Epoch: 0169 loss = 5.124373\n",
+      "Epoch: 0170 loss = 5.123369\n",
+      "Epoch: 0171 loss = 5.122267\n",
+      "Epoch: 0172 loss = 5.121037\n",
+      "Epoch: 0173 loss = 5.119725\n",
+      "Epoch: 0174 loss = 5.118514\n",
+      "Epoch: 0175 loss = 5.117722\n",
+      "Epoch: 0176 loss = 5.117445\n",
+      "Epoch: 0177 loss = 5.117110\n",
+      "Epoch: 0178 loss = 5.116376\n",
+      "Epoch: 0179 loss = 5.115437\n",
+      "Epoch: 0180 loss = 5.114553\n",
+      "Epoch: 0181 loss = 5.113859\n",
+      "Epoch: 0182 loss = 5.113349\n",
+      "Epoch: 0183 loss = 5.112945\n",
+      "Epoch: 0184 loss = 5.112563\n",
+      "Epoch: 0185 loss = 5.112154\n",
+      "Epoch: 0186 loss = 5.111700\n",
+      "Epoch: 0187 loss = 5.111213\n",
+      "Epoch: 0188 loss = 5.110720\n",
+      "Epoch: 0189 loss = 5.110251\n",
+      "Epoch: 0190 loss = 5.109831\n",
+      "Epoch: 0191 loss = 5.109465\n",
+      "Epoch: 0192 loss = 5.109145\n",
+      "Epoch: 0193 loss = 5.108854\n",
+      "Epoch: 0194 loss = 5.108573\n",
+      "Epoch: 0195 loss = 5.108290\n",
+      "Epoch: 0196 loss = 5.108000\n",
+      "Epoch: 0197 loss = 5.107710\n",
+      "Epoch: 0198 loss = 5.107430\n",
+      "Epoch: 0199 loss = 5.107163\n",
+      "Epoch: 0200 loss = 5.106915\n",
+      "Epoch: 0201 loss = 5.106686\n",
+      "Epoch: 0202 loss = 5.106472\n",
+      "Epoch: 0203 loss = 5.106268\n",
+      "Epoch: 0204 loss = 5.106072\n",
+      "Epoch: 0205 loss = 5.105881\n",
+      "Epoch: 0206 loss = 5.105694\n",
+      "Epoch: 0207 loss = 5.105511\n",
+      "Epoch: 0208 loss = 5.105332\n",
+      "Epoch: 0209 loss = 5.105159\n",
+      "Epoch: 0210 loss = 5.104992\n",
+      "Epoch: 0211 loss = 5.104831\n",
+      "Epoch: 0212 loss = 5.104676\n",
+      "Epoch: 0213 loss = 5.104527\n",
+      "Epoch: 0214 loss = 5.104386\n",
+      "Epoch: 0215 loss = 5.104248\n",
+      "Epoch: 0216 loss = 5.104114\n",
+      "Epoch: 0217 loss = 5.103983\n",
+      "Epoch: 0218 loss = 5.103856\n",
+      "Epoch: 0219 loss = 5.103732\n",
+      "Epoch: 0220 loss = 5.103611\n",
+      "Epoch: 0221 loss = 5.103494\n",
+      "Epoch: 0222 loss = 5.103378\n",
+      "Epoch: 0223 loss = 5.103266\n",
+      "Epoch: 0224 loss = 5.103157\n",
+      "Epoch: 0225 loss = 5.103050\n",
+      "Epoch: 0226 loss = 5.102947\n",
+      "Epoch: 0227 loss = 5.102844\n",
+      "Epoch: 0228 loss = 5.102745\n",
+      "Epoch: 0229 loss = 5.102647\n",
+      "Epoch: 0230 loss = 5.102551\n",
+      "Epoch: 0231 loss = 5.102457\n",
+      "Epoch: 0232 loss = 5.102361\n",
+      "Epoch: 0233 loss = 5.102268\n",
+      "Epoch: 0234 loss = 5.102171\n",
+      "Epoch: 0235 loss = 5.102070\n",
+      "Epoch: 0236 loss = 5.101960\n",
+      "Epoch: 0237 loss = 5.101832\n",
+      "Epoch: 0238 loss = 5.101676\n",
+      "Epoch: 0239 loss = 5.101472\n",
+      "Epoch: 0240 loss = 5.101196\n",
+      "Epoch: 0241 loss = 5.100809\n",
+      "Epoch: 0242 loss = 5.100273\n",
+      "Epoch: 0243 loss = 5.099566\n",
+      "Epoch: 0244 loss = 5.098716\n",
+      "Epoch: 0245 loss = 5.097848\n",
+      "Epoch: 0246 loss = 5.097132\n",
+      "Epoch: 0247 loss = 5.096595\n",
+      "Epoch: 0248 loss = 5.095999\n",
+      "Epoch: 0249 loss = 5.095338\n",
+      "Epoch: 0250 loss = 5.094879\n",
+      "Epoch: 0251 loss = 5.094688\n",
+      "Epoch: 0252 loss = 5.094509\n",
+      "Epoch: 0253 loss = 5.094189\n",
+      "Epoch: 0254 loss = 5.093769\n",
+      "Epoch: 0255 loss = 5.093320\n",
+      "Epoch: 0256 loss = 5.092857\n",
+      "Epoch: 0257 loss = 5.092339\n",
+      "Epoch: 0258 loss = 5.091694\n",
+      "Epoch: 0259 loss = 5.090842\n",
+      "Epoch: 0260 loss = 5.089743\n",
+      "Epoch: 0261 loss = 5.088472\n",
+      "Epoch: 0262 loss = 5.087311\n",
+      "Epoch: 0263 loss = 5.086748\n",
+      "Epoch: 0264 loss = 5.086977\n",
+      "Epoch: 0265 loss = 5.087138\n",
+      "Epoch: 0266 loss = 5.086704\n",
+      "Epoch: 0267 loss = 5.085968\n",
+      "Epoch: 0268 loss = 5.085289\n",
+      "Epoch: 0269 loss = 5.084831\n",
+      "Epoch: 0270 loss = 5.084577\n",
+      "Epoch: 0271 loss = 5.084425\n",
+      "Epoch: 0272 loss = 5.084273\n",
+      "Epoch: 0273 loss = 5.084045\n",
+      "Epoch: 0274 loss = 5.083702\n",
+      "Epoch: 0275 loss = 5.083218\n",
+      "Epoch: 0276 loss = 5.082567\n",
+      "Epoch: 0277 loss = 5.081728\n",
+      "Epoch: 0278 loss = 5.080690\n",
+      "Epoch: 0279 loss = 5.079515\n",
+      "Epoch: 0280 loss = 5.078421\n",
+      "Epoch: 0281 loss = 5.077838\n",
+      "Epoch: 0282 loss = 5.078005\n",
+      "Epoch: 0283 loss = 5.078021\n",
+      "Epoch: 0284 loss = 5.077395\n",
+      "Epoch: 0285 loss = 5.076532\n",
+      "Epoch: 0286 loss = 5.075821\n",
+      "Epoch: 0287 loss = 5.075374\n",
+      "Epoch: 0288 loss = 5.075100\n",
+      "Epoch: 0289 loss = 5.074864\n",
+      "Epoch: 0290 loss = 5.074569\n",
+      "Epoch: 0291 loss = 5.074176\n",
+      "Epoch: 0292 loss = 5.073692\n",
+      "Epoch: 0293 loss = 5.073163\n",
+      "Epoch: 0294 loss = 5.072643\n",
+      "Epoch: 0295 loss = 5.072181\n",
+      "Epoch: 0296 loss = 5.071800\n",
+      "Epoch: 0297 loss = 5.071474\n",
+      "Epoch: 0298 loss = 5.071146\n",
+      "Epoch: 0299 loss = 5.070775\n",
+      "Epoch: 0300 loss = 5.070364\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = NgmOne(device)\n",
+    "\n",
+    "EPOCHS = 300\n",
+    "criterion = nn.CrossEntropyLoss()\n",
+    "optimizer = optim.Adam(model.parameters(), lr=0.05)\n",
+    "\n",
+    "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n",
+    "    relations = json.load(f)\n",
+    "\n",
+    "train, train_mask, corr_rels = make_batch()\n",
+    "corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)\n",
+    "\n",
+    "for epoch in range(EPOCHS):\n",
+    "    optimizer.zero_grad()\n",
+    "\n",
+    "    # Forward pass\n",
+    "    output = model(train, train_mask)\n",
+    "    loss = criterion(output, corr_indx)\n",
+    "\n",
+    "    if (epoch + 1) % 1 == 0:\n",
+    "        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n",
+    "    # Backward pass\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "    \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Beginning making batch\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "  0%|          | 0/3 [03:03<?, ?it/s]\n"
+      "100%|██████████| 408/408 [00:00<00:00, 814.37it/s] \n",
+      "100%|██████████| 408/408 [00:00<00:00, 101922.34it/s]\n"
      ]
     },
     {
-     "ename": "RuntimeError",
-     "evalue": "0D or 1D target tensor expected, multi-target not supported",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "Cell \u001b[1;32mIn [26], line 13\u001b[0m\n\u001b[0;32m     11\u001b[0m \u001b[39m# Forward pass\u001b[39;00m\n\u001b[0;32m     12\u001b[0m output \u001b[39m=\u001b[39m model(train, train_mask)\n\u001b[1;32m---> 13\u001b[0m loss \u001b[39m=\u001b[39m criterion(output, corr_rels)\n\u001b[0;32m     15\u001b[0m \u001b[39mif\u001b[39;00m (epoch \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m) \u001b[39m%\u001b[39m \u001b[39m10\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m     16\u001b[0m     \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mEpoch:\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m%04d\u001b[39;00m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m (epoch \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m), \u001b[39m'\u001b[39m\u001b[39mcost =\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m{:.6f}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(loss))\n",
-      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1129\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m   1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
-      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\loss.py:1164\u001b[0m, in \u001b[0;36mCrossEntropyLoss.forward\u001b[1;34m(self, input, target)\u001b[0m\n\u001b[0;32m   1163\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor, target: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[1;32m-> 1164\u001b[0m     \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mcross_entropy(\u001b[39minput\u001b[39;49m, target, weight\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight,\n\u001b[0;32m   1165\u001b[0m                            ignore_index\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mignore_index, reduction\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mreduction,\n\u001b[0;32m   1166\u001b[0m                            label_smoothing\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlabel_smoothing)\n",
-      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\functional.py:3014\u001b[0m, in \u001b[0;36mcross_entropy\u001b[1;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[0;32m   3012\u001b[0m \u001b[39mif\u001b[39;00m size_average \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mor\u001b[39;00m reduce \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m   3013\u001b[0m     reduction \u001b[39m=\u001b[39m _Reduction\u001b[39m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[1;32m-> 3014\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39;49m_C\u001b[39m.\u001b[39;49m_nn\u001b[39m.\u001b[39;49mcross_entropy_loss(\u001b[39minput\u001b[39;49m, target, weight, _Reduction\u001b[39m.\u001b[39;49mget_enum(reduction), ignore_index, label_smoothing)\n",
-      "\u001b[1;31mRuntimeError\u001b[0m: 0D or 1D target tensor expected, multi-target not supported"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Finished with batches\n",
+      "lowest confidence 0.19197923\n",
+      "Accuracy: 0.6323529411764706\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "KeyboardInterrupt\n",
+      "\n"
      ]
     }
    ],
    "source": [
-    "model = NgmOne()\n",
+    "# Predict\n",
+    "train, train_mask, corr_rels = make_batch()\n",
+    "with torch.no_grad():\n",
+    "    output = model(train, train_mask)\n",
     "\n",
-    "EPOCHS = 3\n",
-    "criterion = nn.CrossEntropyLoss()\n",
-    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
+    "output = output.detach().cpu().numpy()\n",
     "\n",
-    "train,train_mask, corr_rels, correct_rels_mask = make_batch()\n",
-    "for epoch in tqdm(range(EPOCHS)):\n",
-    "    optimizer.zero_grad()\n",
     "\n",
-    "    # Forward pass\n",
-    "    output = model(train, train_mask)\n",
-    "    loss = criterion(output, corr_rels)\n",
+    "prediction = [relations[np.argmax(pred).item()]for pred in output]\n",
+    "probability = [pred[np.argmax(pred)] for pred in output]\n",
+    "correct_pred = [corr_rels[i] for i in range(len(output))]\n",
     "\n",
-    "    if (epoch + 1) % 10 == 0:\n",
-    "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
-    "    # Backward pass\n",
-    "    loss.backward()\n",
-    "    optimizer.step()\n",
-    "    \n"
+    "\n",
+    "\n",
+    "preds = [\n",
+    "    {\"pred\": relations[np.argmax(pred).item()], \n",
+    "     \"prob\": pred[np.argmax(pred)],\n",
+    "     \"correct\": corr_rels[count]} \n",
+    "    for count, pred in enumerate(output)\n",
+    "    ]\n",
+    "\n",
+    "\n",
+    "\n",
+    "# for pred in preds:\n",
+    "#     print(\"pred\", pred[\"pred\"], \"prob\", pred[\"prob\"], \"correct\", pred[\"correct\"])\n",
+    "\n",
+    "# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):\n",
+    "#     print(\"pred:\", pred, \"prob:\", prob, \"| correct\", correct_pred)\n",
+    "\n",
+    "print(\"lowest confidence\", min(probability))\n",
+    "\n",
+    "def accuracy_score(y_true, y_pred):\n",
+    "    corr_preds=0\n",
+    "    wrong_preds=0\n",
+    "    for pred, correct in zip(y_pred, y_true):\n",
+    "        if pred == correct:\n",
+    "            corr_preds += 1\n",
+    "        else:\n",
+    "            wrong_preds += 1\n",
+    "    return corr_preds/(corr_preds+wrong_preds)\n",
+    "\n",
+    "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n",
+    "\n"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.9.11 64-bit",
+   "display_name": "Python 3.10.4 ('tdde19')",
    "language": "python",
    "name": "python3"
   },
@@ -284,12 +698,12 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.11"
+   "version": "3.10.4"
   },
   "orig_nbformat": 4,
   "vscode": {
    "interpreter": {
-    "hash": "64e7cd3b4b88defe39dd61a4584920400d6beb2615ab2244e340c2e20eecdfe9"
+    "hash": "8e4aa0e1a1e15de86146661edda0b2884b54582522f7ff2b916774ba6b8accb1"
    }
   }
  },
diff --git a/data/prefix-relations-qald-9-train-linked.json b/data/relations-prefixed-qald-9-train-linked.json
similarity index 100%
rename from data/prefix-relations-qald-9-train-linked.json
rename to data/relations-prefixed-qald-9-train-linked.json
diff --git a/data/relations-query-qald-9-linked.json b/data/relations-query-qald-9-linked.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e305e9d35c7ddc8a9970de83ff4d99f0fdb9c80
--- /dev/null
+++ b/data/relations-query-qald-9-linked.json
@@ -0,0 +1,139 @@
+[
+  "dbo:foundationPlace",
+  "dbo:founder",
+  "<http://dbpedia.org/ontology/officialSchoolColour>",
+  "<http://dbpedia.org/property/ethnicity>",
+  "<http://dbpedia.org/ontology/mission>",
+  "<http://dbpedia.org/property/admittancedate>",
+  "dbo:writer",
+  "<http://dbpedia.org/property/carbs>",
+  "<http://dbpedia.org/property/borderingstates>",
+  "dbo:party",
+  "dbo:growingGrape",
+  "<http://dbpedia.org/property/residence>",
+  "dbo:wineRegion",
+  "dbo:language",
+  "dbo:officialLanguage",
+  "dbo:influencedBy",
+  "dbo:routeEnd",
+  "<http://dbpedia.org/ontology/netIncome>",
+  "dbo:bandMember",
+  "dbo:team",
+  "dbo:origin",
+  "<http://dbpedia.org/property/launchPad>",
+  "dbp:species",
+  "dbo:composer",
+  "<http://dbpedia.org/property/author>",
+  "<http://dbpedia.org/ontology/creator>",
+  "dbo:developer",
+  "<http://dbpedia.org/ontology/owner>",
+  "<http://dbpedia.org/ontology/spouse>",
+  "dbo:timeZone",
+  "<http://dbpedia.org/property/governor>",
+  "dbo:numberOfPages",
+  "dbo:deathCause",
+  "dbo:award",
+  "dbo:activeYearsEndDate",
+  "dbo:governmentType",
+  "dbo:maximumDepth",
+  "dbo:owner",
+  "dbo:leader",
+  "dbo:birthDate",
+  "dbo:editor",
+  "dbo:knownFor",
+  "dbo:starring",
+  "dbo:targetAirport",
+  "<http://dbpedia.org/property/founded>",
+  "<http://dbpedia.org/property/leaderParty>",
+  "<http://dbpedia.org/property/speciality>",
+  "dbo:country",
+  "dbo:runtime",
+  "<http://dbpedia.org/ontology/spokenIn>",
+  "dbo:battle",
+  "dbo:discoverer",
+  "dbo:portrayer",
+  "dbo:areaTotal",
+  "<http://dbpedia.org/ontology/portrayer>",
+  "dbo:height",
+  "dbo:spouse",
+  "dbo:abbreviation",
+  "<http://dbpedia.org/ontology/profession>",
+  "dbo:musicComposer",
+  "dbo:firstAscentPerson",
+  "dbo:publisher",
+  "<http://dbpedia.org/ontology/deathPlace>",
+  "<http://dbpedia.org/property/largestmetro>",
+  "dbp:writer",
+  "dbo:ingredient",
+  "dbo:class",
+  "<http://dbpedia.org/property/breed>",
+  "dbo:director",
+  "dbo:alias",
+  "<http://dbpedia.org/ontology/currency>",
+  "<http://dbpedia.org/ontology/populationTotal>",
+  "<http://dbpedia.org/property/successor>",
+  "<http://dbpedia.org/ontology/manager>",
+  "dbo:mayor",
+  "dbo:date",
+  "dbp:editor",
+  "a",
+  "<http://dbpedia.org/ontology/foundedBy>",
+  "dbo:completionDate",
+  "dbp:populationDensityRank",
+  "dbo:presenter",
+  "<http://dbpedia.org/ontology/instrument>",
+  "<http://dbpedia.org/ontology/numberOfEmployees>",
+  "dbo:series",
+  "dbo:creator",
+  "<http://dbpedia.org/ontology/producer>",
+  "dbo:leaderName",
+  "<http://dbpedia.org/property/children>",
+  "<http://dbpedia.org/property/title>",
+  "<http://dbpedia.org/property/fifaMin>",
+  "dbo:crosses",
+  "dbo:mission",
+  "dbo:architect",
+  "dbo:largestCity",
+  "dbo:budget",
+  "<http://dbpedia.org/property/birthName>",
+  "dbo:populationTotal",
+  "dbo:foundingDate",
+  "dbo:vicePresident",
+  "<http://dbpedia.org/ontology/genre>",
+  "dbo:product",
+  "dbo:type",
+  "dbo:state",
+  "dbo:influenced",
+  "dbo:doctoralAdvisor",
+  "dbo:numberOfLocations",
+  "dbo:successor",
+  "<http://dbpedia.org/property/ballpark>",
+  "<http://dbpedia.org/ontology/country>",
+  "dbo:sourceCountry",
+  "dbo:birthPlace",
+  "<http://dbpedia.org/property/highest>",
+  "<http://dbpedia.org/ontology/child>",
+  "dbo:birthName",
+  "dbo:child",
+  "dbo:deathPlace",
+  "<http://dbpedia.org/ontology/abbreviation>",
+  "dbo:dissolutionDate",
+  "dbo:author",
+  "dbo:birthYear",
+  "<http://dbpedia.org/property/beginningDate>",
+  "dbo:ethnicGroup",
+  "dbo:currency",
+  "<http://dbpedia.org/property/programme>",
+  "<http://dbpedia.org/property/shipNamesake>",
+  "dbo:capital",
+  "dbo:programmingLanguage",
+  "dbo:city",
+  "<http://dbpedia.org/ontology/deathDate>",
+  "dbo:almaMater",
+  "<http://dbpedia.org/property/employees>",
+  "dbo:location",
+  "<http://dbpedia.org/property/accessioneudate>",
+  "rdf:type",
+  "<http://dbpedia.org/ontology/developer>",
+  "dbo:restingPlace"
+]