CustomDataset.py
from torch.utils.data import Dataset
from transformers import T5Tokenizer
import csv

tokenizer = T5Tokenizer.from_pretrained('t5-base')


class CustomDataset(Dataset):
    def __init__(self, tokenizer, data_dir, type_path, max_len=256):
        # self.path = os.path.join(data_dir, type_path + '.csv')
        self.source_column = "question1"
        self.target_column = "question2"
        self.data = []
        with open(type_path + ".csv", "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            line_count = 0
            for row in csv_reader:
                # skip the header row (question1, question2) if one is present
                if line_count == 0 and row and row[0] == self.source_column:
                    line_count += 1
                    continue
                self.data.append(row)
                line_count += 1

        self.max_len = max_len
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []

        self._build()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()

        # squeeze away the batch dimension added by batch_encode_plus
        src_mask = self.inputs[index]["attention_mask"].squeeze()
        target_mask = self.targets[index]["attention_mask"].squeeze()

        return {"source_ids": source_ids, "source_mask": src_mask,
                "target_ids": target_ids, "target_mask": target_mask}

    def _build(self):
        for example in self.data:
            input_ = example[0]
            target = example[1]

            # prepend the task prefix and append the end-of-sequence token used for T5
            input_ = "paraphrase: " + input_ + " </s>"
            target = target + " </s>"

            # tokenize inputs
            tokenized_inputs = self.tokenizer.batch_encode_plus(
                [input_], max_length=self.max_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            # tokenize targets
            tokenized_targets = self.tokenizer.batch_encode_plus(
                [target], max_length=self.max_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )

            self.inputs.append(tokenized_inputs)
            self.targets.append(tokenized_targets)
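
For reference, a minimal usage sketch follows. It is not part of the original file: the train.csv name is a placeholder for a CSV with question1/question2 columns in the working directory (data_dir is accepted by __init__ but not used to resolve the path), and the DataLoader batching is standard PyTorch.

# Minimal usage sketch (assumes a local train.csv with question1/question2 columns).
from torch.utils.data import DataLoader

train_dataset = CustomDataset(tokenizer, data_dir=".", type_path="train", max_len=256)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

batch = next(iter(train_loader))
print(batch["source_ids"].shape)   # (batch_size, max_len)
print(batch["target_ids"].shape)   # (batch_size, max_len)

Because __getitem__ returns per-example tensors of shape (max_len,), the default collate function stacks them into (batch_size, max_len) tensors, which is the shape a T5 model expects for input_ids and attention_mask.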