diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index bf75c0e5c224fa7785492429a288b518ec45bada..529b7d2ba3c0c0b0a457fc8c293954f3c241b12e 100644
--- a/Neural graph module/ngm.ipynb
+++ b/Neural graph module/ngm.ipynb
@@ -2,31 +2,47 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [],
    "source": [
     "import datasets\n",
     "import torch\n",
     "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
     "import pandas as pd\n",
     "import numpy as np\n",
     "from transformers import BertTokenizer, BertModel\n",
-    "from transformers.models.bert.modeling_bert import shift_tokens_right\n"
+    "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
+    "from tqdm import tqdm\n",
+    "import json\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 29,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Downloading: 100%|██████████| 232k/232k [00:00<00:00, 636kB/s] \n",
+      "c:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\huggingface_hub\\file_download.py:123: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\maxbj\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
+      "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
+      "  warnings.warn(message)\n",
+      "Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 28.9kB/s]\n",
+      "Downloading: 100%|██████████| 570/570 [00:00<00:00, 572kB/s]\n"
+     ]
+    }
+   ],
    "source": [
     "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 64,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -36,28 +52,129 @@
     "        super(NgmOne, self).__init__()\n",
     "        self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
     "        self.bert = BertModel.from_pretrained(\"bert-base-uncased\")\n",
-    "        self.linear = nn.Linear(768, 1)\n",
+    "        self.linear = nn.Linear(768, 247)\n",
     "        self.softmax = nn.Softmax(dim=1)\n",
     "    \n",
-    "    def forward(self, triplet, question):\n",
-    "        \"\"\"Triplet is a list of subject entity, relation, object entity, None if not present\"\"\"\n",
-    "    \n",
+    "    def forward(self, tokenized_seq):\n",
+    "        # BertModel returns a model-output object, not a tensor;\n",
+    "        # take the pooled [CLS] vector before the linear layer\n",
+    "        x = self.bert(tokenized_seq).pooler_output\n",
+    "        x = self.linear(x)\n",
+    "\n",
+    "        x = self.softmax(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# def encode(batch):\n",
+    "#     return tokenizer(batch, padding=\"max_length\", max_length=256, return_tensors=\"pt\")\n",
+    "\n",
+    "\n",
+    "# def convert_to_features(example_batch):\n",
+    "#     input_encodings = encode(example_batch['text'])\n",
+    "#     target_encodings = encode(example_batch['summary'])\n",
+    "\n",
+    "#     labels = target_encodings['input_ids']\n",
+    "#     decoder_input_ids = shift_tokens_right(\n",
+    "#         labels, model.config.pad_token_id, model.config.decoder_start_token_id)\n",
+    "#     labels[labels[:, :] == model.config.pad_token_id] = -100\n",
+    "\n",
+    "#     encodings = {\n",
+    "#         'input_ids': input_encodings['input_ids'],\n",
+    "#         'attention_mask': input_encodings['attention_mask'],\n",
+    "#         'decoder_input_ids': decoder_input_ids,\n",
+    "#         'labels': labels,\n",
+    "#     }\n",
+    "\n",
+    "#     return encodings\n",
+    "\n",
+    "\n",
+    "# def get_dataset(path):\n",
+    "#     df = pd.read_csv(path, sep=\",\", on_bad_lines='skip')\n",
+    "#     dataset = datasets.Dataset.from_pandas(df)\n",
+    "#     dataset = dataset.map(convert_to_features, batched=True)\n",
+    "#     columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]\n",
+    "#     dataset.set_format(type='torch', columns=columns)\n",
+    "#     return dataset\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def make_batch():\n",
+    "    \"\"\"Triplet is a list of [subject entity, relation, object entity], None if not present\"\"\"\n",
+    "\n",
+    "    # Load predicted data\n",
+    "    pred = \"../data/qald-9-train-linked.json\"\n",
+    "\n",
+    "    # Load gold data\n",
+    "    gold = \"../data/qald-9-train-linked.json\"\n",
+    "    print(\"Building batches\")\n",
+    "    with open(pred, \"r\") as p, open(gold, \"r\") as g:\n",
+    "        pred = json.load(p)\n",
+    "        gold = json.load(g)\n",
+    "\n",
+    "    inputs = []\n",
+    "    inputs_max_len = 0\n",
+    "    for d in tqdm(pred[\"questions\"]):\n",
+    "        question = d[\"question\"][0][\"string\"]\n",
+    "        query = d[\"query\"][\"sparql\"]\n",
+    "\n",
+    "        # Take the first triplet in the query\n",
+    "        trip = query.split(\"WHERE\")[1]\n",
+    "        trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n",
+    "        triplet = trip.split(\" \")\n",
+    "\n",
+    "        # remove empty strings\n",
+    "        triplet = [x for x in triplet if x != \"\"]\n",
\"\"]\n", + "\n", + " for t in triplet:\n", + " if not(t.find(\"?\")):\n", + " triplet[triplet.index(t)] = None\n", + "\n", " #seq = \"[CLS] \" + question + \" [SEP] \"\n", " if triplet[0] is not None:\n", " #seq += \"[SUB] [SEP] \" + triplet[0]\n", - " tokenized_seq = self.tokenizer(question, \"[SUB]\", triplet[0])#, padding=True, truncation=True)\n", + " # , padding=True, truncation=True)\n", + " tokenized_seq = tokenizer(question, \"[SUB]\", triplet[0], padding=True, truncation=True)\n", " elif triplet[2] is not None:\n", " #seq += \"[OBJ] [SEP] \" + triplet[2]\n", - " tokenized_seq = self.tokenizer(question, \"[OBJ]\", triplet[2])#, padding=True, truncation=True)\n", - " \n", - " x = self.bert.forward(**tokenized_seq)\n", - " x = self.linear(x)\n", - " \n", - " x = self.softmax(x)\n", - " return x\n", + " tokenized_seq = tokenizer(question, \"[OBJ]\", triplet[2], padding=True, truncation=True)\n", + "\n", + " if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n", + " inputs_max_len = len(tokenized_seq[\"input_ids\"])\n", + " inputs.append(list(tokenized_seq.values())[0])\n", "\n", + " correct_rels_max_len = 0\n", + " correct_rels = []\n", + " for d in tqdm(gold[\"questions\"]):\n", + " question = d[\"question\"][0][\"string\"]\n", + " query = d[\"query\"][\"sparql\"]\n", "\n", - "\n" + " #Take the first tripletin query\n", + " trip = query.split(\"WHERE\")[1]\n", + " trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n", + " triplet = trip.split(\" \")\n", + "\n", + " #remove empty strings\n", + " triplet = [x for x in triplet if x != \"\"]\n", + "\n", + " tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n", + " if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n", + " correct_rels_max_len = len(tokenized[\"input_ids\"])\n", + "\n", + " correct_rels.append(list(tokenized.values())[0])\n", + "\n", + " inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n", + " correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n", + " print(\"Finished with batches\")\n", + " return torch.IntTensor(inputs_padded), torch.IntTensor(correct_rels_padded)\n" ] }, { @@ -66,36 +183,65 @@ "metadata": {}, "outputs": [], "source": [ - "def encode(batch):\n", - " return tokenizer(batch, padding=\"max_length\", max_length=256, return_tensors=\"pt\")\n", - "\n", - "\n", - "def convert_to_features(example_batch):\n", - " input_encodings = encode(example_batch['text'])\n", - " target_encodings = encode(example_batch['summary'])\n", - "\n", - " labels = target_encodings['input_ids']\n", - " decoder_input_ids = shift_tokens_right(\n", - " labels, model.config.pad_token_id, model.config.decoder_start_token_id)\n", - " labels[labels[:, :] == model.config.pad_token_id] = -100\n", + "# training_args = Seq2SeqTrainingArguments(\n", + "# output_dir='./models/blackbox',\n", + "# num_train_epochs=1,\n", + "# per_device_train_batch_size=1,\n", + "# per_device_eval_batch_size=1,\n", + "# warmup_steps=10,\n", + "# weight_decay=0.01,\n", + "# logging_dir='./logs',\n", + "# )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "ename": "MemoryError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "File 
+      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\modeling_utils.py:399\u001b[0m, in \u001b[0;36mload_state_dict\u001b[1;34m(checkpoint_file)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m--> 399\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39;49mload(checkpoint_file, map_location\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mcpu\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[0;32m 400\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n",
+      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py:713\u001b[0m, in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[0;32m 712\u001b[0m \u001b[39mreturn\u001b[39;00m _load(opened_zipfile, map_location, pickle_module, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpickle_load_args)\n\u001b[1;32m--> 713\u001b[0m \u001b[39mreturn\u001b[39;00m _legacy_load(opened_file, map_location, pickle_module, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpickle_load_args)\n",
+      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py:930\u001b[0m, in \u001b[0;36m_legacy_load\u001b[1;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[0;32m 929\u001b[0m unpickler\u001b[39m.\u001b[39mpersistent_load \u001b[39m=\u001b[39m persistent_load\n\u001b[1;32m--> 930\u001b[0m result \u001b[39m=\u001b[39m unpickler\u001b[39m.\u001b[39;49mload()\n\u001b[0;32m 932\u001b[0m deserialized_storage_keys \u001b[39m=\u001b[39m pickle_module\u001b[39m.\u001b[39mload(f, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpickle_load_args)\n",
+      "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py:871\u001b[0m, in \u001b[0;36m_legacy_load.<locals>.persistent_load\u001b[1;34m(saved_id)\u001b[0m\n\u001b[0;32m 870\u001b[0m \u001b[39mif\u001b[39;00m root_key \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m deserialized_objects:\n\u001b[1;32m--> 871\u001b[0m obj \u001b[39m=\u001b[39m cast(Storage, torch\u001b[39m.\u001b[39;49m_UntypedStorage(nbytes))\n\u001b[0;32m 872\u001b[0m obj\u001b[39m.\u001b[39m_torch_load_uninitialized \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
+      "\u001b[1;31mRuntimeError\u001b[0m: [enforce fail at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\c10\\core\\impl\\alloc_cpu.cpp:81] data. DefaultCPUAllocator: not enough memory: you tried to allocate 93763584 bytes.",
+      "\nDuring handling of the above exception, another exception occurred:\n",
+      "\u001b[1;31mMemoryError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[1;32mIn [69], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m model \u001b[39m=\u001b[39m NgmOne()\n\u001b[0;32m 3\u001b[0m EPOCHS \u001b[39m=\u001b[39m \u001b[39m3\u001b[39m\n\u001b[0;32m 4\u001b[0m criterion \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mCrossEntropyLoss()\n",
+      "Cell \u001b[1;32mIn [64], line 5\u001b[0m, in \u001b[0;36mNgmOne.__init__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[39msuper\u001b[39m(NgmOne, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m()\n\u001b[0;32m 4\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtokenizer \u001b[39m=\u001b[39m BertTokenizer\u001b[39m.\u001b[39mfrom_pretrained(\u001b[39m\"\u001b[39m\u001b[39mbert-base-uncased\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m----> 5\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbert \u001b[39m=\u001b[39m BertModel\u001b[39m.\u001b[39;49mfrom_pretrained(\u001b[39m\"\u001b[39;49m\u001b[39mbert-base-uncased\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[0;32m 6\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlinear \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mLinear(\u001b[39m768\u001b[39m, \u001b[39m247\u001b[39m)\n\u001b[0;32m 7\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msoftmax \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mSoftmax(dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n",
If torch_dtype is \"auto\", we auto-detect dtype from the loaded state_dict, by checking its first\u001b[39;00m\n\u001b[0;32m 2189\u001b[0m \u001b[39m# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype\u001b[39;00m\n\u001b[0;32m 2190\u001b[0m \u001b[39m# we also may have config.torch_dtype available, but we won't rely on it till v5\u001b[39;00m\n\u001b[0;32m 2191\u001b[0m dtype_orig \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\modeling_utils.py:403\u001b[0m, in \u001b[0;36mload_state_dict\u001b[1;34m(checkpoint_file)\u001b[0m\n\u001b[0;32m 401\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m 402\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(checkpoint_file) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m--> 403\u001b[0m \u001b[39mif\u001b[39;00m f\u001b[39m.\u001b[39;49mread()\u001b[39m.\u001b[39mstartswith(\u001b[39m\"\u001b[39m\u001b[39mversion\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[0;32m 404\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mOSError\u001b[39;00m(\n\u001b[0;32m 405\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mYou seem to have cloned a repository without having git-lfs installed. Please install \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 406\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mgit-lfs and run `git lfs install` followed by `git lfs pull` in the folder \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 407\u001b[0m \u001b[39m\"\u001b[39m\u001b[39myou cloned.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 408\u001b[0m )\n\u001b[0;32m 409\u001b[0m \u001b[39melse\u001b[39;00m:\n", + "\u001b[1;31mMemoryError\u001b[0m: " + ] + } + ], + "source": [ + "model = NgmOne()\n", "\n", - " encodings = {\n", - " 'input_ids': input_encodings['input_ids'],\n", - " 'attention_mask': input_encodings['attention_mask'],\n", - " 'decoder_input_ids': decoder_input_ids,\n", - " 'labels': labels,\n", - " }\n", + "EPOCHS = 3\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", "\n", - " return encodings\n", + "train, corr_rels = make_batch()\n", + "for epoch in tqdm(range(EPOCHS)):\n", + " optimizer.zero_grad()\n", "\n", + " # Forward pass\n", + " output = model(train)\n", + " loss = criterion(output, corr_rels)\n", "\n", - "def get_dataset(path):\n", - " df = pd.read_csv(path, sep=\",\", on_bad_lines='skip')\n", - " dataset = datasets.Dataset.from_pandas(df)\n", - " dataset = dataset.map(convert_to_features, batched=True)\n", - " columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]\n", - " dataset.set_format(type='torch', columns=columns)\n", - " return dataset\n" + " if (epoch + 1) % 10 == 0:\n", + " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", + " # Backward pass\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n" ] } ], @@ -106,7 +252,15 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", "version": "3.9.11" }, "orig_nbformat": 4,