train.py
import os
from functools import partial

import huggingface_hub
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

from collator import data_collator
from model import model, processor
from speaker_speech import dataset_creator

# Authenticate for push_to_hub. Read the token from the environment instead of
# hardcoding it in the source; never commit access tokens.
huggingface_hub.login(token=os.environ["HF_TOKEN"])

# Build the preprocessed dataset and hold out 10% of it for evaluation.
dataset = dataset_creator()
dataset = dataset.train_test_split(test_size=0.1)
# Disable the decoder cache during training; it is incompatible with gradient
# checkpointing.
model.config.use_cache = False
# Re-enable the cache for generation so inference stays fast.
model.generate = partial(model.generate, use_cache=True)
training_args = Seq2SeqTrainingArguments(
    output_dir="speecht5_finetuned_kazakh_tts2",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=8,  # effective batch size of 32 * 8 = 256 per device
    learning_rate=1e-5,
    warmup_steps=1000,
    max_steps=5000,
    gradient_checkpointing=False,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=2,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    # report_to=["tensorboard"],
    load_best_model_at_end=True,  # keep the checkpoint with the lowest eval loss
    greater_is_better=False,
    label_names=["labels"],
    push_to_hub=True,
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    tokenizer=processor,
)
# Fine-tune, then upload the final model and processor to the Hub.
trainer.train()
trainer.push_to_hub()
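
# ---------------------------------------------------------------------------
# Quick post-training sanity check: synthesize one sample with the fine-tuned
# model. This block is a hedged sketch, not part of the original pipeline. It
# assumes the test split carries a 512-dim "speaker_embeddings" column (one
# x-vector per utterance, as SpeechT5 expects) and uses a hypothetical Kazakh
# sample sentence; adjust both to match what dataset_creator() actually yields.
# ---------------------------------------------------------------------------
import torch
from transformers import SpeechT5HifiGan

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(model.device)
speaker_embeddings = torch.tensor(dataset["test"][0]["speaker_embeddings"]).unsqueeze(0).to(model.device)
inputs = processor(text="Сәлеметсіз бе!", return_tensors="pt").to(model.device)
with torch.no_grad():
    speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
# `speech` is a 1-D float tensor of audio samples at 16 kHz.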