-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
executable file
·210 lines (181 loc) · 7.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# -*- coding: utf-8 -*-
"""Fine-tune a Wav2Vec2 CTC speech-recognition model with the HuggingFace Trainer."""
import os
import json
import argparse
import wandb
import numpy as np
import transformers
from shutil import copyfile
from datasets import load_metric
from transformers import Trainer
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import EarlyStoppingCallback
from utils.generic_utils import load_config, load_vocab
from transformers.trainer_utils import get_last_checkpoint
# Verbose HF logging so model/trainer configuration is printed during setup.
transformers.logging.set_verbosity_info()
# Import-time side effect: authenticates with Weights & Biases (may prompt
# for an API key when no cached credentials exist).
wandb.login()
# Word-error-rate metric, used by evaluation() below.
wer_metric = load_metric("wer")
def map_data_augmentation(aug_config):
    """Build one audiomentations transform from a config dict.

    Args:
        aug_config: dict with a 'name' key selecting the augmentation type
            ('additive', 'gaussian', 'rir', 'gain' or 'pitch_shift'); every
            other key is forwarded as a keyword argument to the transform's
            constructor. The input dict is NOT modified.

    Returns:
        The instantiated audiomentations transform.

    Raises:
        ValueError: if 'name' is not one of the known augmentations.
    """
    # Work on a copy: the original implementation del'd 'name' from the
    # caller's dict, corrupting the config and making a second call with the
    # same dict raise KeyError.
    kwargs = dict(aug_config)
    aug_name = kwargs.pop('name')
    # NOTE: the transform classes are imported at runtime in __main__
    # (from audiomentations) only when 'audio_augmentation' is configured,
    # so they are resolved from module globals at call time.
    if aug_name == 'additive':
        return AddBackgroundNoise(**kwargs)
    elif aug_name == 'gaussian':
        return AddGaussianNoise(**kwargs)
    elif aug_name == 'rir':
        return AddImpulseResponse(**kwargs)
    elif aug_name == 'gain':
        return Gain(**kwargs)
    elif aug_name == 'pitch_shift':
        return PitchShift(**kwargs)
    else:
        raise ValueError("The data augmentation '" + aug_name + "' doesn't exist !!")
def evaluation(pred):
    """Compute word error rate (WER) for a Trainer evaluation step.

    Args:
        pred: an EvalPrediction with `.predictions` (CTC logits) and
            `.label_ids` (token ids, padding marked with -100).

    Returns:
        dict with a single 'wer' entry.
    """
    global processor
    pred_logits = pred.predictions
    # Greedy CTC decoding: pick the highest-scoring token per frame.
    pred_ids = np.argmax(pred_logits, axis=-1)
    # -100 marks ignored/padded label positions; map them back to the pad
    # token so the tokenizer can decode. (Mutates pred.label_ids in place,
    # as in the upstream HF examples.)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    # Drop pairs whose reference is empty ("") or a single space (" "):
    # WER is undefined for empty references. Single O(n) pass instead of the
    # original quadratic while/index/del loop.
    pairs = [(p, l) for p, l in zip(pred_str, label_str) if l not in ("", " ")]
    if pairs:
        pred_str = [p for p, _ in pairs]
        label_str = [l for _, l in pairs]
    else:
        pred_str, label_str = [], []
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
if __name__ == '__main__':
    # ----- command-line arguments -----
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c',
        '--config_path',
        type=str, required=True,
        help="json file with configurations"
    )
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        default='facebook/wav2vec2-large-xlsr-53',
        help="path of checkpoint pt file, for continue training"
    )
    parser.add_argument(
        '--continue_train',
        default=False,
        action='store_true',
        help='If True Continue the training using the checkpoint_path'
    )
    args = parser.parse_args()

    # ----- configuration / output dir -----
    # NOTE(review): config is used with both item access (config['x']) and
    # attribute access (config.x) below — presumably load_config returns an
    # object supporting both; confirm against utils.generic_utils.
    config = load_config(args.config_path)
    OUTPUT_DIR = config['output_path']
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    vocab = load_vocab(config.vocab['vocab_path'])

    # Select the dataset implementation: the "preprocessed" variant is used
    # when the config enables it, otherwise the on-the-fly one.
    if 'preprocess_dataset' in config.keys() and config['preprocess_dataset']:
        from utils.dataset_preprocessed import Dataset, DataColletor
    else:
        from utils.dataset import Dataset, DataColletor
    dataset = Dataset(config, vocab)
    # preprocess and normalise datasets
    dataset.preprocess_datasets()
    processor = dataset.processor
    # save the feature_extractor and the tokenizer
    processor.save_pretrained(OUTPUT_DIR)
    # save vocab
    with open(os.path.join(OUTPUT_DIR, 'vocab.json'), "w", encoding="utf-8") as vocab_file:
        json.dump(vocab, vocab_file, ensure_ascii=False)
    # save config train
    copyfile(args.config_path, os.path.join(OUTPUT_DIR, 'config_train.json'))

    # ----- Audio Data augmentation -----
    if 'audio_augmentation' in config.keys():
        # Imported here so audiomentations is only required when augmentation
        # is configured; map_data_augmentation resolves these names at call
        # time from module globals.
        from audiomentations import Compose, Gain, AddGaussianNoise, PitchShift, AddBackgroundNoise, AddImpulseResponse
        # ToDo: Implement Time mask and Freq mask
        audio_augmentator = Compose([map_data_augmentation(aug_config) for aug_config in config['audio_augmentation']])
    else:
        audio_augmentator = None

    # create data colletor
    data_collator = DataColletor(processor, audio_augmentator=audio_augmentator, sampling_rate=config.sampling_rate, padding=True, apply_dbfs_norm=config.apply_dbfs_norm, target_dbfs=config.target_dbfs)

    # If checkpoint_path is a training output directory, resume from its
    # newest checkpoint; otherwise treat it as a model id/path for
    # from_pretrained (e.g. the facebook/wav2vec2-large-xlsr-53 default).
    if os.path.isdir(args.checkpoint_path):
        last_checkpoint = get_last_checkpoint(args.checkpoint_path)
        print("> Resuming Train with checkpoint: ", last_checkpoint)
    else:
        last_checkpoint = None

    # load model
    model = Wav2Vec2ForCTC.from_pretrained(
        last_checkpoint if last_checkpoint else args.checkpoint_path,
        attention_dropout=config['attention_dropout'],
        hidden_dropout=config['hidden_dropout'],
        feat_proj_dropout=config['feat_proj_dropout'],
        mask_time_prob=config['mask_time_prob'],
        mask_feature_prob=config['mask_feature_prob'],
        layerdrop=config['layerdrop'],
        gradient_checkpointing=config['gradient_checkpointing'],
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        # vocab size comes from the tokenizer so the CTC head matches vocab.json
        vocab_size=len(processor.tokenizer),
        # avoid inf CTC losses on degenerate (too-short) inputs
        ctc_zero_infinity=True,
        apply_spec_augment=config['apply_spec_augment'],
        mask_time_length=config['mask_time_length'],
        mask_feature_length=config['mask_feature_length']
    )
    # freeze feature extractor
    if config['freeze_feature_extractor']:
        model.freeze_feature_extractor()

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        logging_dir=os.path.join(OUTPUT_DIR, "tensorboard"),
        report_to="all",
        run_name="CORAA-norm-spontaneous_speech-inf_train",
        # batch sequences of similar length together to minimize padding
        group_by_length=True,
        logging_first_step=True,
        per_device_train_batch_size=config['batch_size'],
        per_device_eval_batch_size=config['batch_size'],
        dataloader_num_workers=config['num_loader_workers'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        seed=config.seed,
        num_train_epochs=config['epochs'],
        fp16=config.mixed_precision,
        logging_steps=config['logging_steps'],
        learning_rate=config['lr'],
        warmup_steps=config['warmup_steps'],
        warmup_ratio=config['warmup_ratio'],
        # checkpoint + evaluate once per epoch; keep the best model by
        # eval_loss (lower is better)
        save_strategy="epoch",
        evaluation_strategy="epoch",
        load_best_model_at_end=config['load_best_model_at_end'],
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_total_limit=config['save_total_limit']
    )
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=evaluation,
        train_dataset=dataset.train_dataset,
        eval_dataset=dataset.devel_dataset,
        # pass the feature extractor as "tokenizer" so it is saved with
        # every checkpoint (standard HF CTC recipe)
        tokenizer=processor.feature_extractor
    )
    # NOTE(review): despite the name, early_stop_epochs is passed as the
    # patience in *evaluation rounds* (here: epochs, since
    # evaluation_strategy="epoch").
    if config['early_stop_epochs']:
        print(f"> Adding Early Stop in: {config['early_stop_epochs']} epochs")
        trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=config['early_stop_epochs']))

    print("> Starting Training")
    # Resume optimizer/scheduler state only when --continue_train was given.
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint if args.continue_train else None)
    # save best model
    # model.save_pretrained(OUTPUT_DIR)
    trainer.save_model()
    # save train results
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset.train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    # save eval results
    print("--- Evaluate ---")
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dataset.devel_dataset)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)