Commit 4aa6eb71 authored by Albin

Merge branch 'main' of gitlab.liu.se:tdde19-2022-1/codebase into main

parents 1afff55f c620b964
%% Cell type:code id: tags:
``` python
import datasets
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from tqdm import tqdm
import json
```
%% Output
c:\Users\maxbj\AppData\Local\Programs\Python\Python39\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
%% Cell type:code id: tags:
``` python
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
```
%% Cell type:code id: tags:
``` python
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Print the selected device
print(device)
```
%% Cell type:code id: tags:
``` python
class NgmOne(nn.Module):
    def __init__(self, device):
        super(NgmOne, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(device)
        # Classification head: [CLS] embedding (768) -> one score per relation class (247)
        self.linear = nn.Linear(768, 247).to(device)
        self.softmax = nn.Softmax(dim=1).to(device)
        self.device = device

    def forward(self, tokenized_seq, tokenized_mask):
        tokenized_seq = tokenized_seq.to(self.device)
        tokenized_mask = tokenized_mask.to(self.device)

        # BERT is used as a frozen encoder; only the linear head receives gradients
        with torch.no_grad():
            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)
            x = x[0][:, 0, :].to(self.device)

        x = self.linear(x)
        x = self.softmax(x)
        return x
```
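%% Cell type:markdown id: tags:
A minimal sketch of how one question/entity pair flows through `NgmOne`, assuming the `tokenizer` and `device` cells above have been run. The question and DBpedia URI are made up for illustration; the real inputs are built by `make_batch()` below.
%% Cell type:code id: tags:
``` python
# Illustrative sketch: push one made-up example through an untrained NgmOne
# to inspect the output shape. The question and entity URI are invented.
sketch_model = NgmOne(device)

# Pack question + marker token + linked entity, roughly as make_batch() does.
enc = tokenizer("Who developed Skype?",
                "[SUB] <http://dbpedia.org/resource/Skype>",
                return_tensors="pt")

with torch.no_grad():
    probs = sketch_model(enc["input_ids"], enc["attention_mask"])

print(probs.shape)          # torch.Size([1, 247]): one probability per relation class
print(probs.argmax(dim=1))  # index of the highest-scoring relation
print(probs.sum(dim=1))     # ~1.0, since the forward pass ends in a softmax
```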
%% Cell type:code id: tags:
``` python
# def encode(batch):
#     return tokenizer(batch, padding="max_length", max_length=256, return_tensors="pt")

# def convert_to_features(example_batch):
#     input_encodings = encode(example_batch['text'])
#     target_encodings = encode(example_batch['summary'])
#
#     labels = target_encodings['input_ids']
#     decoder_input_ids = shift_tokens_right(
#         labels, model.config.pad_token_id, model.config.decoder_start_token_id)
#     labels[labels[:, :] == model.config.pad_token_id] = -100
#
#     encodings = {
#         'input_ids': input_encodings['input_ids'],
#         'attention_mask': input_encodings['attention_mask'],
#         'decoder_input_ids': decoder_input_ids,
#         'labels': labels,
#     }
#     return encodings

# def get_dataset(path):
#     df = pd.read_csv(path, sep=",", on_bad_lines='skip')
#     dataset = datasets.Dataset.from_pandas(df)
#     dataset = dataset.map(convert_to_features, batched=True)
#     columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]
#     dataset.set_format(type='torch', columns=columns)
#     return dataset
```
%% Cell type:code id: tags:
``` python
def make_batch():
    """Triplet is a list of [subject entity, relation, object entity], None if not present"""
    # Load predicted data
    pred = "../data/qald-9-train-linked.json"
    # Load gold data
    gold = "../data/qald-9-train-linked.json"
    print("Beginning making batch")

    with open(pred, "r") as p, open(gold, "r") as g:
        pred = json.load(p)
        gold = json.load(g)

    inputs = []
    correct_rels = []
    inputs_max_len = 0
    for d in tqdm(pred["questions"]):
        question = d["question"][0]["string"]
        query = d["query"]["sparql"]

        # Take the first triplet in the query
        trip = query.split("WHERE {")[1]
        trip = trip.split("}")[0]
        trip = trip.split("FILTER ")[0]
        trip = trip.replace("{", "").replace("}", "")
        trip = trip.replace(".", "")
        trip = trip.replace(";", "")
        triplet = trip.split(" ")

        # Remove empty strings
        triplet = [x for x in triplet if x != ""]
        if len(triplet) != 3:
            continue

        # find("?") == 0 means the token starts with "?", i.e. it is a SPARQL variable
        for t in triplet:
            if not t.find("?"):
                triplet[triplet.index(t)] = None

        # seq = "[CLS] " + question + " [SEP] "
        if triplet[0] is not None:
            # seq += "[SUB] [SEP] " + triplet[0]
            # , padding=True, truncation=True)
            tokenized_seq = tokenizer(question, "[SUB]", triplet[0], padding=True, truncation=True)
        elif triplet[2] is not None:
            # seq += "[OBJ] [SEP] " + triplet[2]
            tokenized_seq = tokenizer(question, "[OBJ]", triplet[2], padding=True, truncation=True)

        if inputs_max_len < len(tokenized_seq["input_ids"]):
            inputs_max_len = len(tokenized_seq["input_ids"])
        inputs.append(list(tokenized_seq.values())[0])

    correct_rels_max_len = 0
    correct_rels = []
    for d in tqdm(gold["questions"]):
        question = d["question"][0]["string"]
        query = d["query"]["sparql"]

        # Take the first triplet in the query
        trip = query.split("WHERE")[1]
        trip = trip.replace("{", "").replace("}", "")
        triplet = trip.split(" ")

        # Remove empty strings
        triplet = [x for x in triplet if x != ""]
        if len(triplet) != 3:
            continue
        # tokenized = tokenizer(triplet[1], padding=True, truncation=True)
        # if correct_rels_max_len < len(tokenized["input_ids"]):
        #     correct_rels_max_len = len(tokenized["input_ids"])
        correct_rels.append(triplet[1])  # list(tokenized.values())[0])

        # Queries with several triplets and no rdf:type constraint: add one example per triplet
        if len(triplet) % 3 == 0 and " ".join(triplet).find("rdf") == -1:
            for i in range(len(triplet)//3):
                triplet_i = triplet[i*3:i*3+3]
                for t in triplet_i:
                    if not t.find("?"):
                        triplet_i[triplet_i.index(t)] = None

                # seq = "[CLS] " + question + " [SEP] "
                if triplet_i[0] is not None and triplet_i[1] is not None:
                    # seq += "[SUB] [SEP] " + triplet[0]
                    # , padding=True, truncation=True)
                    tokenized_seq = tokenizer(question, "[SUB]", triplet_i[0], padding=True, truncation=True)
                elif triplet_i[2] is not None and triplet_i[1] is not None:
                    # seq += "[OBJ] [SEP] " + triplet[2]
                    tokenized_seq = tokenizer(question, "[OBJ]", triplet_i[2], padding=True, truncation=True)
                else:
                    continue

                if inputs_max_len < len(tokenized_seq["input_ids"]):
                    inputs_max_len = len(tokenized_seq["input_ids"])
                inputs.append(list(tokenized_seq.values())[0])
                correct_rels.append(triplet_i[1])

    # Pad every input to the longest sequence and build the attention masks
    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])
    # correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])
    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)
    # correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)

    print("Finished with batches")
    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels  # torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)
```
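%% Cell type:markdown id: tags:
A small illustration of what the string handling in `make_batch` does to one hand-written SPARQL query (the query below is made up, not taken from the dataset). Note that `replace(".", "")` also strips the dots inside URIs, so `dbpedia.org` becomes `dbpediaorg` in the extracted triplet.
%% Cell type:code id: tags:
``` python
# Illustrative sketch: trace the triplet extraction from make_batch on a made-up query.
sample_query = ("SELECT DISTINCT ?uri WHERE { "
                "<http://dbpedia.org/resource/Skype> "
                "<http://dbpedia.org/ontology/developer> ?uri . }")

trip = sample_query.split("WHERE {")[1]
trip = trip.split("}")[0]
trip = trip.split("FILTER ")[0]
trip = trip.replace("{", "").replace("}", "")
trip = trip.replace(".", "").replace(";", "")
triplet = [x for x in trip.split(" ") if x != ""]
print(triplet)
# ['<http://dbpediaorg/resource/Skype>', '<http://dbpediaorg/ontology/developer>', '?uri']

# Variables (tokens starting with "?") become None, mirroring the find("?") check above
triplet = [None if t.startswith("?") else t for t in triplet]
print(triplet)
# ['<http://dbpediaorg/resource/Skype>', '<http://dbpediaorg/ontology/developer>', None]
```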
%% Cell type:code id: tags:
``` python
# training_args = Seq2SeqTrainingArguments(
# output_dir='./models/blackbox',
# num_train_epochs=1,
# per_device_train_batch_size=1,
# per_device_eval_batch_size=1,
# warmup_steps=10,
# weight_decay=0.01,
# logging_dir='./logs',
# )
```
%% Cell type:code id: tags:
``` python
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, inputs, attention_mask, correct_rels, relations):
        self.inputs = inputs
        self.attention_mask = attention_mask
        self.correct_rels = correct_rels
        self.relations = relations

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # The label is the index of the gold relation in the global relation list
        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])

# From-scratch dataset that reads the QALD JSON directly (unused).
# class MyDataset(Dataset):
#     def __init__(self, json_file, transform=None):
#         self.qald_data = json.load(json_file)
#
#     def __len__(self):
#         return len(self.qald_data)
#
#     def __getitem__(self, idx):
#         return self.inputs[idx], self.attention_mask[idx], self.labels[idx]
```
%% Cell type:code id: tags:
``` python
# Prepare data
def open_json(file):
    with open(file, "r") as f:
        return json.load(f)

relations = open_json("../data/relations-query-qald-9-linked.json")

train_set = MyDataset(*make_batch(), relations=relations)
train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)

# Show the first entry
train_features, train_mask, train_label = next(iter(train_dataloader))
print("features:", train_features, "mask:", train_mask, "label_index", train_label[0])
```
%% Output
Beginning making batch
100%|██████████| 408/408 [00:00<00:00, 955.51it/s]
100%|██████████| 408/408 [00:00<00:00, 136139.70it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Finished with batches
features: tensor([[ 101, 2040, 2003, 1996, 3677, 1997, 1996, 4035, 6870, 19247,
1029, 102, 1031, 4942, 1033, 102, 0, 0, 0, 0,
0, 0, 0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(81)
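%% Cell type:markdown id: tags:
To sanity-check the tensors printed above, the token ids can be decoded back to text. A short sketch, assuming the `tokenizer`, `relations`, and `train_dataloader` cells above have been run.
%% Cell type:code id: tags:
``` python
# Illustrative sketch: decode one batch back to tokens to see how the question,
# the "[SUB]"/"[OBJ]" marker, and the padding are packed.
features, mask, label_idx = next(iter(train_dataloader))

print(tokenizer.decode(features[0]))
# e.g. "[CLS] who is the host of the ... ? [SEP] [ sub ] [SEP] [PAD] [PAD] ..."

print("gold relation:", relations[label_idx[0].item()])
```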
%% Cell type:code id: tags:
``` python
# Train with data loader.
model = NgmOne(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epoch = 500
batch_size = 200

for e in range(epoch):
    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    for i_batch, sample_batched in enumerate(train_dataloader):
        optimizer.zero_grad()

        train = sample_batched[0]
        train_mask = sample_batched[1]
        label_index = sample_batched[2].to(device)

        # Forward pass
        output = model(train, train_mask)
        loss = criterion(output, label_index)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        if i_batch % batch_size == 0:
            print("Epoch", e, "batch:", i_batch, ', loss =', '{:.6f}'.format(loss))
```
%% Output
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epoch 0 batch: 0 , loss = 5.509496
Epoch 1 batch: 0 , loss = 5.507990
Epoch 2 batch: 0 , loss = 5.506450
Epoch 3 batch: 0 , loss = 5.504800
Epoch 4 batch: 0 , loss = 5.502756
Epoch 5 batch: 0 , loss = 5.500911
Epoch 6 batch: 0 , loss = 5.498587
Epoch 7 batch: 0 , loss = 5.496222
Epoch 8 batch: 0 , loss = 5.493753
Epoch 9 batch: 0 , loss = 5.492003
Epoch 10 batch: 0 , loss = 5.489583
Epoch 11 batch: 0 , loss = 5.487563
Epoch 12 batch: 0 , loss = 5.485127
Epoch 13 batch: 0 , loss = 5.481408
Epoch 14 batch: 0 , loss = 5.479365
Epoch 15 batch: 0 , loss = 5.476476
Epoch 16 batch: 0 , loss = 5.475063
Epoch 17 batch: 0 , loss = 5.473976
Epoch 18 batch: 0 , loss = 5.471124
Epoch 19 batch: 0 , loss = 5.472827
Epoch 20 batch: 0 , loss = 5.466723
Epoch 21 batch: 0 , loss = 5.464960
Epoch 22 batch: 0 , loss = 5.466041
Epoch 23 batch: 0 , loss = 5.462078
Epoch 24 batch: 0 , loss = 5.460968
Epoch 25 batch: 0 , loss = 5.463361
Epoch 26 batch: 0 , loss = 5.462362
Epoch 27 batch: 0 , loss = 5.459974
Epoch 28 batch: 0 , loss = 5.462227
Epoch 29 batch: 0 , loss = 5.457432
Epoch 30 batch: 0 , loss = 5.456643
Epoch 31 batch: 0 , loss = 5.455560
Epoch 32 batch: 0 , loss = 5.454978
Epoch 33 batch: 0 , loss = 5.452355
Epoch 34 batch: 0 , loss = 5.449059
Epoch 35 batch: 0 , loss = 5.450156
Epoch 36 batch: 0 , loss = 5.443484
Epoch 37 batch: 0 , loss = 5.441142
Epoch 38 batch: 0 , loss = 5.438660
Epoch 39 batch: 0 , loss = 5.437134
Epoch 40 batch: 0 , loss = 5.434308
Epoch 41 batch: 0 , loss = 5.431943
Epoch 42 batch: 0 , loss = 5.429959
Epoch 43 batch: 0 , loss = 5.428108
Epoch 44 batch: 0 , loss = 5.426363
Epoch 45 batch: 0 , loss = 5.428246
Epoch 46 batch: 0 , loss = 5.424921
Epoch 47 batch: 0 , loss = 5.425705
Epoch 48 batch: 0 , loss = 5.424184
Epoch 49 batch: 0 , loss = 5.423537
Epoch 50 batch: 0 , loss = 5.424012
Epoch 51 batch: 0 , loss = 5.419333
Epoch 52 batch: 0 , loss = 5.414816
Epoch 53 batch: 0 , loss = 5.418826
Epoch 54 batch: 0 , loss = 5.418625
Epoch 55 batch: 0 , loss = 5.419236
Epoch 56 batch: 0 , loss = 5.421367
Epoch 57 batch: 0 , loss = 5.418901
Epoch 58 batch: 0 , loss = 5.416471
Epoch 59 batch: 0 , loss = 5.414352
Epoch 60 batch: 0 , loss = 5.416772
Epoch 61 batch: 0 , loss = 5.412184
Epoch 62 batch: 0 , loss = 5.407087
Epoch 63 batch: 0 , loss = 5.404146
Epoch 64 batch: 0 , loss = 5.405499
Epoch 65 batch: 0 , loss = 5.404672
Epoch 66 batch: 0 , loss = 5.399035
Epoch 67 batch: 0 , loss = 5.407500
Epoch 68 batch: 0 , loss = 5.403162
Epoch 69 batch: 0 , loss = 5.397717
Epoch 70 batch: 0 , loss = 5.398447
Epoch 71 batch: 0 , loss = 5.398669
Epoch 72 batch: 0 , loss = 5.392782
Epoch 73 batch: 0 , loss = 5.392102
Epoch 74 batch: 0 , loss = 5.386912
Epoch 75 batch: 0 , loss = 5.384616
Epoch 76 batch: 0 , loss = 5.382891
Epoch 77 batch: 0 , loss = 5.381707
Epoch 78 batch: 0 , loss = 5.376785
Epoch 79 batch: 0 , loss = 5.374002
Epoch 80 batch: 0 , loss = 5.376191
Epoch 81 batch: 0 , loss = 5.381535
Epoch 82 batch: 0 , loss = 5.376332
Epoch 83 batch: 0 , loss = 5.372051
Epoch 84 batch: 0 , loss = 5.367785
Epoch 85 batch: 0 , loss = 5.367027
Epoch 86 batch: 0 , loss = 5.366450
Epoch 87 batch: 0 , loss = 5.364227
Epoch 88 batch: 0 , loss = 5.364299
Epoch 89 batch: 0 , loss = 5.363710
Epoch 90 batch: 0 , loss = 5.356105
Epoch 91 batch: 0 , loss = 5.353946
Epoch 92 batch: 0 , loss = 5.356158
Epoch 93 batch: 0 , loss = 5.355119
Epoch 94 batch: 0 , loss = 5.348075
Epoch 95 batch: 0 , loss = 5.351038
Epoch 96 batch: 0 , loss = 5.349224
Epoch 97 batch: 0 , loss = 5.344652
Epoch 98 batch: 0 , loss = 5.341354
Epoch 99 batch: 0 , loss = 5.342059
Epoch 100 batch: 0 , loss = 5.344870
Epoch 101 batch: 0 , loss = 5.335486
Epoch 102 batch: 0 , loss = 5.339808
Epoch 103 batch: 0 , loss = 5.330384
Epoch 104 batch: 0 , loss = 5.333795
Epoch 105 batch: 0 , loss = 5.336934
Epoch 106 batch: 0 , loss = 5.334551
Epoch 107 batch: 0 , loss = 5.328543
Epoch 108 batch: 0 , loss = 5.329463
Epoch 109 batch: 0 , loss = 5.327895
Epoch 110 batch: 0 , loss = 5.324828
Epoch 111 batch: 0 , loss = 5.325212
Epoch 112 batch: 0 , loss = 5.320701
Epoch 113 batch: 0 , loss = 5.327697
Epoch 114 batch: 0 , loss = 5.321108
Epoch 115 batch: 0 , loss = 5.318196
Epoch 116 batch: 0 , loss = 5.313715
Epoch 117 batch: 0 , loss = 5.311478
Epoch 118 batch: 0 , loss = 5.310149
Epoch 119 batch: 0 , loss = 5.304658
Epoch 120 batch: 0 , loss = 5.299352
Epoch 121 batch: 0 , loss = 5.297422
Epoch 122 batch: 0 , loss = 5.296243
Epoch 123 batch: 0 , loss = 5.299733
Epoch 124 batch: 0 , loss = 5.296957
Epoch 125 batch: 0 , loss = 5.296990
Epoch 126 batch: 0 , loss = 5.295414
Epoch 127 batch: 0 , loss = 5.291314
Epoch 128 batch: 0 , loss = 5.288688
Epoch 129 batch: 0 , loss = 5.290800
Epoch 130 batch: 0 , loss = 5.290220
Epoch 131 batch: 0 , loss = 5.292923
Epoch 132 batch: 0 , loss = 5.282723
Epoch 133 batch: 0 , loss = 5.281114
Epoch 134 batch: 0 , loss = 5.284384
Epoch 135 batch: 0 , loss = 5.285651
Epoch 136 batch: 0 , loss = 5.279850
Epoch 137 batch: 0 , loss = 5.276040
Epoch 138 batch: 0 , loss = 5.279545
Epoch 139 batch: 0 , loss = 5.275866
Epoch 140 batch: 0 , loss = 5.275192
Epoch 141 batch: 0 , loss = 5.278750
Epoch 142 batch: 0 , loss = 5.281165
Epoch 143 batch: 0 , loss = 5.281812
Epoch 144 batch: 0 , loss = 5.270550
Epoch 145 batch: 0 , loss = 5.269819
Epoch 146 batch: 0 , loss = 5.268355
Epoch 147 batch: 0 , loss = 5.266929
Epoch 148 batch: 0 , loss = 5.261328
Epoch 149 batch: 0 , loss = 5.268577
Epoch 150 batch: 0 , loss = 5.260501
Epoch 151 batch: 0 , loss = 5.263970
Epoch 152 batch: 0 , loss = 5.261120
Epoch 153 batch: 0 , loss = 5.258860
Epoch 154 batch: 0 , loss = 5.253738
Epoch 155 batch: 0 , loss = 5.261181
Epoch 156 batch: 0 , loss = 5.246259
Epoch 157 batch: 0 , loss = 5.245531
Epoch 158 batch: 0 , loss = 5.244577
Epoch 159 batch: 0 , loss = 5.243086
Epoch 160 batch: 0 , loss = 5.241098
Epoch 161 batch: 0 , loss = 5.245851
Epoch 162 batch: 0 , loss = 5.245241
Epoch 163 batch: 0 , loss = 5.244030
Epoch 164 batch: 0 , loss = 5.238179
Epoch 165 batch: 0 , loss = 5.232537
Epoch 166 batch: 0 , loss = 5.238233
Epoch 167 batch: 0 , loss = 5.244127
Epoch 168 batch: 0 , loss = 5.238196
Epoch 169 batch: 0 , loss = 5.231448
Epoch 170 batch: 0 , loss = 5.231340
Epoch 171 batch: 0 , loss = 5.230532
Epoch 172 batch: 0 , loss = 5.232080
Epoch 173 batch: 0 , loss = 5.228147
Epoch 174 batch: 0 , loss = 5.225359
Epoch 175 batch: 0 , loss = 5.229460
Epoch 176 batch: 0 , loss = 5.223203
Epoch 177 batch: 0 , loss = 5.226225
Epoch 178 batch: 0 , loss = 5.224785
Epoch 179 batch: 0 , loss = 5.232825
Epoch 180 batch: 0 , loss = 5.227349
Epoch 181 batch: 0 , loss = 5.222055
Epoch 182 batch: 0 , loss = 5.212518
Epoch 183 batch: 0 , loss = 5.221508
Epoch 184 batch: 0 , loss = 5.218023
Epoch 185 batch: 0 , loss = 5.221085
Epoch 186 batch: 0 , loss = 5.212870
Epoch 187 batch: 0 , loss = 5.217089
Epoch 188 batch: 0 , loss = 5.220237
Epoch 189 batch: 0 , loss = 5.220267
Epoch 190 batch: 0 , loss = 5.216769
Epoch 191 batch: 0 , loss = 5.218822
Epoch 192 batch: 0 , loss = 5.220790
Epoch 193 batch: 0 , loss = 5.206959
Epoch 194 batch: 0 , loss = 5.204648
Epoch 195 batch: 0 , loss = 5.202166
Epoch 196 batch: 0 , loss = 5.199206
Epoch 197 batch: 0 , loss = 5.200701
Epoch 198 batch: 0 , loss = 5.187927
Epoch 199 batch: 0 , loss = 5.194535
Epoch 200 batch: 0 , loss = 5.190029
Epoch 201 batch: 0 , loss = 5.189040
Epoch 202 batch: 0 , loss = 5.190596
Epoch 203 batch: 0 , loss = 5.189606
Epoch 204 batch: 0 , loss = 5.186440
Epoch 205 batch: 0 , loss = 5.191073
Epoch 206 batch: 0 , loss = 5.184719
Epoch 207 batch: 0 , loss = 5.179452
Epoch 208 batch: 0 , loss = 5.173555
Epoch 209 batch: 0 , loss = 5.180843
Epoch 210 batch: 0 , loss = 5.186315
Epoch 211 batch: 0 , loss = 5.182886
Epoch 212 batch: 0 , loss = 5.179914
Epoch 213 batch: 0 , loss = 5.175862
Epoch 214 batch: 0 , loss = 5.183718
Epoch 215 batch: 0 , loss = 5.162323
Epoch 216 batch: 0 , loss = 5.170660
Epoch 217 batch: 0 , loss = 5.169362
Epoch 218 batch: 0 , loss = 5.162789
Epoch 219 batch: 0 , loss = 5.160868
Epoch 220 batch: 0 , loss = 5.164267
Epoch 221 batch: 0 , loss = 5.164003
Epoch 222 batch: 0 , loss = 5.165540
Epoch 223 batch: 0 , loss = 5.161615
Epoch 224 batch: 0 , loss = 5.152071
Epoch 225 batch: 0 , loss = 5.162641
Epoch 226 batch: 0 , loss = 5.158351
Epoch 227 batch: 0 , loss = 5.167762
Epoch 228 batch: 0 , loss = 5.166461
Epoch 229 batch: 0 , loss = 5.157466
Epoch 230 batch: 0 , loss = 5.156347
Epoch 231 batch: 0 , loss = 5.156344
Epoch 232 batch: 0 , loss = 5.164078
Epoch 233 batch: 0 , loss = 5.155646
Epoch 234 batch: 0 , loss = 5.154130
Epoch 235 batch: 0 , loss = 5.152800
Epoch 236 batch: 0 , loss = 5.161687
Epoch 237 batch: 0 , loss = 5.151773
Epoch 238 batch: 0 , loss = 5.146783
Epoch 239 batch: 0 , loss = 5.146691
Epoch 240 batch: 0 , loss = 5.151015
Epoch 241 batch: 0 , loss = 5.145954
Epoch 242 batch: 0 , loss = 5.154754
Epoch 243 batch: 0 , loss = 5.144487
Epoch 244 batch: 0 , loss = 5.152771
Epoch 245 batch: 0 , loss = 5.156934
Epoch 246 batch: 0 , loss = 5.146832
Epoch 247 batch: 0 , loss = 5.148869
Epoch 248 batch: 0 , loss = 5.142589
Epoch 249 batch: 0 , loss = 5.141822
Epoch 250 batch: 0 , loss = 5.144300
Epoch 251 batch: 0 , loss = 5.140250
Epoch 252 batch: 0 , loss = 5.135976
Epoch 253 batch: 0 , loss = 5.144660
Epoch 254 batch: 0 , loss = 5.144534
Epoch 255 batch: 0 , loss = 5.140327
Epoch 256 batch: 0 , loss = 5.139380
Epoch 257 batch: 0 , loss = 5.149578
Epoch 258 batch: 0 , loss = 5.139646
Epoch 259 batch: 0 , loss = 5.139930
Epoch 260 batch: 0 , loss = 5.143685
Epoch 261 batch: 0 , loss = 5.138163
Epoch 262 batch: 0 , loss = 5.138759
Epoch 263 batch: 0 , loss = 5.133649
Epoch 264 batch: 0 , loss = 5.137669
Epoch 265 batch: 0 , loss = 5.133360
Epoch 266 batch: 0 , loss = 5.137102
Epoch 267 batch: 0 , loss = 5.144728
Epoch 268 batch: 0 , loss = 5.135507
Epoch 269 batch: 0 , loss = 5.133213
Epoch 270 batch: 0 , loss = 5.130654
Epoch 271 batch: 0 , loss = 5.128548
Epoch 272 batch: 0 , loss = 5.118666
Epoch 273 batch: 0 , loss = 5.121534
Epoch 274 batch: 0 , loss = 5.122122
Epoch 275 batch: 0 , loss = 5.127708
Epoch 276 batch: 0 , loss = 5.132959
Epoch 277 batch: 0 , loss = 5.129933
Epoch 278 batch: 0 , loss = 5.125852
Epoch 279 batch: 0 , loss = 5.120491
Epoch 280 batch: 0 , loss = 5.123435
Epoch 281 batch: 0 , loss = 5.129238
Epoch 282 batch: 0 , loss = 5.124735
Epoch 283 batch: 0 , loss = 5.116185
Epoch 284 batch: 0 , loss = 5.113766
Epoch 285 batch: 0 , loss = 5.114601
Epoch 286 batch: 0 , loss = 5.112402
Epoch 287 batch: 0 , loss = 5.111151
Epoch 288 batch: 0 , loss = 5.117523
Epoch 289 batch: 0 , loss = 5.117987
Epoch 290 batch: 0 , loss = 5.112226
Epoch 291 batch: 0 , loss = 5.103069
Epoch 292 batch: 0 , loss = 5.105778
Epoch 293 batch: 0 , loss = 5.117328
Epoch 294 batch: 0 , loss = 5.117546
Epoch 295 batch: 0 , loss = 5.114745
Epoch 296 batch: 0 , loss = 5.111210
Epoch 297 batch: 0 , loss = 5.116782
Epoch 298 batch: 0 , loss = 5.126735
Epoch 299 batch: 0 , loss = 5.110609
Epoch 300 batch: 0 , loss = 5.119059
Epoch 301 batch: 0 , loss = 5.115992
Epoch 302 batch: 0 , loss = 5.106349
Epoch 303 batch: 0 , loss = 5.104899
Epoch 304 batch: 0 , loss = 5.104232
Epoch 305 batch: 0 , loss = 5.112337
Epoch 306 batch: 0 , loss = 5.111980
Epoch 307 batch: 0 , loss = 5.106858
Epoch 308 batch: 0 , loss = 5.109094
Epoch 309 batch: 0 , loss = 5.111865
Epoch 310 batch: 0 , loss = 5.096600
Epoch 311 batch: 0 , loss = 5.106102
Epoch 312 batch: 0 , loss = 5.102217
Epoch 313 batch: 0 , loss = 5.111903
Epoch 314 batch: 0 , loss = 5.102543
Epoch 315 batch: 0 , loss = 5.101692
Epoch 316 batch: 0 , loss = 5.114990
Epoch 317 batch: 0 , loss = 5.108889
Epoch 318 batch: 0 , loss = 5.102445
Epoch 319 batch: 0 , loss = 5.096200
Epoch 320 batch: 0 , loss = 5.100635
Epoch 321 batch: 0 , loss = 5.099531
Epoch 322 batch: 0 , loss = 5.100465
Epoch 323 batch: 0 , loss = 5.095500
Epoch 324 batch: 0 , loss = 5.100604
Epoch 325 batch: 0 , loss = 5.099518
Epoch 326 batch: 0 , loss = 5.094840
Epoch 327 batch: 0 , loss = 5.093679
Epoch 328 batch: 0 , loss = 5.093080
Epoch 329 batch: 0 , loss = 5.095733
Epoch 330 batch: 0 , loss = 5.099861
Epoch 331 batch: 0 , loss = 5.090419
Epoch 332 batch: 0 , loss = 5.099437
Epoch 333 batch: 0 , loss = 5.098624
Epoch 334 batch: 0 , loss = 5.093289
Epoch 335 batch: 0 , loss = 5.097763
Epoch 336 batch: 0 , loss = 5.097146
Epoch 337 batch: 0 , loss = 5.101106
Epoch 338 batch: 0 , loss = 5.082275
Epoch 339 batch: 0 , loss = 5.087108
Epoch 340 batch: 0 , loss = 5.092202
Epoch 341 batch: 0 , loss = 5.082821
Epoch 342 batch: 0 , loss = 5.092401
Epoch 343 batch: 0 , loss = 5.097392
Epoch 344 batch: 0 , loss = 5.097518
Epoch 345 batch: 0 , loss = 5.095610
Epoch 346 batch: 0 , loss = 5.089358
Epoch 347 batch: 0 , loss = 5.084134
Epoch 348 batch: 0 , loss = 5.093090
Epoch 349 batch: 0 , loss = 5.094424
Epoch 350 batch: 0 , loss = 5.089258
Epoch 351 batch: 0 , loss = 5.089724
Epoch 352 batch: 0 , loss = 5.092799
Epoch 353 batch: 0 , loss = 5.089324
Epoch 354 batch: 0 , loss = 5.093751
Epoch 355 batch: 0 , loss = 5.088273
Epoch 356 batch: 0 , loss = 5.083136
Epoch 357 batch: 0 , loss = 5.087886
Epoch 358 batch: 0 , loss = 5.092871
Epoch 359 batch: 0 , loss = 5.077843
Epoch 360 batch: 0 , loss = 5.087354
Epoch 361 batch: 0 , loss = 5.077330
Epoch 362 batch: 0 , loss = 5.096845
Epoch 363 batch: 0 , loss = 5.081802
Epoch 364 batch: 0 , loss = 5.086153
Epoch 365 batch: 0 , loss = 5.081354
Epoch 366 batch: 0 , loss = 5.084242
Epoch 367 batch: 0 , loss = 5.094817
Epoch 368 batch: 0 , loss = 5.089900
Epoch 369 batch: 0 , loss = 5.089128
Epoch 370 batch: 0 , loss = 5.084967
Epoch 371 batch: 0 , loss = 5.085038
Epoch 372 batch: 0 , loss = 5.084878
Epoch 373 batch: 0 , loss = 5.085416
Epoch 374 batch: 0 , loss = 5.084856
Epoch 375 batch: 0 , loss = 5.085449
Epoch 376 batch: 0 , loss = 5.080372
Epoch 377 batch: 0 , loss = 5.079864
Epoch 378 batch: 0 , loss = 5.089789
Epoch 379 batch: 0 , loss = 5.084702
Epoch 380 batch: 0 , loss = 5.080189
Epoch 381 batch: 0 , loss = 5.080185
Epoch 382 batch: 0 , loss = 5.080244
Epoch 383 batch: 0 , loss = 5.080095
Epoch 384 batch: 0 , loss = 5.088795
Epoch 385 batch: 0 , loss = 5.084120
Epoch 386 batch: 0 , loss = 5.087992
Epoch 387 batch: 0 , loss = 5.082991
Epoch 388 batch: 0 , loss = 5.072812
Epoch 389 batch: 0 , loss = 5.077140
Epoch 390 batch: 0 , loss = 5.081978
Epoch 391 batch: 0 , loss = 5.086642
Epoch 392 batch: 0 , loss = 5.091150
Epoch 393 batch: 0 , loss = 5.081849
Epoch 394 batch: 0 , loss = 5.071902
Epoch 395 batch: 0 , loss = 5.090557
Epoch 396 batch: 0 , loss = 5.085775
Epoch 397 batch: 0 , loss = 5.076406
Epoch 398 batch: 0 , loss = 5.086689
Epoch 399 batch: 0 , loss = 5.082210
Epoch 400 batch: 0 , loss = 5.091569
Epoch 401 batch: 0 , loss = 5.082448
Epoch 402 batch: 0 , loss = 5.081671
Epoch 403 batch: 0 , loss = 5.086580
Epoch 404 batch: 0 , loss = 5.081904
Epoch 405 batch: 0 , loss = 5.082079
Epoch 406 batch: 0 , loss = 5.081985
Epoch 407 batch: 0 , loss = 5.086331
Epoch 408 batch: 0 , loss = 5.076781
Epoch 409 batch: 0 , loss = 5.086619
Epoch 410 batch: 0 , loss = 5.085722
Epoch 411 batch: 0 , loss = 5.081203
Epoch 412 batch: 0 , loss = 5.080491
Epoch 413 batch: 0 , loss = 5.086140
Epoch 414 batch: 0 , loss = 5.076138
Epoch 415 batch: 0 , loss = 5.090469
Epoch 416 batch: 0 , loss = 5.075500
Epoch 417 batch: 0 , loss = 5.080328
Epoch 418 batch: 0 , loss = 5.084836
Epoch 419 batch: 0 , loss = 5.079198
Epoch 420 batch: 0 , loss = 5.068509
Epoch 421 batch: 0 , loss = 5.081526
Epoch 422 batch: 0 , loss = 5.075669
Epoch 423 batch: 0 , loss = 5.070409
Epoch 424 batch: 0 , loss = 5.069685
Epoch 425 batch: 0 , loss = 5.074168
Epoch 426 batch: 0 , loss = 5.073371
Epoch 427 batch: 0 , loss = 5.077053
Epoch 428 batch: 0 , loss = 5.076373
Epoch 429 batch: 0 , loss = 5.065866
Epoch 430 batch: 0 , loss = 5.069530
Epoch 431 batch: 0 , loss = 5.063504
Epoch 432 batch: 0 , loss = 5.066977
Epoch 433 batch: 0 , loss = 5.061393
Epoch 434 batch: 0 , loss = 5.064750
Epoch 435 batch: 0 , loss = 5.063208
Epoch 436 batch: 0 , loss = 5.062410
Epoch 437 batch: 0 , loss = 5.056867
Epoch 438 batch: 0 , loss = 5.061165
Epoch 439 batch: 0 , loss = 5.055538
Epoch 440 batch: 0 , loss = 5.064088
Epoch 441 batch: 0 , loss = 5.068861
Epoch 442 batch: 0 , loss = 5.058971
Epoch 443 batch: 0 , loss = 5.054216
Epoch 444 batch: 0 , loss = 5.053985
Epoch 445 batch: 0 , loss = 5.053171
Epoch 446 batch: 0 , loss = 5.052192
Epoch 447 batch: 0 , loss = 5.055376
Epoch 448 batch: 0 , loss = 5.048327
Epoch 449 batch: 0 , loss = 5.047246
Epoch 450 batch: 0 , loss = 5.056007
Epoch 451 batch: 0 , loss = 5.045246
Epoch 452 batch: 0 , loss = 5.048958
Epoch 453 batch: 0 , loss = 5.043387
Epoch 454 batch: 0 , loss = 5.052229
Epoch 455 batch: 0 , loss = 5.036552
Epoch 456 batch: 0 , loss = 5.047226
Epoch 457 batch: 0 , loss = 5.047490
Epoch 458 batch: 0 , loss = 5.045058
Epoch 459 batch: 0 , loss = 5.034773
Epoch 460 batch: 0 , loss = 5.037952
Epoch 461 batch: 0 , loss = 5.040687
Epoch 462 batch: 0 , loss = 5.031157
Epoch 463 batch: 0 , loss = 5.032485
Epoch 464 batch: 0 , loss = 5.017200
Epoch 465 batch: 0 , loss = 5.020655
Epoch 466 batch: 0 , loss = 5.024525
Epoch 467 batch: 0 , loss = 5.021727
Epoch 468 batch: 0 , loss = 5.015924
Epoch 469 batch: 0 , loss = 5.019942
Epoch 470 batch: 0 , loss = 5.024877
Epoch 471 batch: 0 , loss = 5.019931
Epoch 472 batch: 0 , loss = 5.023381
Epoch 473 batch: 0 , loss = 5.007413
Epoch 474 batch: 0 , loss = 5.015634
Epoch 475 batch: 0 , loss = 5.019332
Epoch 476 batch: 0 , loss = 5.023258
Epoch 477 batch: 0 , loss = 5.007919
Epoch 478 batch: 0 , loss = 5.012198
Epoch 479 batch: 0 , loss = 5.016301
Epoch 480 batch: 0 , loss = 5.015327
Epoch 481 batch: 0 , loss = 5.015544
Epoch 482 batch: 0 , loss = 5.015000
Epoch 483 batch: 0 , loss = 5.009738
Epoch 484 batch: 0 , loss = 5.004784
Epoch 485 batch: 0 , loss = 5.009272
Epoch 486 batch: 0 , loss = 5.014519
Epoch 487 batch: 0 , loss = 5.014775
Epoch 488 batch: 0 , loss = 5.018719
Epoch 489 batch: 0 , loss = 5.009684
Epoch 490 batch: 0 , loss = 5.014743
Epoch 491 batch: 0 , loss = 5.015655
Epoch 492 batch: 0 , loss = 5.011545
Epoch 493 batch: 0 , loss = 5.015726
Epoch 494 batch: 0 , loss = 5.009212
Epoch 495 batch: 0 , loss = 5.008894
Epoch 496 batch: 0 , loss = 5.002744
Epoch 497 batch: 0 , loss = 5.007370
Epoch 498 batch: 0 , loss = 5.007164
Epoch 499 batch: 0 , loss = 5.007099
Finished with batches
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In [43], line 11
8 relations = json.load(f)
10 train, train_mask, corr_rels = make_batch()
---> 11 corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)
13 for epoch in range(EPOCHS):
14 optimizer.zero_grad()
Cell In [43], line 11, in <listcomp>(.0)
8 relations = json.load(f)
10 train, train_mask, corr_rels = make_batch()
---> 11 corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)
13 for epoch in range(EPOCHS):
14 optimizer.zero_grad()
ValueError: '<http://dbpediaorg/property/launchPad>' is not in list
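%% Cell type:markdown id: tags:
One possible reason the losses above decrease so slowly: `nn.CrossEntropyLoss` applies log-softmax internally and expects raw logits, while `NgmOne.forward` already ends in an explicit softmax. A minimal sketch on dummy data illustrating the difference; dropping the model's softmax would be a design change and is not done here.
%% Cell type:code id: tags:
``` python
# Illustrative sketch on random dummy data (not the QALD data).
logits = torch.randn(4, 247)           # 4 examples, 247 relation classes
targets = torch.randint(0, 247, (4,))  # made-up gold class indices

ce = nn.CrossEntropyLoss()
print("loss on raw logits:     ", ce(logits, targets).item())
print("loss on softmax outputs:", ce(nn.Softmax(dim=1)(logits), targets).item())
# The second value stays near log(247) ~ 5.5, because the softmax squashes the
# scores into [0, 1] before CrossEntropyLoss applies its own log-softmax.
```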
%% Cell type:code id: tags:
``` python
model = NgmOne(device)

EPOCHS = 1500
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.007)

with open("../data/relations-query-qald-9-linked.json", "r") as f:
    relations = json.load(f)

train, train_mask, corr_rels = make_batch()
corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)

for epoch in range(EPOCHS):
    optimizer.zero_grad()

    # Forward pass
    output = model(train, train_mask)
    loss = criterion(output, corr_indx)
    if (epoch + 1) % 1 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

    # Backward pass
    loss.backward()
    optimizer.step()
```
%% Cell type:code id: tags:
``` python
# Predict
train, train_mask, corr_rels = make_batch()

with torch.no_grad():
    output = model(train, train_mask)
    output = output.detach().cpu().numpy()

prediction = [relations[np.argmax(pred).item()] for pred in output]
probability = [pred[np.argmax(pred)] for pred in output]
correct_pred = [corr_rels[i] for i in range(len(output))]

preds = [
    {"pred": relations[np.argmax(pred).item()],
     "prob": pred[np.argmax(pred)],
     "correct": corr_rels[count]}
    for count, pred in enumerate(output)
]

# for pred in preds:
#     print("pred", pred["pred"], "prob", pred["prob"], "correct", pred["correct"])

# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):
#     print("pred:", pred, "prob:", prob, "| correct", correct_pred)

print("lowest confidence", min(probability))

def accuracy_score(y_true, y_pred):
    corr_preds = 0
    wrong_preds = 0
    for pred, correct in zip(y_pred, y_true):
        if pred == correct:
            corr_preds += 1
        else:
            wrong_preds += 1
    return corr_preds / (corr_preds + wrong_preds)

print("Accuracy:", accuracy_score(correct_pred, prediction))
```
%% Output
Beginning making batch
100%|██████████| 408/408 [00:00<00:00, 962.29it/s]
100%|██████████| 408/408 [00:00<00:00, 81687.72it/s]
Finished with batches
lowest confidence 0.14628477
Accuracy: 0.5245098039215687
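%% Cell type:markdown id: tags:
A short sketch that reuses the `preds` list from the prediction cell to show where the classifier is confident but still wrong, which is often more informative than the overall accuracy alone. It assumes the prediction cell above has been run.
%% Cell type:code id: tags:
``` python
# Illustrative sketch: list the five most confident wrong predictions from `preds`.
wrong = [p for p in preds if p["pred"] != p["correct"]]
wrong.sort(key=lambda p: p["prob"], reverse=True)

for p in wrong[:5]:
    print("prob: {:.3f} | predicted: {} | gold: {}".format(p["prob"], p["pred"], p["correct"]))
```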
%% Cell type:code id: tags:
``` python
```