-
Notifications
You must be signed in to change notification settings - Fork 0
/
validation.py
34 lines (27 loc) · 1.08 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from NN import doc_retrieval_nn, embedding, retriever
from context_class import corpus
from random import choice
import json
# NOTE(review): this script is named "validation" but reads the SQuAD v1.1
# *train* split; confirm whether "../squad1/dev-v1.1.json" was intended.
# The path is relative to the current working directory, not to this file.
_SQUAD_PATH = "../squad1/train-v1.1.json"
_PROMPT = "Enter a question or type random : "


def _load_corpus(path=_SQUAD_PATH):
    """Load a SQuAD-format JSON file and wrap it in a ``corpus`` object.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The ``corpus`` built from the parsed JSON.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's locale
    # encoding, which can fail or mis-decode non-ASCII text.
    with open(path, "r", encoding="utf-8") as read_file:
        raw = json.load(read_file)
    return corpus(raw)


def _pick_random_question(corp):
    """Return a random question from a random paragraph of a random text.

    Args:
        corp: Object exposing ``list_texts``; each text exposes
            ``list_paragraphs`` and each paragraph exposes ``questions``.

    Raises:
        IndexError: If any of the sequences sampled from is empty
            (propagated from ``random.choice``).
    """
    trial_text = choice(corp.list_texts)
    trial_paragraph = choice(trial_text.list_paragraphs)
    return choice(trial_paragraph.questions)


def _resolve_question(message, corp):
    """Turn the user's raw input into the question to look up.

    The keyword ``random`` (case-insensitive, surrounding whitespace ignored)
    selects a random question from *corp*. Any other input is returned with
    surrounding whitespace stripped.
    """
    message = message.strip()
    if message.lower() == "random":
        return _pick_random_question(corp)
    return message


def main():
    """Fit the retriever on the corpus contexts and run one retrieval query."""
    corp = _load_corpus()
    df = corp.create_dataframe()
    # Index each distinct (context, title) pair exactly once.
    documents = df[["context", "title"]].drop_duplicates().reset_index(drop=True)
    X = embedding.fit_transform(documents["context"])
    retriever.fit(X, documents["title"])

    trial_question = _resolve_question(input(_PROMPT), corp)
    # NOTE(review): in "random" mode this is whatever the corpus stores in
    # paragraph.questions -- confirm it is a plain string like typed input.
    print(trial_question)
    # Retrieve the 'closest' document based on our NN method.
    title, pred = doc_retrieval_nn(documents, trial_question, retriever)
    print(title)
    print(pred)


if __name__ == "__main__":
    main()