Commit ba6de697 authored by Ludwig Forsberg

Added train and evaluation script

parent 5b6afabb
Showing with 3766 additions and 0 deletions
import sys
import tqdm
from data import eval_query as eq
from data import pred_query_responses as pqr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.models.bart.modeling_bart import shift_tokens_right
import datasets
import torch
import pandas as pd
import numpy as np
import csv
import json
import bart.generate_single_query as gsq
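# Fine-tune BART on tokenized question/query pairs and, after every epoch,
# predict queries for the test questions and evaluate the results.
#
# Example invocation (the script and data file names here are only illustrative):
#   python train.py lc-quad-requeried-linked-train.json APPEND-1 BASE 10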
if __name__ == "__main__":
    if len(sys.argv) > 1:
        train_file = sys.argv[1]
    else:
        print("Please provide a train file name")
        sys.exit(1)

    if len(sys.argv) > 2:
        er_link_style = sys.argv[2]
        er_link_styles = ["REPLACE", "APPEND-1", "APPEND-2", "TOKENIZE", "TOKENIZE_REPLACE"]
        if er_link_style not in er_link_styles:
            print("Please provide a valid entity/relationship linking style: " + ", ".join(er_link_styles))
            sys.exit(1)
    else:
        er_link_style = ""

    if len(sys.argv) > 3:
        bart = sys.argv[3]
        types = ["BASE", "LARGE"]
        if bart not in types:
            print("Please provide a valid BART model size (BASE or LARGE)")
            sys.exit(1)
    else:
        bart = "BASE"

    if len(sys.argv) > 4:
        epochs = int(sys.argv[4])
    else:
        epochs = 1

    # Derive the train/test data paths and the output path for predictions.
    train_path = "data/" + train_file
    test_path = train_path.replace("train", "test")
    train_tok_path = "data/tokenized/" + train_file.replace(".json", "-" + er_link_style.lower() + ".csv")
    test_tok_path = train_tok_path.replace("train", "test")
    output = test_tok_path.replace("/tokenized/", "/predicted/").replace(".csv", "") + "-" + bart.lower()
    print(train_path, test_path, train_tok_path, test_tok_path, output)
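    # Load the pretrained BART model and tokenizer, and define how the tokenized
    # CSV (columns 'text' and 'summary') is converted into model-ready features.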
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-" + bart.lower())
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-" + bart.lower())

    def encode(batch):
        return tokenizer(batch, padding="max_length", max_length=256, return_tensors="pt")

    def convert_to_features(example_batch):
        input_encodings = encode(example_batch['text'])
        target_encodings = encode(example_batch['summary'])

        labels = target_encodings['input_ids']
        decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)
        # Padding positions in the labels are set to -100 so the loss ignores them.
        labels[labels[:, :] == model.config.pad_token_id] = -100

        encodings = {
            'input_ids': input_encodings['input_ids'],
            'attention_mask': input_encodings['attention_mask'],
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
        }
        return encodings

    def get_dataset(path):
        df = pd.read_csv(path, sep=",", on_bad_lines='skip')
        dataset = datasets.Dataset.from_pandas(df)
        dataset = dataset.map(convert_to_features, batched=True)
        columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask']
        dataset.set_format(type='torch', columns=columns)
        return dataset
    train_dataset = get_dataset(train_tok_path)

    training_args = Seq2SeqTrainingArguments(
        output_dir='./trained-models/blackbox',
        num_train_epochs=1,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        warmup_steps=10,
        weight_decay=0.01,
        logging_dir='./logs',
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
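    # num_train_epochs is fixed at 1 above; the loop below calls trainer.train()
    # once per requested epoch so predictions and evaluation can be dumped per epoch.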
    for epoch in range(epochs):
        trainer.train()
        pred_path = output + "-" + str(epoch) + ".csv"
        with open(test_path, "r", encoding="utf-8") as f, open(pred_path, "w", encoding="utf-8", newline="") as out:
            test_data = json.load(f)
            test_data = test_data["questions"]
            test_data = [q["question"][0]["string"] for q in test_data]

            # Use the csv module so questions containing quotes or commas are escaped correctly.
            writer = csv.writer(out, quoting=csv.QUOTE_ALL)
            writer.writerow(["text", "summary"])

            print("Running predictions on the test questions")
            pbar = tqdm.tqdm(total=len(test_data))
            for question in test_data:
                predicted = gsq.predict_query(question, model, tokenizer)
                writer.writerow([question, predicted])
                pbar.update(1)
            pbar.close()
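        # Build a response file for the predicted queries and evaluate them, both as
        # query responses and as query strings, against the test data.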
        dump_path = pred_path.replace("predicted", "pred_responses").replace(".csv", ".json")
        pqr.build_responsefile(dump_path, test_path, pred_path)

        print("Evaluation against server results")
        precision_macro_query, recall_macro_query, f1_macro_query, precision_micro_query, recall_micro_query, f1_micro_query, fully_correct_query = eq.eval_query_response(test_path, dump_path)

        print("Evaluation of queries as strings")
        precision_macro_string, recall_macro_string, f1_macro_string, precision_micro_string, recall_micro_string, f1_micro_string, fully_correct_string = eq.eval_query_json(test_path, dump_path)

        res_path = dump_path.replace("pred_responses", "eval").replace(".json", ".txt")
        with open(res_path, "w") as f:
            f.write("String evaluation\n\n")
            f.write(f"Precision macro: {precision_macro_string}\n")
            f.write(f"Recall macro: {recall_macro_string}\n")
            f.write(f"F1 macro: {f1_macro_string}\n")
            f.write(f"Precision micro: {precision_micro_string}\n")
            f.write(f"Recall micro: {recall_micro_string}\n")
            f.write(f"F1 micro: {f1_micro_string}\n")
            f.write(f"Fully correct: {fully_correct_string}\n\n")

            f.write("Query evaluation\n\n")
            f.write(f"Precision macro: {precision_macro_query}\n")
            f.write(f"Recall macro: {recall_macro_query}\n")
            f.write(f"F1 macro: {f1_macro_query}\n")
            f.write(f"Precision micro: {precision_micro_query}\n")
            f.write(f"Recall micro: {recall_micro_query}\n")
            f.write(f"F1 micro: {f1_micro_query}\n")
            f.write(f"Fully correct: {fully_correct_query}\n\n")
#train_dataset = get_dataset("./data/tokenized/lc-quad-requeried-linked-train-tokenized-append-1.csv")
#test_dataset = get_dataset("test.csv")

## Specify the correct and predicted query responses
#input_path = "data/lc-quad-requeried-linked-test.json"
#pred_path = "data/predicted/lc-quad-requeried-linked-test-predicted-append-1-10-epochs.csv"
#dump_path = pred_path.replace("predicted", "pred_responses").replace(".csv", ".json")