-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
44 lines (33 loc) · 1.21 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from torch.utils.data import Dataset
import csv
class AnswersCSVDataset(Dataset):
    """
    Minimalistic Dataset over one or more answer CSV files.

    - No tokenization
    - No padding
    - No class dictionary

    Each item is a tuple of the raw CSV strings
    ``(answer_id, answer_text, question_id, answer_score, answer_feedback)``
    taken from the first five columns of a row. No type conversion is done.
    Rows whose answer id has already been seen (in this or an earlier file)
    are dropped, so every answer id appears at most once.
    """

    def __init__(self, answer_file_path):
        """
        Load all answers eagerly into memory.

        Args:
            answer_file_path: A single CSV path (``str``, ``bytes`` or
                path-like object) or an iterable of such paths.
        """
        super().__init__()
        # data_list:  list of 5-tuples, one per unique answer id
        # answer_ids: set of every answer id stored in data_list
        # header:     header row of the last non-empty file, or None
        self.data_list, self.answer_ids, self.header = self.load_data(answer_file_path)

    def __len__(self):
        """Return the number of unique answers that were loaded."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the answer tuple stored at position ``index``."""
        return self.data_list[index]

    def load_data(self, data_file):
        """
        Read the given CSV file(s) and collect the unique answers.

        The first row of every file is treated as a header and is not
        returned as data. Completely empty files and blank lines are skipped.

        Args:
            data_file: A single CSV path (``str``, ``bytes`` or path-like
                object) or an iterable of such paths.

        Returns:
            A tuple ``(data_list, answer_ids, header)``:
            ``data_list`` is the list of answer tuples in file order,
            ``answer_ids`` is the set of answer ids contained in it, and
            ``header`` is the header row of the last non-empty file read
            (``None`` if no file had a header).

        Raises:
            OSError: If a file cannot be opened.
            IndexError: If a non-blank row has fewer than five columns.
        """
        # A lone path is itself iterable (character by character), so wrap
        # it before looping; iterables of paths pass through unchanged.
        if isinstance(data_file, (str, bytes)) or hasattr(data_file, "__fspath__"):
            data_file = [data_file]

        data_list = []
        answer_ids = set()
        header = None
        for path in data_file:
            # newline="" hands line-ending handling to the csv module, which
            # is required for quoted fields that contain line breaks.
            with open(path, newline="") as csvfile:
                csv_reader = csv.reader(csvfile)
                file_header = next(csv_reader, None)
                if file_header is None:
                    # Empty file: nothing to read, keep any earlier header.
                    continue
                header = file_header
                for row in csv_reader:
                    if not row:
                        # csv.reader yields [] for blank lines.
                        continue
                    a_id = row[0]  # answer id
                    if a_id in answer_ids:
                        # Keep only the first occurrence of each answer id.
                        continue
                    answer_ids.add(a_id)
                    # (answer id, answer text, question id, score, feedback)
                    data_list.append((a_id, row[1], row[2], row[3], row[4]))
        return data_list, answer_ids, header