Commit 88732e7f authored by Albin's avatar Albin

Neural graph model runs, (sketchy tho)

parent ebac0096
%% Cell type:code id: tags:
``` python
import datasets
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from tqdm import tqdm
import json
```
%% Output
b:\Programs\Miniconda\envs\tdde19\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
%% Cell type:code id: tags:
``` python
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
```
%% Cell type:code id: tags:
``` python
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Print the selected device
print(device)
```
%% Output
cuda
%% Cell type:code id: tags:
``` python
class NgmOne(nn.Module):
    def __init__(self, device):
        super(NgmOne, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(device)
        self.linear = nn.Linear(768, 247).to(device)
        self.softmax = nn.Softmax(dim=1).to(device)
        self.device = device

    def forward(self, tokenized_seq, tokenized_mask):
        tokenized_seq = tokenized_seq.to(self.device)
        tokenized_mask = tokenized_mask.to(self.device)

        # BERT is kept frozen; only the linear head receives gradients
        with torch.no_grad():
            x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)

        x = x[0][:, 0, :].to(self.device)  # [CLS] embedding of the final layer
        x = self.linear(x)
        # NOTE: nn.CrossEntropyLoss applies log-softmax itself, so feeding it these
        # probabilities applies softmax twice; returning the linear output directly
        # would be the more conventional choice
        x = self.softmax(x)
        return x
```
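%% Cell type:markdown id: tags:
The classifier keeps BERT frozen, takes the final-layer `[CLS]` embedding (768 dimensions) and maps it through a single linear layer to a probability distribution over 247 relation classes. The cell below is a minimal, illustrative sketch of how one tokenized question flows through an (untrained) `NgmOne` instance; the question and entity strings are made-up assumptions, not values from the dataset.
%% Cell type:code id: tags:
``` python
# Illustrative sketch only: encode one question together with a "[SUB]" marker and a
# made-up linked entity, run it through an untrained model, and inspect the output shape.
example = tokenizer("Who founded Tesla?", "[SUB] dbr:Tesla_Motors",
                    padding=True, truncation=True, return_tensors="pt")

sketch_model = NgmOne(device)
with torch.no_grad():
    probs = sketch_model(example["input_ids"], example["attention_mask"])

print(probs.shape)          # torch.Size([1, 247]) -- one probability per relation class
print(probs.argmax(dim=1))  # index of the (untrained, hence meaningless) most likely relation
```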
%% Cell type:code id: tags:
``` python
# def encode(batch):
#     return tokenizer(batch, padding="max_length", max_length=256, return_tensors="pt")

# def convert_to_features(example_batch):
#     input_encodings = encode(example_batch['text'])
#     target_encodings = encode(example_batch['summary'])

#     labels = target_encodings['input_ids']
#     decoder_input_ids = shift_tokens_right(
#         labels, model.config.pad_token_id, model.config.decoder_start_token_id)
#     labels[labels[:, :] == model.config.pad_token_id] = -100

#     encodings = {
#         'input_ids': input_encodings['input_ids'],
#         'attention_mask': input_encodings['attention_mask'],
#         'decoder_input_ids': decoder_input_ids,
#         'labels': labels,
#     }
#     return encodings

# def get_dataset(path):
#     df = pd.read_csv(path, sep=",", on_bad_lines='skip')
#     dataset = datasets.Dataset.from_pandas(df)
#     dataset = dataset.map(convert_to_features, batched=True)
#     columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]
#     dataset.set_format(type='torch', columns=columns)
#     return dataset
```
%% Cell type:code id: tags:
``` python
def make_batch():
    """Triplet is a list of [subject entity, relation, object entity], None if not present"""
    # Load predicted data
    pred = "../data/qald-9-train-linked.json"

    # Load gold data
    gold = "../data/qald-9-train-linked.json"

    print("Beginning making batch")
    with open(pred, "r") as p, open(gold, "r") as g:
        pred = json.load(p)
        gold = json.load(g)

    inputs = []
    inputs_max_len = 0
    for d in tqdm(pred["questions"]):
        question = d["question"][0]["string"]
        query = d["query"]["sparql"]

        # Take the first triplet in the query
        trip = query.split("WHERE")[1]
        trip = trip.replace("{", "").replace("}", "")
        triplet = trip.split(" ")

        # Remove empty strings
        triplet = [x for x in triplet if x != ""]
        if len(triplet) != 3:
            continue

        # str.find returns 0 when the term starts with "?", so variables become None
        for t in triplet:
            if not t.find("?"):
                triplet[triplet.index(t)] = None

        # seq = "[CLS] " + question + " [SEP] "
        if triplet[0] is not None:
            # seq += "[SUB] [SEP] " + triplet[0]
            tokenized_seq = tokenizer(question, "[SUB]", triplet[0], padding=True, truncation=True)
        elif triplet[2] is not None:
            # seq += "[OBJ] [SEP] " + triplet[2]
            tokenized_seq = tokenizer(question, "[OBJ]", triplet[2], padding=True, truncation=True)

        if inputs_max_len < len(tokenized_seq["input_ids"]):
            inputs_max_len = len(tokenized_seq["input_ids"])
        inputs.append(tokenized_seq["input_ids"])

    correct_rels = []
    for d in tqdm(gold["questions"]):
        question = d["question"][0]["string"]
        query = d["query"]["sparql"]

        # Take the first triplet in the query
        trip = query.split("WHERE")[1]
        trip = trip.replace("{", "").replace("}", "")
        triplet = trip.split(" ")

        # Remove empty strings
        triplet = [x for x in triplet if x != ""]
        if len(triplet) != 3:
            continue

        # Keep the gold relation as a plain string
        correct_rels.append(triplet[1])

    # Zero-pad every question to the longest sequence and mask out the padding
    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])
    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)

    print("Finished with batches")
    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels
```
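%% Cell type:markdown id: tags:
`make_batch` pairs each question with the first triple of its gold SPARQL query: the question plus a `[SUB]`/`[OBJ]` marker (and the linked entity) is tokenized, the token id lists are zero-padded to the longest sequence, and the attention mask is 1 wherever a real token is present, while the gold relation (the middle element of the triple) is kept as a plain string. The cell below is an illustrative sketch that only inspects those outputs; the decoded example in the comment is an assumption about what a typical row looks like, not an actual dataset entry.
%% Cell type:code id: tags:
``` python
# Illustrative sketch only: inspect the shapes and one row of the batch built above.
train_ids, train_mask_ids, gold_rels = make_batch()

print(train_ids.shape, train_mask_ids.shape, len(gold_rels))  # (N, max_len), (N, max_len), N
print(gold_rels[0])                     # a relation string, e.g. something like "dbo:author"
print(tokenizer.decode(train_ids[0]))   # roughly "[CLS] <question> [SEP] [ sub ] ... [PAD] ..."
```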
%% Cell type:code id: tags:
``` python
# training_args = Seq2SeqTrainingArguments(
#     output_dir='./models/blackbox',
#     num_train_epochs=1,
#     per_device_train_batch_size=1,
#     per_device_eval_batch_size=1,
#     warmup_steps=10,
#     weight_decay=0.01,
#     logging_dir='./logs',
# )
```
%% Cell type:code id: tags:
``` python
model = NgmOne(device)

EPOCHS = 300
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)

# List of all candidate relations; each gold relation is mapped to its index in this list
with open("../data/relations-query-qald-9-linked.json", "r") as f:
    relations = json.load(f)

train, train_mask, corr_rels = make_batch()
corr_indx = torch.LongTensor([relations.index(r) for r in corr_rels]).to(device)

for epoch in range(EPOCHS):
    optimizer.zero_grad()

    # Forward pass
    output = model(train, train_mask)
    loss = criterion(output, corr_indx)

    if (epoch + 1) % 1 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

    # Backward pass
    loss.backward()
    optimizer.step()
```
%% Output
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Beginning making batch
100%|██████████| 408/408 [00:00<00:00, 2684.45it/s]
100%|██████████| 408/408 [00:00<00:00, 204160.82it/s]
Finished with batches
Epoch: 0001 loss = 5.509346
Epoch: 0002 loss = 5.508417
Epoch: 0003 loss = 5.507416
Epoch: 0004 loss = 5.506336
Epoch: 0005 loss = 5.505169
Epoch: 0006 loss = 5.503905
Epoch: 0007 loss = 5.502542
Epoch: 0008 loss = 5.501083
Epoch: 0009 loss = 5.499537
Epoch: 0010 loss = 5.497901
Epoch: 0011 loss = 5.496177
Epoch: 0012 loss = 5.494359
Epoch: 0013 loss = 5.492444
Epoch: 0014 loss = 5.490432
Epoch: 0015 loss = 5.488323
Epoch: 0016 loss = 5.486121
Epoch: 0017 loss = 5.483822
Epoch: 0018 loss = 5.481424
Epoch: 0019 loss = 5.478921
Epoch: 0020 loss = 5.476310
Epoch: 0021 loss = 5.473589
Epoch: 0022 loss = 5.470768
Epoch: 0023 loss = 5.467863
Epoch: 0024 loss = 5.464887
Epoch: 0025 loss = 5.461853
Epoch: 0026 loss = 5.458760
Epoch: 0027 loss = 5.455590
Epoch: 0028 loss = 5.452313
Epoch: 0029 loss = 5.448884
Epoch: 0030 loss = 5.445263
Epoch: 0031 loss = 5.441427
Epoch: 0032 loss = 5.437370
Epoch: 0033 loss = 5.433111
Epoch: 0034 loss = 5.428682
Epoch: 0035 loss = 5.424133
Epoch: 0036 loss = 5.419532
Epoch: 0037 loss = 5.414931
Epoch: 0038 loss = 5.410347
Epoch: 0039 loss = 5.405736
Epoch: 0040 loss = 5.401030
Epoch: 0041 loss = 5.396172
Epoch: 0042 loss = 5.391167
Epoch: 0043 loss = 5.386050
Epoch: 0044 loss = 5.380856
Epoch: 0045 loss = 5.375600
Epoch: 0046 loss = 5.370294
Epoch: 0047 loss = 5.364953
Epoch: 0048 loss = 5.359621
Epoch: 0049 loss = 5.354363
Epoch: 0050 loss = 5.349256
Epoch: 0051 loss = 5.344321
Epoch: 0052 loss = 5.339508
Epoch: 0053 loss = 5.334712
Epoch: 0054 loss = 5.329853
Epoch: 0055 loss = 5.324935
Epoch: 0056 loss = 5.320047
Epoch: 0057 loss = 5.315296
Epoch: 0058 loss = 5.310766
Epoch: 0059 loss = 5.306486
Epoch: 0060 loss = 5.302420
Epoch: 0061 loss = 5.298513
Epoch: 0062 loss = 5.294705
Epoch: 0063 loss = 5.290928
Epoch: 0064 loss = 5.287130
Epoch: 0065 loss = 5.283293
Epoch: 0066 loss = 5.279424
Epoch: 0067 loss = 5.275546
Epoch: 0068 loss = 5.271703
Epoch: 0069 loss = 5.267990
Epoch: 0070 loss = 5.264547
Epoch: 0071 loss = 5.261484
Epoch: 0072 loss = 5.258725
Epoch: 0073 loss = 5.256050
Epoch: 0074 loss = 5.253313
Epoch: 0075 loss = 5.250482
Epoch: 0076 loss = 5.247588
Epoch: 0077 loss = 5.244689
Epoch: 0078 loss = 5.241833
Epoch: 0079 loss = 5.239061
Epoch: 0080 loss = 5.236411
Epoch: 0081 loss = 5.233915
Epoch: 0082 loss = 5.231617
Epoch: 0083 loss = 5.229539
Epoch: 0084 loss = 5.227572
Epoch: 0085 loss = 5.225513
Epoch: 0086 loss = 5.223309
Epoch: 0087 loss = 5.221033
Epoch: 0088 loss = 5.218764
Epoch: 0089 loss = 5.216528
Epoch: 0090 loss = 5.214312
Epoch: 0091 loss = 5.212102
Epoch: 0092 loss = 5.209961
Epoch: 0093 loss = 5.208022
Epoch: 0094 loss = 5.206336
Epoch: 0095 loss = 5.204650
Epoch: 0096 loss = 5.202743
Epoch: 0097 loss = 5.200671
Epoch: 0098 loss = 5.198618
Epoch: 0099 loss = 5.196756
Epoch: 0100 loss = 5.195175
Epoch: 0101 loss = 5.193856
Epoch: 0102 loss = 5.192677
Epoch: 0103 loss = 5.191495
Epoch: 0104 loss = 5.190241
Epoch: 0105 loss = 5.188928
Epoch: 0106 loss = 5.187597
Epoch: 0107 loss = 5.186284
Epoch: 0108 loss = 5.184998
Epoch: 0109 loss = 5.183724
Epoch: 0110 loss = 5.182417
Epoch: 0111 loss = 5.181022
Epoch: 0112 loss = 5.179502
Epoch: 0113 loss = 5.177876
Epoch: 0114 loss = 5.176267
Epoch: 0115 loss = 5.174932
Epoch: 0116 loss = 5.174131
Epoch: 0117 loss = 5.173597
Epoch: 0118 loss = 5.172731
Epoch: 0119 loss = 5.171515
Epoch: 0120 loss = 5.170227
Epoch: 0121 loss = 5.169064
Epoch: 0122 loss = 5.168075
Epoch: 0123 loss = 5.167188
Epoch: 0124 loss = 5.166307
Epoch: 0125 loss = 5.165361
Epoch: 0126 loss = 5.164318
Epoch: 0127 loss = 5.163185
Epoch: 0128 loss = 5.162001
Epoch: 0129 loss = 5.160835
Epoch: 0130 loss = 5.159757
Epoch: 0131 loss = 5.158821
Epoch: 0132 loss = 5.158039
Epoch: 0133 loss = 5.157377
Epoch: 0134 loss = 5.156784
Epoch: 0135 loss = 5.156210
Epoch: 0136 loss = 5.155619
Epoch: 0137 loss = 5.154994
Epoch: 0138 loss = 5.154339
Epoch: 0139 loss = 5.153667
Epoch: 0140 loss = 5.152996
Epoch: 0141 loss = 5.152328
Epoch: 0142 loss = 5.151660
Epoch: 0143 loss = 5.150975
Epoch: 0144 loss = 5.150242
Epoch: 0145 loss = 5.149413
Epoch: 0146 loss = 5.148431
Epoch: 0147 loss = 5.147236
Epoch: 0148 loss = 5.145802
Epoch: 0149 loss = 5.144192
Epoch: 0150 loss = 5.142632
Epoch: 0151 loss = 5.141498
Epoch: 0152 loss = 5.140967
Epoch: 0153 loss = 5.140542
Epoch: 0154 loss = 5.139736
Epoch: 0155 loss = 5.138566
Epoch: 0156 loss = 5.137214
Epoch: 0157 loss = 5.135843
Epoch: 0158 loss = 5.134551
Epoch: 0159 loss = 5.133349
Epoch: 0160 loss = 5.132212
Epoch: 0161 loss = 5.131142
Epoch: 0162 loss = 5.130210
Epoch: 0163 loss = 5.129475
Epoch: 0164 loss = 5.128819
Epoch: 0165 loss = 5.128070
Epoch: 0166 loss = 5.127200
Epoch: 0167 loss = 5.126271
Epoch: 0168 loss = 5.125329
Epoch: 0169 loss = 5.124373
Epoch: 0170 loss = 5.123369
Epoch: 0171 loss = 5.122267
Epoch: 0172 loss = 5.121037
Epoch: 0173 loss = 5.119725
Epoch: 0174 loss = 5.118514
Epoch: 0175 loss = 5.117722
Epoch: 0176 loss = 5.117445
Epoch: 0177 loss = 5.117110
Epoch: 0178 loss = 5.116376
Epoch: 0179 loss = 5.115437
Epoch: 0180 loss = 5.114553
Epoch: 0181 loss = 5.113859
Epoch: 0182 loss = 5.113349
Epoch: 0183 loss = 5.112945
Epoch: 0184 loss = 5.112563
Epoch: 0185 loss = 5.112154
Epoch: 0186 loss = 5.111700
Epoch: 0187 loss = 5.111213
Epoch: 0188 loss = 5.110720
Epoch: 0189 loss = 5.110251
Epoch: 0190 loss = 5.109831
Epoch: 0191 loss = 5.109465
Epoch: 0192 loss = 5.109145
Epoch: 0193 loss = 5.108854
Epoch: 0194 loss = 5.108573
Epoch: 0195 loss = 5.108290
Epoch: 0196 loss = 5.108000
Epoch: 0197 loss = 5.107710
Epoch: 0198 loss = 5.107430
Epoch: 0199 loss = 5.107163
Epoch: 0200 loss = 5.106915
Epoch: 0201 loss = 5.106686
Epoch: 0202 loss = 5.106472
Epoch: 0203 loss = 5.106268
Epoch: 0204 loss = 5.106072
Epoch: 0205 loss = 5.105881
Epoch: 0206 loss = 5.105694
Epoch: 0207 loss = 5.105511
Epoch: 0208 loss = 5.105332
Epoch: 0209 loss = 5.105159
Epoch: 0210 loss = 5.104992
Epoch: 0211 loss = 5.104831
Epoch: 0212 loss = 5.104676
Epoch: 0213 loss = 5.104527
Epoch: 0214 loss = 5.104386
Epoch: 0215 loss = 5.104248
Epoch: 0216 loss = 5.104114
Epoch: 0217 loss = 5.103983
Epoch: 0218 loss = 5.103856
Epoch: 0219 loss = 5.103732
Epoch: 0220 loss = 5.103611
Epoch: 0221 loss = 5.103494
Epoch: 0222 loss = 5.103378
Epoch: 0223 loss = 5.103266
Epoch: 0224 loss = 5.103157
Epoch: 0225 loss = 5.103050
Epoch: 0226 loss = 5.102947
Epoch: 0227 loss = 5.102844
Epoch: 0228 loss = 5.102745
Epoch: 0229 loss = 5.102647
Epoch: 0230 loss = 5.102551
Epoch: 0231 loss = 5.102457
Epoch: 0232 loss = 5.102361
Epoch: 0233 loss = 5.102268
Epoch: 0234 loss = 5.102171
Epoch: 0235 loss = 5.102070
Epoch: 0236 loss = 5.101960
Epoch: 0237 loss = 5.101832
Epoch: 0238 loss = 5.101676
Epoch: 0239 loss = 5.101472
Epoch: 0240 loss = 5.101196
Epoch: 0241 loss = 5.100809
Epoch: 0242 loss = 5.100273
Epoch: 0243 loss = 5.099566
Epoch: 0244 loss = 5.098716
Epoch: 0245 loss = 5.097848
Epoch: 0246 loss = 5.097132
Epoch: 0247 loss = 5.096595
Epoch: 0248 loss = 5.095999
Epoch: 0249 loss = 5.095338
Epoch: 0250 loss = 5.094879
Epoch: 0251 loss = 5.094688
Epoch: 0252 loss = 5.094509
Epoch: 0253 loss = 5.094189
Epoch: 0254 loss = 5.093769
Epoch: 0255 loss = 5.093320
Epoch: 0256 loss = 5.092857
Epoch: 0257 loss = 5.092339
Epoch: 0258 loss = 5.091694
Epoch: 0259 loss = 5.090842
Epoch: 0260 loss = 5.089743
Epoch: 0261 loss = 5.088472
Epoch: 0262 loss = 5.087311
Epoch: 0263 loss = 5.086748
Epoch: 0264 loss = 5.086977
Epoch: 0265 loss = 5.087138
Epoch: 0266 loss = 5.086704
Epoch: 0267 loss = 5.085968
Epoch: 0268 loss = 5.085289
Epoch: 0269 loss = 5.084831
Epoch: 0270 loss = 5.084577
Epoch: 0271 loss = 5.084425
Epoch: 0272 loss = 5.084273
Epoch: 0273 loss = 5.084045
Epoch: 0274 loss = 5.083702
Epoch: 0275 loss = 5.083218
Epoch: 0276 loss = 5.082567
Epoch: 0277 loss = 5.081728
Epoch: 0278 loss = 5.080690
Epoch: 0279 loss = 5.079515
Epoch: 0280 loss = 5.078421
Epoch: 0281 loss = 5.077838
Epoch: 0282 loss = 5.078005
Epoch: 0283 loss = 5.078021
Epoch: 0284 loss = 5.077395
Epoch: 0285 loss = 5.076532
Epoch: 0286 loss = 5.075821
Epoch: 0287 loss = 5.075374
Epoch: 0288 loss = 5.075100
Epoch: 0289 loss = 5.074864
Epoch: 0290 loss = 5.074569
Epoch: 0291 loss = 5.074176
Epoch: 0292 loss = 5.073692
Epoch: 0293 loss = 5.073163
Epoch: 0294 loss = 5.072643
Epoch: 0295 loss = 5.072181
Epoch: 0296 loss = 5.071800
Epoch: 0297 loss = 5.071474
Epoch: 0298 loss = 5.071146
Epoch: 0299 loss = 5.070775
Epoch: 0300 loss = 5.070364
%% Cell type:code id: tags:
``` python
# Predict
train, train_mask, corr_rels = make_batch()

with torch.no_grad():
    output = model(train, train_mask)
output = output.detach().cpu().numpy()

prediction = [relations[np.argmax(pred).item()] for pred in output]
probability = [pred[np.argmax(pred)] for pred in output]
correct_pred = [corr_rels[i] for i in range(len(output))]
preds = [
    {"pred": relations[np.argmax(pred).item()],
     "prob": pred[np.argmax(pred)],
     "correct": corr_rels[count]}
    for count, pred in enumerate(output)
]

# for pred in preds:
#     print("pred", pred["pred"], "prob", pred["prob"], "correct", pred["correct"])

# for pred, prob, correct_pred in zip(prediction, probability, correct_pred):
#     print("pred:", pred, "prob:", prob, "| correct", correct_pred)

print("lowest confidence", min(probability))

def accuracy_score(y_true, y_pred):
    corr_preds = 0
    wrong_preds = 0
    for pred, correct in zip(y_pred, y_true):
        if pred == correct:
            corr_preds += 1
        else:
            wrong_preds += 1
    return corr_preds / (corr_preds + wrong_preds)

print("Accuracy:", accuracy_score(correct_pred, prediction))
```
%% Output
Beginning making batch
100%|██████████| 408/408 [00:00<00:00, 814.37it/s]
100%|██████████| 408/408 [00:00<00:00, 101922.34it/s]
Finished with batches
lowest confidence 0.19197923
Accuracy: 0.6323529411764706
%% Cell type:code id: tags:
``` python
```
......
[
"dbo:foundationPlace",
"dbo:founder",
"<http://dbpedia.org/ontology/officialSchoolColour>",
"<http://dbpedia.org/property/ethnicity>",
"<http://dbpedia.org/ontology/mission>",
"<http://dbpedia.org/property/admittancedate>",
"dbo:writer",
"<http://dbpedia.org/property/carbs>",
"<http://dbpedia.org/property/borderingstates>",
"dbo:party",
"dbo:growingGrape",
"<http://dbpedia.org/property/residence>",
"dbo:wineRegion",
"dbo:language",
"dbo:officialLanguage",
"dbo:influencedBy",
"dbo:routeEnd",
"<http://dbpedia.org/ontology/netIncome>",
"dbo:bandMember",
"dbo:team",
"dbo:origin",
"<http://dbpedia.org/property/launchPad>",
"dbp:species",
"dbo:composer",
"<http://dbpedia.org/property/author>",
"<http://dbpedia.org/ontology/creator>",
"dbo:developer",
"<http://dbpedia.org/ontology/owner>",
"<http://dbpedia.org/ontology/spouse>",
"dbo:timeZone",
"<http://dbpedia.org/property/governor>",
"dbo:numberOfPages",
"dbo:deathCause",
"dbo:award",
"dbo:activeYearsEndDate",
"dbo:governmentType",
"dbo:maximumDepth",
"dbo:owner",
"dbo:leader",
"dbo:birthDate",
"dbo:editor",
"dbo:knownFor",
"dbo:starring",
"dbo:targetAirport",
"<http://dbpedia.org/property/founded>",
"<http://dbpedia.org/property/leaderParty>",
"<http://dbpedia.org/property/speciality>",
"dbo:country",
"dbo:runtime",
"<http://dbpedia.org/ontology/spokenIn>",
"dbo:battle",
"dbo:discoverer",
"dbo:portrayer",
"dbo:areaTotal",
"<http://dbpedia.org/ontology/portrayer>",
"dbo:height",
"dbo:spouse",
"dbo:abbreviation",
"<http://dbpedia.org/ontology/profession>",
"dbo:musicComposer",
"dbo:firstAscentPerson",
"dbo:publisher",
"<http://dbpedia.org/ontology/deathPlace>",
"<http://dbpedia.org/property/largestmetro>",
"dbp:writer",
"dbo:ingredient",
"dbo:class",
"<http://dbpedia.org/property/breed>",
"dbo:director",
"dbo:alias",
"<http://dbpedia.org/ontology/currency>",
"<http://dbpedia.org/ontology/populationTotal>",
"<http://dbpedia.org/property/successor>",
"<http://dbpedia.org/ontology/manager>",
"dbo:mayor",
"dbo:date",
"dbp:editor",
"a",
"<http://dbpedia.org/ontology/foundedBy>",
"dbo:completionDate",
"dbp:populationDensityRank",
"dbo:presenter",
"<http://dbpedia.org/ontology/instrument>",
"<http://dbpedia.org/ontology/numberOfEmployees>",
"dbo:series",
"dbo:creator",
"<http://dbpedia.org/ontology/producer>",
"dbo:leaderName",
"<http://dbpedia.org/property/children>",
"<http://dbpedia.org/property/title>",
"<http://dbpedia.org/property/fifaMin>",
"dbo:crosses",
"dbo:mission",
"dbo:architect",
"dbo:largestCity",
"dbo:budget",
"<http://dbpedia.org/property/birthName>",
"dbo:populationTotal",
"dbo:foundingDate",
"dbo:vicePresident",
"<http://dbpedia.org/ontology/genre>",
"dbo:product",
"dbo:type",
"dbo:state",
"dbo:influenced",
"dbo:doctoralAdvisor",
"dbo:numberOfLocations",
"dbo:successor",
"<http://dbpedia.org/property/ballpark>",
"<http://dbpedia.org/ontology/country>",
"dbo:sourceCountry",
"dbo:birthPlace",
"<http://dbpedia.org/property/highest>",
"<http://dbpedia.org/ontology/child>",
"dbo:birthName",
"dbo:child",
"dbo:deathPlace",
"<http://dbpedia.org/ontology/abbreviation>",
"dbo:dissolutionDate",
"dbo:author",
"dbo:birthYear",
"<http://dbpedia.org/property/beginningDate>",
"dbo:ethnicGroup",
"dbo:currency",
"<http://dbpedia.org/property/programme>",
"<http://dbpedia.org/property/shipNamesake>",
"dbo:capital",
"dbo:programmingLanguage",
"dbo:city",
"<http://dbpedia.org/ontology/deathDate>",
"dbo:almaMater",
"<http://dbpedia.org/property/employees>",
"dbo:location",
"<http://dbpedia.org/property/accessioneudate>",
"rdf:type",
"<http://dbpedia.org/ontology/developer>",
"dbo:restingPlace"
]