-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTeacherDataset.py
33 lines (29 loc) · 1.24 KB
/
TeacherDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import jsonlines
import random
from torch.utils.data import Dataset
class TeacherDataset(Dataset):
    """Dataset of teacher-model outputs loaded from a JSON-lines file.

    Each record is expected to be a dict with keys 'text',
    'top_token_prob', 'loss_mask' and 'attention_mask', where
    'top_token_prob' is a list (one entry per position) of dicts mapping
    token-id strings to probabilities.
    """

    def __init__(self, data_path: str, percentage: float = 1):
        """Load all records, shuffle them, and keep a random fraction.

        Args:
            data_path: Path to a JSON-lines file of teacher records.
            percentage: Fraction of the shuffled data to retain
                (default 1 keeps everything; values > 1 are a no-op).

        Raises:
            ValueError: If percentage is negative — previously a negative
                value silently sliced records off the wrong end of the list.
        """
        if percentage < 0:
            raise ValueError(f"percentage must be non-negative, got {percentage}")
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = list(jsonlines.Reader(f))
        # Shuffle before truncating so the retained subset is a random sample.
        random.shuffle(self.data)
        bound = round(len(self.data) * percentage)
        self.data = self.data[:bound]

    def __len__(self):
        """Return the number of retained records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record at *index* as a dict of its four fields."""
        record = self.data[index]
        return {
            'text': record['text'],
            'top_token_prob': record['top_token_prob'],
            'loss_mask': record['loss_mask'],
            'attention_mask': record['attention_mask'],
        }

    def collate_fn(self, batch):
        """Collate a list of records into per-field batch lists.

        Splits each per-position {token-id-str: prob} dict into parallel
        lists of int token ids ('output_token') and probabilities
        ('top_prob'); dict insertion order keeps the two aligned.
        """
        return {
            'input_token': [x['text'] for x in batch],
            'output_token': [[[int(tok) for tok in pos] for pos in x['top_token_prob']]
                             for x in batch],
            'top_prob': [[list(pos.values()) for pos in x['top_token_prob']]
                         for x in batch],
            "loss_mask": [x['loss_mask'] for x in batch],
            "attention_mask": [x["attention_mask"] for x in batch],
        }