Commit 90c052d3 authored by Ludwig Forsberg

Merge branch 'main' of gitlab.liu.se:tdde19-2022-1/codebase

parents 9e461682 284732fd
Showing with 139375 additions and 1496 deletions
%% Cell type:code id: tags:
``` python
import datasets
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
# Note: shift_tokens_right is provided by the BART module in transformers, not the BERT one.
from transformers.models.bart.modeling_bart import shift_tokens_right
```
%% Cell type:code id: tags:
``` python
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
```
%% Cell type:code id: tags:
``` python
class NgmOne(nn.Module):
    def __init__(self):
        super(NgmOne, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.linear = nn.Linear(768, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, triplet, question):
        """Triplet is a list of subject entity, relation, object entity; None if not present."""
        # Pair the question with the known part of the triplet (subject or object).
        if triplet[0] is not None:
            tokenized_seq = self.tokenizer(question, "[SUB] " + triplet[0],
                                           return_tensors="pt")  # padding/truncation can be added here
        elif triplet[2] is not None:
            tokenized_seq = self.tokenizer(question, "[OBJ] " + triplet[2],
                                           return_tensors="pt")  # padding/truncation can be added here
        # Use the pooled [CLS] representation (768-dim) as the sequence embedding.
        x = self.bert(**tokenized_seq).pooler_output
        x = self.linear(x)
        x = self.softmax(x)
        return x
```
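A minimal usage sketch (not from the notebook; the question and candidate triplet below are invented) showing how `NgmOne.forward` expects its arguments:

``` python
# Hypothetical inputs: score one candidate subject against a question.
model = NgmOne()
question = "Who is the mayor of Berlin?"
candidate = ["http://dbpedia.org/resource/Berlin", None, None]  # [subject, relation, object]
score = model(candidate, question)  # tensor of shape (1, 1)
```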
%% Cell type:code id: tags:
``` python
def encode(batch):
    return tokenizer(batch, padding="max_length", max_length=256, return_tensors="pt")


def convert_to_features(example_batch):
    # `model` is assumed to be defined elsewhere: a seq2seq model whose config
    # supplies pad_token_id and decoder_start_token_id.
    input_encodings = encode(example_batch['text'])
    target_encodings = encode(example_batch['summary'])

    labels = target_encodings['input_ids']
    decoder_input_ids = shift_tokens_right(
        labels, model.config.pad_token_id, model.config.decoder_start_token_id)
    # Ignore padding positions when computing the loss.
    labels[labels[:, :] == model.config.pad_token_id] = -100

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'decoder_input_ids': decoder_input_ids,
        'labels': labels,
    }
    return encodings


def get_dataset(path):
    df = pd.read_csv(path, sep=",", on_bad_lines='skip')
    dataset = datasets.Dataset.from_pandas(df)
    dataset = dataset.map(convert_to_features, batched=True)
    columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]
    dataset.set_format(type='torch', columns=columns)
    return dataset
```
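`get_dataset` expects a CSV with `text` and `summary` columns, i.e. the format `tokenizer.py` writes, and `convert_to_features` relies on a globally defined `model`. A hedged sketch of wiring this up, assuming a BART checkpoint (the cells above do not show which seq2seq model is actually used) and a hypothetical CSV path:

``` python
from transformers import BartForConditionalGeneration

# Assumption: any seq2seq model whose config has pad_token_id and
# decoder_start_token_id works here; "facebook/bart-base" is illustrative only.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Hypothetical path to a "text,summary" CSV produced by tokenizer.py.
train_dataset = get_dataset("qald-9-train-tokenized.csv")
print(train_dataset[0]["input_ids"].shape)
```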
(Eleven further diffs in this commit are collapsed; one source diff was too large to display.)
-# Parameters: input_file output_file
+# Parameters: input_file output_file er_link_style(optional)
 # Example usage: python tokenizer.py qald-9-test-linked.json qald-9-test-tokenized
+# Example usage: python tokenizer.py qald-9-test-linked.json qald-9-test-tokenized REPLACE
 # Note: output_file is the name of the output file without the extension
+# Note: er_link_style is the style of entity linking to use.
+#       It can be "REPLACE", "APPEND" or blank. Default is no usage of entity links.
 import sys
 import json

 def main():
     print(sys.argv[0])
-    input_file = sys.argv[1] or "qald-9-test-linked.json"
-    output_file = sys.argv[2] or "qald-9-test-tokenized"
-    import json
-    # Decides if objects are changed to <objx> or not
-    TOKENIZE = False
+    if len(sys.argv) > 1:
+        input_file = sys.argv[1]
+    else:
+        print("Please provide input file")
+        sys.exit(1)
+    if len(sys.argv) > 2:
+        output_file = sys.argv[2]
+    else:
+        print("Please provide output file")
+        sys.exit(1)
+    if len(sys.argv) > 3:
+        er_link_style = sys.argv[3]
+        if (er_link_style != "REPLACE" and er_link_style != "APPEND"):
+            print("Please provide a valid entity/relationship linking style")
+            sys.exit(1)
+    else:
+        er_link_style = ""
     prefixes = {
         "http://dbpedia.org/resource/": "res:",
@@ -36,42 +51,23 @@ def main():
     out.write("text,summary\n")
     for item in data["questions"]:
         # Keeps track of how many objects we have replaced
-        obj_count = 0
         question_string = item["question"][0]["string"]
         query_string = item["query"]["sparql"]
         entities = item.get("entities") or []
+        relations = item.get("relations") or []
-        # out.write(f'\"{question_string}\", \"{query_string}\"\n')
-        # out.write(" ".join(["-".join(ent) for ent in entities])+"\n")
-        # print(query_string)
         # Find first occurence of PREFIX
         prefix_index = query_string.find("PREFIX")
         # While there is a PREFIX, find the next one
         while prefix_index != -1:
-            # Find the word after PREFIX
-            prefixed_word = query_string[prefix_index + 7:query_string.find(" ", prefix_index + 7)]
             # Find next occurence of />
             end_index = query_string.find(">", prefix_index)
             # Remove the prefix
             query_string = query_string[:prefix_index] + query_string[end_index+2:]
-            if TOKENIZE:
-                # Find occurence of prefixed word:xyz in query and replace it with <obj{obj_count}>
-                obj_index = query_string.find(prefixed_word)
-                obj_end_index = query_string.find(" ", obj_index)
-                while obj_index != -1:
-                    query_string = query_string[:obj_index] + f"<obj{obj_count}>" + query_string[obj_end_index:]
-                    obj_index = query_string.find(prefixed_word)
-                    obj_end_index = query_string.find(" ", obj_index)
-                    obj_count += 1
             # Find next occurence of PREFIX
             prefix_index = query_string.find("PREFIX")
-        obj_count = 0
-        # Loop through all entities and repalce them with <obj{obj_count}>
         # Replace all occurences of the prefix with value in prefixes
         for prefix in prefixes:
@@ -87,17 +83,63 @@ def main():
         # Replace all occurences of "COUNT(" with "COUNT( "
         query_string = query_string.replace("COUNT(", "COUNT( ")
-        if TOKENIZE:
-            for entity in entities:
-                entity_token = entity[1]
-                entity_index = question_string.find(entity_token)
-                # print(entity_token)
-                while entity_index != -1 and question_string[entity_index - 1] != "j":
-                    question_string = question_string[:entity_index] + f"<obj{obj_count}>" + question_string[entity_index+len(entity_token):]
-                    entity_index = question_string.find(entity_token)
-                    obj_count += 1
-        # print("question_string", question_string)
+        if er_link_style == "APPEND":
+            # Append all entities to the end of the question
+            if (len(entities) > 0):
+                question_string += " | "
+            for k, entity in enumerate(entities):
+                if k > 0:
+                    question_string += " | "
+                uri = entity["URI"]
+                # Hopefully the uri includes a uri which we know how to shorten
+                for prefix in prefixes:
+                    if uri.startswith(prefix):
+                        uri = prefixes[prefix] + uri[len(prefix):]
+                        break
+                question_string += f"{uri}"
+        if er_link_style == "REPLACE":
+            # Helper function since entities and relations are replaced in the same way.
+            def replace_entity_or_relation(er_source, question_string):
+                for er in er_source:
+                    uri = er["URI"].strip()
+                    surface_form = er["surface form"].strip()
+                    if not surface_form:
+                        continue
+                    # Hopefully the uri includes a uri which we know how to shorten
+                    for prefix in prefixes:
+                        if uri.startswith(prefix):
+                            uri = prefixes[prefix] + uri[len(prefix):]
+                            break
+                    # Find the first occurence of the surface form
+                    er_index = question_string.find(surface_form)
+                    while er_index != -1:
+                        # Check if the surface form is not part of an already replaced entity/relation
+                        previous_colon = question_string.rfind(":", 0, er_index)
+                        previous_space = question_string.rfind(" ", 0, er_index)
+                        previous_space = max(0, previous_space)
+                        # If there is a colon to the left and there is a space after it then it is part of an entity/relation
+                        if previous_colon > previous_space and previous_colon != -1:
+                            next_space = question_string.find(" ", er_index)
+                            er_index = question_string.find(surface_form, next_space)
+                            continue
+                        # Else replace the surface form with the uri
+                        question_string = question_string[:er_index] + uri + question_string[er_index+len(surface_form):]
+                        # Find the next occurence of the surface form
+                        er_index = question_string.find(surface_form, er_index + len(uri))
+                return question_string
+            question_string = replace_entity_or_relation(entities, question_string)
+            question_string = replace_entity_or_relation(relations, question_string)
         out.write(f'\"{question_string}\", \"{query_string}\"\n')

 if __name__ == "__main__":
...
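To make the two new `er_link_style` modes concrete, here is a simplified, invented example of their effect (the real script additionally guards against re-replacing spans that were already linked):

``` python
# Invented linked question, mirroring the fields tokenizer.py reads.
prefixes = {"http://dbpedia.org/resource/": "res:"}
entity = {"URI": "http://dbpedia.org/resource/Berlin", "surface form": "Berlin"}
question = "Who is the mayor of Berlin?"

# Shorten the URI using the known prefix table.
uri = entity["URI"]
for prefix, short in prefixes.items():
    if uri.startswith(prefix):
        uri = short + uri[len(prefix):]  # -> "res:Berlin"

print(question.replace(entity["surface form"], uri))  # REPLACE: Who is the mayor of res:Berlin?
print(question + " | " + uri)                         # APPEND:  Who is the mayor of Berlin? | res:Berlin
```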
import sys
import json
import random

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python train-test-split.py <train_percentage> <input_file>")
        sys.exit(1)

    train_percentage = int(sys.argv[1])
    input_file = sys.argv[2]
    file_name = input_file.split(".json")[0]

    questions = []
    with open(input_file, "r") as f:
        data = json.load(f)
        questions = data["questions"]

    train_size = int(len(questions) * train_percentage / 100)
    test_size = len(questions) - train_size

    # Sample without replacement so the train split contains no duplicates
    # and the test split is exactly the remaining questions.
    train_questions = random.sample(questions, k=train_size)
    test_questions = [q for q in questions if q not in train_questions]

    train_data = {"questions": train_questions}
    test_data = {"questions": test_questions}

    with open(file_name + "-train.json", "w") as f:
        json.dump(train_data, f)
    with open(file_name + "-test.json", "w") as f:
        json.dump(test_data, f)
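For example, `python train-test-split.py 80 qald-9-test-linked.json` would write roughly 80% of the questions to `qald-9-test-linked-train.json` and the remainder to `qald-9-test-linked-test.json` (the output names reuse the input path without its `.json` extension); the input file name here is only illustrative.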