diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index a924791ec8388b0e67e4a9635a89828d25d7f9bb..d95c77016b150ff219acb1da22657a5994bbad3b 100644
--- a/Neural graph module/ngm.ipynb	
+++ b/Neural graph module/ngm.ipynb	
@@ -2,7 +2,7 @@
   "cells": [
     {
       "cell_type": "code",
-      "execution_count": 17,
+      "execution_count": 25,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -15,21 +15,23 @@
         "from transformers import BertTokenizer, BertModel\n",
         "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
         "from tqdm import tqdm\n",
-        "import json\n"
+        "import json\n",
+        "import requests\n"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 18,
+      "execution_count": 26,
       "metadata": {},
       "outputs": [],
       "source": [
-        "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
+        "mod = \"bert-base-uncased\"\n",
+        "tokenizer = BertTokenizer.from_pretrained(mod)"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 19,
+      "execution_count": 27,
       "metadata": {},
       "outputs": [
         {
@@ -49,7 +51,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 20,
+      "execution_count": 28,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -57,8 +59,12 @@
         "class NgmOne(nn.Module):\n",
         "    def __init__(self, device, relations):\n",
         "        super(NgmOne, self).__init__()\n",
-        "        self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
-        "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
+        "        with torch.no_grad():\n",
+        "            self.tokenizer = BertTokenizer.from_pretrained(mod)\n",
+        "            self.bert = BertModel.from_pretrained(mod).to(device)\n",
+        "\n",
+        "        for param in self.bert.parameters():\n",
+        "            param.requires_grad = False\n",
         "        self.linear = nn.Linear(768, len(relations)).to(device)\n",
         "        self.softmax = nn.Softmax(dim=1).to(device)\n",
         "        self.device = device\n",
@@ -68,8 +74,7 @@
         "        tokenized_mask = tokenized_mask.to(self.device)\n",
         "        with torch.no_grad():\n",
         "            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
-        "            x = x[0][:,0,:].to(self.device)\n",
-        "            \n",
+        "        x = x[0][:,0,:].to(self.device)\n",
         "        x = self.linear(x)\n",
         "        x = self.softmax(x)\n",
         "        return x"
@@ -77,7 +82,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 21,
+      "execution_count": 29,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -98,6 +103,25 @@
         "    \"http://www.w3.org/2004/02/skos/core#\": \"skos:\",\n",
         "}\n",
         "\n",
+        "prefixes_reverse = {v: k for k, v in prefixes.items()}\n",
+        "\n",
+        "# {\n",
+        "# \"res:\": \"http://dbpedia.org/resource/\",\n",
+        "# \"dbo:\": \"http://dbpedia.org/ontology/\",\n",
+        "# \"dbp:\": \"http://dbpedia.org/property/\",\n",
+        "# \"rdfs:\": \"http://www.w3.org/2000/01/rdf-schema#\",\n",
+        "# \"rdf:\": \"http://www.w3.org/1999/02/22-rdf-syntax-ns#\",\n",
+        "# \"yago:\": \"http://dbpedia.org/class/yago/\",\n",
+        "# \"wdt:\": \"http://www.wikidata.org/prop/direct/\",\n",
+        "# \"wd:\": \"http://www.wikidata.org/entity/\",\n",
+        "# \"p:\": \"http://www.wikidata.org/prop/\",\n",
+        "# \"ps:\": \"https://w3id.org/payswarm#\",\n",
+        "# \"pq:\": \"http://www.wikidata.org/prop/qualifier/\",\n",
+        "# \"bd:\": \"http://www.bigdata.com/rdf#\",\n",
+        "# \"wikibase:\": \"http://wikiba.se/ontology#\",\n",
+        "# \"skos:\": \"http://www.w3.org/2004/02/skos/core#\",\n",
+        "# }\n",
+        "\n",
         "prefixes_end = [\"org\", \"se\", \"com\", \"nu\"]\n",
         "ALL_HTTP_PREFIXES = False\n",
         "\n",
@@ -119,6 +143,7 @@
         "    \n",
         "    inputs = []\n",
         "    correct_rels = []\n",
+        "    sub_obj_ents = []\n",
         "    inputs_max_len = 0\n",
         "    for d in tqdm(pred[\"questions\"]):\n",
         "        question = d[\"question\"][0][\"string\"]\n",
@@ -171,17 +196,19 @@
         "                if triplet_i[0] is not None and triplet_i[1] is not None:\n",
         "                    #seq += \"[SUB] [SEP] \" + triplet[0]\n",
         "                    # , padding=True, truncation=True)\n",
-        "                    tokenized_seq = tokenizer(question, \"[SUB]\", triplet_i[0], padding=True, truncation=True)\n",
+        "                    tokenized_seq = tokenizer(question.lower(), \"[OBJ] [SEP] \" + triplet_i[0].split(\":\")[1].lower(), padding=True, truncation=True)\n",
+        "                    sub_obj_ents.append(\"[SUB] \" + prefixes_reverse[\"\".join(triplet_i[0].split(\":\")[0]) + \":\"] + triplet_i[0].split(\":\")[1])\n",
         "                elif triplet_i[2] is not None and triplet_i[1] is not None:\n",
         "                    #seq += \"[OBJ] [SEP] \" + triplet[2]\n",
-        "                    tokenized_seq = tokenizer(question, \"[OBJ]\", triplet_i[2], padding=True, truncation=True)\n",
+        "                    tokenized_seq = tokenizer(question.lower(), \"[OBJ] [SEP] \" + triplet_i[2].split(\":\")[1].lower(), padding=True, truncation=True)\n",
+        "                    sub_obj_ents.append(\"[OBJ] \" + prefixes_reverse[\"\".join(triplet_i[2].split(\":\")[0]) + \":\"] + triplet_i[2].split(\":\")[1])\n",
         "                else:\n",
         "                    continue\n",
         "\n",
         "                if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n",
         "                    inputs_max_len = len(tokenized_seq[\"input_ids\"])\n",
         "                inputs.append(list(tokenized_seq.values())[0])\n",
-        "                correct_rels.append(triplet_i[1])\n",
+        "                correct_rels.append(triplet_i[1].lower())\n",
         "\n",
         "    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n",
         "    #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n",
@@ -189,34 +216,35 @@
         "    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n",
         "    #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n",
         "    print(\"Finished with batches\")\n",
-        "    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n"
+        "    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels, sub_obj_ents #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 22,
+      "execution_count": 30,
       "metadata": {},
       "outputs": [],
       "source": [
         "from torch.utils.data import Dataset, DataLoader\n",
         "\n",
         "class MyDataset(Dataset):\n",
-        "    def __init__(self, inputs, attention_mask, correct_rels, relations):\n",
+        "    def __init__(self, inputs, attention_mask, correct_rels,ents, relations):\n",
         "        self.inputs = inputs\n",
         "        self.attention_mask = attention_mask\n",
         "        self.correct_rels = correct_rels\n",
+        "        self.ents = ents\n",
         "        self.relations = relations\n",
         "\n",
         "    def __len__(self):\n",
         "        return len(self.inputs)\n",
         "\n",
         "    def __getitem__(self, idx):\n",
-        "        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])"
+        "        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx]), self.ents[idx]"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 23,
+      "execution_count": 31,
       "metadata": {},
       "outputs": [
         {
@@ -230,7 +258,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "100%|██████████| 2052/2052 [00:03<00:00, 546.37it/s]\n"
+            "100%|██████████| 2052/2052 [00:00<00:00, 2906.57it/s]\n"
           ]
         },
         {
@@ -238,16 +266,17 @@
           "output_type": "stream",
           "text": [
             "Finished with batches\n",
-            "features: tensor([[  101,  2054,  2003,  1996, 13314,  1997, 16122, 10230,  5400,  2100,\n",
-            "         12378,  2136,  1029,   102,  1031,  4942,  1033,   102,     0,     0,\n",
+            "features: ['[CLS] name the home stadium of fc spartak moscow season 2011 - 12 [SEP] [ obj ] [SEP] 2011 – 12 _ fc _ spartak _ moscow _ season [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'] mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+            "         1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+            "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(431)\n",
+            "valid features: tensor([[  101,  2054,  3063,  2001,  3378,  4876,  2007,  1996,  6436, 26785,\n",
+            "          7971,  6447,  1998,  2600,  1037, 20160, 10362,  1029,   102,  1031,\n",
+            "         27885,  3501,  1033,   102,  2600,  1011,  1037,  1011, 20160, 10362,\n",
+            "           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
             "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
-            "             0,     0,     0,     0,     0,     0,     0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n",
-            "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(378)\n",
-            "valid features: tensor([[  101,  2054,  4752,  1997,  7992, 13843,  5112,  2267,  2003,  2036,\n",
-            "          1996,  9353,  4215, 26432,  2278,  3037,  1997,  2703, 27668,  3077,\n",
-            "          1029,   102,  1031,  4942,  1033,   102,     0,     0,     0,     0,\n",
-            "             0,     0,     0,     0,     0,     0,     0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
-            "         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(28)\n"
+            "             0,     0,     0,     0,     0,     0,     0,     0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+            "         1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+            "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(149)\n"
           ]
         }
       ],
@@ -259,18 +288,19 @@
         "    with open(file, \"r\") as f:\n",
         "        return json.load(f)\n",
         "\n",
+        "\n",
         "#relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
-        "relations = open_json(\"../data/relations-all-lc-quad-no-http.json\")\n",
+        "relations = open_json(\"../data/relations-all-no-http-lowercase.json\")\n",
         "\n",
         "# \"../data/qald-9-train-linked.json\"\n",
         "#pred = \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
-        "inputs, attention_mask, correct_rels = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True) #train\n",
+        "inputs, attention_mask, correct_rels, sub_objs = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True) #train\n",
         "\n",
         "# relations = open_json(\"../data/relations-lcquad-without-http-train-linked.json\")\n",
         "# train_set = MyDataset(*make_batch(), relations=relations)\n",
         "\n",
         "\n",
-        "dataset = MyDataset(inputs, attention_mask, correct_rels, relations=relations)\n",
+        "dataset = MyDataset(inputs, attention_mask, correct_rels,sub_objs, relations=relations)\n",
         "train_size = int(0.8 * len(dataset))\n",
         "valid_size = len(dataset) - train_size\n",
         "\n",
@@ -278,94 +308,65 @@
         "\n",
         "train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)\n",
         "#show first entry\n",
-        "train_features, train_mask, train_label = next(iter(train_dataloader))\n",
-        "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])\n",
+        "train_features, train_mask, train_label, ents = next(iter(train_dataloader))\n",
+        "print(\"features:\", tokenizer.batch_decode(train_features), \"mask:\",train_mask,\"label_index\", train_label[0])\n",
         "\n",
         "\n",
         "valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=True)\n",
-        "valid_features, valid_mask, valid_label = next(iter(valid_dataloader))\n",
+        "valid_features, valid_mask, valid_label, ents = next(iter(valid_dataloader))\n",
         "print(\"valid features:\", valid_features, \"valid mask:\",valid_mask,\"valid label_index\", valid_label[0])"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 24,
+      "execution_count": 32,
       "metadata": {},
       "outputs": [
         {
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n",
+            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']\n",
             "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
             "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
           ]
         }
       ],
       "source": [
+        "SPARQL_ENDPOINT = \"https://dbpedia.org/sparql\"\n",
+        "\n",
+        "\n",
+        "# test_data_file = sys.argv[1]\n",
+        "# predicted_data_file = sys.argv[2]\n",
+        "\n",
+        "headers = {\n",
+        "    'Accept': 'application/sparql-results+json',\n",
+        "    'Content-Type': 'application/x-www-form-urlencoded',\n",
+        "}\n",
+        "\n",
         "# Initialize model\n",
         "model = NgmOne(device, relations)"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 25,
+      "execution_count": 33,
       "metadata": {},
       "outputs": [
         {
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "1 Train 6.453075773575726 , Valid  7.615683317184448\n",
-            "2 Train 6.444019513971665 , Valid  7.6052809953689575\n",
-            "3 Train 6.43224432889153 , Valid  7.602278828620911\n",
-            "4 Train 6.415224019218893 , Valid  7.598743319511414\n",
-            "5 Train 6.391957339118509 , Valid  7.584347605705261\n",
-            "6 Train 6.368611419902129 , Valid  7.573382139205933\n",
-            "7 Train 6.339372747084674 , Valid  7.557032465934753\n",
-            "8 Train 6.318648646859562 , Valid  7.539492607116699\n",
-            "9 Train 6.297140289755428 , Valid  7.523481965065002\n",
-            "10 Train 6.281140131108901 , Valid  7.519673824310303\n",
-            "11 Train 6.265462370479808 , Valid  7.5268988609313965\n",
-            "12 Train 6.24550900739782 , Valid  7.504716753959656\n",
-            "13 Train 6.234297051149256 , Valid  7.505620360374451\n",
-            "14 Train 6.220176332137164 , Valid  7.500500917434692\n",
-            "15 Train 6.204408701728372 , Valid  7.493185758590698\n",
-            "16 Train 6.194800797630759 , Valid  7.488588452339172\n",
-            "17 Train 6.183391935685101 , Valid  7.466232180595398\n",
-            "18 Train 6.173011443194221 , Valid  7.457393527030945\n",
-            "19 Train 6.168144422418931 , Valid  7.45364773273468\n",
-            "20 Train 6.157315562753117 , Valid  7.452357649803162\n",
-            "21 Train 6.140512690824621 , Valid  7.456966400146484\n",
-            "22 Train 6.129443336935604 , Valid  7.464043855667114\n",
-            "23 Train 6.12704220940085 , Valid  7.443313360214233\n",
-            "24 Train 6.117525269003475 , Valid  7.4395164251327515\n",
-            "25 Train 6.116069064420812 , Valid  7.456610083580017\n",
-            "26 Train 6.108902790967156 , Valid  7.430782794952393\n",
-            "27 Train 6.108471814323874 , Valid  7.44858992099762\n",
-            "28 Train 6.100216697244083 , Valid  7.457515120506287\n",
-            "29 Train 6.099361447726979 , Valid  7.438013672828674\n",
-            "30 Train 6.092377662658691 , Valid  7.448408484458923\n",
-            "31 Train 6.088698302998262 , Valid  7.442046403884888\n",
-            "32 Train 6.083712998558493 , Valid  7.420018911361694\n",
-            "33 Train 6.081799563239603 , Valid  7.426819205284119\n",
-            "34 Train 6.070652428795309 , Valid  7.426627039909363\n",
-            "35 Train 6.069005825940301 , Valid  7.425306558609009\n",
-            "36 Train 6.059389002182904 , Valid  7.415045142173767\n",
-            "37 Train 6.0618573918062095 , Valid  7.418038249015808\n",
-            "38 Train 6.055309716392966 , Valid  7.434216380119324\n",
-            "39 Train 6.049336994395537 , Valid  7.43299674987793\n",
-            "40 Train 6.048270281623392 , Valid  7.410753607749939\n",
-            "41 Train 6.043416135451373 , Valid  7.410261392593384\n",
-            "42 Train 6.042512613184312 , Valid  7.4074866771698\n",
-            "43 Train 6.032251245835248 , Valid  7.390785813331604\n",
-            "44 Train 6.025217645308551 , Valid  7.370600938796997\n",
-            "45 Train 6.027591985814712 , Valid  7.377846837043762\n",
-            "46 Train 6.0253391826854035 , Valid  7.39339554309845\n",
-            "47 Train 6.019261051626766 , Valid  7.407409071922302\n",
-            "48 Train 6.014283713172464 , Valid  7.393349766731262\n",
-            "49 Train 6.010964337517233 , Valid  7.37293815612793\n",
-            "50 Train 6.010590104495778 , Valid  7.4031054973602295\n"
+            "1 Train 4.168464240410345 , Valid  6.253535368863274\n",
+            "2 Train 4.158141554688378 , Valid  6.238900885862463\n",
+            "3 Train 4.135342076527986 , Valid  6.225469000199261\n",
+            "4 Train 4.12202568705991 , Valid  6.212939725202673\n",
+            "5 Train 4.109015821552963 , Valid  6.205244933857637\n",
+            "6 Train 4.098594773587563 , Valid  6.194438050774967\n",
+            "7 Train 4.08812516884838 , Valid  6.178463767556583\n",
+            "8 Train 4.079949097667667 , Valid  6.175140450982487\n",
+            "9 Train 4.07018859094853 , Valid  6.167128044016221\n",
+            "10 Train 4.065771084037616 , Valid  6.161452812307021\n"
           ]
         }
       ],
@@ -373,11 +374,15 @@
         "# Train with data loader.\n",
         "criterion = nn.CrossEntropyLoss()\n",
         "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
+        "#optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.5)\n",
         "\n",
-        "epoch = 50\n",
-        "batch_size = 64\n",
+        "epoch = 10\n",
+        "batch_size = 8\n",
+        "alpha = 0.5\n",
         "train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
         "valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)\n",
+        "\n",
+        "model.train()\n",
         "for e in range(epoch):\n",
         "    train_loss_epoch = 0\n",
         "    valid_loss_epoch = 0\n",
@@ -386,10 +391,47 @@
         "        train = sample_batched_train[0]\n",
         "        train_mask = sample_batched_train[1]\n",
         "        label_index = sample_batched_train[2].to(device)\n",
+        "        sub_objs = sample_batched_train[3]\n",
         "        \n",
         "        # Forward pass\n",
         "        output = model(train, train_mask)\n",
-        "        loss = criterion(output, label_index)\n",
+        "        loss_gs = []\n",
+        "        for j in range(len(sub_objs)):\n",
+        "            if not (sub_objs[j].split(\" \")[0] == \"[SUB]\" or sub_objs[j].split(\" \")[0] == \"[OBJ]\"):\n",
+        "                continue\n",
+        "\n",
+        "            if sub_objs[j].split(\" \")[0] == \"[SUB]\":\n",
+        "                sub = sub_objs[j].split(\" \")[1]\n",
+        "                q = \"SELECT ?r WHERE { <\" + sub + \"> ?r ?o }\"\n",
+        "\n",
+        "            if sub_objs[j].split(\" \")[0] == \"[OBJ]\":\n",
+        "                obj = sub_objs[j].split(\" \")[1]\n",
+        "                q = \"SELECT ?r WHERE { ?s ?r <\" + obj + \"> }\"\n",
+        "            \n",
+        "            params = {\n",
+        "                \"default-graph-uri\": \"http://dbpedia.org\",\n",
+        "                \"query\": q,\n",
+        "                \"format\": \"json\"\n",
+        "            }\n",
+        "\n",
+        "            response = requests.get(SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)\n",
+        "            results = response.json()\n",
+        "            res_rels = {}\n",
+        "            for i in range(len(list(results[\"results\"].values())[2])):\n",
+        "                res_rels[list(results[\"results\"].values())[2][i][\"r\"][\"value\"].lower()] = \"True\"\n",
+        "            \n",
+        "            loss_rels = []\n",
+        "            for i in range(len(relations)):\n",
+        "                if prefixes_reverse[\"\".join(relations[i].split(\":\")[0]) + \":\"] + relations[i].split(\":\")[1] in list(res_rels.keys()):\n",
+        "                    loss_rels.append(5) \n",
+        "                else:\n",
+        "                    loss_rels.append(1/20)\n",
+        "            loss_gs.append(loss_rels)\n",
+        "        loss_gs = torch.FloatTensor(loss_gs).to(device)\n",
+        "\n",
+        "            #print(response)\n",
+        "        output_gs = output * (1-alpha) + (loss_gs) * alpha\n",
+        "        loss = criterion(output_gs, label_index)\n",
         "\n",
         "        # backward and optimize\n",
         "        loss.backward()\n",
@@ -413,7 +455,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 31,
+      "execution_count": 35,
       "metadata": {},
       "outputs": [
         {
@@ -427,7 +469,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "100%|██████████| 2052/2052 [00:03<00:00, 577.85it/s]\n"
+            "100%|██████████| 2052/2052 [00:00<00:00, 2654.58it/s]\n"
           ]
         },
         {
@@ -442,7 +484,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "100%|██████████| 1161/1161 [00:02<00:00, 566.88it/s]\n"
+            "100%|██████████| 1161/1161 [00:00<00:00, 2831.73it/s]\n"
           ]
         },
         {
@@ -450,22 +492,22 @@
           "output_type": "stream",
           "text": [
             "Finished with batches\n",
-            "test loss 6.031976699829102\n",
-            "lowest confidence train 0.08024344\n",
-            "lowest confidence test 0.08576707\n",
-            "Accuracy train: 0.3939828080229226\n",
-            "Accuracy test: 0.006361323155216285\n"
+            "test loss 6.039710998535156\n",
+            "lowest confidence train 0.07557418\n",
+            "lowest confidence test 0.08204544\n",
+            "Accuracy train: 0.2349570200573066\n",
+            "Accuracy test: 0.007633587786259542\n"
           ]
         }
       ],
       "source": [
         "# Predict\n",
-        "train, train_mask, corr_rels = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True)\n",
-        "test, test_mask, corr_rels_test = make_batch(src=\"../LC-QuAD/combined-requeried-linked-test.json\", http_prefix = True)\n",
-        "test_data = MyDataset(test, test_mask, corr_rels_test, relations=relations)\n",
+        "train, train_mask, corr_rels,ents = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True)\n",
+        "test, test_mask, corr_rels_test, ents_test = make_batch(src=\"../LC-QuAD/combined-requeried-linked-test.json\", http_prefix = True)\n",
+        "test_data = MyDataset(test, test_mask, corr_rels_test, ents=ents_test, relations=relations)\n",
         "test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)\n",
         "\n",
-        "test_batch, test_mask_batch, corr_rels_test_batch = next(iter(test_dataloader))\n",
+        "test_batch, test_mask_batch, corr_rels_test_batch, sub_objs = next(iter(test_dataloader))\n",
         "corr_rels_test_batch = corr_rels_test_batch.to(device)\n",
         "with torch.no_grad():\n",
         "    output_train = model(train, train_mask)\n",
@@ -473,8 +515,44 @@
         "    loss = criterion(output_test, corr_rels_test_batch)\n",
         "    print(\"test loss\", loss.item())\n",
         "\n",
+        "    loss_gs = []\n",
+        "    for j in range(len(sub_objs)):\n",
+        "        if not (sub_objs[j].split(\" \")[0] == \"[SUB]\" or sub_objs[j].split(\" \")[0] == \"[OBJ]\"):\n",
+        "            continue\n",
+        "\n",
+        "        if sub_objs[j].split(\" \")[0] == \"[SUB]\":\n",
+        "            sub = sub_objs[j].split(\" \")[1]\n",
+        "            q = \"SELECT ?r WHERE { <\" + sub + \"> ?r ?o }\"\n",
+        "\n",
+        "        if sub_objs[j].split(\" \")[0] == \"[OBJ]\":\n",
+        "            obj = sub_objs[j].split(\" \")[1]\n",
+        "            q = \"SELECT ?r WHERE { ?s ?r <\" + obj + \"> }\"\n",
+        "\n",
+        "        params = {\n",
+        "            \"default-graph-uri\": \"http://dbpedia.org\",\n",
+        "            \"query\": q,\n",
+        "            \"format\": \"json\"\n",
+        "        }\n",
+        "\n",
+        "        response = requests.get(\n",
+        "            SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)\n",
+        "        results = response.json()\n",
+        "        res_rels = {}\n",
+        "        for i in range(len(list(results[\"results\"].values())[2])):\n",
+        "            res_rels[list(results[\"results\"].values())[2][i]\n",
+        "                            [\"r\"][\"value\"].lower()] = \"True\"\n",
+        "\n",
+        "        loss_rels = []\n",
+        "        for i in range(len(relations)):\n",
+        "            if prefixes_reverse[\"\".join(relations[i].split(\":\")[0]) + \":\"] + relations[i].split(\":\")[1] in list(res_rels.keys()):\n",
+        "                loss_rels.append(5)\n",
+        "            else:\n",
+        "                loss_rels.append(1/20)\n",
+        "        loss_gs.append(loss_rels)\n",
+        "    loss_gs = torch.FloatTensor(loss_gs).to(device)\n",
+        "output_gs_test = output_test# * loss_gs\n",
         "output_train = output_train.detach().cpu().numpy()\n",
-        "output_test = output_test.detach().cpu().numpy()\n",
+        "output_test = output_gs_test.detach().cpu().numpy()\n",
         "\n",
         "\n",
         "prediction_train = [relations[np.argmax(pred).item()]for pred in output_train]\n",
@@ -506,15 +584,55 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 190,
       "metadata": {},
-      "outputs": [],
-      "source": []
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "+---------------+-------------------+\n",
+            "|    Mod name   | Parameters Listed |\n",
+            "+---------------+-------------------+\n",
+            "| linear.weight |       338688      |\n",
+            "|  linear.bias  |        441        |\n",
+            "+---------------+-------------------+\n",
+            "Sum of trained parameters: 339129\n"
+          ]
+        },
+        {
+          "data": {
+            "text/plain": [
+              "339129"
+            ]
+          },
+          "execution_count": 190,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "from prettytable import PrettyTable\n",
+        "def count_parameters(model):\n",
+        "    table = PrettyTable([\"Mod name\", \"Parameters Listed\"])\n",
+        "    t_params = 0\n",
+        "    for name, parameter in model.named_parameters():\n",
+        "        if not parameter.requires_grad:\n",
+        "            continue\n",
+        "        param = parameter.numel()\n",
+        "        table.add_row([name, param])\n",
+        "        t_params += param\n",
+        "    print(table)\n",
+        "    print(f\"Sum of trained parameters: {t_params}\")\n",
+        "    return t_params\n",
+        "\n",
+        "count_parameters(model)"
+      ]
     }
   ],
   "metadata": {
     "kernelspec": {
-      "display_name": "Python 3.10.4 ('tdde19')",
+      "display_name": "Python 3.9.11 64-bit",
       "language": "python",
       "name": "python3"
     },
@@ -528,12 +646,12 @@
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
-      "version": "3.10.4"
+      "version": "3.9.11"
     },
     "orig_nbformat": 4,
     "vscode": {
       "interpreter": {
-        "hash": "8e4aa0e1a1e15de86146661edda0b2884b54582522f7ff2b916774ba6b8accb1"
+        "hash": "64e7cd3b4b88defe39dd61a4584920400d6beb2615ab2244e340c2e20eecdfe9"
       }
     }
   },
diff --git a/bart/sparql.ipynb b/bart/sparql.ipynb
index adaf70cfc3a9b55cb2a1499a3d81c2e673ec1f0c..6dcb66eb8825c68573affc46257035764f82681b 100644
--- a/bart/sparql.ipynb
+++ b/bart/sparql.ipynb
@@ -4075,15 +4075,13 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.10"
+   "version": "3.9.11"
   },
   "vscode": {
-   "interpreter": {
     "hash": "0988fb18a177ab47253348976d615106e52998fa9d41728b0d76cf08f2eafea6"
    }
   },
   "widgets": {
-   "application/vnd.jupyter.widget-state+json": {
     "04562bf2742444049a073935c78ad553": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
diff --git a/data/relations-all-no-http-lowercase.json b/data/relations-all-no-http-lowercase.json
new file mode 100644
index 0000000000000000000000000000000000000000..90d64557b4730b9cf007741d100ac3a8b89f775b
--- /dev/null
+++ b/data/relations-all-no-http-lowercase.json
@@ -0,0 +1,443 @@
+[
+  "dbp:firstdriver",
+  "dbp:governingbody",
+  "dbp:mission",
+  "dbp:flagbearer",
+  "dbp:manager",
+  "dbo:foundationplace",
+  "dbo:related",
+  "dbo:lieutenant",
+  "dbo:formerpartner",
+  "dbp:combatant",
+  "dbo:spouse",
+  "dbo:division",
+  "dbp:religion",
+  "dbp:affiliations",
+  "dbo:predecessor",
+  "dbp:knownfor",
+  "dbp:cinematography",
+  "dbo:ideology",
+  "dbp:meaning",
+  "dbp:constituency",
+  "dbo:maintainedby",
+  "dbo:layout",
+  "dbp:address",
+  "dbo:computingplatform",
+  "dbp:position",
+  "dbo:party",
+  "dbp:religiousaffiliation",
+  "dbp:genre",
+  "dbp:successor",
+  "dbp:nationality",
+  "dbo:parentorganisation",
+  "dbp:assembly",
+  "dbp:subject",
+  "dbp:headquarters",
+  "dbp:placeofbirth",
+  "dbp:team",
+  "dbp:director",
+  "dbo:family",
+  "dbp:district",
+  "dbp:creators",
+  "dbo:firstdriver",
+  "dbo:architecturalstyle",
+  "dbo:mountainrange",
+  "dbp:leadername",
+  "dbo:product",
+  "dbp:origin",
+  "dbp:doctoralstudents",
+  "dbo:trainer",
+  "dbo:order",
+  "dbp:birthplace",
+  "dbo:builder",
+  "dbo:voice",
+  "dbp:majorshrine",
+  "dbp:office",
+  "dbp:leader",
+  "dbp:locationcountry",
+  "dbo:manager",
+  "dbp:draftteam",
+  "dbp:license",
+  "dbp:editor",
+  "dbp:trainer",
+  "dbp:stadium",
+  "dbp:domain",
+  "dbo:gender",
+  "dbo:highschool",
+  "dbo:associatedband",
+  "dbo:coach",
+  "dbp:broadcastarea",
+  "dbo:automobileplatform",
+  "dbo:lyrics",
+  "dbo:training",
+  "dbo:hubairport",
+  "dbo:militaryunit",
+  "dbp:membership",
+  "dbo:firstascentperson",
+  "dbo:operator",
+  "dbo:notablework",
+  "dbp:teamname",
+  "dbp:creator",
+  "dbp:architecturalstyle",
+  "dbp:owners",
+  "dbo:language",
+  "dbp:state",
+  "dbo:mouthmountain",
+  "dbp:archipelago",
+  "dbp:publisher",
+  "dbp:format",
+  "dbp:products",
+  "dbp:predecessor",
+  "dbo:season",
+  "dbo:bronzemedalist",
+  "dbp:discipline",
+  "dbp:lyrics",
+  "dbo:bandmember",
+  "dbo:knownfor",
+  "dbo:silvermedalist",
+  "dbp:architecture",
+  "dbo:stateoforigin",
+  "dbo:network",
+  "dbo:cinematography",
+  "dbp:coach",
+  "dbo:nearestcity",
+  "dbp:author",
+  "dbp:debutteam",
+  "dbp:prizes",
+  "dbo:mayor",
+  "dbo:denomination",
+  "dbo:currency",
+  "dbp:workinstitutions",
+  "dbp:spouse",
+  "dbo:ethnicity",
+  "dbp:notablecommanders",
+  "dbp:starring",
+  "dbp:battles",
+  "dbo:formerteam",
+  "dbo:foundedby",
+  "dbo:leader",
+  "dbo:veneratedin",
+  "dbo:restingplace",
+  "dbp:licensee",
+  "dbp:appointer",
+  "dbp:keypeople",
+  "dbo:relative",
+  "dbp:training",
+  "dbp:presenter",
+  "dbo:managerclub",
+  "dbo:writer",
+  "dbo:partner",
+  "dbo:operatingsystem",
+  "dbp:country",
+  "dbp:music",
+  "dbo:basedon",
+  "dbp:deathdate",
+  "dbp:line",
+  "dbo:creator",
+  "dbo:series",
+  "dbo:previouswork",
+  "dbp:founded",
+  "dbp:jurisdiction",
+  "dbo:opponent",
+  "dbp:area",
+  "dbo:garrison",
+  "dbp:gender",
+  "dbo:stadium",
+  "dbp:place",
+  "dbp:chancellor",
+  "dbp:hostcity",
+  "dbp:placeofburial",
+  "dbp:cityserved",
+  "dbo:starring",
+  "dbo:profession",
+  "dbp:agencyname",
+  "dbo:federalstate",
+  "dbo:origin",
+  "dbp:party",
+  "dbp:tenants",
+  "dbp:schooltype",
+  "dbp:magazine",
+  "dbo:distributinglabel",
+  "dbp:purpose",
+  "dbo:incumbent",
+  "dbp:homestadium",
+  "dbp:canonizedby",
+  "dbo:deathplace",
+  "dbp:design",
+  "dbo:producer",
+  "dbo:officiallanguage",
+  "dbp:foundation",
+  "dbp:champion",
+  "dbo:author",
+  "dbp:hometown",
+  "dbp:doctoraladvisor",
+  "dbp:affiliation",
+  "dbp:children",
+  "dbp:occupation",
+  "dbo:affiliation",
+  "dbo:doctoralstudent",
+  "dbo:authority",
+  "dbp:employer",
+  "dbp:partner",
+  "dbo:architect",
+  "dbo:deathcause",
+  "dbo:jurisdiction",
+  "dbo:battle",
+  "dbp:manufacturer",
+  "dbp:relatives",
+  "dbp:deathplace",
+  "dbp:pastmembers",
+  "dbp:operator",
+  "dbo:honours",
+  "dbp:owner",
+  "dbo:headquarter",
+  "dbo:ceremonialcounty",
+  "dbp:firstteam",
+  "dbp:college",
+  "dbp:label",
+  "dbp:poledriver",
+  "dbo:president",
+  "dbo:destination",
+  "dbp:screenplay",
+  "dbp:material",
+  "dbp:programminglanguage",
+  "dbp:sisterstations",
+  "dbo:founder",
+  "dbo:targetairport",
+  "dbp:officialname",
+  "dbp:hubs",
+  "dbo:academicadvisor",
+  "dbo:locationcity",
+  "dbp:parent",
+  "dbp:currentclub",
+  "dbp:services",
+  "dbo:discoverer",
+  "dbo:colour",
+  "dbo:team",
+  "dbo:largestcity",
+  "dbo:recordedin",
+  "dbp:governor",
+  "dbp:university",
+  "dbo:servingrailwayline",
+  "dbo:commander",
+  "dbo:species",
+  "dbo:poledriver",
+  "dbp:athletics",
+  "dbo:race",
+  "dbp:league",
+  "dbp:languages",
+  "dbo:militarybranch",
+  "dbp:leadertitle",
+  "dbp:notableworks",
+  "dbo:literarygenre",
+  "dbo:genre",
+  "dbp:placeofdeath",
+  "dbp:president",
+  "dbo:manufacturer",
+  "dbo:significantbuilding",
+  "dbp:mascot",
+  "dbp:borough",
+  "dbp:race",
+  "dbp:language",
+  "dbp:birthname",
+  "dbp:commander",
+  "dbo:distributor",
+  "dbp:beatifiedby",
+  "dbp:managerclubs",
+  "dbp:firstaired",
+  "dbo:countyseat",
+  "dbp:artist",
+  "dbo:majorshrine",
+  "dbo:athletics",
+  "dbo:education",
+  "dbo:programmeformat",
+  "dbp:branch",
+  "dbo:location",
+  "dbo:cpu",
+  "dbp:characters",
+  "dbp:inflow",
+  "dbp:outflow",
+  "dbp:houses",
+  "dbp:mainingredient",
+  "dbp:restingplace",
+  "dbo:wineregion",
+  "dbp:producer",
+  "dbp:primeminister",
+  "dbp:related",
+  "dbp:recorded",
+  "dbp:guests",
+  "dbo:director",
+  "dbp:style",
+  "dbo:successor",
+  "dbo:parentcompany",
+  "dbo:university",
+  "dbo:editor",
+  "dbo:timezone",
+  "dbo:hometown",
+  "dbo:commandstructure",
+  "dbo:monarch",
+  "dbo:publisher",
+  "dbo:residence",
+  "dbo:rivermouth",
+  "dbo:otherparty",
+  "dbo:designer",
+  "dbp:titles",
+  "dbo:nationality",
+  "dbo:stylisticorigin",
+  "dbo:movement",
+  "dbp:club",
+  "dbp:neighboringmunicipalities",
+  "dbp:locationtown",
+  "dbp:venue",
+  "dbp:destinations",
+  "dbo:board",
+  "dbp:residence",
+  "dbp:pastteams",
+  "dbo:editing",
+  "dbo:nonfictionsubject",
+  "dbp:lieutenant",
+  "dbp:order",
+  "dbp:title",
+  "dbo:routeend",
+  "dbo:launchsite",
+  "dbo:doctoraladvisor",
+  "dbp:writer",
+  "dbo:owningcompany",
+  "dbo:anthem",
+  "dbo:placeofburial",
+  "dbp:province",
+  "dbo:executiveproducer",
+  "dbp:distributor",
+  "dbo:debutteam",
+  "dbo:developer",
+  "dbo:chairman",
+  "dbp:ground",
+  "dbp:buildingtype",
+  "dbo:programminglanguage",
+  "dbp:nearestcity",
+  "dbp:international",
+  "dbp:almamater",
+  "dbo:academicdiscipline",
+  "dbp:deputy",
+  "dbo:river",
+  "dbp:garrison",
+  "dbo:associatedmusicalartist",
+  "dbp:school",
+  "dbo:breeder",
+  "dbp:chairman",
+  "dbo:sport",
+  "dbo:club",
+  "dbo:illustrator",
+  "dbp:founder",
+  "dbo:school",
+  "dbp:citizenship",
+  "dbp:designer",
+  "dbp:playedfor",
+  "dbp:arena",
+  "dbo:parent",
+  "dbo:ingredient",
+  "dbp:junction",
+  "dbo:institution",
+  "dbp:engine",
+  "dbo:regionserved",
+  "dbp:rank",
+  "dbp:operatingsystem",
+  "dbo:type",
+  "dbp:carries",
+  "dbo:citizenship",
+  "dbp:locationcity",
+  "dbp:nickname",
+  "dbo:governmenttype",
+  "dbp:commandstructure",
+  "dbo:militaryrank",
+  "dbo:formerbandmember",
+  "dbp:maininterests",
+  "dbo:almamater",
+  "dbo:country",
+  "dbp:narrated",
+  "dbo:occupation",
+  "dbo:primeminister",
+  "dbo:field",
+  "dbo:composer",
+  "dbo:routestart",
+  "dbo:league",
+  "dbo:outflow",
+  "dbo:subsidiary",
+  "dbo:child",
+  "dbp:highschool",
+  "dbp:sisternames",
+  "dbo:portrayer",
+  "dbo:award",
+  "dbp:crosses",
+  "dbo:employer",
+  "dbp:name",
+  "dbp:city",
+  "dbo:birthplace",
+  "dbo:musicby",
+  "dbp:editing",
+  "dbo:license",
+  "dbo:relation",
+  "dbp:fields",
+  "dbp:nationalorigin",
+  "dbp:location",
+  "dbo:kingdom",
+  "dbp:largestcity",
+  "dbo:service",
+  "dbo:capital",
+  "dbp:field",
+  "dbp:currency",
+  "dbo:phylum",
+  "dbp:haircolor",
+  "dbp:deathcause",
+  "dbo:ground",
+  "dbp:education",
+  "dbo:broadcastarea",
+  "dbo:territory",
+  "dbp:mother",
+  "dbo:narrator",
+  "dbp:allegiance",
+  "dbo:assembly",
+  "dbo:college",
+  "dbp:region",
+  "dbo:region",
+  "dbo:album",
+  "dbo:county",
+  "dbo:tenant",
+  "dbp:executiveproducer",
+  "dbp:awards",
+  "dbp:developer",
+  "dbp:writers",
+  "dbp:architect",
+  "dbp:thememusiccomposer",
+  "dbo:presenter",
+  "dbp:company",
+  "dbo:owner",
+  "dbp:birthdate",
+  "dbp:nationalteam",
+  "dbo:keyperson",
+  "dbp:cities",
+  "dbo:campus",
+  "dbo:city",
+  "dbo:languagefamily",
+  "dbo:sisterstation",
+  "dbp:coverartist",
+  "dbp:veneratedin",
+  "dbp:animator",
+  "dbp:type",
+  "dbp:notableinstruments",
+  "dbo:locatedinarea",
+  "dbo:religion",
+  "dbo:binomialauthority",
+  "dbo:neighboringmunicipality",
+  "dbp:role",
+  "dbo:homestadium",
+  "dbp:os",
+  "dbp:youthclubs",
+  "dbp:currentmembers",
+  "dbp:album",
+  "dbo:instrument",
+  "dbo:subsequentwork",
+  "dbo:inflow",
+  "dbo:artist",
+  "dbp:associatedacts"
+]