diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb
index 8bddb696fb833290afd664e8215874bfdc5168ba..14d945f4bdc93c70b54d5eb307a085a1777c93db 100644
--- a/Neural graph module/ngm.ipynb
+++ b/Neural graph module/ngm.ipynb
@@ -2,18 +2,9 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "b:\\Programs\\Miniconda\\envs\\tdde19\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      " from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import datasets\n",
     "import torch\n",
@@ -29,7 +20,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -38,17 +29,9 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "cuda\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "# Use GPU if available\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
@@ -58,7 +41,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -78,7 +61,7 @@
    "        with torch.no_grad():\n",
    "            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n",
    "            x = x[0][:,0,:].to(self.device)\n",
-   "        \n",
+   "\n",
    "        x = self.linear(x)\n",
    "        x = self.softmax(x)\n",
    "        return x"
@@ -86,7 +69,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -124,7 +107,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -209,7 +192,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -226,347 +209,661 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 13,
+  "execution_count": 44,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from torch.utils.data import Dataset, DataLoader\n",
+   "\n",
+   "class MyDataset(Dataset):\n",
+   "    def __init__(self, inputs, attention_mask, correct_rels, relations):\n",
+   "        self.inputs = inputs\n",
+   "        self.attention_mask = attention_mask\n",
+   "        self.correct_rels = correct_rels\n",
+   "        self.relations = relations\n",
+   "\n",
+   "    def __len__(self):\n",
+   "        return len(self.inputs)\n",
+   "\n",
+   "    def __getitem__(self, idx):\n",
+   "        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])\n",
+   "\n",
+   "\n",
+   "# Alternative kept for reference: build the dataset straight from a QALD JSON file.\n",
+   "# class MyDataset(Dataset):\n",
+   "#     def __init__(self, json_file, transform=None):\n",
+   "#         self.qald_data = json.load(json_file)\n",
+   "\n",
+   "#     def __len__(self):\n",
+   "#         return len(self.qald_data)\n",
+   "\n",
+   "#     def __getitem__(self, idx):\n",
+   "#         self.inputs[idx], self.attention_mask[idx], self.labels[idx]"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 45,
   "metadata": {},
   "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Beginning making batch\n"
+    ]
+   },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
-     "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
-     "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+     "100%|██████████| 408/408 [00:00<00:00, 955.51it/s] \n",
+     "100%|██████████| 408/408 [00:00<00:00, 136139.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Beginning making batch\n"
+     "Finished with batches\n",
+     "features: tensor([[  101,  2040,  2003,  1996,  3677,  1997,  1996,  4035,  6870, 19247,\n",
+     "          1029,   102,  1031,  4942,  1033,   102,     0,     0,     0,     0,\n",
+     "             0,     0,     0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|██████████| 408/408 [00:00<00:00, 2684.45it/s]\n",
-     "100%|██████████| 408/408 [00:00<00:00, 204160.82it/s]\n"
+     "\n"
     ]
-   },
+   }
+  ],
+  "source": [
+   "# Prepare data\n",
+   "\n",
+   "def open_json(file):\n",
+   "    with open(file, \"r\") as f:\n",
+   "        return json.load(f)\n",
+   "\n",
+   "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n",
+   "train_set = MyDataset(*make_batch(), relations=relations)\n",
+   "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n",
+   "\n",
+   "# show first entry\n",
+   "train_features, train_mask, train_label = next(iter(train_dataloader))\n",
+   "print(\"features:\", train_features, \"mask:\", train_mask, \"label_index\", train_label[0])"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 65,
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
+     "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+     "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+    ]
+   },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Finished with batches\n",
-     "Epoch: 0001 loss = 5.509346\n",
-     "Epoch: 0002 loss = 5.508417\n",
[... 296 more "-" lines trimmed: the old run's loss falls smoothly from 5.507 to 5.071 over epochs 0003-0298 ...]
-     "Epoch: 0299 loss = 5.070775\n",
-     "Epoch: 0300 loss = 5.070364\n"
+     "Epoch 0 batch: 0 , loss = 5.509496\n",
+     "Epoch 1 batch: 0 , loss = 5.507990\n",
[... 496 more "+" lines trimmed: the new run's loss drifts down, with noise, from 5.506 to 5.007 over epochs 2-497 ...]
+     "Epoch 498 batch: 0 , loss = 5.007164\n",
"Epoch 481 batch: 0 , loss = 5.015544\n", + "Epoch 482 batch: 0 , loss = 5.015000\n", + "Epoch 483 batch: 0 , loss = 5.009738\n", + "Epoch 484 batch: 0 , loss = 5.004784\n", + "Epoch 485 batch: 0 , loss = 5.009272\n", + "Epoch 486 batch: 0 , loss = 5.014519\n", + "Epoch 487 batch: 0 , loss = 5.014775\n", + "Epoch 488 batch: 0 , loss = 5.018719\n", + "Epoch 489 batch: 0 , loss = 5.009684\n", + "Epoch 490 batch: 0 , loss = 5.014743\n", + "Epoch 491 batch: 0 , loss = 5.015655\n", + "Epoch 492 batch: 0 , loss = 5.011545\n", + "Epoch 493 batch: 0 , loss = 5.015726\n", + "Epoch 494 batch: 0 , loss = 5.009212\n", + "Epoch 495 batch: 0 , loss = 5.008894\n", + "Epoch 496 batch: 0 , loss = 5.002744\n", + "Epoch 497 batch: 0 , loss = 5.007370\n", + "Epoch 498 batch: 0 , loss = 5.007164\n", + "Epoch 499 batch: 0 , loss = 5.007099\n" ] } ], + "source": [ + "# Train with data loader.\n", + "\n", + "\n", + "\n", + "model = NgmOne(device)\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "\n", + "epoch = 500\n", + "batch_size = 200\n", + "for e in range(epoch):\n", + " train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", + " for i_batch, sample_batched in enumerate(train_dataloader):\n", + " optimizer.zero_grad()\n", + " train = sample_batched[0]\n", + " train_mask = sample_batched[1]\n", + " label_index = sample_batched[2].to(device)\n", + "\n", + " # Forward pass\n", + " output = model(train, train_mask)\n", + " loss = criterion(output, label_index)\n", + "\n", + " # backward and optimize\n", + " loss.backward()\n", + " optimizer.step()\n", + " if i_batch % batch_size == 0:\n", + " print(\"Epoch\", e, \"batch:\",i_batch, ', loss =', '{:.6f}'.format(loss))\n", + " \n", + " \n", + " \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "model = NgmOne(device)\n", "\n", - "EPOCHS = 300\n", + "EPOCHS = 1500\n", "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=0.05)\n", + "optimizer = optim.Adam(model.parameters(), lr=0.007)\n", "\n", "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n", " relations = json.load(f)\n", @@ -591,7 +888,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 66, "metadata": {}, "outputs": [ { @@ -605,8 +902,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 408/408 [00:00<00:00, 814.37it/s] \n", - "100%|██████████| 408/408 [00:00<00:00, 101922.34it/s]\n" + "100%|██████████| 408/408 [00:00<00:00, 962.29it/s] \n", + "100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]\n" ] }, { @@ -614,17 +911,8 @@ "output_type": "stream", "text": [ "Finished with batches\n", - "lowest confidence 0.19197923\n", - "Accuracy: 0.6323529411764706\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "KeyboardInterrupt\n", - "\n" + "lowest confidence 0.14628477\n", + "Accuracy: 0.5245098039215687\n" ] } ], @@ -636,7 +924,6 @@ "\n", "output = output.detach().cpu().numpy()\n", "\n", - "\n", "prediction = [relations[np.argmax(pred).item()]for pred in output]\n", "probability = [pred[np.argmax(pred)] for pred in output]\n", "correct_pred = [corr_rels[i] for i in range(len(output))]\n",