forked from microsoft/nlp-recipes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsquad.py
108 lines (91 loc) · 3.98 KB
/
squad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import json
import pandas as pd
from utils_nlp.dataset.url_utils import maybe_download
# Download URLs for the raw SQuAD JSON files, keyed by dataset version
# ("v1.1" / "v2.0") and then by split ("train" / "dev").
URL_DICT = {
    "v1.1": {
        "train": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/"
        "master/dataset/train-v1.1.json",
        "dev": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/"
        "master/dataset/dev-v1.1.json",
    },
    "v2.0": {
        "train": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/"
        "master/dataset/train-v2.0.json",
        "dev": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/"
        "master/dataset/dev-v2.0.json",
    },
}
def load_pandas_df(local_cache_path=".", squad_version="v1.1", file_split="train"):
    """Loads the SQuAD dataset in pandas data frame.

    Args:
        local_cache_path (str, optional): Path to load the data from. If the file doesn't exist,
            download it first. Defaults to the current directory.
        squad_version (str, optional): Version of the SQuAD dataset, accepted values are:
            "v1.1" and "v2.0". Defaults to "v1.1".
        file_split (str, optional): Dataset split to load, accepted values are: "train" and "dev".
            Defaults to "train".

    Returns:
        pd.DataFrame: One row per question, with columns "doc_text", "question_text",
            "answer_start", "answer_text", "qa_id", and "is_impossible". For the "train"
            split, "answer_start"/"answer_text" hold a single offset/string; for "dev"
            they hold lists of all reference answers. Unanswerable v2.0 questions get
            "" for the answer text.

    Raises:
        ValueError: If ``file_split`` or ``squad_version`` is not an accepted value, or
            if a training question that is answerable does not have exactly one answer.
    """
    # Validate both inputs up front so the caller gets a clear ValueError instead of
    # a raw KeyError from the URL_DICT lookup (previously only file_split was checked).
    if file_split not in ["train", "dev"]:
        raise ValueError("file_split should be either train or dev")
    if squad_version not in ["v1.1", "v2.0"]:
        raise ValueError("squad_version should be either v1.1 or v2.0")

    URL = URL_DICT[squad_version][file_split]
    file_name = URL.split("/")[-1]
    maybe_download(URL, file_name, local_cache_path)

    file_path = os.path.join(local_cache_path, file_name)
    with open(file_path, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]

    paragraph_text_list = []
    question_text_list = []
    answer_start_list = []
    answer_text_list = []
    qa_id_list = []
    is_impossible_list = []

    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                answer_offset = None
                is_impossible = False
                # v1.1 has no "is_impossible" field; every question is answerable.
                if squad_version == "v2.0":
                    is_impossible = qa["is_impossible"]
                if file_split == "train":
                    # Training expects exactly one gold answer per answerable question.
                    if (len(qa["answers"]) != 1) and (not is_impossible):
                        raise ValueError(
                            "For training, each question should have exactly 1 answer."
                        )
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                    else:
                        orig_answer_text = ""
                else:
                    # Dev questions may carry multiple reference answers; keep them all.
                    if not is_impossible:
                        orig_answer_text = []
                        answer_offset = []
                        for answer in qa["answers"]:
                            orig_answer_text.append(answer["text"])
                            answer_offset.append(answer["answer_start"])
                    else:
                        orig_answer_text = ""
                paragraph_text_list.append(paragraph_text)
                question_text_list.append(question_text)
                answer_start_list.append(answer_offset)
                answer_text_list.append(orig_answer_text)
                qa_id_list.append(qas_id)
                is_impossible_list.append(is_impossible)

    output_df = pd.DataFrame(
        {
            "doc_text": paragraph_text_list,
            "question_text": question_text_list,
            "answer_start": answer_start_list,
            "answer_text": answer_text_list,
            "qa_id": qa_id_list,
            "is_impossible": is_impossible_list,
        }
    )

    return output_df