# dataset_maker.py
import logging
import pdb
from typing import Dict

from datasets import load_dataset, DownloadConfig
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments

from args import DataTrainingArguments

logger = logging.getLogger(__name__)


class DatasetMaker:
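    """Builds a tokenized seq2seq dataset from raw train/validation/test files and saves it to disk."""
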
    def __init__(self, dataset_saved_path: str, data_args: DataTrainingArguments,
                 training_args: Seq2SeqTrainingArguments, tokenizer: PreTrainedTokenizerBase):
        self.data_args = data_args
        self.training_args = training_args
        self.tokenizer = tokenizer
        self.dataset_saved_path = dataset_saved_path

    def make_dataset(self):
        logger.info('******* Making Dataset **********')
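        # Collect whichever of the train/validation/test files are configured; the dataset
        # loader name (json, csv, text, ...) is inferred from the file extension.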
        data_files = {}
        if self.data_args.train_file is not None:
            data_files["train"] = self.data_args.train_file
            extension = self.data_args.train_file.split(".")[-1]
        if self.data_args.validation_file is not None:
            data_files["validation"] = self.data_args.validation_file
            extension = self.data_args.validation_file.split(".")[-1]
        if self.data_args.test_file is not None:
            data_files["test"] = self.data_args.test_file
            extension = self.data_args.test_file.split(".")[-1]
        if extension == 'txt':
            extension = 'text'
        datasets = load_dataset(extension, data_files=data_files, download_config=DownloadConfig(use_etag=False))
        # Temporarily set max_target_length for training.
        max_target_length = self.data_args.max_target_length
        padding = "max_length" if self.data_args.pad_to_max_length else False
        if self.training_args.label_smoothing_factor > 0:
            logger.warning(
                "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined "
                "for the model. This will lead to the loss being calculated twice and will take up more memory."
            )
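
        # Each example is expected to provide "src" (source text), "tgt" (summary), and "idx" (example id)
        # fields; preprocess_function turns them into model input ids, label ids, and idxes.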
        def preprocess_function(examples: Dict):
            """
            For JSON data, `examples` is the dict parsed from the JSON. For plain text data,
            `examples["text"]` holds the whole file and each item is one line of it.
            """
            if isinstance(examples["src"][0], str):
                inputs = [ex.replace(' ', '') if self.data_args.chinese_data else ex for ex in examples["src"]]
            elif isinstance(examples["src"][0], list):
                inputs = [' '.join(ex).replace(' ', '') if self.data_args.chinese_data else ' '.join(ex)
                          for ex in examples["src"]]
            else:
                raise ValueError(f'only support str/list in content, now {type(examples["src"][0])}')
            if isinstance(examples["tgt"][0], str):
                targets = [ex.replace(' ', '') + self.tokenizer.eos_token if self.data_args.chinese_data
                           else ex + self.tokenizer.eos_token for ex in examples["tgt"]]
            elif isinstance(examples["tgt"][0], list):
                targets = [' '.join(ex).replace(' ', '') + self.tokenizer.eos_token if self.data_args.chinese_data
                           else ' '.join(ex) + self.tokenizer.eos_token for ex in examples["tgt"]]
            else:
                raise ValueError(f'only support str/list in summary, now {type(examples["tgt"][0])}')
            model_inputs = self.tokenizer(inputs, max_length=self.data_args.max_source_length,
                                          padding=padding, truncation=True)
            # addi_source = tokenizer(addi_source, max_length=data_args.max_source_length, padding=False,
            #                         truncation=True, add_special_tokens=False)
            # pdb.set_trace()
            # Set up the tokenizer for targets.
            with self.tokenizer.as_target_tokenizer():
                labels = self.tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
            idxes = [ex for ex in examples["idx"]]
            # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to
            # ignore padding in the loss.
            if padding == "max_length" and self.data_args.ignore_pad_token_for_loss:
                labels["input_ids"] = [
                    [(l if l != self.tokenizer.pad_token_id else -100) for l in label]
                    for label in labels["input_ids"]
                ]
            model_inputs["labels"] = labels["input_ids"]
            model_inputs["idxes"] = idxes
            # pdb.set_trace()
            return model_inputs
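
        # Tokenize every split in batched mode; cached results are reused unless
        # `overwrite_cache` forces preprocessing to run again.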
        datasets = datasets.map(
            preprocess_function,
            batched=True,
            num_proc=self.data_args.preprocessing_num_workers,
            load_from_cache_file=not self.data_args.overwrite_cache,
        )
        logger.info('saving dataset')
        dataset_saved_path = self.dataset_saved_path
        datasets.save_to_disk(dataset_saved_path)
        logger.info(f'******* Dataset Finish {dataset_saved_path} **********')
        return datasets
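
# Example usage (a minimal sketch; the tokenizer name, argument objects, and save path below
# are illustrative assumptions, not taken from this repository's training script):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
#     maker = DatasetMaker("./cached_dataset", data_args, training_args, tokenizer)
#     datasets = maker.make_dataset()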