diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb index a1576638b3ea4097ce2d5ff2925261ea74af84c4..8bddb696fb833290afd664e8215874bfdc5168ba 100644 --- a/Neural graph module/ngm.ipynb +++ b/Neural graph module/ngm.ipynb @@ -2,9 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "b:\\Programs\\Miniconda\\envs\\tdde19\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import datasets\n", "import torch\n", @@ -20,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -29,32 +38,55 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuda\n" + ] + } + ], + "source": [ + "# Use GPU if available\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# Print cuda version\n", + "print(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "\n", "class NgmOne(nn.Module):\n", - " def __init__(self):\n", + " def __init__(self, device):\n", " super(NgmOne, self).__init__()\n", " self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", - " self.bert = BertModel.from_pretrained(\"bert-base-uncased\")\n", - " self.linear = nn.Linear(768, 247)\n", - " self.softmax = nn.Softmax(dim=1)\n", + " self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n", + " self.linear = nn.Linear(768, 247).to(device)\n", + " self.softmax = nn.Softmax(dim=1).to(device)\n", + " self.device = device\n", " \n", " def forward(self, tokenized_seq, tokenized_mask):\n", - " x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n", - " x = x[0][:,0,:]\n", + " tokenized_seq= tokenized_seq.to(self.device)\n", + " tokenized_mask = tokenized_mask.to(self.device)\n", + " with torch.no_grad():\n", + " x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n", + " x = x[0][:,0,:].to(self.device)\n", " \n", " x = self.linear(x)\n", - "\n", " x = self.softmax(x)\n", " return x" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -123,6 +155,9 @@ " #remove empty strings\n", " triplet = [x for x in triplet if x != \"\"]\n", "\n", + " if len(triplet) != 3:\n", + " continue\n", + "\n", " for t in triplet:\n", " if not(t.find(\"?\")):\n", " triplet[triplet.index(t)] = None\n", @@ -150,28 +185,31 @@ " trip = query.split(\"WHERE\")[1]\n", " trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n", " triplet = trip.split(\" \")\n", + " \n", "\n", " #remove empty strings\n", " triplet = [x for x in triplet if x != \"\"]\n", "\n", - " tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n", - " if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n", - " correct_rels_max_len = len(tokenized[\"input_ids\"])\n", + " if len(triplet) != 3:\n", + " continue\n", + " # tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n", + " # if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n", + " # correct_rels_max_len = len(tokenized[\"input_ids\"])\n", "\n", - " correct_rels.append(list(tokenized.values())[0])\n", + " correct_rels.append(triplet[1])#list(tokenized.values())[0])\n", "\n", " inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n", - " correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n", + " #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n", "\n", " inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n", - " correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n", + " #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n", " print(\"Finished with batches\")\n", - " return torch.IntTensor(inputs_padded), torch.IntTensor(inputs_attention_mask), torch.IntTensor(correct_rels_padded), torch.IntTensor(correct_rels_attention_mask)\n" + " return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -188,14 +226,14 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n", + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -211,66 +249,442 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 408/408 [00:00<00:00, 688.03it/s]\n", - "100%|██████████| 408/408 [00:00<00:00, 2241.79it/s]\n" + "100%|██████████| 408/408 [00:00<00:00, 2684.45it/s]\n", + "100%|██████████| 408/408 [00:00<00:00, 204160.82it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Finished with batches\n" + "Finished with batches\n", + "Epoch: 0001 loss = 5.509346\n", + "Epoch: 0002 loss = 5.508417\n", + "Epoch: 0003 loss = 5.507416\n", + "Epoch: 0004 loss = 5.506336\n", + "Epoch: 0005 loss = 5.505169\n", + "Epoch: 0006 loss = 5.503905\n", + "Epoch: 0007 loss = 5.502542\n", + "Epoch: 0008 loss = 5.501083\n", + "Epoch: 0009 loss = 5.499537\n", + "Epoch: 0010 loss = 5.497901\n", + "Epoch: 0011 loss = 5.496177\n", + "Epoch: 0012 loss = 5.494359\n", + "Epoch: 0013 loss = 5.492444\n", + "Epoch: 0014 loss = 5.490432\n", + "Epoch: 0015 loss = 5.488323\n", + "Epoch: 0016 loss = 5.486121\n", + "Epoch: 0017 loss = 5.483822\n", + "Epoch: 0018 loss = 5.481424\n", + "Epoch: 0019 loss = 5.478921\n", + "Epoch: 0020 loss = 5.476310\n", + "Epoch: 0021 loss = 5.473589\n", + "Epoch: 0022 loss = 5.470768\n", + "Epoch: 0023 loss = 5.467863\n", + "Epoch: 0024 loss = 5.464887\n", + "Epoch: 0025 loss = 5.461853\n", + "Epoch: 0026 loss = 5.458760\n", + "Epoch: 0027 loss = 5.455590\n", + "Epoch: 0028 loss = 5.452313\n", + "Epoch: 0029 loss = 5.448884\n", + "Epoch: 0030 loss = 5.445263\n", + "Epoch: 0031 loss = 5.441427\n", + "Epoch: 0032 loss = 5.437370\n", + "Epoch: 0033 loss = 5.433111\n", + "Epoch: 0034 loss = 5.428682\n", + "Epoch: 0035 loss = 5.424133\n", + "Epoch: 0036 loss = 5.419532\n", + "Epoch: 0037 loss = 5.414931\n", + "Epoch: 0038 loss = 5.410347\n", + "Epoch: 0039 loss = 5.405736\n", + "Epoch: 0040 loss = 5.401030\n", + "Epoch: 0041 loss = 5.396172\n", + "Epoch: 0042 loss = 5.391167\n", + "Epoch: 0043 loss = 5.386050\n", + "Epoch: 0044 loss = 5.380856\n", + "Epoch: 0045 loss = 5.375600\n", + "Epoch: 0046 loss = 5.370294\n", + "Epoch: 0047 loss = 5.364953\n", + "Epoch: 0048 loss = 5.359621\n", + "Epoch: 0049 loss = 5.354363\n", + "Epoch: 0050 loss = 5.349256\n", + "Epoch: 0051 loss = 5.344321\n", + "Epoch: 0052 loss = 5.339508\n", + "Epoch: 0053 loss = 5.334712\n", + "Epoch: 0054 loss = 5.329853\n", + "Epoch: 0055 loss = 5.324935\n", + "Epoch: 0056 loss = 5.320047\n", + "Epoch: 0057 loss = 5.315296\n", + "Epoch: 0058 loss = 5.310766\n", + "Epoch: 0059 loss = 5.306486\n", + "Epoch: 0060 loss = 5.302420\n", + "Epoch: 0061 loss = 5.298513\n", + "Epoch: 0062 loss = 5.294705\n", + "Epoch: 0063 loss = 5.290928\n", + "Epoch: 0064 loss = 5.287130\n", + "Epoch: 0065 loss = 5.283293\n", + "Epoch: 0066 loss = 5.279424\n", + "Epoch: 0067 loss = 5.275546\n", + "Epoch: 0068 loss = 5.271703\n", + "Epoch: 0069 loss = 5.267990\n", + "Epoch: 0070 loss = 5.264547\n", + "Epoch: 0071 loss = 5.261484\n", + "Epoch: 0072 loss = 5.258725\n", + "Epoch: 0073 loss = 5.256050\n", + "Epoch: 0074 loss = 5.253313\n", + "Epoch: 0075 loss = 5.250482\n", + "Epoch: 0076 loss = 5.247588\n", + "Epoch: 0077 loss = 5.244689\n", + "Epoch: 0078 loss = 5.241833\n", + "Epoch: 0079 loss = 5.239061\n", + "Epoch: 0080 loss = 5.236411\n", + "Epoch: 0081 loss = 5.233915\n", + "Epoch: 0082 loss = 5.231617\n", + "Epoch: 0083 loss = 5.229539\n", + "Epoch: 0084 loss = 5.227572\n", + "Epoch: 0085 loss = 5.225513\n", + "Epoch: 0086 loss = 5.223309\n", + "Epoch: 0087 loss = 5.221033\n", + "Epoch: 0088 loss = 5.218764\n", + "Epoch: 0089 loss = 5.216528\n", + "Epoch: 0090 loss = 5.214312\n", + "Epoch: 0091 loss = 5.212102\n", + "Epoch: 0092 loss = 5.209961\n", + "Epoch: 0093 loss = 5.208022\n", + "Epoch: 0094 loss = 5.206336\n", + "Epoch: 0095 loss = 5.204650\n", + "Epoch: 0096 loss = 5.202743\n", + "Epoch: 0097 loss = 5.200671\n", + "Epoch: 0098 loss = 5.198618\n", + "Epoch: 0099 loss = 5.196756\n", + "Epoch: 0100 loss = 5.195175\n", + "Epoch: 0101 loss = 5.193856\n", + "Epoch: 0102 loss = 5.192677\n", + "Epoch: 0103 loss = 5.191495\n", + "Epoch: 0104 loss = 5.190241\n", + "Epoch: 0105 loss = 5.188928\n", + "Epoch: 0106 loss = 5.187597\n", + "Epoch: 0107 loss = 5.186284\n", + "Epoch: 0108 loss = 5.184998\n", + "Epoch: 0109 loss = 5.183724\n", + "Epoch: 0110 loss = 5.182417\n", + "Epoch: 0111 loss = 5.181022\n", + "Epoch: 0112 loss = 5.179502\n", + "Epoch: 0113 loss = 5.177876\n", + "Epoch: 0114 loss = 5.176267\n", + "Epoch: 0115 loss = 5.174932\n", + "Epoch: 0116 loss = 5.174131\n", + "Epoch: 0117 loss = 5.173597\n", + "Epoch: 0118 loss = 5.172731\n", + "Epoch: 0119 loss = 5.171515\n", + "Epoch: 0120 loss = 5.170227\n", + "Epoch: 0121 loss = 5.169064\n", + "Epoch: 0122 loss = 5.168075\n", + "Epoch: 0123 loss = 5.167188\n", + "Epoch: 0124 loss = 5.166307\n", + "Epoch: 0125 loss = 5.165361\n", + "Epoch: 0126 loss = 5.164318\n", + "Epoch: 0127 loss = 5.163185\n", + "Epoch: 0128 loss = 5.162001\n", + "Epoch: 0129 loss = 5.160835\n", + "Epoch: 0130 loss = 5.159757\n", + "Epoch: 0131 loss = 5.158821\n", + "Epoch: 0132 loss = 5.158039\n", + "Epoch: 0133 loss = 5.157377\n", + "Epoch: 0134 loss = 5.156784\n", + "Epoch: 0135 loss = 5.156210\n", + "Epoch: 0136 loss = 5.155619\n", + "Epoch: 0137 loss = 5.154994\n", + "Epoch: 0138 loss = 5.154339\n", + "Epoch: 0139 loss = 5.153667\n", + "Epoch: 0140 loss = 5.152996\n", + "Epoch: 0141 loss = 5.152328\n", + "Epoch: 0142 loss = 5.151660\n", + "Epoch: 0143 loss = 5.150975\n", + "Epoch: 0144 loss = 5.150242\n", + "Epoch: 0145 loss = 5.149413\n", + "Epoch: 0146 loss = 5.148431\n", + "Epoch: 0147 loss = 5.147236\n", + "Epoch: 0148 loss = 5.145802\n", + "Epoch: 0149 loss = 5.144192\n", + "Epoch: 0150 loss = 5.142632\n", + "Epoch: 0151 loss = 5.141498\n", + "Epoch: 0152 loss = 5.140967\n", + "Epoch: 0153 loss = 5.140542\n", + "Epoch: 0154 loss = 5.139736\n", + "Epoch: 0155 loss = 5.138566\n", + "Epoch: 0156 loss = 5.137214\n", + "Epoch: 0157 loss = 5.135843\n", + "Epoch: 0158 loss = 5.134551\n", + "Epoch: 0159 loss = 5.133349\n", + "Epoch: 0160 loss = 5.132212\n", + "Epoch: 0161 loss = 5.131142\n", + "Epoch: 0162 loss = 5.130210\n", + "Epoch: 0163 loss = 5.129475\n", + "Epoch: 0164 loss = 5.128819\n", + "Epoch: 0165 loss = 5.128070\n", + "Epoch: 0166 loss = 5.127200\n", + "Epoch: 0167 loss = 5.126271\n", + "Epoch: 0168 loss = 5.125329\n", + "Epoch: 0169 loss = 5.124373\n", + "Epoch: 0170 loss = 5.123369\n", + "Epoch: 0171 loss = 5.122267\n", + "Epoch: 0172 loss = 5.121037\n", + "Epoch: 0173 loss = 5.119725\n", + "Epoch: 0174 loss = 5.118514\n", + "Epoch: 0175 loss = 5.117722\n", + "Epoch: 0176 loss = 5.117445\n", + "Epoch: 0177 loss = 5.117110\n", + "Epoch: 0178 loss = 5.116376\n", + "Epoch: 0179 loss = 5.115437\n", + "Epoch: 0180 loss = 5.114553\n", + "Epoch: 0181 loss = 5.113859\n", + "Epoch: 0182 loss = 5.113349\n", + "Epoch: 0183 loss = 5.112945\n", + "Epoch: 0184 loss = 5.112563\n", + "Epoch: 0185 loss = 5.112154\n", + "Epoch: 0186 loss = 5.111700\n", + "Epoch: 0187 loss = 5.111213\n", + "Epoch: 0188 loss = 5.110720\n", + "Epoch: 0189 loss = 5.110251\n", + "Epoch: 0190 loss = 5.109831\n", + "Epoch: 0191 loss = 5.109465\n", + "Epoch: 0192 loss = 5.109145\n", + "Epoch: 0193 loss = 5.108854\n", + "Epoch: 0194 loss = 5.108573\n", + "Epoch: 0195 loss = 5.108290\n", + "Epoch: 0196 loss = 5.108000\n", + "Epoch: 0197 loss = 5.107710\n", + "Epoch: 0198 loss = 5.107430\n", + "Epoch: 0199 loss = 5.107163\n", + "Epoch: 0200 loss = 5.106915\n", + "Epoch: 0201 loss = 5.106686\n", + "Epoch: 0202 loss = 5.106472\n", + "Epoch: 0203 loss = 5.106268\n", + "Epoch: 0204 loss = 5.106072\n", + "Epoch: 0205 loss = 5.105881\n", + "Epoch: 0206 loss = 5.105694\n", + "Epoch: 0207 loss = 5.105511\n", + "Epoch: 0208 loss = 5.105332\n", + "Epoch: 0209 loss = 5.105159\n", + "Epoch: 0210 loss = 5.104992\n", + "Epoch: 0211 loss = 5.104831\n", + "Epoch: 0212 loss = 5.104676\n", + "Epoch: 0213 loss = 5.104527\n", + "Epoch: 0214 loss = 5.104386\n", + "Epoch: 0215 loss = 5.104248\n", + "Epoch: 0216 loss = 5.104114\n", + "Epoch: 0217 loss = 5.103983\n", + "Epoch: 0218 loss = 5.103856\n", + "Epoch: 0219 loss = 5.103732\n", + "Epoch: 0220 loss = 5.103611\n", + "Epoch: 0221 loss = 5.103494\n", + "Epoch: 0222 loss = 5.103378\n", + "Epoch: 0223 loss = 5.103266\n", + "Epoch: 0224 loss = 5.103157\n", + "Epoch: 0225 loss = 5.103050\n", + "Epoch: 0226 loss = 5.102947\n", + "Epoch: 0227 loss = 5.102844\n", + "Epoch: 0228 loss = 5.102745\n", + "Epoch: 0229 loss = 5.102647\n", + "Epoch: 0230 loss = 5.102551\n", + "Epoch: 0231 loss = 5.102457\n", + "Epoch: 0232 loss = 5.102361\n", + "Epoch: 0233 loss = 5.102268\n", + "Epoch: 0234 loss = 5.102171\n", + "Epoch: 0235 loss = 5.102070\n", + "Epoch: 0236 loss = 5.101960\n", + "Epoch: 0237 loss = 5.101832\n", + "Epoch: 0238 loss = 5.101676\n", + "Epoch: 0239 loss = 5.101472\n", + "Epoch: 0240 loss = 5.101196\n", + "Epoch: 0241 loss = 5.100809\n", + "Epoch: 0242 loss = 5.100273\n", + "Epoch: 0243 loss = 5.099566\n", + "Epoch: 0244 loss = 5.098716\n", + "Epoch: 0245 loss = 5.097848\n", + "Epoch: 0246 loss = 5.097132\n", + "Epoch: 0247 loss = 5.096595\n", + "Epoch: 0248 loss = 5.095999\n", + "Epoch: 0249 loss = 5.095338\n", + "Epoch: 0250 loss = 5.094879\n", + "Epoch: 0251 loss = 5.094688\n", + "Epoch: 0252 loss = 5.094509\n", + "Epoch: 0253 loss = 5.094189\n", + "Epoch: 0254 loss = 5.093769\n", + "Epoch: 0255 loss = 5.093320\n", + "Epoch: 0256 loss = 5.092857\n", + "Epoch: 0257 loss = 5.092339\n", + "Epoch: 0258 loss = 5.091694\n", + "Epoch: 0259 loss = 5.090842\n", + "Epoch: 0260 loss = 5.089743\n", + "Epoch: 0261 loss = 5.088472\n", + "Epoch: 0262 loss = 5.087311\n", + "Epoch: 0263 loss = 5.086748\n", + "Epoch: 0264 loss = 5.086977\n", + "Epoch: 0265 loss = 5.087138\n", + "Epoch: 0266 loss = 5.086704\n", + "Epoch: 0267 loss = 5.085968\n", + "Epoch: 0268 loss = 5.085289\n", + "Epoch: 0269 loss = 5.084831\n", + "Epoch: 0270 loss = 5.084577\n", + "Epoch: 0271 loss = 5.084425\n", + "Epoch: 0272 loss = 5.084273\n", + "Epoch: 0273 loss = 5.084045\n", + "Epoch: 0274 loss = 5.083702\n", + "Epoch: 0275 loss = 5.083218\n", + "Epoch: 0276 loss = 5.082567\n", + "Epoch: 0277 loss = 5.081728\n", + "Epoch: 0278 loss = 5.080690\n", + "Epoch: 0279 loss = 5.079515\n", + "Epoch: 0280 loss = 5.078421\n", + "Epoch: 0281 loss = 5.077838\n", + "Epoch: 0282 loss = 5.078005\n", + "Epoch: 0283 loss = 5.078021\n", + "Epoch: 0284 loss = 5.077395\n", + "Epoch: 0285 loss = 5.076532\n", + "Epoch: 0286 loss = 5.075821\n", + "Epoch: 0287 loss = 5.075374\n", + "Epoch: 0288 loss = 5.075100\n", + "Epoch: 0289 loss = 5.074864\n", + "Epoch: 0290 loss = 5.074569\n", + "Epoch: 0291 loss = 5.074176\n", + "Epoch: 0292 loss = 5.073692\n", + "Epoch: 0293 loss = 5.073163\n", + "Epoch: 0294 loss = 5.072643\n", + "Epoch: 0295 loss = 5.072181\n", + "Epoch: 0296 loss = 5.071800\n", + "Epoch: 0297 loss = 5.071474\n", + "Epoch: 0298 loss = 5.071146\n", + "Epoch: 0299 loss = 5.070775\n", + "Epoch: 0300 loss = 5.070364\n" + ] + } + ], + "source": [ + "model = NgmOne(device)\n", + "\n", + "EPOCHS = 300\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.05)\n", + "\n", + "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n", + " relations = json.load(f)\n", + "\n", + "train, train_mask, corr_rels = make_batch()\n", + "corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)\n", + "\n", + "for epoch in range(EPOCHS):\n", + " optimizer.zero_grad()\n", + "\n", + " # Forward pass\n", + " output = model(train, train_mask)\n", + " loss = criterion(output, corr_indx)\n", + "\n", + " if (epoch + 1) % 1 == 0:\n", + " print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n", + " # Backward pass\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Beginning making batch\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/3 [03:03<?, ?it/s]\n" + "100%|██████████| 408/408 [00:00<00:00, 814.37it/s] \n", + "100%|██████████| 408/408 [00:00<00:00, 101922.34it/s]\n" ] }, { - "ename": "RuntimeError", - "evalue": "0D or 1D target tensor expected, multi-target not supported", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn [26], line 13\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[39m# Forward pass\u001b[39;00m\n\u001b[0;32m 12\u001b[0m output \u001b[39m=\u001b[39m model(train, train_mask)\n\u001b[1;32m---> 13\u001b[0m loss \u001b[39m=\u001b[39m criterion(output, corr_rels)\n\u001b[0;32m 15\u001b[0m \u001b[39mif\u001b[39;00m (epoch \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m) \u001b[39m%\u001b[39m \u001b[39m10\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m 16\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mEpoch:\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m%04d\u001b[39;00m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m (epoch \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m), \u001b[39m'\u001b[39m\u001b[39mcost =\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m{:.6f}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(loss))\n", - "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\loss.py:1164\u001b[0m, in \u001b[0;36mCrossEntropyLoss.forward\u001b[1;34m(self, input, target)\u001b[0m\n\u001b[0;32m 1163\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor, target: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[1;32m-> 1164\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mcross_entropy(\u001b[39minput\u001b[39;49m, target, weight\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight,\n\u001b[0;32m 1165\u001b[0m ignore_index\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mignore_index, reduction\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mreduction,\n\u001b[0;32m 1166\u001b[0m label_smoothing\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlabel_smoothing)\n", - "File \u001b[1;32mc:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\functional.py:3014\u001b[0m, in \u001b[0;36mcross_entropy\u001b[1;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[0;32m 3012\u001b[0m \u001b[39mif\u001b[39;00m size_average \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mor\u001b[39;00m reduce \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 3013\u001b[0m reduction \u001b[39m=\u001b[39m _Reduction\u001b[39m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[1;32m-> 3014\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39;49m_C\u001b[39m.\u001b[39;49m_nn\u001b[39m.\u001b[39;49mcross_entropy_loss(\u001b[39minput\u001b[39;49m, target, weight, _Reduction\u001b[39m.\u001b[39;49mget_enum(reduction), ignore_index, label_smoothing)\n", - "\u001b[1;31mRuntimeError\u001b[0m: 0D or 1D target tensor expected, multi-target not supported" + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished with batches\n", + "lowest confidence 0.19197923\n", + "Accuracy: 0.6323529411764706\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "KeyboardInterrupt\n", + "\n" ] } ], "source": [ - "model = NgmOne()\n", + "# Predict\n", + "train, train_mask, corr_rels = make_batch()\n", + "with torch.no_grad():\n", + " output = model(train, train_mask)\n", "\n", - "EPOCHS = 3\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "output = output.detach().cpu().numpy()\n", "\n", - "train,train_mask, corr_rels, correct_rels_mask = make_batch()\n", - "for epoch in tqdm(range(EPOCHS)):\n", - " optimizer.zero_grad()\n", "\n", - " # Forward pass\n", - " output = model(train, train_mask)\n", - " loss = criterion(output, corr_rels)\n", + "prediction = [relations[np.argmax(pred).item()]for pred in output]\n", + "probability = [pred[np.argmax(pred)] for pred in output]\n", + "correct_pred = [corr_rels[i] for i in range(len(output))]\n", "\n", - " if (epoch + 1) % 10 == 0:\n", - " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", - " # Backward pass\n", - " loss.backward()\n", - " optimizer.step()\n", - " \n" + "\n", + "\n", + "preds = [\n", + " {\"pred\": relations[np.argmax(pred).item()], \n", + " \"prob\": pred[np.argmax(pred)],\n", + " \"correct\": corr_rels[count]} \n", + " for count, pred in enumerate(output)\n", + " ]\n", + "\n", + "\n", + "\n", + "# for pred in preds:\n", + "# print(\"pred\", pred[\"pred\"], \"prob\", pred[\"prob\"], \"correct\", pred[\"correct\"])\n", + "\n", + "# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):\n", + "# print(\"pred:\", pred, \"prob:\", prob, \"| correct\", correct_pred)\n", + "\n", + "print(\"lowest confidence\", min(probability))\n", + "\n", + "def accuracy_score(y_true, y_pred):\n", + " corr_preds=0\n", + " wrong_preds=0\n", + " for pred, correct in zip(y_pred, y_true):\n", + " if pred == correct:\n", + " corr_preds += 1\n", + " else:\n", + " wrong_preds += 1\n", + " return corr_preds/(corr_preds+wrong_preds)\n", + "\n", + "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n", + "\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.11 64-bit", + "display_name": "Python 3.10.4 ('tdde19')", "language": "python", "name": "python3" }, @@ -284,12 +698,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.11" + "version": "3.10.4" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "64e7cd3b4b88defe39dd61a4584920400d6beb2615ab2244e340c2e20eecdfe9" + "hash": "8e4aa0e1a1e15de86146661edda0b2884b54582522f7ff2b916774ba6b8accb1" } } }, diff --git a/data/prefix-relations-qald-9-train-linked.json b/data/relations-prefixed-qald-9-train-linked.json similarity index 100% rename from data/prefix-relations-qald-9-train-linked.json rename to data/relations-prefixed-qald-9-train-linked.json diff --git a/data/relations-query-qald-9-linked.json b/data/relations-query-qald-9-linked.json new file mode 100644 index 0000000000000000000000000000000000000000..3e305e9d35c7ddc8a9970de83ff4d99f0fdb9c80 --- /dev/null +++ b/data/relations-query-qald-9-linked.json @@ -0,0 +1,139 @@ +[ + "dbo:foundationPlace", + "dbo:founder", + "<http://dbpedia.org/ontology/officialSchoolColour>", + "<http://dbpedia.org/property/ethnicity>", + "<http://dbpedia.org/ontology/mission>", + "<http://dbpedia.org/property/admittancedate>", + "dbo:writer", + "<http://dbpedia.org/property/carbs>", + "<http://dbpedia.org/property/borderingstates>", + "dbo:party", + "dbo:growingGrape", + "<http://dbpedia.org/property/residence>", + "dbo:wineRegion", + "dbo:language", + "dbo:officialLanguage", + "dbo:influencedBy", + "dbo:routeEnd", + "<http://dbpedia.org/ontology/netIncome>", + "dbo:bandMember", + "dbo:team", + "dbo:origin", + "<http://dbpedia.org/property/launchPad>", + "dbp:species", + "dbo:composer", + "<http://dbpedia.org/property/author>", + "<http://dbpedia.org/ontology/creator>", + "dbo:developer", + "<http://dbpedia.org/ontology/owner>", + "<http://dbpedia.org/ontology/spouse>", + "dbo:timeZone", + "<http://dbpedia.org/property/governor>", + "dbo:numberOfPages", + "dbo:deathCause", + "dbo:award", + "dbo:activeYearsEndDate", + "dbo:governmentType", + "dbo:maximumDepth", + "dbo:owner", + "dbo:leader", + "dbo:birthDate", + "dbo:editor", + "dbo:knownFor", + "dbo:starring", + "dbo:targetAirport", + "<http://dbpedia.org/property/founded>", + "<http://dbpedia.org/property/leaderParty>", + "<http://dbpedia.org/property/speciality>", + "dbo:country", + "dbo:runtime", + "<http://dbpedia.org/ontology/spokenIn>", + "dbo:battle", + "dbo:discoverer", + "dbo:portrayer", + "dbo:areaTotal", + "<http://dbpedia.org/ontology/portrayer>", + "dbo:height", + "dbo:spouse", + "dbo:abbreviation", + "<http://dbpedia.org/ontology/profession>", + "dbo:musicComposer", + "dbo:firstAscentPerson", + "dbo:publisher", + "<http://dbpedia.org/ontology/deathPlace>", + "<http://dbpedia.org/property/largestmetro>", + "dbp:writer", + "dbo:ingredient", + "dbo:class", + "<http://dbpedia.org/property/breed>", + "dbo:director", + "dbo:alias", + "<http://dbpedia.org/ontology/currency>", + "<http://dbpedia.org/ontology/populationTotal>", + "<http://dbpedia.org/property/successor>", + "<http://dbpedia.org/ontology/manager>", + "dbo:mayor", + "dbo:date", + "dbp:editor", + "a", + "<http://dbpedia.org/ontology/foundedBy>", + "dbo:completionDate", + "dbp:populationDensityRank", + "dbo:presenter", + "<http://dbpedia.org/ontology/instrument>", + "<http://dbpedia.org/ontology/numberOfEmployees>", + "dbo:series", + "dbo:creator", + "<http://dbpedia.org/ontology/producer>", + "dbo:leaderName", + "<http://dbpedia.org/property/children>", + "<http://dbpedia.org/property/title>", + "<http://dbpedia.org/property/fifaMin>", + "dbo:crosses", + "dbo:mission", + "dbo:architect", + "dbo:largestCity", + "dbo:budget", + "<http://dbpedia.org/property/birthName>", + "dbo:populationTotal", + "dbo:foundingDate", + "dbo:vicePresident", + "<http://dbpedia.org/ontology/genre>", + "dbo:product", + "dbo:type", + "dbo:state", + "dbo:influenced", + "dbo:doctoralAdvisor", + "dbo:numberOfLocations", + "dbo:successor", + "<http://dbpedia.org/property/ballpark>", + "<http://dbpedia.org/ontology/country>", + "dbo:sourceCountry", + "dbo:birthPlace", + "<http://dbpedia.org/property/highest>", + "<http://dbpedia.org/ontology/child>", + "dbo:birthName", + "dbo:child", + "dbo:deathPlace", + "<http://dbpedia.org/ontology/abbreviation>", + "dbo:dissolutionDate", + "dbo:author", + "dbo:birthYear", + "<http://dbpedia.org/property/beginningDate>", + "dbo:ethnicGroup", + "dbo:currency", + "<http://dbpedia.org/property/programme>", + "<http://dbpedia.org/property/shipNamesake>", + "dbo:capital", + "dbo:programmingLanguage", + "dbo:city", + "<http://dbpedia.org/ontology/deathDate>", + "dbo:almaMater", + "<http://dbpedia.org/property/employees>", + "dbo:location", + "<http://dbpedia.org/property/accessioneudate>", + "rdf:type", + "<http://dbpedia.org/ontology/developer>", + "dbo:restingPlace" +]