diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb index 14d945f4bdc93c70b54d5eb307a085a1777c93db..303e0909ef4f68fccbddc758a6c5d890c1eebfe3 100644 --- a/Neural graph module/ngm.ipynb +++ b/Neural graph module/ngm.ipynb @@ -1,999 +1,514 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import datasets\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "import pandas as pd\n", - "import numpy as np\n", - "from transformers import BertTokenizer, BertModel\n", - "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", - "from tqdm import tqdm\n", - "import json\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use GPU if available\n", - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", - "# Print cuda version\n", - "print(device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "class NgmOne(nn.Module):\n", - " def __init__(self, device):\n", - " super(NgmOne, self).__init__()\n", - " self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", - " self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n", - " self.linear = nn.Linear(768, 247).to(device)\n", - " self.softmax = nn.Softmax(dim=1).to(device)\n", - " self.device = device\n", - " \n", - " def forward(self, tokenized_seq, tokenized_mask):\n", - " tokenized_seq= tokenized_seq.to(self.device)\n", - " tokenized_mask = tokenized_mask.to(self.device)\n", - " with torch.no_grad():\n", - " x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n", - " x = x[0][:,0,:].to(self.device)\n", - " \n", - " x = self.linear(x)\n", - " x = self.softmax(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# def encode(batch):\n", - "# return tokenizer(batch, padding=\"max_length\", max_length=256, return_tensors=\"pt\")\n", - "\n", - "\n", - "# def convert_to_features(example_batch):\n", - "# input_encodings = encode(example_batch['text'])\n", - "# target_encodings = encode(example_batch['summary'])\n", - "\n", - "# labels = target_encodings['input_ids']\n", - "# decoder_input_ids = shift_tokens_right(\n", - "# labels, model.config.pad_token_id, model.config.decoder_start_token_id)\n", - "# labels[labels[:, :] == model.config.pad_token_id] = -100\n", - "\n", - "# encodings = {\n", - "# 'input_ids': input_encodings['input_ids'],\n", - "# 'attention_mask': input_encodings['attention_mask'],\n", - "# 'decoder_input_ids': decoder_input_ids,\n", - "# 'labels': labels,\n", - "# }\n", - "\n", - "# return encodings\n", - "\n", - "\n", - "# def get_dataset(path):\n", - "# df = pd.read_csv(path, sep=\",\", on_bad_lines='skip')\n", - "# dataset = datasets.Dataset.from_pandas(df)\n", - "# dataset = dataset.map(convert_to_features, batched=True)\n", - "# columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]\n", - "# dataset.set_format(type='torch', columns=columns)\n", - "# return dataset\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def make_batch():\n", - 
" \"\"\"Triplet is a list of [subject entity, relation, object entity], None if not present\"\"\"\n", - "\n", - " # Load predicted data\n", - " pred = \"../data/qald-9-train-linked.json\"\n", - "\n", - " #Load gold data\n", - " gold = \"../data/qald-9-train-linked.json\"\n", - " print(\"Beginning making batch\")\n", - " with open(pred, \"r\") as p, open(gold, \"r\") as g:\n", - " pred = json.load(p)\n", - " gold = json.load(g)\n", - "\n", - " inputs = []\n", - " inputs_max_len = 0\n", - " for d in tqdm(pred[\"questions\"]):\n", - " question = d[\"question\"][0][\"string\"]\n", - " query = d[\"query\"][\"sparql\"]\n", - "\n", - " #Take the first tripletin query\n", - " trip = query.split(\"WHERE\")[1]\n", - " trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n", - " triplet = trip.split(\" \")\n", - "\n", - " #remove empty strings\n", - " triplet = [x for x in triplet if x != \"\"]\n", - "\n", - " if len(triplet) != 3:\n", - " continue\n", - "\n", - " for t in triplet:\n", - " if not(t.find(\"?\")):\n", - " triplet[triplet.index(t)] = None\n", - "\n", - " #seq = \"[CLS] \" + question + \" [SEP] \"\n", - " if triplet[0] is not None:\n", - " #seq += \"[SUB] [SEP] \" + triplet[0]\n", - " # , padding=True, truncation=True)\n", - " tokenized_seq = tokenizer(question, \"[SUB]\", triplet[0], padding=True, truncation=True)\n", - " elif triplet[2] is not None:\n", - " #seq += \"[OBJ] [SEP] \" + triplet[2]\n", - " tokenized_seq = tokenizer(question, \"[OBJ]\", triplet[2], padding=True, truncation=True)\n", - "\n", - " if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n", - " inputs_max_len = len(tokenized_seq[\"input_ids\"])\n", - " inputs.append(list(tokenized_seq.values())[0])\n", - "\n", - " correct_rels_max_len = 0\n", - " correct_rels = []\n", - " for d in tqdm(gold[\"questions\"]):\n", - " question = d[\"question\"][0][\"string\"]\n", - " query = d[\"query\"][\"sparql\"]\n", - "\n", - " #Take the first tripletin query\n", - " trip = query.split(\"WHERE\")[1]\n", - " trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n", - " triplet = trip.split(\" \")\n", - " \n", - "\n", - " #remove empty strings\n", - " triplet = [x for x in triplet if x != \"\"]\n", - "\n", - " if len(triplet) != 3:\n", - " continue\n", - " # tokenized = tokenizer(triplet[1], padding=True, truncation=True)\n", - " # if correct_rels_max_len < len(tokenized[\"input_ids\"]):\n", - " # correct_rels_max_len = len(tokenized[\"input_ids\"])\n", - "\n", - " correct_rels.append(triplet[1])#list(tokenized.values())[0])\n", - "\n", - " inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n", - " #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n", - "\n", - " inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n", - " #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n", - " print(\"Finished with batches\")\n", - " return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# training_args = Seq2SeqTrainingArguments(\n", - "# output_dir='./models/blackbox',\n", - "# num_train_epochs=1,\n", - "# per_device_train_batch_size=1,\n", - "# per_device_eval_batch_size=1,\n", - "# warmup_steps=10,\n", - "# weight_decay=0.01,\n", - "# logging_dir='./logs',\n", - "# )\n" - ] - }, - { - "cell_type": "code", 
- "execution_count": 44, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.utils.data import Dataset, DataLoader\n", - "\n", - "class MyDataset(Dataset):\n", - " def __init__(self, inputs, attention_mask, correct_rels, relations):\n", - " self.inputs = inputs\n", - " self.attention_mask = attention_mask\n", - " self.correct_rels = correct_rels\n", - " self.relations = relations\n", - " \n", - "\n", - " def __len__(self):\n", - " return len(self.inputs)\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])\n", - "\n", - "\n", - "\n", - "\n", - "#From scratch json creates data set.\n", - "# class MyDataset(Dataset):\n", - "# def __init__(self, json_file, transform=None):\n", - "# self.qald_data = json.load(json_file)\n", - "\n", - "# def __len__(self):\n", - "# return len(self.qald_data)\n", - "\n", - "# def __getitem__(self, idx):\n", - "# self.inputs[idx], self.attention_mask[idx], self.labels[idx]" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Beginning making batch\n" - ] + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\maxbj\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import datasets\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import pandas as pd\n", + "import numpy as np\n", + "from transformers import BertTokenizer, BertModel\n", + "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", + "from tqdm import tqdm\n", + "import json\n" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 408/408 [00:00<00:00, 955.51it/s] \n", - "100%|██████████| 408/408 [00:00<00:00, 136139.70it/s]" - ] + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Finished with batches\n", - "features: tensor([[ 101, 2040, 2003, 1996, 3677, 1997, 1996, 4035, 6870, 19247,\n", - " 1029, 102, 1031, 4942, 1033, 102, 0, 0, 0, 0,\n", - " 0, 0, 0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)\n" - ] + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Use GPU if available\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# Print cuda version\n", + "print(device)" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "#Prepare data\n", - "\n", - "def open_json(file):\n", - " with open(file, \"r\") as f:\n", - " return json.load(f)\n", - "\n", - "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n", - "train_set = MyDataset(*make_batch(), relations=relations)\n", - "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n", - "\n", - "#show first entry\n", - "train_features, train_mask, train_label = 
next(iter(train_dataloader))\n", - "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class NgmOne(nn.Module):\n", + " def __init__(self, device):\n", + " super(NgmOne, self).__init__()\n", + " self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", + " self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n", + " self.linear = nn.Linear(768, 247).to(device)\n", + " self.softmax = nn.Softmax(dim=1).to(device)\n", + " self.device = device\n", + " \n", + " def forward(self, tokenized_seq, tokenized_mask):\n", + " tokenized_seq= tokenized_seq.to(self.device)\n", + " tokenized_mask = tokenized_mask.to(self.device)\n", + " with torch.no_grad():\n", + " x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n", + " x = x[0][:,0,:].to(self.device)\n", + " \n", + " x = self.linear(x)\n", + " x = self.softmax(x)\n", + " return x" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n", - "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" - ] + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# def encode(batch):\n", + "# return tokenizer(batch, padding=\"max_length\", max_length=256, return_tensors=\"pt\")\n", + "\n", + "\n", + "# def convert_to_features(example_batch):\n", + "# input_encodings = encode(example_batch['text'])\n", + "# target_encodings = encode(example_batch['summary'])\n", + "\n", + "# labels = target_encodings['input_ids']\n", + "# decoder_input_ids = shift_tokens_right(\n", + "# labels, model.config.pad_token_id, model.config.decoder_start_token_id)\n", + "# labels[labels[:, :] == model.config.pad_token_id] = -100\n", + "\n", + "# encodings = {\n", + "# 'input_ids': input_encodings['input_ids'],\n", + "# 'attention_mask': input_encodings['attention_mask'],\n", + "# 'decoder_input_ids': decoder_input_ids,\n", + "# 'labels': labels,\n", + "# }\n", + "\n", + "# return encodings\n", + "\n", + "\n", + "# def get_dataset(path):\n", + "# df = pd.read_csv(path, sep=\",\", on_bad_lines='skip')\n", + "# dataset = datasets.Dataset.from_pandas(df)\n", + "# dataset = dataset.map(convert_to_features, batched=True)\n", + "# columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]\n", + "# dataset.set_format(type='torch', columns=columns)\n", + "# return dataset\n" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0 batch: 0 , loss = 5.509496\n", - "Epoch 1 batch: 0 , loss = 5.507990\n", - "Epoch 2 batch: 0 , 
loss = 5.506450\n", - "Epoch 3 batch: 0 , loss = 5.504800\n", - "[... epochs 4-488 elided: loss falls slowly, with small fluctuations, from 5.50 to about 5.01 ...]\n", - "Epoch 489 
batch: 0 , loss = 5.009684\n", - "Epoch 490 batch: 0 , loss = 5.014743\n", - "Epoch 491 batch: 0 , loss = 5.015655\n", - "Epoch 492 batch: 0 , loss = 5.011545\n", - "Epoch 493 batch: 0 , loss = 5.015726\n", - "Epoch 494 batch: 0 , loss = 5.009212\n", - "Epoch 495 batch: 0 , loss = 5.008894\n", - "Epoch 496 batch: 0 , loss = 5.002744\n", - "Epoch 497 batch: 0 , loss = 5.007370\n", - "Epoch 498 batch: 0 , loss = 5.007164\n", - "Epoch 499 batch: 0 , loss = 5.007099\n" - ] - } - ], - "source": [ - "# Train with data loader.\n", - "\n", - "\n", - "\n", - "model = NgmOne(device)\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", - "\n", - "epoch = 500\n", - "batch_size = 200\n", - "for e in range(epoch):\n", - " train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", - " for i_batch, sample_batched in enumerate(train_dataloader):\n", - " optimizer.zero_grad()\n", - " train = sample_batched[0]\n", - " train_mask = sample_batched[1]\n", - " label_index = sample_batched[2].to(device)\n", - "\n", - " # Forward pass\n", - " output = model(train, train_mask)\n", - " loss = criterion(output, label_index)\n", - "\n", - " # backward and optimize\n", - " loss.backward()\n", - " optimizer.step()\n", - " if i_batch % batch_size == 0:\n", - " print(\"Epoch\", e, \"batch:\",i_batch, ', loss =', '{:.6f}'.format(loss))\n", - " \n", - " \n", - " \n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = NgmOne(device)\n", - "\n", - "EPOCHS = 1500\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=0.007)\n", - "\n", - "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n", - " relations = json.load(f)\n", - "\n", - "train, train_mask, corr_rels = make_batch()\n", - "corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)\n", - "\n", - "for epoch in range(EPOCHS):\n", - " optimizer.zero_grad()\n", - "\n", - " # Forward pass\n", - " output = model(train, train_mask)\n", - " loss = criterion(output, corr_indx)\n", - "\n", - " if (epoch + 1) % 1 == 0:\n", - " print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n", - " # Backward pass\n", - " loss.backward()\n", - " optimizer.step()\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "def make_batch():\n", + " \"\"\"Triplet is a list of [subject entity, relation, object entity], None if not present\"\"\"\n", + "\n", + " # Load predicted data\n", + " pred = \"../data/qald-9-train-linked.json\"\n", + "\n", + " # Load gold data\n", + " gold = \"../data/qald-9-train-linked.json\"\n", + " print(\"Beginning making batch\")\n", + " with open(pred, \"r\") as p, open(gold, \"r\") as g:\n", + " pred = json.load(p)\n", + " gold = json.load(g)\n", + "\n", + " inputs = []\n", + " correct_rels = []\n", + " inputs_max_len = 0\n", + " for d in tqdm(pred[\"questions\"]):\n", + " question = d[\"question\"][0][\"string\"]\n", + " query = d[\"query\"][\"sparql\"]\n", + "\n", + " # Take the first triplet in the query\n", + " trip = query.split(\"WHERE {\")[1]\n", + " trip = trip.split(\"}\")[0]\n", + " trip = trip.split(\"FILTER \")[0]\n", + " trip = trip.replace(\"{\", \"\").replace(\"}\", \"\")\n", + " # NOTE: stripping every \".\" also mangles URIs (dbpedia.org -> dbpediaorg); this is what raises the ValueError further down\n", + " trip = trip.replace(\".\", \"\")\n", + " trip = trip.replace(\";\", \"\")\n", + " 
triplet = trip.split(\" \")\n", + "\n", + " #remove empty strings\n", + " triplet = [x for x in triplet if x != \"\"]\n", + "\n", + " if len(triplet) % 3 == 0 and \" \".join(triplet).find(\"rdf\") == -1:\n", + " for i in range(len(triplet)//3):\n", + " triplet_i = triplet[i*3:i*3+3]\n", + " for t in triplet_i:\n", + " if not(t.find(\"?\")):\n", + " triplet_i[triplet_i.index(t)] = None\n", + "\n", + " #seq = \"[CLS] \" + question + \" [SEP] \"\n", + " if triplet_i[0] is not None and triplet_i[1] is not None:\n", + " #seq += \"[SUB] [SEP] \" + triplet[0]\n", + " # , padding=True, truncation=True)\n", + " tokenized_seq = tokenizer(question, \"[SUB]\", triplet_i[0], padding=True, truncation=True)\n", + " elif triplet_i[2] is not None and triplet_i[1] is not None:\n", + " #seq += \"[OBJ] [SEP] \" + triplet[2]\n", + " tokenized_seq = tokenizer(question, \"[OBJ]\", triplet_i[2], padding=True, truncation=True)\n", + " else:\n", + " continue\n", + "\n", + " if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n", + " inputs_max_len = len(tokenized_seq[\"input_ids\"])\n", + " inputs.append(list(tokenized_seq.values())[0])\n", + " correct_rels.append(triplet_i[1])\n", + "\n", + " inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n", + " #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n", + "\n", + " inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n", + " #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n", + " print(\"Finished with batches\")\n", + " return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# training_args = Seq2SeqTrainingArguments(\n", + "# output_dir='./models/blackbox',\n", + "# num_train_epochs=1,\n", + "# per_device_train_batch_size=1,\n", + "# per_device_eval_batch_size=1,\n", + "# warmup_steps=10,\n", + "# weight_decay=0.01,\n", + "# logging_dir='./logs',\n", + "# )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "class MyDataset(Dataset):\n", + " def __init__(self, inputs, attention_mask, correct_rels, relations):\n", + " self.inputs = inputs\n", + " self.attention_mask = attention_mask\n", + " self.correct_rels = correct_rels\n", + " self.relations = relations\n", + " \n", + "\n", + " def __len__(self):\n", + " return len(self.inputs)\n", + "\n", + " def __getitem__(self, idx):\n", + " return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])\n", + "\n", + "\n", + "\n", + "\n", + "#From scratch json creates data set.\n", + "# class MyDataset(Dataset):\n", + "# def __init__(self, json_file, transform=None):\n", + "# self.qald_data = json.load(json_file)\n", + "\n", + "# def __len__(self):\n", + "# return len(self.qald_data)\n", + "\n", + "# def __getitem__(self, idx):\n", + "# self.inputs[idx], self.attention_mask[idx], self.labels[idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Beginning making batch\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used 
when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished with batches\n", + "features: tensor([[ 101, 2040, 2003, 1996, 3677, 1997, 1996, 4035, 6870, 19247,\n", + " 1029, 102, 1031, 4942, 1033, 102, 0, 0, 0, 0,\n", + " 0, 0, 0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "#Prepare data\n", + "\n", + "def open_json(file):\n", + " with open(file, \"r\") as f:\n", + " return json.load(f)\n", + "\n", + "relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n", + "train_set = MyDataset(*make_batch(), relations=relations)\n", + "train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)\n", + "\n", + "#show first entry\n", + "train_features, train_mask, train_label = next(iter(train_dataloader))\n", + "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Beginning making batch\n" - ] + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished with batches\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "ename": "ValueError", + "evalue": "'<http://dbpediaorg/property/launchPad>' is not in list", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn [43], line 11\u001b[0m\n\u001b[0;32m 8\u001b[0m relations \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mload(f)\n\u001b[0;32m 10\u001b[0m train, train_mask, corr_rels \u001b[39m=\u001b[39m make_batch()\n\u001b[1;32m---> 11\u001b[0m corr_indx \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mLongTensor([relations\u001b[39m.\u001b[39mindex(r) \u001b[39mfor\u001b[39;00m r \u001b[39min\u001b[39;00m corr_rels])\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m 13\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(EPOCHS):\n\u001b[0;32m 14\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n", + "Cell \u001b[1;32mIn [43], line 11\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 8\u001b[0m relations \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mload(f)\n\u001b[0;32m 10\u001b[0m train, train_mask, corr_rels \u001b[39m=\u001b[39m make_batch()\n\u001b[1;32m---> 11\u001b[0m corr_indx \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mLongTensor([relations\u001b[39m.\u001b[39;49mindex(r) \u001b[39mfor\u001b[39;00m r \u001b[39min\u001b[39;00m corr_rels])\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m 13\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(EPOCHS):\n\u001b[0;32m 14\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n", + "\u001b[1;31mValueError\u001b[0m: '<http://dbpediaorg/property/launchPad>' is not in list" + ] + } + ], + "source": [ + "# Train with data loader.\n", + "\n", + "\n", + "\n", + "model = NgmOne(device)\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "\n", + "epoch = 500\n", + "batch_size = 200\n", + "for e in range(epoch):\n", + " train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", + " for i_batch, sample_batched in enumerate(train_dataloader):\n", + " optimizer.zero_grad()\n", + " train = sample_batched[0]\n", + " train_mask = sample_batched[1]\n", + " label_index = sample_batched[2].to(device)\n", + "\n", + " # Forward pass\n", + " output = model(train, train_mask)\n", + " loss = criterion(output, label_index)\n", + "\n", + " # backward and optimize\n", + " loss.backward()\n", + " optimizer.step()\n", + " if i_batch % batch_size == 0:\n", + " print(\"Epoch\", e, \"batch:\",i_batch, ', loss =', '{:.6f}'.format(loss))\n", + " \n", + " \n", + " \n", + "\n" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 408/408 [00:00<00:00, 962.29it/s] \n", - "100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = 
NgmOne(device)\n", + "\n", + "EPOCHS = 1500\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.007)\n", + "\n", + "with open(\"../data/relations-query-qald-9-linked.json\", \"r\") as f:\n", + " relations = json.load(f)\n", + "\n", + "train, train_mask, corr_rels = make_batch()\n", + "corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)\n", + "\n", + "for epoch in range(EPOCHS):\n", + " optimizer.zero_grad()\n", + "\n", + " # Forward pass\n", + " output = model(train, train_mask)\n", + " loss = criterion(output, corr_indx)\n", + "\n", + " if (epoch + 1) % 1 == 0:\n", + " print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n", + " # Backward pass\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Finished with batches\n", - "lowest confidence 0.14628477\n", - "Accuracy: 0.5245098039215687\n" - ] + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Beginning making batch\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 408/408 [00:00<00:00, 962.29it/s] \n", + "100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished with batches\n", + "lowest confidence 0.14628477\n", + "Accuracy: 0.5245098039215687\n" + ] + } + ], + "source": [ + "# Predict\n", + "train, train_mask, corr_rels = make_batch()\n", + "with torch.no_grad():\n", + " output = model(train, train_mask)\n", + "\n", + "output = output.detach().cpu().numpy()\n", + "\n", + "prediction = [relations[np.argmax(pred).item()]for pred in output]\n", + "probability = [pred[np.argmax(pred)] for pred in output]\n", + "correct_pred = [corr_rels[i] for i in range(len(output))]\n", + "\n", + "\n", + "\n", + "preds = [\n", + " {\"pred\": relations[np.argmax(pred).item()], \n", + " \"prob\": pred[np.argmax(pred)],\n", + " \"correct\": corr_rels[count]} \n", + " for count, pred in enumerate(output)\n", + " ]\n", + "\n", + "\n", + "\n", + "# for pred in preds:\n", + "# print(\"pred\", pred[\"pred\"], \"prob\", pred[\"prob\"], \"correct\", pred[\"correct\"])\n", + "\n", + "# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):\n", + "# print(\"pred:\", pred, \"prob:\", prob, \"| correct\", correct_pred)\n", + "\n", + "print(\"lowest confidence\", min(probability))\n", + "\n", + "def accuracy_score(y_true, y_pred):\n", + " corr_preds=0\n", + " wrong_preds=0\n", + " for pred, correct in zip(y_pred, y_true):\n", + " if pred == correct:\n", + " corr_preds += 1\n", + " else:\n", + " wrong_preds += 1\n", + " return corr_preds/(corr_preds+wrong_preds)\n", + "\n", + "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.11 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.11" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": 
"64e7cd3b4b88defe39dd61a4584920400d6beb2615ab2244e340c2e20eecdfe9" + } } - ], - "source": [ - "# Predict\n", - "train, train_mask, corr_rels = make_batch()\n", - "with torch.no_grad():\n", - " output = model(train, train_mask)\n", - "\n", - "output = output.detach().cpu().numpy()\n", - "\n", - "prediction = [relations[np.argmax(pred).item()]for pred in output]\n", - "probability = [pred[np.argmax(pred)] for pred in output]\n", - "correct_pred = [corr_rels[i] for i in range(len(output))]\n", - "\n", - "\n", - "\n", - "preds = [\n", - " {\"pred\": relations[np.argmax(pred).item()], \n", - " \"prob\": pred[np.argmax(pred)],\n", - " \"correct\": corr_rels[count]} \n", - " for count, pred in enumerate(output)\n", - " ]\n", - "\n", - "\n", - "\n", - "# for pred in preds:\n", - "# print(\"pred\", pred[\"pred\"], \"prob\", pred[\"prob\"], \"correct\", pred[\"correct\"])\n", - "\n", - "# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):\n", - "# print(\"pred:\", pred, \"prob:\", prob, \"| correct\", correct_pred)\n", - "\n", - "print(\"lowest confidence\", min(probability))\n", - "\n", - "def accuracy_score(y_true, y_pred):\n", - " corr_preds=0\n", - " wrong_preds=0\n", - " for pred, correct in zip(y_pred, y_true):\n", - " if pred == correct:\n", - " corr_preds += 1\n", - " else:\n", - " wrong_preds += 1\n", - " return corr_preds/(corr_preds+wrong_preds)\n", - "\n", - "print(\"Accuracy:\", accuracy_score(correct_pred, prediction))\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.4 ('tdde19')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "8e4aa0e1a1e15de86146661edda0b2884b54582522f7ff2b916774ba6b8accb1" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 }