diff --git a/Neural graph module/ngm.ipynb b/Neural graph module/ngm.ipynb index a924791ec8388b0e67e4a9635a89828d25d7f9bb..d95c77016b150ff219acb1da22657a5994bbad3b 100644 --- a/Neural graph module/ngm.ipynb +++ b/Neural graph module/ngm.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 17, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -15,21 +15,23 @@ "from transformers import BertTokenizer, BertModel\n", "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", "from tqdm import tqdm\n", - "import json\n" + "import json\n", + "import requests\n" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ - "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" + "mod = \"bert-base-uncased\"\n", + "tokenizer = BertTokenizer.from_pretrained(mod)" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -49,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -57,8 +59,12 @@ "class NgmOne(nn.Module):\n", " def __init__(self, device, relations):\n", " super(NgmOne, self).__init__()\n", - " self.tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", - " self.bert = BertModel.from_pretrained(\"bert-base-uncased\").to(device)\n", + " with torch.no_grad():\n", + " self.tokenizer = BertTokenizer.from_pretrained(mod)\n", + " self.bert = BertModel.from_pretrained(mod).to(device)\n", + "\n", + " for param in self.bert.parameters():\n", + " param.requires_grad = False\n", " self.linear = nn.Linear(768, len(relations)).to(device)\n", " self.softmax = nn.Softmax(dim=1).to(device)\n", " self.device = device\n", @@ -68,8 +74,7 @@ " tokenized_mask = tokenized_mask.to(self.device)\n", " with torch.no_grad():\n", " x = self.bert.forward(tokenized_seq, attention_mask=tokenized_mask)\n", - " x = x[0][:,0,:].to(self.device)\n", - " \n", + " x = x[0][:,0,:].to(self.device)\n", " x = self.linear(x)\n", " x = self.softmax(x)\n", " return x" @@ -77,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -98,6 +103,25 @@ " \"http://www.w3.org/2004/02/skos/core#\": \"skos:\",\n", "}\n", "\n", + "prefixes_reverse = {v: k for k, v in prefixes.items()}\n", + "\n", + "# {\n", + "# \"res:\": \"http://dbpedia.org/resource/\",\n", + "# \"dbo:\": \"http://dbpedia.org/ontology/\",\n", + "# \"dbp:\": \"http://dbpedia.org/property/\",\n", + "# \"rdfs:\": \"http://www.w3.org/2000/01/rdf-schema#\",\n", + "# \"rdf:\": \"http://www.w3.org/1999/02/22-rdf-syntax-ns#\",\n", + "# \"yago:\": \"http://dbpedia.org/class/yago/\",\n", + "# \"wdt:\": \"http://www.wikidata.org/prop/direct/\",\n", + "# \"wd:\": \"http://www.wikidata.org/entity/\",\n", + "# \"p:\": \"http://www.wikidata.org/prop/\",\n", + "# \"ps:\": \"https://w3id.org/payswarm#\",\n", + "# \"pq:\": \"http://www.wikidata.org/prop/qualifier/\",\n", + "# \"bd:\": \"http://www.bigdata.com/rdf#\",\n", + "# \"wikibase:\": \"http://wikiba.se/ontology#\",\n", + "# \"skos:\": \"http://www.w3.org/2004/02/skos/core#\",\n", + "# }\n", + "\n", "prefixes_end = [\"org\", \"se\", \"com\", \"nu\"]\n", "ALL_HTTP_PREFIXES = False\n", "\n", @@ -119,6 +143,7 @@ " \n", " inputs = []\n", " correct_rels = []\n", + " sub_obj_ents = []\n", " inputs_max_len = 0\n", " for d in tqdm(pred[\"questions\"]):\n", " question = d[\"question\"][0][\"string\"]\n", @@ -171,17 +196,19 @@ " if triplet_i[0] is not None and triplet_i[1] is not None:\n", " #seq += \"[SUB] [SEP] \" + triplet[0]\n", " # , padding=True, truncation=True)\n", - " tokenized_seq = tokenizer(question, \"[SUB]\", triplet_i[0], padding=True, truncation=True)\n", + " tokenized_seq = tokenizer(question.lower(), \"[OBJ] [SEP] \" + triplet_i[0].split(\":\")[1].lower(), padding=True, truncation=True)\n", + " sub_obj_ents.append(\"[SUB] \" + prefixes_reverse[\"\".join(triplet_i[0].split(\":\")[0]) + \":\"] + triplet_i[0].split(\":\")[1])\n", " elif triplet_i[2] is not None and triplet_i[1] is not None:\n", " #seq += \"[OBJ] [SEP] \" + triplet[2]\n", - " tokenized_seq = tokenizer(question, \"[OBJ]\", triplet_i[2], padding=True, truncation=True)\n", + " tokenized_seq = tokenizer(question.lower(), \"[OBJ] [SEP] \" + triplet_i[2].split(\":\")[1].lower(), padding=True, truncation=True)\n", + " sub_obj_ents.append(\"[OBJ] \" + prefixes_reverse[\"\".join(triplet_i[2].split(\":\")[0]) + \":\"] + triplet_i[2].split(\":\")[1])\n", " else:\n", " continue\n", "\n", " if inputs_max_len < len(tokenized_seq[\"input_ids\"]):\n", " inputs_max_len = len(tokenized_seq[\"input_ids\"])\n", " inputs.append(list(tokenized_seq.values())[0])\n", - " correct_rels.append(triplet_i[1])\n", + " correct_rels.append(triplet_i[1].lower())\n", "\n", " inputs_padded = np.array([i + [0]*(inputs_max_len-len(i)) for i in inputs])\n", " #correct_rels_padded = np.array([i + [0]*(correct_rels_max_len-len(i)) for i in correct_rels])\n", @@ -189,34 +216,35 @@ " inputs_attention_mask = np.where(inputs_padded != 0, 1, 0)\n", " #correct_rels_attention_mask = np.where(correct_rels_padded != 0, 1, 0)\n", " print(\"Finished with batches\")\n", - " return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n" + " return torch.LongTensor(inputs_padded), torch.LongTensor(inputs_attention_mask), correct_rels, sub_obj_ents #torch.IntTensor(correct_rels_padded), torch.LongTensor(correct_rels_attention_mask)\n" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import Dataset, DataLoader\n", "\n", "class MyDataset(Dataset):\n", - " def __init__(self, inputs, attention_mask, correct_rels, relations):\n", + " def __init__(self, inputs, attention_mask, correct_rels,ents, relations):\n", " self.inputs = inputs\n", " self.attention_mask = attention_mask\n", " self.correct_rels = correct_rels\n", + " self.ents = ents\n", " self.relations = relations\n", "\n", " def __len__(self):\n", " return len(self.inputs)\n", "\n", " def __getitem__(self, idx):\n", - " return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx])" + " return self.inputs[idx], self.attention_mask[idx], self.relations.index(self.correct_rels[idx]), self.ents[idx]" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -230,7 +258,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2052/2052 [00:03<00:00, 546.37it/s]\n" + "100%|██████████| 2052/2052 [00:00<00:00, 2906.57it/s]\n" ] }, { @@ -238,16 +266,17 @@ "output_type": "stream", "text": [ "Finished with batches\n", - "features: tensor([[ 101, 2054, 2003, 1996, 13314, 1997, 16122, 10230, 5400, 2100,\n", - " 12378, 2136, 1029, 102, 1031, 4942, 1033, 102, 0, 0,\n", + "features: ['[CLS] name the home stadium of fc spartak moscow season 2011 - 12 [SEP] [ obj ] [SEP] 2011 – 12 _ fc _ spartak _ moscow _ season [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'] mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(431)\n", + "valid features: tensor([[ 101, 2054, 3063, 2001, 3378, 4876, 2007, 1996, 6436, 26785,\n", + " 7971, 6447, 1998, 2600, 1037, 20160, 10362, 1029, 102, 1031,\n", + " 27885, 3501, 1033, 102, 2600, 1011, 1037, 1011, 20160, 10362,\n", + " 102, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0]]) mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) label_index tensor(378)\n", - "valid features: tensor([[ 101, 2054, 4752, 1997, 7992, 13843, 5112, 2267, 2003, 2036,\n", - " 1996, 9353, 4215, 26432, 2278, 3037, 1997, 2703, 27668, 3077,\n", - " 1029, 102, 1031, 4942, 1033, 102, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(28)\n" + " 0, 0, 0, 0, 0, 0, 0, 0]]) valid mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) valid label_index tensor(149)\n" ] } ], @@ -259,18 +288,19 @@ " with open(file, \"r\") as f:\n", " return json.load(f)\n", "\n", + "\n", "#relations = open_json(\"../data/relations-query-qald-9-linked.json\")\n", - "relations = open_json(\"../data/relations-all-lc-quad-no-http.json\")\n", + "relations = open_json(\"../data/relations-all-no-http-lowercase.json\")\n", "\n", "# \"../data/qald-9-train-linked.json\"\n", "#pred = \"../LC-QuAD/combined-requeried-linked-train.json\"\n", - "inputs, attention_mask, correct_rels = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True) #train\n", + "inputs, attention_mask, correct_rels, sub_objs = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True) #train\n", "\n", "# relations = open_json(\"../data/relations-lcquad-without-http-train-linked.json\")\n", "# train_set = MyDataset(*make_batch(), relations=relations)\n", "\n", "\n", - "dataset = MyDataset(inputs, attention_mask, correct_rels, relations=relations)\n", + "dataset = MyDataset(inputs, attention_mask, correct_rels,sub_objs, relations=relations)\n", "train_size = int(0.8 * len(dataset))\n", "valid_size = len(dataset) - train_size\n", "\n", @@ -278,94 +308,65 @@ "\n", "train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)\n", "#show first entry\n", - "train_features, train_mask, train_label = next(iter(train_dataloader))\n", - "print(\"features:\", train_features, \"mask:\",train_mask,\"label_index\", train_label[0])\n", + "train_features, train_mask, train_label, ents = next(iter(train_dataloader))\n", + "print(\"features:\", tokenizer.batch_decode(train_features), \"mask:\",train_mask,\"label_index\", train_label[0])\n", "\n", "\n", "valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=True)\n", - "valid_features, valid_mask, valid_label = next(iter(valid_dataloader))\n", + "valid_features, valid_mask, valid_label, ents = next(iter(valid_dataloader))\n", "print(\"valid features:\", valid_features, \"valid mask:\",valid_mask,\"valid label_index\", valid_label[0])" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n", + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ + "SPARQL_ENDPOINT = \"https://dbpedia.org/sparql\"\n", + "\n", + "\n", + "# test_data_file = sys.argv[1]\n", + "# predicted_data_file = sys.argv[2]\n", + "\n", + "headers = {\n", + " 'Accept': 'application/sparql-results+json',\n", + " 'Content-Type': 'application/x-www-form-urlencoded',\n", + "}\n", + "\n", "# Initialize model\n", "model = NgmOne(device, relations)" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1 Train 6.453075773575726 , Valid 7.615683317184448\n", - "2 Train 6.444019513971665 , Valid 7.6052809953689575\n", - "3 Train 6.43224432889153 , Valid 7.602278828620911\n", - "4 Train 6.415224019218893 , Valid 7.598743319511414\n", - "5 Train 6.391957339118509 , Valid 7.584347605705261\n", - "6 Train 6.368611419902129 , Valid 7.573382139205933\n", - "7 Train 6.339372747084674 , Valid 7.557032465934753\n", - "8 Train 6.318648646859562 , Valid 7.539492607116699\n", - "9 Train 6.297140289755428 , Valid 7.523481965065002\n", - "10 Train 6.281140131108901 , Valid 7.519673824310303\n", - "11 Train 6.265462370479808 , Valid 7.5268988609313965\n", - "12 Train 6.24550900739782 , Valid 7.504716753959656\n", - "13 Train 6.234297051149256 , Valid 7.505620360374451\n", - "14 Train 6.220176332137164 , Valid 7.500500917434692\n", - "15 Train 6.204408701728372 , Valid 7.493185758590698\n", - "16 Train 6.194800797630759 , Valid 7.488588452339172\n", - "17 Train 6.183391935685101 , Valid 7.466232180595398\n", - "18 Train 6.173011443194221 , Valid 7.457393527030945\n", - "19 Train 6.168144422418931 , Valid 7.45364773273468\n", - "20 Train 6.157315562753117 , Valid 7.452357649803162\n", - "21 Train 6.140512690824621 , Valid 7.456966400146484\n", - "22 Train 6.129443336935604 , Valid 7.464043855667114\n", - "23 Train 6.12704220940085 , Valid 7.443313360214233\n", - "24 Train 6.117525269003475 , Valid 7.4395164251327515\n", - "25 Train 6.116069064420812 , Valid 7.456610083580017\n", - "26 Train 6.108902790967156 , Valid 7.430782794952393\n", - "27 Train 6.108471814323874 , Valid 7.44858992099762\n", - "28 Train 6.100216697244083 , Valid 7.457515120506287\n", - "29 Train 6.099361447726979 , Valid 7.438013672828674\n", - "30 Train 6.092377662658691 , Valid 7.448408484458923\n", - "31 Train 6.088698302998262 , Valid 7.442046403884888\n", - "32 Train 6.083712998558493 , Valid 7.420018911361694\n", - "33 Train 6.081799563239603 , Valid 7.426819205284119\n", - "34 Train 6.070652428795309 , Valid 7.426627039909363\n", - "35 Train 6.069005825940301 , Valid 7.425306558609009\n", - "36 Train 6.059389002182904 , Valid 7.415045142173767\n", - "37 Train 6.0618573918062095 , Valid 7.418038249015808\n", - "38 Train 6.055309716392966 , Valid 7.434216380119324\n", - "39 Train 6.049336994395537 , Valid 7.43299674987793\n", - "40 Train 6.048270281623392 , Valid 7.410753607749939\n", - "41 Train 6.043416135451373 , Valid 7.410261392593384\n", - "42 Train 6.042512613184312 , Valid 7.4074866771698\n", - "43 Train 6.032251245835248 , Valid 7.390785813331604\n", - "44 Train 6.025217645308551 , Valid 7.370600938796997\n", - "45 Train 6.027591985814712 , Valid 7.377846837043762\n", - "46 Train 6.0253391826854035 , Valid 7.39339554309845\n", - "47 Train 6.019261051626766 , Valid 7.407409071922302\n", - "48 Train 6.014283713172464 , Valid 7.393349766731262\n", - "49 Train 6.010964337517233 , Valid 7.37293815612793\n", - "50 Train 6.010590104495778 , Valid 7.4031054973602295\n" + "1 Train 4.168464240410345 , Valid 6.253535368863274\n", + "2 Train 4.158141554688378 , Valid 6.238900885862463\n", + "3 Train 4.135342076527986 , Valid 6.225469000199261\n", + "4 Train 4.12202568705991 , Valid 6.212939725202673\n", + "5 Train 4.109015821552963 , Valid 6.205244933857637\n", + "6 Train 4.098594773587563 , Valid 6.194438050774967\n", + "7 Train 4.08812516884838 , Valid 6.178463767556583\n", + "8 Train 4.079949097667667 , Valid 6.175140450982487\n", + "9 Train 4.07018859094853 , Valid 6.167128044016221\n", + "10 Train 4.065771084037616 , Valid 6.161452812307021\n" ] } ], @@ -373,11 +374,15 @@ "# Train with data loader.\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "#optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.5)\n", "\n", - "epoch = 50\n", - "batch_size = 64\n", + "epoch = 10\n", + "batch_size = 8\n", + "alpha = 0.5\n", "train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", "valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)\n", + "\n", + "model.train()\n", "for e in range(epoch):\n", " train_loss_epoch = 0\n", " valid_loss_epoch = 0\n", @@ -386,10 +391,47 @@ " train = sample_batched_train[0]\n", " train_mask = sample_batched_train[1]\n", " label_index = sample_batched_train[2].to(device)\n", + " sub_objs = sample_batched_train[3]\n", " \n", " # Forward pass\n", " output = model(train, train_mask)\n", - " loss = criterion(output, label_index)\n", + " loss_gs = []\n", + " for j in range(len(sub_objs)):\n", + " if not (sub_objs[j].split(\" \")[0] == \"[SUB]\" or sub_objs[j].split(\" \")[0] == \"[OBJ]\"):\n", + " continue\n", + "\n", + " if sub_objs[j].split(\" \")[0] == \"[SUB]\":\n", + " sub = sub_objs[j].split(\" \")[1]\n", + " q = \"SELECT ?r WHERE { <\" + sub + \"> ?r ?o }\"\n", + "\n", + " if sub_objs[j].split(\" \")[0] == \"[OBJ]\":\n", + " obj = sub_objs[j].split(\" \")[1]\n", + " q = \"SELECT ?r WHERE { ?s ?r <\" + obj + \"> }\"\n", + " \n", + " params = {\n", + " \"default-graph-uri\": \"http://dbpedia.org\",\n", + " \"query\": q,\n", + " \"format\": \"json\"\n", + " }\n", + "\n", + " response = requests.get(SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)\n", + " results = response.json()\n", + " res_rels = {}\n", + " for i in range(len(list(results[\"results\"].values())[2])):\n", + " res_rels[list(results[\"results\"].values())[2][i][\"r\"][\"value\"].lower()] = \"True\"\n", + " \n", + " loss_rels = []\n", + " for i in range(len(relations)):\n", + " if prefixes_reverse[\"\".join(relations[i].split(\":\")[0]) + \":\"] + relations[i].split(\":\")[1] in list(res_rels.keys()):\n", + " loss_rels.append(5) \n", + " else:\n", + " loss_rels.append(1/20)\n", + " loss_gs.append(loss_rels)\n", + " loss_gs = torch.FloatTensor(loss_gs).to(device)\n", + "\n", + " #print(response)\n", + " output_gs = output * (1-alpha) + (loss_gs) * alpha\n", + " loss = criterion(output_gs, label_index)\n", "\n", " # backward and optimize\n", " loss.backward()\n", @@ -413,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -427,7 +469,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2052/2052 [00:03<00:00, 577.85it/s]\n" + "100%|██████████| 2052/2052 [00:00<00:00, 2654.58it/s]\n" ] }, { @@ -442,7 +484,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1161/1161 [00:02<00:00, 566.88it/s]\n" + "100%|██████████| 1161/1161 [00:00<00:00, 2831.73it/s]\n" ] }, { @@ -450,22 +492,22 @@ "output_type": "stream", "text": [ "Finished with batches\n", - "test loss 6.031976699829102\n", - "lowest confidence train 0.08024344\n", - "lowest confidence test 0.08576707\n", - "Accuracy train: 0.3939828080229226\n", - "Accuracy test: 0.006361323155216285\n" + "test loss 6.039710998535156\n", + "lowest confidence train 0.07557418\n", + "lowest confidence test 0.08204544\n", + "Accuracy train: 0.2349570200573066\n", + "Accuracy test: 0.007633587786259542\n" ] } ], "source": [ "# Predict\n", - "train, train_mask, corr_rels = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True)\n", - "test, test_mask, corr_rels_test = make_batch(src=\"../LC-QuAD/combined-requeried-linked-test.json\", http_prefix = True)\n", - "test_data = MyDataset(test, test_mask, corr_rels_test, relations=relations)\n", + "train, train_mask, corr_rels,ents = make_batch(src=\"../LC-QuAD/combined-requeried-linked-train.json\", http_prefix = True)\n", + "test, test_mask, corr_rels_test, ents_test = make_batch(src=\"../LC-QuAD/combined-requeried-linked-test.json\", http_prefix = True)\n", + "test_data = MyDataset(test, test_mask, corr_rels_test, ents=ents_test, relations=relations)\n", "test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)\n", "\n", - "test_batch, test_mask_batch, corr_rels_test_batch = next(iter(test_dataloader))\n", + "test_batch, test_mask_batch, corr_rels_test_batch, sub_objs = next(iter(test_dataloader))\n", "corr_rels_test_batch = corr_rels_test_batch.to(device)\n", "with torch.no_grad():\n", " output_train = model(train, train_mask)\n", @@ -473,8 +515,44 @@ " loss = criterion(output_test, corr_rels_test_batch)\n", " print(\"test loss\", loss.item())\n", "\n", + " loss_gs = []\n", + " for j in range(len(sub_objs)):\n", + " if not (sub_objs[j].split(\" \")[0] == \"[SUB]\" or sub_objs[j].split(\" \")[0] == \"[OBJ]\"):\n", + " continue\n", + "\n", + " if sub_objs[j].split(\" \")[0] == \"[SUB]\":\n", + " sub = sub_objs[j].split(\" \")[1]\n", + " q = \"SELECT ?r WHERE { <\" + sub + \"> ?r ?o }\"\n", + "\n", + " if sub_objs[j].split(\" \")[0] == \"[OBJ]\":\n", + " obj = sub_objs[j].split(\" \")[1]\n", + " q = \"SELECT ?r WHERE { ?s ?r <\" + obj + \"> }\"\n", + "\n", + " params = {\n", + " \"default-graph-uri\": \"http://dbpedia.org\",\n", + " \"query\": q,\n", + " \"format\": \"json\"\n", + " }\n", + "\n", + " response = requests.get(\n", + " SPARQL_ENDPOINT, headers=headers, params=params, timeout=15)\n", + " results = response.json()\n", + " res_rels = {}\n", + " for i in range(len(list(results[\"results\"].values())[2])):\n", + " res_rels[list(results[\"results\"].values())[2][i]\n", + " [\"r\"][\"value\"].lower()] = \"True\"\n", + "\n", + " loss_rels = []\n", + " for i in range(len(relations)):\n", + " if prefixes_reverse[\"\".join(relations[i].split(\":\")[0]) + \":\"] + relations[i].split(\":\")[1] in list(res_rels.keys()):\n", + " loss_rels.append(5)\n", + " else:\n", + " loss_rels.append(1/20)\n", + " loss_gs.append(loss_rels)\n", + " loss_gs = torch.FloatTensor(loss_gs).to(device)\n", + "output_gs_test = output_test# * loss_gs\n", "output_train = output_train.detach().cpu().numpy()\n", - "output_test = output_test.detach().cpu().numpy()\n", + "output_test = output_gs_test.detach().cpu().numpy()\n", "\n", "\n", "prediction_train = [relations[np.argmax(pred).item()]for pred in output_train]\n", @@ -506,15 +584,55 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 190, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------------+-------------------+\n", + "| Mod name | Parameters Listed |\n", + "+---------------+-------------------+\n", + "| linear.weight | 338688 |\n", + "| linear.bias | 441 |\n", + "+---------------+-------------------+\n", + "Sum of trained parameters: 339129\n" + ] + }, + { + "data": { + "text/plain": [ + "339129" + ] + }, + "execution_count": 190, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from prettytable import PrettyTable\n", + "def count_parameters(model):\n", + " table = PrettyTable([\"Mod name\", \"Parameters Listed\"])\n", + " t_params = 0\n", + " for name, parameter in model.named_parameters():\n", + " if not parameter.requires_grad:\n", + " continue\n", + " param = parameter.numel()\n", + " table.add_row([name, param])\n", + " t_params += param\n", + " print(table)\n", + " print(f\"Sum of trained parameters: {t_params}\")\n", + " return t_params\n", + "\n", + "count_parameters(model)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.4 ('tdde19')", + "display_name": "Python 3.9.11 64-bit", "language": "python", "name": "python3" }, @@ -528,12 +646,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.9.11" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "8e4aa0e1a1e15de86146661edda0b2884b54582522f7ff2b916774ba6b8accb1" + "hash": "64e7cd3b4b88defe39dd61a4584920400d6beb2615ab2244e340c2e20eecdfe9" } } }, diff --git a/bart/sparql.ipynb b/bart/sparql.ipynb index adaf70cfc3a9b55cb2a1499a3d81c2e673ec1f0c..6dcb66eb8825c68573affc46257035764f82681b 100644 --- a/bart/sparql.ipynb +++ b/bart/sparql.ipynb @@ -4075,15 +4075,13 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.10" + "version": "3.9.11" }, "vscode": { - "interpreter": { "hash": "0988fb18a177ab47253348976d615106e52998fa9d41728b0d76cf08f2eafea6" } }, "widgets": { - "application/vnd.jupyter.widget-state+json": { "04562bf2742444049a073935c78ad553": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", diff --git a/data/relations-all-no-http-lowercase.json b/data/relations-all-no-http-lowercase.json new file mode 100644 index 0000000000000000000000000000000000000000..90d64557b4730b9cf007741d100ac3a8b89f775b --- /dev/null +++ b/data/relations-all-no-http-lowercase.json @@ -0,0 +1,443 @@ +[ + "dbp:firstdriver", + "dbp:governingbody", + "dbp:mission", + "dbp:flagbearer", + "dbp:manager", + "dbo:foundationplace", + "dbo:related", + "dbo:lieutenant", + "dbo:formerpartner", + "dbp:combatant", + "dbo:spouse", + "dbo:division", + "dbp:religion", + "dbp:affiliations", + "dbo:predecessor", + "dbp:knownfor", + "dbp:cinematography", + "dbo:ideology", + "dbp:meaning", + "dbp:constituency", + "dbo:maintainedby", + "dbo:layout", + "dbp:address", + "dbo:computingplatform", + "dbp:position", + "dbo:party", + "dbp:religiousaffiliation", + "dbp:genre", + "dbp:successor", + "dbp:nationality", + "dbo:parentorganisation", + "dbp:assembly", + "dbp:subject", + "dbp:headquarters", + "dbp:placeofbirth", + "dbp:team", + "dbp:director", + "dbo:family", + "dbp:district", + "dbp:creators", + "dbo:firstdriver", + "dbo:architecturalstyle", + "dbo:mountainrange", + "dbp:leadername", + "dbo:product", + "dbp:origin", + "dbp:doctoralstudents", + "dbo:trainer", + "dbo:order", + "dbp:birthplace", + "dbo:builder", + "dbo:voice", + "dbp:majorshrine", + "dbp:office", + "dbp:leader", + "dbp:locationcountry", + "dbo:manager", + "dbp:draftteam", + "dbp:license", + "dbp:editor", + "dbp:trainer", + "dbp:stadium", + "dbp:domain", + "dbo:gender", + "dbo:highschool", + "dbo:associatedband", + "dbo:coach", + "dbp:broadcastarea", + "dbo:automobileplatform", + "dbo:lyrics", + "dbo:training", + "dbo:hubairport", + "dbo:militaryunit", + "dbp:membership", + "dbo:firstascentperson", + "dbo:operator", + "dbo:notablework", + "dbp:teamname", + "dbp:creator", + "dbp:architecturalstyle", + "dbp:owners", + "dbo:language", + "dbp:state", + "dbo:mouthmountain", + "dbp:archipelago", + "dbp:publisher", + "dbp:format", + "dbp:products", + "dbp:predecessor", + "dbo:season", + "dbo:bronzemedalist", + "dbp:discipline", + "dbp:lyrics", + "dbo:bandmember", + "dbo:knownfor", + "dbo:silvermedalist", + "dbp:architecture", + "dbo:stateoforigin", + "dbo:network", + "dbo:cinematography", + "dbp:coach", + "dbo:nearestcity", + "dbp:author", + "dbp:debutteam", + "dbp:prizes", + "dbo:mayor", + "dbo:denomination", + "dbo:currency", + "dbp:workinstitutions", + "dbp:spouse", + "dbo:ethnicity", + "dbp:notablecommanders", + "dbp:starring", + "dbp:battles", + "dbo:formerteam", + "dbo:foundedby", + "dbo:leader", + "dbo:veneratedin", + "dbo:restingplace", + "dbp:licensee", + "dbp:appointer", + "dbp:keypeople", + "dbo:relative", + "dbp:training", + "dbp:presenter", + "dbo:managerclub", + "dbo:writer", + "dbo:partner", + "dbo:operatingsystem", + "dbp:country", + "dbp:music", + "dbo:basedon", + "dbp:deathdate", + "dbp:line", + "dbo:creator", + "dbo:series", + "dbo:previouswork", + "dbp:founded", + "dbp:jurisdiction", + "dbo:opponent", + "dbp:area", + "dbo:garrison", + "dbp:gender", + "dbo:stadium", + "dbp:place", + "dbp:chancellor", + "dbp:hostcity", + "dbp:placeofburial", + "dbp:cityserved", + "dbo:starring", + "dbo:profession", + "dbp:agencyname", + "dbo:federalstate", + "dbo:origin", + "dbp:party", + "dbp:tenants", + "dbp:schooltype", + "dbp:magazine", + "dbo:distributinglabel", + "dbp:purpose", + "dbo:incumbent", + "dbp:homestadium", + "dbp:canonizedby", + "dbo:deathplace", + "dbp:design", + "dbo:producer", + "dbo:officiallanguage", + "dbp:foundation", + "dbp:champion", + "dbo:author", + "dbp:hometown", + "dbp:doctoraladvisor", + "dbp:affiliation", + "dbp:children", + "dbp:occupation", + "dbo:affiliation", + "dbo:doctoralstudent", + "dbo:authority", + "dbp:employer", + "dbp:partner", + "dbo:architect", + "dbo:deathcause", + "dbo:jurisdiction", + "dbo:battle", + "dbp:manufacturer", + "dbp:relatives", + "dbp:deathplace", + "dbp:pastmembers", + "dbp:operator", + "dbo:honours", + "dbp:owner", + "dbo:headquarter", + "dbo:ceremonialcounty", + "dbp:firstteam", + "dbp:college", + "dbp:label", + "dbp:poledriver", + "dbo:president", + "dbo:destination", + "dbp:screenplay", + "dbp:material", + "dbp:programminglanguage", + "dbp:sisterstations", + "dbo:founder", + "dbo:targetairport", + "dbp:officialname", + "dbp:hubs", + "dbo:academicadvisor", + "dbo:locationcity", + "dbp:parent", + "dbp:currentclub", + "dbp:services", + "dbo:discoverer", + "dbo:colour", + "dbo:team", + "dbo:largestcity", + "dbo:recordedin", + "dbp:governor", + "dbp:university", + "dbo:servingrailwayline", + "dbo:commander", + "dbo:species", + "dbo:poledriver", + "dbp:athletics", + "dbo:race", + "dbp:league", + "dbp:languages", + "dbo:militarybranch", + "dbp:leadertitle", + "dbp:notableworks", + "dbo:literarygenre", + "dbo:genre", + "dbp:placeofdeath", + "dbp:president", + "dbo:manufacturer", + "dbo:significantbuilding", + "dbp:mascot", + "dbp:borough", + "dbp:race", + "dbp:language", + "dbp:birthname", + "dbp:commander", + "dbo:distributor", + "dbp:beatifiedby", + "dbp:managerclubs", + "dbp:firstaired", + "dbo:countyseat", + "dbp:artist", + "dbo:majorshrine", + "dbo:athletics", + "dbo:education", + "dbo:programmeformat", + "dbp:branch", + "dbo:location", + "dbo:cpu", + "dbp:characters", + "dbp:inflow", + "dbp:outflow", + "dbp:houses", + "dbp:mainingredient", + "dbp:restingplace", + "dbo:wineregion", + "dbp:producer", + "dbp:primeminister", + "dbp:related", + "dbp:recorded", + "dbp:guests", + "dbo:director", + "dbp:style", + "dbo:successor", + "dbo:parentcompany", + "dbo:university", + "dbo:editor", + "dbo:timezone", + "dbo:hometown", + "dbo:commandstructure", + "dbo:monarch", + "dbo:publisher", + "dbo:residence", + "dbo:rivermouth", + "dbo:otherparty", + "dbo:designer", + "dbp:titles", + "dbo:nationality", + "dbo:stylisticorigin", + "dbo:movement", + "dbp:club", + "dbp:neighboringmunicipalities", + "dbp:locationtown", + "dbp:venue", + "dbp:destinations", + "dbo:board", + "dbp:residence", + "dbp:pastteams", + "dbo:editing", + "dbo:nonfictionsubject", + "dbp:lieutenant", + "dbp:order", + "dbp:title", + "dbo:routeend", + "dbo:launchsite", + "dbo:doctoraladvisor", + "dbp:writer", + "dbo:owningcompany", + "dbo:anthem", + "dbo:placeofburial", + "dbp:province", + "dbo:executiveproducer", + "dbp:distributor", + "dbo:debutteam", + "dbo:developer", + "dbo:chairman", + "dbp:ground", + "dbp:buildingtype", + "dbo:programminglanguage", + "dbp:nearestcity", + "dbp:international", + "dbp:almamater", + "dbo:academicdiscipline", + "dbp:deputy", + "dbo:river", + "dbp:garrison", + "dbo:associatedmusicalartist", + "dbp:school", + "dbo:breeder", + "dbp:chairman", + "dbo:sport", + "dbo:club", + "dbo:illustrator", + "dbp:founder", + "dbo:school", + "dbp:citizenship", + "dbp:designer", + "dbp:playedfor", + "dbp:arena", + "dbo:parent", + "dbo:ingredient", + "dbp:junction", + "dbo:institution", + "dbp:engine", + "dbo:regionserved", + "dbp:rank", + "dbp:operatingsystem", + "dbo:type", + "dbp:carries", + "dbo:citizenship", + "dbp:locationcity", + "dbp:nickname", + "dbo:governmenttype", + "dbp:commandstructure", + "dbo:militaryrank", + "dbo:formerbandmember", + "dbp:maininterests", + "dbo:almamater", + "dbo:country", + "dbp:narrated", + "dbo:occupation", + "dbo:primeminister", + "dbo:field", + "dbo:composer", + "dbo:routestart", + "dbo:league", + "dbo:outflow", + "dbo:subsidiary", + "dbo:child", + "dbp:highschool", + "dbp:sisternames", + "dbo:portrayer", + "dbo:award", + "dbp:crosses", + "dbo:employer", + "dbp:name", + "dbp:city", + "dbo:birthplace", + "dbo:musicby", + "dbp:editing", + "dbo:license", + "dbo:relation", + "dbp:fields", + "dbp:nationalorigin", + "dbp:location", + "dbo:kingdom", + "dbp:largestcity", + "dbo:service", + "dbo:capital", + "dbp:field", + "dbp:currency", + "dbo:phylum", + "dbp:haircolor", + "dbp:deathcause", + "dbo:ground", + "dbp:education", + "dbo:broadcastarea", + "dbo:territory", + "dbp:mother", + "dbo:narrator", + "dbp:allegiance", + "dbo:assembly", + "dbo:college", + "dbp:region", + "dbo:region", + "dbo:album", + "dbo:county", + "dbo:tenant", + "dbp:executiveproducer", + "dbp:awards", + "dbp:developer", + "dbp:writers", + "dbp:architect", + "dbp:thememusiccomposer", + "dbo:presenter", + "dbp:company", + "dbo:owner", + "dbp:birthdate", + "dbp:nationalteam", + "dbo:keyperson", + "dbp:cities", + "dbo:campus", + "dbo:city", + "dbo:languagefamily", + "dbo:sisterstation", + "dbp:coverartist", + "dbp:veneratedin", + "dbp:animator", + "dbp:type", + "dbp:notableinstruments", + "dbo:locatedinarea", + "dbo:religion", + "dbo:binomialauthority", + "dbo:neighboringmunicipality", + "dbp:role", + "dbo:homestadium", + "dbp:os", + "dbp:youthclubs", + "dbp:currentmembers", + "dbp:album", + "dbo:instrument", + "dbo:subsequentwork", + "dbo:inflow", + "dbo:artist", + "dbp:associatedacts" +]