Skip to content
Snippets Groups Projects
Commit d0958432 authored by Max Björkander's avatar Max Björkander
Browse files

began on neural graph search module

parent 0426cd7d
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import datasets
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import shift_tokens_right
```
%% Cell type:code id: tags:
``` python
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
```
%% Cell type:code id: tags:
``` python
class NgmOne(nn.Module):
def __init__(self):
super(NgmOne, self).__init__()
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
self.bert = BertModel.from_pretrained("bert-base-uncased")
self.linear = nn.Linear(768, 1)
self.softmax = nn.Softmax(dim=1)
def forward(self, triplet, question):
"""Triplet is a list of subject entity, relation, object entity, None if not present"""
#seq = "[CLS] " + question + " [SEP] "
if triplet[0] is not None:
#seq += "[SUB] [SEP] " + triplet[0]
tokenized_seq = self.tokenizer(question, "[SUB]", triplet[0])#, padding=True, truncation=True)
elif triplet[2] is not None:
#seq += "[OBJ] [SEP] " + triplet[2]
tokenized_seq = self.tokenizer(question, "[OBJ]", triplet[2])#, padding=True, truncation=True)
x = self.bert.forward(**tokenized_seq)
x = self.linear(x)
x = self.softmax(x)
return x
```
%% Cell type:code id: tags:
``` python
def encode(batch):
return tokenizer(batch, padding="max_length", max_length=256, return_tensors="pt")
def convert_to_features(example_batch):
input_encodings = encode(example_batch['text'])
target_encodings = encode(example_batch['summary'])
labels = target_encodings['input_ids']
decoder_input_ids = shift_tokens_right(
labels, model.config.pad_token_id, model.config.decoder_start_token_id)
labels[labels[:, :] == model.config.pad_token_id] = -100
encodings = {
'input_ids': input_encodings['input_ids'],
'attention_mask': input_encodings['attention_mask'],
'decoder_input_ids': decoder_input_ids,
'labels': labels,
}
return encodings
def get_dataset(path):
df = pd.read_csv(path, sep=",", on_bad_lines='skip')
dataset = datasets.Dataset.from_pandas(df)
dataset = dataset.map(convert_to_features, batched=True)
columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask', ]
dataset.set_format(type='torch', columns=columns)
return dataset
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment