Commit 465ac813 authored by Max Björkander

strange loss func

parent f9949bd5
%% Cell type:code id: tags:
``` python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import json
import requests
```
%% Cell type:code id: tags:
``` python
mod = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(mod)
```
%% Cell type:code id: tags:
``` python
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Print the selected device
print(device)
```
%% Output
cuda
%% Cell type:code id: tags:
``` python
class NgmOne(nn.Module):
    def __init__(self, device, relations):
        super(NgmOne, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(mod)
        self.bert = BertModel.from_pretrained(mod).to(device)
        # Freeze BERT: only the linear head below is trained
        for param in self.bert.parameters():
            param.requires_grad = False
        self.linear = nn.Linear(768, len(relations)).to(device)
        self.softmax = nn.Softmax(dim=1).to(device)
        self.device = device

    def forward(self, tokenized_seq, tokenized_mask):
        tokenized_seq = tokenized_seq.to(self.device)
        tokenized_mask = tokenized_mask.to(self.device)
        with torch.no_grad():
            x = self.bert(tokenized_seq, attention_mask=tokenized_mask)
        x = x[0][:, 0, :]  # [CLS] token embedding
        x = self.linear(x)
        # Note: nn.CrossEntropyLoss applies log-softmax itself, so returning
        # softmax probabilities here normalizes the scores twice
        x = self.softmax(x)
        return x
```
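%% Cell type:markdown id: tags:
A minimal shape check for the head (hypothetical toy inputs; `toy_relations` is a stand-in for the real relation list): the model should return one softmax score per relation.
%% Cell type:code id: tags:
``` python
# Hypothetical smoke test for NgmOne: encode one question and confirm the
# output is a probability distribution over the relation vocabulary.
toy_relations = ["dbo:author", "dbo:team"]  # stand-in for the real list
toy_model = NgmOne(device, toy_relations)
enc = toy_model.tokenizer("who wrote hamlet?", return_tensors="pt")
probs = toy_model(enc["input_ids"], enc["attention_mask"])
print(probs.shape)       # torch.Size([1, 2]) -> one score per relation
print(probs.sum(dim=1))  # ~1.0, since forward ends with a softmax
```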
%% Cell type:code id: tags:
``` python
prefixes = {
"http://dbpedia.org/resource/": "res:",
"http://dbpedia.org/ontology/": "dbo:",
"http://dbpedia.org/property/": "dbp:",
"http://www.w3.org/2000/01/rdf-schema#": "rdfs:",
"http://www.w3.org/1999/02/22-rdf-syntax-ns#": "rdf:",
"http://dbpedia.org/class/yago/": "yago:",
"http://www.wikidata.org/prop/direct/": "wdt:",
"http://www.wikidata.org/entity/": "wd:",
"http://www.wikidata.org/prop/": "p:",
"https://w3id.org/payswarm#": "ps:",
"http://www.wikidata.org/prop/qualifier/": "pq:",
"http://www.bigdata.com/rdf#": "bd:",
"http://wikiba.se/ontology#": "wikibase:",
"http://www.w3.org/2004/02/skos/core#": "skos:",
}
prefixes_reverse = {v: k for k, v in prefixes.items()}  # e.g. "dbo:" -> "http://dbpedia.org/ontology/"
# TLDs used to restore the "." that make_batch strips out of URIs below
prefixes_end = ["org", "se", "com", "nu"]
def make_batch(src, http_prefix=False):
    """Build tokenized (question, entity) inputs and gold relations from a
    linked QALD/LC-QuAD JSON file, e.g. "../LC-QuAD/combined-requeried-linked-train.json".
    A triplet is a list [subject entity, relation, object entity]; an element is None if not present."""
    pred = src
    gold = src
    print("Beginning making batch")
    with open(pred, "r") as p, open(gold, "r") as g:
        pred = json.load(p)
        gold = json.load(g)

    inputs = []
    correct_rels = []
    sub_obj_ents = []
    inputs_max_len = 0
    for d in tqdm(pred["questions"]):
        question = d["question"][0]["string"]
        query = d["query"]["sparql"]

        # Take the triplets in the WHERE clause of the query
        trip = query.split("WHERE {")[1]
        trip = trip.split("}")[0]
        trip = trip.split("FILTER ")[0]
        trip = trip.replace("{", "").replace("}", "")
        # Stripping "." removes SPARQL separators but also mangles URIs
        # ("dbpedia.org" -> "dbpediaorg"); prefixes_end restores the dot below
        trip = trip.replace(".", "")
        trip = trip.replace(";", "")
        triplet = trip.split(" ")

        # Remove empty strings
        triplet = [x for x in triplet if x != ""]

        if len(triplet) % 3 == 0 and " ".join(triplet).find("rdf") == -1:
            for i in range(len(triplet)//3):
                triplet_i = triplet[i*3:i*3+3]
                for t in triplet_i:
                    if t.startswith("?"):
                        # Variables carry no entity information
                        triplet_i[triplet_i.index(t)] = None
                    elif http_prefix:
                        n = t.replace("<", "").replace(">", "")
                        n_sub = n.split("/")[-1]
                        n_ = n
                        for end in prefixes_end:
                            if n.find(end) != -1:
                                n_ = n.split(end)[0] + "." + end + "".join(n.split(end)[1:])
                                break
                        new = ""
                        for pre in prefixes:
                            if n_.find(pre) != -1:
                                new = prefixes[pre]
                                break
                        triplet_i[triplet_i.index(t)] = new + n_sub
                    elif t.find("http") != -1:
                        n_ = t
                        for end in prefixes_end:
                            if t.find(end) != -1:
                                n_ = t.split(end)[0] + "." + end + "".join(t.split(end)[1:])
                                break
                        triplet_i[triplet_i.index(t)] = n_

                # Pair the question with the linked entity; the relation is the label
                if triplet_i[0] is not None and triplet_i[1] is not None:
                    tokenized_seq = tokenizer(question.lower(), "[SUB] [SEP] " + triplet_i[0].split(":")[1].lower(), padding=True, truncation=True)
                    sub_obj_ents.append("[SUB] " + prefixes_reverse[triplet_i[0].split(":")[0] + ":"] + triplet_i[0].split(":")[1])
                elif triplet_i[2] is not None and triplet_i[1] is not None:
                    tokenized_seq = tokenizer(question.lower(), "[OBJ] [SEP] " + triplet_i[2].split(":")[1].lower(), padding=True, truncation=True)
                    sub_obj_ents.append("[OBJ] " + prefixes_reverse[triplet_i[2].split(":")[0] + ":"] + triplet_i[2].split(":")[1])
                else:
                    continue

                if inputs_max_len < len(tokenized_seq["input_ids"]):
                    inputs_max_len = len(tokenized_seq["input_ids"])
                inputs.append(list(tokenized_seq.values())[0])
                correct_rels.append(triplet_i[1].lower())

    inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])
    inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)
    print("Finished with batches")
    return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels, sub_obj_ents
```
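%% Cell type:markdown id: tags:
A quick illustration of the dot round-trip inside `make_batch` (the URI below is a hypothetical example): stripping "." from the WHERE clause mangles URIs, and `prefixes_end` restores the TLD before the prefix lookup.
%% Cell type:code id: tags:
``` python
# Hypothetical example of restoring the "." stripped out of a URI
mangled = "http://dbpediaorg/ontology/team"
restored = mangled
for end in prefixes_end:
    if mangled.find(end) != -1:
        restored = mangled.split(end)[0] + "." + end + "".join(mangled.split(end)[1:])
        break
print(restored)  # http://dbpedia.org/ontology/team
print(prefixes["http://dbpedia.org/ontology/"] + restored.split("/")[-1])  # dbo:team
```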
%% Cell type:code id: tags:
``` python
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    def __init__(self, inputs, attention_mask, correct_rels, ents, relations):
        self.inputs = inputs
        self.attention_mask = attention_mask
        self.correct_rels = correct_rels
        self.ents = ents
        self.relations = relations

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # The label is the index of the gold relation in the relation list
        return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx]), self.ents[idx]
```
%% Cell type:code id: tags:
``` python
from torch.utils.data import random_split
# Prepare data
def open_json(file):
    with open(file, "r") as f:
        return json.load(f)

relations = open_json("../data/relations-all-no-http-lowercase.json")
inputs, attention_mask, correct_rels, sub_objs = make_batch(src="../LC-QuAD/combined-requeried-linked-train.json", http_prefix=True)
dataset = MyDataset(inputs, attention_mask, correct_rels, sub_objs, relations=relations)

train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_data, valid_data = random_split(dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(42))

train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
# Show the first entry
train_features, train_mask, train_label, ents = next(iter(train_dataloader))
print("features:", tokenizer.batch_decode(train_features), "mask:", train_mask, "label_index", train_label[0])

valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=True)
valid_features, valid_mask, valid_label, ents = next(iter(valid_dataloader))
print("valid features:", valid_features, "valid mask:", valid_mask, "valid label_index", valid_label[0])
```
%% Output
Beginning making batch
100%|██████████| 2052/2052 [00:01<00:00, 1068.18it/s]
Finished with batches
features: ['[CLS] what ingredients are used in preparing the dish of ragout fin? [SEP] [ sub ] [SEP] ragout _ fin [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'] mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(336)
valid features: tensor([[  101,  2054,  2003,  1996,  2344,  1997,  2577, 10424,  2483, 11283,
          7570,  2906,  1029,   102,  1031,  4942,  1033,   102,  2577,  1035,
         10424,  2483, 11283,  1035,  7570,  2906,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(297)
%% Cell type:code id: tags:
``` python
SPARQL_ENDPOINT = "https://dbpedia.org/sparql"
headers = {
    'Accept': 'application/sparql-results+json',
    'Content-Type': 'application/x-www-form-urlencoded',
}

# Initialize model
model = NgmOne(device, relations)
```
%% Output
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
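%% Cell type:markdown id: tags:
The training loop below queries DBpedia for the relations that actually occur around each sample's linked entity and penalizes predictions outside that set. A standalone sketch of that lookup ("Ragout_fin" is just an illustrative entity from the sample output above; any linked DBpedia resource works):
%% Cell type:code id: tags:
``` python
# Standalone sketch of the per-entity relation lookup used in training below
entity = "http://dbpedia.org/resource/Ragout_fin"  # illustrative example
q = "SELECT ?r WHERE { <" + entity + "> ?r ?o }"
params = {
    "default-graph-uri": "http://dbpedia.org",
    "query": q,
    "format": "json",
}
response = requests.get(SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)
bindings = response.json()["results"]["bindings"]
entity_rels = {b["r"]["value"].lower() for b in bindings}
print(len(entity_rels), "distinct relations, e.g.", sorted(entity_rels)[:3])
```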
%% Cell type:code id: tags:
``` python
# Train with the data loader.
criterion = nn.CrossEntropyLoss(reduction="none")  # per-sample losses, blended below
optimizer = optim.Adam(model.parameters(), lr=0.001)

epoch = 10
batch_size = 8
alpha = 0.5  # weight of the knowledge-graph penalty relative to cross-entropy

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)

model.train()
for e in range(epoch):
    train_loss_epoch = 0
    valid_loss_epoch = 0
    for i_train, sample_batched_train in enumerate(train_dataloader):
        optimizer.zero_grad()
        train = sample_batched_train[0]
        train_mask = sample_batched_train[1]
        label_index = sample_batched_train[2].to(device)
        sub_objs = sample_batched_train[3]

        # Forward pass
        output = model(train, train_mask)

        # For each sample, ask DBpedia which relations its linked entity actually has
        res_rels_batch = []
        for j in range(len(sub_objs)):
            if not (sub_objs[j].split(" ")[0] == "[SUB]" or sub_objs[j].split(" ")[0] == "[OBJ]"):
                continue
            if sub_objs[j].split(" ")[0] == "[SUB]":
                sub = sub_objs[j].split(" ")[1]
                q = "SELECT ?r WHERE { <" + sub + "> ?r ?o }"
            if sub_objs[j].split(" ")[0] == "[OBJ]":
                obj = sub_objs[j].split(" ")[1]
                q = "SELECT ?r WHERE { ?s ?r <" + obj + "> }"
            params = {
                "default-graph-uri": "http://dbpedia.org",
                "query": q,
                "format": "json",
            }
            response = requests.get(SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)
            results = response.json()
            res_rels = {b["r"]["value"].lower(): True for b in results["results"]["bindings"]}
            res_rels_batch.append(res_rels)

        # Penalize predictions whose relation never occurs around the linked entity
        preds = [relations[np.argmax(pred).item()] for pred in output.detach().cpu().numpy()]
        relation_loss = []
        for i in range(len(preds)):
            full_uri = prefixes_reverse[preds[i].split(":")[0] + ":"] + preds[i].split(":")[1]
            if full_uri in res_rels_batch[i]:
                relation_loss.append(0)
            else:
                relation_loss.append(1)
        relation_loss = torch.FloatTensor(relation_loss).to(device)

        loss = criterion(output, label_index)
        loss = loss * (1 - alpha) + relation_loss * alpha

        # Backward and optimize
        loss = loss.mean()
        loss.backward()
        optimizer.step()
        train_loss_epoch = train_loss_epoch + loss.item()

    for i_valid, sample_batched_valid in enumerate(valid_dataloader):
        valid = sample_batched_valid[0]
        valid_mask = sample_batched_valid[1]
        label_index = sample_batched_valid[2].to(device)
        # Forward pass
        with torch.no_grad():
            output = model(valid, valid_mask)
            loss = criterion(output, label_index)
        valid_loss_epoch = valid_loss_epoch + loss.mean().item()

    print(e+1, "Train", train_loss_epoch/i_train, ", Valid ", valid_loss_epoch/i_valid)
```
%% Output
1 Train 3.517506525670882 , Valid 6.241430044174194
2 Train 3.514909356618099 , Valid 6.230757559047026
3 Train 3.494271710622225 , Valid 6.218190235250137
4 Train 3.487279334514261 , Valid 6.216006854001214
5 Train 3.4676955844000945 , Valid 6.193478163550882
6 Train 3.4382689976863725 , Valid 6.1805572509765625
7 Train 3.442131721716133 , Valid 6.178038597106934
8 Train 3.433108813471074 , Valid 6.17140617090113
9 Train 3.424916239951154 , Valid 6.16327678456026
10 Train 3.4210985067079394 , Valid 6.16277868607465
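%% Cell type:markdown id: tags:
One likely reason the loss moves so little: `NgmOne.forward` already applies a softmax, and `nn.CrossEntropyLoss` applies log-softmax internally, so the scores are normalized twice and the gradients are compressed. The validation loss hovering near ln(441) ≈ 6.09 (the uniform-prediction value for the 441 relations in the linear head) is consistent with that. A minimal sketch of the effect:
%% Cell type:code id: tags:
``` python
# Hedged illustration: CrossEntropyLoss expects raw logits. Passing it
# probabilities that already went through a softmax (as NgmOne.forward does)
# compresses the loss range and shrinks the gradients.
logits = torch.tensor([[4.0, 0.0, 0.0]])
target = torch.tensor([0])
ce = nn.CrossEntropyLoss()
print("loss on logits:         ", ce(logits, target).item())                 # ~0.04
print("loss on softmax outputs:", ce(logits.softmax(dim=1), target).item())  # ~0.57
```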
%% Cell type:code id: tags:
``` python
# Predict
train, train_mask, corr_rels, ents = make_batch(src="../LC-QuAD/combined-requeried-linked-train.json", http_prefix=True)
test, test_mask, corr_rels_test, ents_test = make_batch(src="../LC-QuAD/combined-requeried-linked-test.json", http_prefix=True)

test_data = MyDataset(test, test_mask, corr_rels_test, ents=ents_test, relations=relations)
test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)
test_batch, test_mask_batch, corr_rels_test_batch, sub_objs = next(iter(test_dataloader))
corr_rels_test_batch = corr_rels_test_batch.to(device)

with torch.no_grad():
    output_train = model(train, train_mask)
    output_test = model(test_batch, test_mask_batch)
    loss = criterion(output_test, corr_rels_test_batch)
print("test loss", loss.mean().item())

# Knowledge-graph reweighting of the test scores (currently switched off below)
loss_gs = []
for j in range(len(sub_objs)):
    if not (sub_objs[j].split(" ")[0] == "[SUB]" or sub_objs[j].split(" ")[0] == "[OBJ]"):
        continue
    if sub_objs[j].split(" ")[0] == "[SUB]":
        sub = sub_objs[j].split(" ")[1]
        q = "SELECT ?r WHERE { <" + sub + "> ?r ?o }"
    if sub_objs[j].split(" ")[0] == "[OBJ]":
        obj = sub_objs[j].split(" ")[1]
        q = "SELECT ?r WHERE { ?s ?r <" + obj + "> }"
    params = {
        "default-graph-uri": "http://dbpedia.org",
        "query": q,
        "format": "json",
    }
    response = requests.get(SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)
    results = response.json()
    res_rels = {b["r"]["value"].lower(): True for b in results["results"]["bindings"]}

    # Boost relations the knowledge graph supports, damp the rest
    loss_rels = []
    for i in range(len(relations)):
        if prefixes_reverse[relations[i].split(":")[0] + ":"] + relations[i].split(":")[1] in res_rels:
            loss_rels.append(5)
        else:
            loss_rels.append(1/20)
    loss_gs.append(loss_rels)
loss_gs = torch.FloatTensor(loss_gs).to(device)

output_gs_test = output_test  # * loss_gs  (reweighting disabled)

output_train = output_train.detach().cpu().numpy()
output_test = output_gs_test.detach().cpu().numpy()

prediction_train = [relations[np.argmax(pred).item()] for pred in output_train]
probability_train = [pred[np.argmax(pred)] for pred in output_train]
correct_pred_train = [corr_rels[i] for i in range(len(output_train))]

prediction_test = [relations[np.argmax(pred).item()] for pred in output_test]
probability_test = [pred[np.argmax(pred)] for pred in output_test]
# The dataloader shuffles the test set, so recover the gold labels from the batch
correct_pred_test = [relations[i] for i in corr_rels_test_batch.cpu().numpy()]

print("lowest confidence train", min(probability_train))
print("lowest confidence test", min(probability_test))

def accuracy_score(y_true, y_pred):
    corr_preds = 0
    wrong_preds = 0
    for pred, correct in zip(y_pred, y_true):
        if pred == correct:
            corr_preds += 1
        else:
            wrong_preds += 1
    return corr_preds/(corr_preds+wrong_preds)

print("Accuracy train:", accuracy_score(correct_pred_train, prediction_train))
print("Accuracy test:", accuracy_score(correct_pred_test, prediction_test))
```
%% Output
Beginning making batch
100%|██████████| 2052/2052 [00:02<00:00, 992.25it/s]
Finished with batches
Beginning making batch
100%|██████████| 1161/1161 [00:01<00:00, 1040.33it/s]
Finished with batches
test loss 6.041214942932129
lowest confidence train 0.07357887
lowest confidence test 0.07673955
Accuracy train: 0.24283667621776503
Accuracy test: 0.015267175572519083
%% Cell type:code id: tags:
``` python
from prettytable import PrettyTable
def count_parameters(model):
    table = PrettyTable(["Mod name", "Parameters Listed"])
    t_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        table.add_row([name, param])
        t_params += param
    print(table)
    print(f"Sum of trained parameters: {t_params}")
    return t_params

count_parameters(model)
```
%% Output
+---------------+-------------------+
| Mod name | Parameters Listed |
+---------------+-------------------+
| linear.weight | 338688 |
| linear.bias | 441 |
+---------------+-------------------+
Sum of trained parameters: 339129
339129