diff --git a/bart/generate_single_query.py b/bart/generate_single_query.py index 5ade102d7ee3c81214a932adf1f42fa08389b51f..4865619f981bc0211f813e31f315ad70692d8978 100644 --- a/bart/generate_single_query.py +++ b/bart/generate_single_query.py @@ -6,7 +6,7 @@ from transformers.models.bart.modeling_bart import shift_tokens_right # Save the model and tokenizer as local variables to use them in the predict_query function def setup(): # Import model from ./trained-models/blackbox - model = BartForConditionalGeneration.from_pretrained("./trained-models/blackbox") + model = BartForConditionalGeneration.from_pretrained("ludfo774/sparql-bart-append-1") tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") model.to('cuda') return model, tokenizer