From 13dc93681ff8f827845be9cf2ffd9cf9cc987c71 Mon Sep 17 00:00:00 2001
From: Albin <abbe_h@hotmail.com>
Date: Wed, 23 Nov 2022 11:59:07 +0100
Subject: [PATCH] Test and validation

---
 Neural graph module/ngm.ipynb           | 298 +++++++++++-----
 data/relations-all-lc-quad-no-http.json | 446 ++++++++++++++++++++++++
 2 files changed, 657 insertions(+), 87 deletions(-)
 create mode 100644 data/relations-all-lc-quad-no-http.json

diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index 448cf36..a924791 100644
--- a/Neural graph module/ngm.ipynb	
+++ b/Neural graph module/ngm.ipynb	
@@ -2,7 +2,7 @@
   "cells": [
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 17,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -20,7 +20,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 18,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -29,9 +29,17 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 19,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "cuda\n"
+          ]
+        }
+      ],
       "source": [
         "# Use GPU if available\n",
         "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
@@ -41,17 +49,17 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 20,
       "metadata": {},
       "outputs": [],
       "source": [
         "\n",
         "class NgmOne(nn.Module):\n",
-        "    def __init__(self, device):\n",
+        "    def __init__(self, device, relations):\n",
         "        super(NgmOne, self).__init__()\n",
         "        self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
         "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
-        "        self.linear = nn.Linear(768, 1396).to(device)\n",
+        "        self.linear = nn.Linear(768, len(relations)).to(device)\n",
         "        self.softmax = nn.Softmax(dim=1).to(device)\n",
         "        self.device = device\n",
         "    \n",
@@ -69,7 +77,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 21,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -91,18 +99,19 @@
         "}\n",
         "\n",
         "prefixes_end = [\"org\", \"se\", \"com\", \"nu\"]\n",
-        "ALL_HTTP_PREFIXES = True\n",
+        "ALL_HTTP_PREFIXES = False\n",
         "\n",
-        "def make_batch():\n",
+        "def make_batch(src, http_prefix=False):\n",
         "    \"\"\"Triplet is a list of [subject entity, relation, object entity], None if not present\"\"\"\n",
-        "\n",
+        "    pred = src\n",
+        "    gold = src\n",
         "    # Load predicted data\n",
         "    # \"../data/qald-9-train-linked.json\"\n",
-        "    pred = \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
+        "    #pred = \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
         "\n",
         "    #Load gold data\n",
         "    # \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
-        "    gold = \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
+        "    #gold = \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
         "    print(\"Beginning making batch\")\n",
         "    with open(pred, \"r\") as p, open(gold, \"r\") as g:\n",
         "        pred = json.load(p)\n",
@@ -133,7 +142,7 @@
         "                for t in triplet_i:\n",
         "                    if not(t.find(\"?\")):\n",
         "                        triplet_i[triplet_i.index(t)] = None\n",
-        "                    elif ALL_HTTP_PREFIXES:\n",
+        "                    elif http_prefix:\n",
         "                        n = t.replace(\"<\", \"\").replace(\">\", \"\")\n",
         "                        n_sub = n.split(\"/\")[-1]\n",
         "                        for i in prefixes_end:\n",
@@ -185,7 +194,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 22,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -207,7 +216,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 15,
+      "execution_count": 23,
       "metadata": {},
       "outputs": [
         {
@@ -221,7 +230,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "100%|██████████| 408/408 [00:00<00:00, 616.32it/s]\n"
+            "100%|██████████| 2052/2052 [00:03<00:00, 546.37it/s]\n"
           ]
         },
         {
@@ -229,40 +238,16 @@
           "output_type": "stream",
           "text": [
             "Finished with batches\n",
-            "Beginning making batch\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "100%|██████████| 408/408 [00:00<00:00, 620.42it/s]"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Finished with batches\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "\n"
-          ]
-        },
-        {
-          "ename": "ValueError",
-          "evalue": "Sum of input lengths does not equal the length of the input dataset!",
-          "output_type": "error",
-          "traceback": [
-            "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
-            "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
-            "\u001b[1;32mc:\\Users\\Albin\\Documents\\TDDE19\\codebase\\Neural graph module\\ngm.ipynb Cell 7\u001b[0m in \u001b[0;36m<cell line: 14>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/Albin/Documents/TDDE19/codebase/Neural%20graph%20module/ngm.ipynb#X11sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m inputs, attention_mask, correct_rels \u001b[39m=\u001b[39m make_batch()\n\u001b[0;32m     <a href='vscode-notebook-cell:/c%3A/Users/Albin/Documents/TDDE19/codebase/Neural%20graph%20module/ngm.ipynb#X11sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m dataset \u001b[39m=\u001b[39m MyDataset(\u001b[39m*\u001b[39mmake_batch(), relations\u001b[39m=\u001b[39mrelations)\n\u001b[1;32m---> <a href='vscode-notebook-cell:/c%3A/Users/Albin/Documents/TDDE19/codebase/Neural%20graph%20module/ngm.ipynb#X11sZmlsZQ%3D%3D?line=13'>14</a>\u001b[0m split_data \u001b[39m=\u001b[39m random_split(dataset, [\u001b[39m0.8\u001b[39;49m, \u001b[39m0.2\u001b[39;49m], generator\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mGenerator()\u001b[39m.\u001b[39;49mmanual_seed(\u001b[39m42\u001b[39;49m))\n\u001b[0;32m     <a href='vscode-notebook-cell:/c%3A/Users/Albin/Documents/TDDE19/codebase/Neural%20graph%20module/ngm.ipynb#X11sZmlsZQ%3D%3D?line=15'>16</a>\u001b[0m train_dataloader \u001b[39m=\u001b[39m DataLoader(train_set, batch_size\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m, shuffle\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m     <a href='vscode-notebook-cell:/c%3A/Users/Albin/Documents/TDDE19/codebase/Neural%20graph%20module/ngm.ipynb#X11sZmlsZQ%3D%3D?line=17'>18</a>\u001b[0m \u001b[39m#show first entry\u001b[39;00m\n",
-            "File \u001b[1;32mb:\\Programs\\Miniconda\\envs\\tdde19\\lib\\site-packages\\torch\\utils\\data\\dataset.py:311\u001b[0m, in \u001b[0;36mrandom_split\u001b[1;34m(dataset, lengths, generator)\u001b[0m\n\u001b[0;32m    309\u001b[0m \u001b[39m# Cannot verify that dataset is Sized\u001b[39;00m\n\u001b[0;32m    310\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39msum\u001b[39m(lengths) \u001b[39m!=\u001b[39m \u001b[39mlen\u001b[39m(dataset):    \u001b[39m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m--> 311\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mSum of input lengths does not equal the length of the input dataset!\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m    313\u001b[0m indices \u001b[39m=\u001b[39m randperm(\u001b[39msum\u001b[39m(lengths), generator\u001b[39m=\u001b[39mgenerator)\u001b[39m.\u001b[39mtolist()\n\u001b[0;32m    314\u001b[0m \u001b[39mreturn\u001b[39;00m [Subset(dataset, indices[offset \u001b[39m-\u001b[39m length : offset]) \u001b[39mfor\u001b[39;00m offset, length \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(_accumulate(lengths), lengths)]\n",
-            "\u001b[1;31mValueError\u001b[0m: Sum of input lengths does not equal the length of the input dataset!"
+            "features: tensor([[  101,  2054,  2003,  1996, 13314,  1997, 16122, 10230,  5400,  2100,\n",
+            "         12378,  2136,  1029,   102,  1031,  4942,  1033,   102,     0,     0,\n",
+            "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+            "             0,     0,     0,     0,     0,     0,     0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n",
+            "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(378)\n",
+            "valid features: tensor([[  101,  2054,  4752,  1997,  7992, 13843,  5112,  2267,  2003,  2036,\n",
+            "          1996,  9353,  4215, 26432,  2278,  3037,  1997,  2703, 27668,  3077,\n",
+            "          1029,   102,  1031,  4942,  1033,   102,     0,     0,     0,     0,\n",
+            "             0,     0,     0,     0,     0,     0,     0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+            "         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(28)\n"
           ]
         }
       ],
@@ -274,55 +259,133 @@
         "    with open(file, \"r\") as f:\n",
         "        return json.load(f)\n",
         "\n",
-        "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
-        "inputs, attention_mask, correct_rels = make_batch()\n",
+        "#relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
+        "relations = open_json(\"../data/relations-all-lc-quad-no-http.json\")\n",
+        "\n",
+        "# \"../data/qald-9-train-linked.json\"\n",
+        "#pred = \"../LC-QuAD/combined-requeried-linked-train.json\"\n",
+        "inputs, attention_mask, correct_rels = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True) #train\n",
         "\n",
         "# relations = open_json(\"../data/relations-lcquad-without-http-train-linked.json\")\n",
         "# train_set = MyDataset(*make_batch(), relations=relations)\n",
         "\n",
         "\n",
-        "dataset = MyDataset(*make_batch(), relations=relations)\n",
-        "split_data = random_split(dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))\n",
+        "dataset = MyDataset(inputs, attention_mask, correct_rels, relations=relations)\n",
+        "train_size = int(0.8 * len(dataset))\n",
+        "valid_size = len(dataset) - train_size\n",
         "\n",
-        "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n",
+        "train_data, valid_data = random_split(dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(42))\n",
         "\n",
+        "train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)\n",
         "#show first entry\n",
         "train_features, train_mask, train_label = next(iter(train_dataloader))\n",
-        "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])"
+        "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])\n",
+        "\n",
+        "\n",
+        "valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=True)\n",
+        "valid_features, valid_mask, valid_label = next(iter(valid_dataloader))\n",
+        "print(\"valid features:\", valid_features, \"valid mask:\",valid_mask,\"valid label_index\", valid_label[0])"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 24,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n",
+            "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+            "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+          ]
+        }
+      ],
       "source": [
         "# Initialize model\n",
-        "model = NgmOne(device)"
+        "model = NgmOne(device, relations)"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 25,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "1 Train 6.453075773575726 , Valid  7.615683317184448\n",
+            "2 Train 6.444019513971665 , Valid  7.6052809953689575\n",
+            "3 Train 6.43224432889153 , Valid  7.602278828620911\n",
+            "4 Train 6.415224019218893 , Valid  7.598743319511414\n",
+            "5 Train 6.391957339118509 , Valid  7.584347605705261\n",
+            "6 Train 6.368611419902129 , Valid  7.573382139205933\n",
+            "7 Train 6.339372747084674 , Valid  7.557032465934753\n",
+            "8 Train 6.318648646859562 , Valid  7.539492607116699\n",
+            "9 Train 6.297140289755428 , Valid  7.523481965065002\n",
+            "10 Train 6.281140131108901 , Valid  7.519673824310303\n",
+            "11 Train 6.265462370479808 , Valid  7.5268988609313965\n",
+            "12 Train 6.24550900739782 , Valid  7.504716753959656\n",
+            "13 Train 6.234297051149256 , Valid  7.505620360374451\n",
+            "14 Train 6.220176332137164 , Valid  7.500500917434692\n",
+            "15 Train 6.204408701728372 , Valid  7.493185758590698\n",
+            "16 Train 6.194800797630759 , Valid  7.488588452339172\n",
+            "17 Train 6.183391935685101 , Valid  7.466232180595398\n",
+            "18 Train 6.173011443194221 , Valid  7.457393527030945\n",
+            "19 Train 6.168144422418931 , Valid  7.45364773273468\n",
+            "20 Train 6.157315562753117 , Valid  7.452357649803162\n",
+            "21 Train 6.140512690824621 , Valid  7.456966400146484\n",
+            "22 Train 6.129443336935604 , Valid  7.464043855667114\n",
+            "23 Train 6.12704220940085 , Valid  7.443313360214233\n",
+            "24 Train 6.117525269003475 , Valid  7.4395164251327515\n",
+            "25 Train 6.116069064420812 , Valid  7.456610083580017\n",
+            "26 Train 6.108902790967156 , Valid  7.430782794952393\n",
+            "27 Train 6.108471814323874 , Valid  7.44858992099762\n",
+            "28 Train 6.100216697244083 , Valid  7.457515120506287\n",
+            "29 Train 6.099361447726979 , Valid  7.438013672828674\n",
+            "30 Train 6.092377662658691 , Valid  7.448408484458923\n",
+            "31 Train 6.088698302998262 , Valid  7.442046403884888\n",
+            "32 Train 6.083712998558493 , Valid  7.420018911361694\n",
+            "33 Train 6.081799563239603 , Valid  7.426819205284119\n",
+            "34 Train 6.070652428795309 , Valid  7.426627039909363\n",
+            "35 Train 6.069005825940301 , Valid  7.425306558609009\n",
+            "36 Train 6.059389002182904 , Valid  7.415045142173767\n",
+            "37 Train 6.0618573918062095 , Valid  7.418038249015808\n",
+            "38 Train 6.055309716392966 , Valid  7.434216380119324\n",
+            "39 Train 6.049336994395537 , Valid  7.43299674987793\n",
+            "40 Train 6.048270281623392 , Valid  7.410753607749939\n",
+            "41 Train 6.043416135451373 , Valid  7.410261392593384\n",
+            "42 Train 6.042512613184312 , Valid  7.4074866771698\n",
+            "43 Train 6.032251245835248 , Valid  7.390785813331604\n",
+            "44 Train 6.025217645308551 , Valid  7.370600938796997\n",
+            "45 Train 6.027591985814712 , Valid  7.377846837043762\n",
+            "46 Train 6.0253391826854035 , Valid  7.39339554309845\n",
+            "47 Train 6.019261051626766 , Valid  7.407409071922302\n",
+            "48 Train 6.014283713172464 , Valid  7.393349766731262\n",
+            "49 Train 6.010964337517233 , Valid  7.37293815612793\n",
+            "50 Train 6.010590104495778 , Valid  7.4031054973602295\n"
+          ]
+        }
+      ],
       "source": [
         "# Train with data loader.\n",
         "criterion = nn.CrossEntropyLoss()\n",
-        "optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
+        "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
         "\n",
-        "epoch = 500\n",
+        "epoch = 50\n",
         "batch_size = 64\n",
-        "train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
-        "valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=True)\n",
+        "train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
+        "valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)\n",
         "for e in range(epoch):\n",
         "    train_loss_epoch = 0\n",
         "    valid_loss_epoch = 0\n",
-        "    for i_batch, sample_batched in enumerate(train_dataloader):\n",
+        "    for i_train, sample_batched_train in enumerate(train_dataloader):\n",
         "        optimizer.zero_grad()\n",
-        "        train = sample_batched[0]\n",
-        "        train_mask = sample_batched[1]\n",
-        "        label_index = sample_batched[2].to(device)\n",
+        "        train = sample_batched_train[0]\n",
+        "        train_mask = sample_batched_train[1]\n",
+        "        label_index = sample_batched_train[2].to(device)\n",
         "        \n",
         "        # Forward pass\n",
         "        output = model(train, train_mask)\n",
@@ -333,38 +396,98 @@
         "        optimizer.step()\n",
         "        train_loss_epoch = train_loss_epoch + loss.item()\n",
         "\n",
-        "    for i_batch, sample_batched in enumerate(valid_dataloader):\n",
-        "        valid = sample_batched[0]\n",
-        "        valid_mask = sample_batched[1]\n",
-        "        label_index = sample_batched[2].to(device)\n",
+        "    for i_valid, sample_batched_valid in enumerate(valid_dataloader):\n",
+        "        valid = sample_batched_valid[0]\n",
+        "        valid_mask = sample_batched_valid[1]\n",
+        "        label_index = sample_batched_valid[2].to(device)\n",
         "        \n",
         "        # Forward pass\n",
-        "        output = model(train, train_mask)\n",
-        "        loss = criterion(output, label_index)\n",
+        "        with torch.no_grad():\n",
+        "            output = model(valid, valid_mask)\n",
+        "            loss = criterion(output, label_index)\n",
         "\n",
         "        valid_loss_epoch = valid_loss_epoch + loss.item()\n",
         "\n",
-        "    print(e+1, \"Train\", valid_loss_epoch/len(sample_batched), \", Valid \", valid_loss_epoch/len(sample_batched))"
+        "    print(e+1, \"Train\", train_loss_epoch/i_train, \", Valid \", valid_loss_epoch/i_valid)"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 31,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Beginning making batch\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "100%|██████████| 2052/2052 [00:03<00:00, 577.85it/s]\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Finished with batches\n",
+            "Beginning making batch\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "100%|██████████| 1161/1161 [00:02<00:00, 566.88it/s]\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Finished with batches\n",
+            "test loss 6.031976699829102\n",
+            "lowest confidence train 0.08024344\n",
+            "lowest confidence test 0.08576707\n",
+            "Accuracy train: 0.3939828080229226\n",
+            "Accuracy test: 0.006361323155216285\n"
+          ]
+        }
+      ],
       "source": [
         "# Predict\n",
-        "train, train_mask, corr_rels = make_batch()\n",
+        "train, train_mask, corr_rels = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True)\n",
+        "test, test_mask, corr_rels_test = make_batch(src=\"../LC-QuAD/combined-requeried-linked-test.json\", http_prefix = True)\n",
+        "test_data = MyDataset(test, test_mask, corr_rels_test, relations=relations)\n",
+        "test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)\n",
+        "\n",
+        "test_batch, test_mask_batch, corr_rels_test_batch = next(iter(test_dataloader))\n",
+        "corr_rels_test_batch = corr_rels_test_batch.to(device)\n",
         "with torch.no_grad():\n",
-        "    output = model(train, train_mask)\n",
+        "    output_train = model(train, train_mask)\n",
+        "    output_test = model(test_batch, test_mask_batch)\n",
+        "    loss = criterion(output_test, corr_rels_test_batch)\n",
+        "    print(\"test loss\", loss.item())\n",
+        "\n",
+        "output_train = output_train.detach().cpu().numpy()\n",
+        "output_test = output_test.detach().cpu().numpy()\n",
+        "\n",
+        "\n",
+        "prediction_train = [relations[np.argmax(pred).item()]for pred in output_train]\n",
+        "probability_train = [pred[np.argmax(pred)] for pred in output_train]\n",
+        "correct_pred_train = [corr_rels[i] for i in range(len(output_train))]\n",
         "\n",
-        "output = output.detach().cpu().numpy()\n",
         "\n",
-        "prediction = [relations[np.argmax(pred).item()]for pred in output]\n",
-        "probability = [pred[np.argmax(pred)] for pred in output]\n",
-        "correct_pred = [corr_rels[i] for i in range(len(output))]\n",
+        "prediction_test = [relations[np.argmax(pred).item()]for pred in output_test]\n",
+        "probability_test = [pred[np.argmax(pred)] for pred in output_test]\n",
+        "correct_pred_test = [corr_rels[i] for i in range(len(output_test))]\n",
         "\n",
-        "print(\"lowest confidence\", min(probability))\n",
+        "print(\"lowest confidence train\", min(probability_train))\n",
+        "print(\"lowest confidence test\", min(probability_test))\n",
         "\n",
         "def accuracy_score(y_true, y_pred):\n",
         "    corr_preds=0\n",
@@ -376,7 +499,8 @@
         "            wrong_preds += 1\n",
         "    return corr_preds/(corr_preds+wrong_preds)\n",
         "\n",
-        "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n",
+        "print(\"Accuracy train:\", accuracy_score(correct_pred_train, prediction_train))\n",
+        "print(\"Accuracy test:\", accuracy_score(correct_pred_test, prediction_test))\n",
         "\n"
       ]
     },
diff --git a/data/relations-all-lc-quad-no-http.json b/data/relations-all-lc-quad-no-http.json
new file mode 100644
index 0000000..f37bd46
--- /dev/null
+++ b/data/relations-all-lc-quad-no-http.json
@@ -0,0 +1,446 @@
+[
+  "dbo:layout",
+  "dbp:owners",
+  "dbo:literaryGenre",
+  "dbp:teamName",
+  "dbp:style",
+  "dbo:genre",
+  "dbo:honours",
+  "dbo:child",
+  "dbo:writer",
+  "dbo:currency",
+  "dbp:president",
+  "dbp:editor",
+  "dbp:governor",
+  "dbo:placeOfBurial",
+  "dbp:governingBody",
+  "dbp:placeofburial",
+  "dbp:race",
+  "dbp:cityServed",
+  "dbo:subsidiary",
+  "dbp:constituency",
+  "dbp:league",
+  "dbo:ideology",
+  "dbp:district",
+  "dbp:deathPlace",
+  "dbp:headquarters",
+  "dbp:sisterNames",
+  "dbp:writers",
+  "dbp:veneratedIn",
+  "dbp:fields",
+  "dbp:branch",
+  "dbo:origin",
+  "dbp:school",
+  "dbp:firstAired",
+  "dbo:movement",
+  "dbp:prizes",
+  "dbo:river",
+  "dbo:battle",
+  "dbo:associatedBand",
+  "dbp:place",
+  "dbp:state",
+  "dbp:animator",
+  "dbp:artist",
+  "dbp:children",
+  "dbp:club",
+  "dbo:nearestCity",
+  "dbo:club",
+  "dbo:designer",
+  "dbp:manager",
+  "dbo:city",
+  "dbp:nationalOrigin",
+  "dbo:governmentType",
+  "dbp:junction",
+  "dbo:predecessor",
+  "dbp:designer",
+  "dbo:party",
+  "dbo:school",
+  "dbp:field",
+  "dbp:placeOfBirth",
+  "dbo:type",
+  "dbo:locationCity",
+  "dbo:kingdom",
+  "dbo:academicDiscipline",
+  "dbp:chancellor",
+  "dbo:doctoralAdvisor",
+  "dbp:poleDriver",
+  "dbp:locationTown",
+  "dbo:silverMedalist",
+  "dbp:chairman",
+  "dbp:carries",
+  "dbo:languageFamily",
+  "dbo:mayor",
+  "dbp:hometown",
+  "dbp:almaMater",
+  "dbp:beatifiedBy",
+  "dbo:riverMouth",
+  "dbo:managerClub",
+  "dbo:mouthMountain",
+  "dbo:sport",
+  "dbo:chairman",
+  "dbp:firstDriver",
+  "dbp:awards",
+  "dbo:mountainRange",
+  "dbp:leaderName",
+  "dbp:lieutenant",
+  "dbp:employer",
+  "dbo:distributor",
+  "dbo:militaryUnit",
+  "dbp:type",
+  "dbo:series",
+  "dbp:nearestCity",
+  "dbp:leaderTitle",
+  "dbo:lyrics",
+  "dbo:firstDriver",
+  "dbo:homeStadium",
+  "dbo:publisher",
+  "dbo:narrator",
+  "dbo:builder",
+  "dbp:highschool",
+  "dbp:meaning",
+  "dbp:presenter",
+  "dbp:garrison",
+  "dbp:director",
+  "dbp:birthDate",
+  "dbp:placeOfBurial",
+  "dbp:destinations",
+  "dbo:cpu",
+  "dbo:inflow",
+  "dbp:officialName",
+  "dbo:automobilePlatform",
+  "dbo:manufacturer",
+  "dbp:cinematography",
+  "dbp:lyrics",
+  "dbp:architecture",
+  "dbp:subject",
+  "dbp:licensee",
+  "dbp:playedFor",
+  "dbp:partner",
+  "dbo:parentOrganisation",
+  "dbp:residence",
+  "dbo:ingredient",
+  "dbo:locatedInArea",
+  "dbp:managerclubs",
+  "dbo:bandMember",
+  "dbp:nationality",
+  "dbp:team",
+  "dbp:pastMembers",
+  "dbo:bronzeMedalist",
+  "dbp:firstTeam",
+  "dbo:successor",
+  "dbo:illustrator",
+  "dbo:related",
+  "dbo:broadcastArea",
+  "dbo:militaryBranch",
+  "dbo:lieutenant",
+  "dbp:agencyName",
+  "dbo:institution",
+  "dbo:relation",
+  "dbp:architecturalStyle",
+  "dbp:affiliation",
+  "dbp:locationCountry",
+  "dbp:athletics",
+  "dbp:debutteam",
+  "dbp:developer",
+  "dbp:restingplace",
+  "dbp:religion",
+  "dbo:nationality",
+  "dbp:domain",
+  "dbp:cities",
+  "dbo:highschool",
+  "dbo:computingPlatform",
+  "dbp:broadcastArea",
+  "dbo:affiliation",
+  "dbp:hostCity",
+  "dbo:tenant",
+  "dbo:owner",
+  "dbp:narrated",
+  "dbo:artist",
+  "dbp:deathCause",
+  "dbo:nonFictionSubject",
+  "dbp:international",
+  "dbo:award",
+  "dbo:residence",
+  "dbp:keyPeople",
+  "dbo:order",
+  "dbo:headquarter",
+  "dbo:species",
+  "dbo:distributingLabel",
+  "dbp:borough",
+  "dbo:country",
+  "dbo:foundationPlace",
+  "dbo:starring",
+  "dbo:relative",
+  "dbo:creator",
+  "dbo:network",
+  "dbo:spouse",
+  "dbp:languages",
+  "dbo:operatingSystem",
+  "dbp:draftTeam",
+  "dbp:genre",
+  "dbo:ground",
+  "dbp:design",
+  "dbo:formerTeam",
+  "dbp:spouse",
+  "dbo:parent",
+  "dbp:role",
+  "dbo:editor",
+  "dbp:address",
+  "dbo:developer",
+  "dbo:instrument",
+  "dbo:foundedBy",
+  "dbo:recordedIn",
+  "dbp:membership",
+  "dbp:venue",
+  "dbo:programmingLanguage",
+  "dbp:rank",
+  "dbo:producer",
+  "dbp:editing",
+  "dbp:successor",
+  "dbp:guests",
+  "dbo:owningCompany",
+  "dbo:timeZone",
+  "dbo:maintainedBy",
+  "dbo:binomialAuthority",
+  "dbp:college",
+  "dbp:allegiance",
+  "dbp:notableInstruments",
+  "dbo:author",
+  "dbp:discipline",
+  "dbo:gender",
+  "dbo:garrison",
+  "dbp:starring",
+  "dbp:crosses",
+  "dbp:foundation",
+  "dbp:license",
+  "dbo:parentCompany",
+  "dbp:highSchool",
+  "dbp:largestCity",
+  "dbp:music",
+  "dbp:themeMusicComposer",
+  "dbp:university",
+  "dbp:affiliations",
+  "dbp:knownFor",
+  "dbp:currentMembers",
+  "dbp:os",
+  "dbo:architect",
+  "dbp:screenplay",
+  "dbp:title",
+  "dbo:campus",
+  "dbo:subsequentWork",
+  "dbp:deputy",
+  "dbo:university",
+  "dbo:targetAirport",
+  "dbp:arena",
+  "dbo:associatedMusicalArtist",
+  "dbo:breeder",
+  "dbo:religion",
+  "dbp:sisterStations",
+  "dbo:birthPlace",
+  "dbo:trainer",
+  "dbo:formerPartner",
+  "dbp:order",
+  "dbo:ceremonialCounty",
+  "dbo:season",
+  "dbp:programmingLanguage",
+  "dbo:hubAirport",
+  "dbo:notableWork",
+  "dbo:routeEnd",
+  "dbo:doctoralStudent",
+  "dbp:champion",
+  "dbp:deathDate",
+  "dbp:owner",
+  "dbp:predecessor",
+  "dbo:executiveProducer",
+  "dbp:battles",
+  "dbo:formerBandMember",
+  "dbo:manager",
+  "dbp:executiveProducer",
+  "dbp:doctoralAdvisor",
+  "dbo:editing",
+  "dbo:firstAscentPerson",
+  "dbo:territory",
+  "dbo:athletics",
+  "dbp:producer",
+  "dbo:album",
+  "dbp:birthplace",
+  "dbp:purpose",
+  "dbp:hubs",
+  "dbp:neighboringMunicipalities",
+  "dbo:discoverer",
+  "dbo:division",
+  "dbo:employer",
+  "dbp:name",
+  "dbo:wineRegion",
+  "dbp:ground",
+  "dbp:currency",
+  "dbo:otherParty",
+  "dbo:founder",
+  "dbo:voice",
+  "dbp:schooltype",
+  "dbo:family",
+  "dbp:buildingType",
+  "dbo:league",
+  "dbo:jurisdiction",
+  "dbo:monarch",
+  "dbp:origin",
+  "dbo:president",
+  "dbp:appointer",
+  "dbp:distributor",
+  "dbp:position",
+  "dbo:keyPerson",
+  "dbo:poleDriver",
+  "dbo:architecturalStyle",
+  "dbp:locationCity",
+  "dbp:notableworks",
+  "dbo:servingRailwayLine",
+  "dbp:author",
+  "dbp:stadium",
+  "dbp:products",
+  "dbp:associatedActs",
+  "dbp:creator",
+  "dbp:office",
+  "dbp:citizenship",
+  "dbo:colour",
+  "dbo:stadium",
+  "dbp:writer",
+  "dbo:region",
+  "dbo:incumbent",
+  "dbp:canonizedBy",
+  "dbp:party",
+  "dbo:profession",
+  "dbp:services",
+  "dbo:officialLanguage",
+  "dbo:leader",
+  "dbo:team",
+  "dbo:previousWork",
+  "dbp:youthclubs",
+  "dbp:workInstitutions",
+  "dbo:almaMater",
+  "dbo:capital",
+  "dbo:field",
+  "dbo:college",
+  "dbo:militaryRank",
+  "dbp:manufacturer",
+  "dbo:phylum",
+  "dbo:primeMinister",
+  "dbp:founded",
+  "dbp:region",
+  "dbp:homeStadium",
+  "dbp:birthName",
+  "dbp:operatingSystem",
+  "dbp:publisher",
+  "dbp:parent",
+  "dbp:commandStructure",
+  "dbo:citizenship",
+  "dbp:label",
+  "dbp:related",
+  "dbo:commander",
+  "dbo:county",
+  "dbo:training",
+  "dbp:location",
+  "dbp:houses",
+  "dbo:routeStart",
+  "dbp:city",
+  "dbo:cinematography",
+  "dbo:ethnicity",
+  "dbp:leader",
+  "dbp:religiousAffiliation",
+  "dbp:assembly",
+  "dbp:currentclub",
+  "dbp:area",
+  "dbo:language",
+  "dbo:board",
+  "dbp:mainIngredient",
+  "dbp:mainInterests",
+  "dbo:majorShrine",
+  "dbp:outflow",
+  "dbp:mother",
+  "dbp:training",
+  "dbo:director",
+  "dbp:line",
+  "dbp:architect",
+  "dbo:academicAdvisor",
+  "dbp:birthPlace",
+  "dbp:magazine",
+  "dbp:majorShrine",
+  "dbo:portrayer",
+  "dbp:jurisdiction",
+  "dbp:format",
+  "dbo:anthem",
+  "dbp:album",
+  "dbp:education",
+  "dbo:location",
+  "dbp:hairColor",
+  "dbo:stylisticOrigin",
+  "dbo:musicBy",
+  "dbp:inflow",
+  "dbo:neighboringMunicipality",
+  "dbo:countySeat",
+  "dbp:mascot",
+  "dbp:creators",
+  "dbo:regionServed",
+  "dbp:founder",
+  "dbp:commander",
+  "dbo:deathPlace",
+  "dbo:debutTeam",
+  "dbp:combatant",
+  "dbo:sisterStation",
+  "dbp:coverArtist",
+  "dbo:significantBuilding",
+  "dbp:occupation",
+  "dbo:commandStructure",
+  "dbo:programmeFormat",
+  "dbp:flagbearer",
+  "dbp:engine",
+  "dbp:company",
+  "dbo:knownFor",
+  "dbo:deathCause",
+  "dbo:product",
+  "dbp:pastteams",
+  "dbo:opponent",
+  "dbp:titles",
+  "dbp:doctoralStudents",
+  "dbo:service",
+  "dbo:veneratedIn",
+  "dbp:recorded",
+  "dbp:tenants",
+  "dbo:partner",
+  "dbp:province",
+  "dbo:license",
+  "dbp:material",
+  "dbp:nationalteam",
+  "dbo:stateOfOrigin",
+  "dbp:operator",
+  "dbo:race",
+  "dbo:operator",
+  "dbp:characters",
+  "dbo:outflow",
+  "dbp:language",
+  "dbp:gender",
+  "dbo:occupation",
+  "dbo:composer",
+  "dbp:mission",
+  "dbo:restingPlace",
+  "dbp:primeminister",
+  "dbp:nickname",
+  "dbo:federalState",
+  "dbo:assembly",
+  "dbp:placeOfDeath",
+  "dbp:notableCommanders",
+  "dbp:relatives",
+  "dbo:hometown",
+  "dbo:largestCity",
+  "dbo:authority",
+  "dbo:destination",
+  "dbp:archipelago",
+  "dbo:basedOn",
+  "dbp:trainer",
+  "dbp:country",
+  "dbo:education",
+  "dbo:denomination",
+  "dbp:coach",
+  "dbo:presenter",
+  "dbo:coach",
+  "dbo:launchSite"
+]
-- 
GitLab