# train_tacotron.py
import argparse
import os
from time import sleep
import infolog
import tensorflow as tf
from hparams import hparams
from infolog import log
from tacotron.synthesize import tacotron_synthesize
from tacotron.train import tacotron_train
# Expose only the GPU with index 1 to this process.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES value inherited from
# the caller's environment -- confirm that pinning the device here is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Module-level alias for infolog's logging function; it rebinds the same name that
# `from infolog import log` already provides.
log = infolog.log
def save_seq(file, sequence, input_path):
    '''Write the training-stage flags and the input path to the state file.

    Every entry of ``sequence`` is converted with ``int`` and rendered as text; the resulting
    fields, followed by ``input_path`` as the final field, are stored as one '|'-separated
    record. Any previous contents of ``file`` are overwritten.
    '''
    fields = list(map(str, map(int, sequence)))
    fields.append(input_path)
    with open(file, 'w') as handle:
        handle.write('|'.join(fields))
def read_seq(file):
    '''Load the Tacotron-2 training state from disk. (To skip stages if not first run)

    The state file holds a single '|'-separated record: integer stage flags followed by the
    input path as the last field.

    Returns:
        A ``(states, input_path)`` tuple. ``states`` is a list of booleans, one per stored
        flag, and ``input_path`` is the final field of the record. When ``file`` does not
        exist, or exists but is empty/blank, the first-run default ``([0, 0, 0], '')`` is
        returned.

    Raises:
        ValueError: if a stored flag field cannot be parsed by ``int``.
    '''
    if not os.path.isfile(file):
        return [0, 0, 0], ''

    with open(file, 'r') as f:
        content = f.read()

    if not content.strip():
        # A blank file (e.g. truncated by an interrupted write) carries no state. Splitting it
        # would yield zero flags and break callers that unpack three of them, so treat it
        # exactly like a missing file.
        return [0, 0, 0], ''

    *flags, input_path = content.split('|')
    return [bool(int(s)) for s in flags], input_path
def prepare_run(args):
    '''Parse hyperparameter overrides and set up the run's log directory and logger.

    Parses ``args.hparams`` with ``hparams.parse``, exports ``args.tf_log_level`` as the
    ``TF_CPP_MIN_LOG_LEVEL`` environment variable, creates ``<base_dir>/logs-<name or model>``
    when it does not exist yet, and initialises ``infolog`` to write ``Terminal_train_log``
    inside that directory.

    Returns:
        ``(log_dir, modified_hp)``: the log directory path and the value returned by
        ``hparams.parse``.
    '''
    parsed_hparams = hparams.parse(args.hparams)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)

    # An explicit --name wins; otherwise the model name labels the run.
    run_label = args.name if args.name else args.model
    logs_root = os.path.join(args.base_dir, 'logs-{}'.format(run_label))
    os.makedirs(logs_root, exist_ok=True)

    terminal_log = os.path.join(logs_root, 'Terminal_train_log')
    infolog.init(terminal_log, run_label, args.slack_url)
    return logs_root, parsed_hparams
def train(args, log_dir, hparams):
    '''Run Tacotron training followed by GTA synthesis, skipping stages already completed.

    Stage completion flags are loaded from ``<log_dir>/state_log`` with ``read_seq`` and
    written back with ``save_seq`` after each stage finishes, so an interrupted run can be
    resumed without repeating finished stages.

    Args:
        args: parsed command-line namespace; forwarded to ``tacotron_train`` and
            ``tacotron_synthesize``. ``args.output_dir`` is also read here.
        log_dir: directory containing the state file and the ``taco_pretrained/`` checkpoints.
        hparams: hyperparameters forwarded to the training and synthesis calls.

    Raises:
        RuntimeError: if ``tacotron_train`` returns ``None`` instead of a checkpoint.
    '''
    state_file = os.path.join(log_dir, 'state_log')

    # Get training states
    (taco_state, GTA_state, wave_state), input_path = read_seq(state_file)
    print('taco_state, GTA_state, wave_state:',taco_state, GTA_state, wave_state)

    if not taco_state:
        log('\n#############################################################\n')
        log('Tacotron Train\n')
        log('###########################################################\n')
        checkpoint = tacotron_train(args, log_dir, hparams)
        tf.reset_default_graph()
        # Sleep 1/2 second to let previous graph close and avoid error messages while synthesis
        sleep(0.5)
        if checkpoint is None:
            # Raise a real exception: raising a bare string is itself a TypeError and would
            # hide this message behind an unrelated error.
            raise RuntimeError('Error occured while training Tacotron, Exiting!')
        taco_state = 1
        save_seq(state_file, [taco_state, GTA_state, wave_state], input_path)
    else:
        # Training already done in a previous run: reuse its checkpoint directory.
        checkpoint = os.path.join(log_dir, 'taco_pretrained/')
    log('tacotron_train done!!')

    if not GTA_state:
        log('\n#############################################################\n')
        log('Tacotron GTA Synthesis\n')
        log('###########################################################\n')
        input_path = tacotron_synthesize(args, hparams, checkpoint)
        tf.reset_default_graph()
        # Sleep 1/2 second to let previous graph close and avoid error messages while Wavenet is training
        sleep(0.5)
        GTA_state = 1
        save_seq(state_file, [taco_state, GTA_state, wave_state], input_path)
    else:
        # GTA synthesis already done; this is where its map file is expected to live.
        # The value is not consumed further inside this function.
        input_path = os.path.join(log_dir, 'tacotron_' + args.output_dir, 'gta', 'map.txt')
    log('Tacotron GTA Synthesis done')
def _str_to_bool(value):
    '''Convert a command-line string to a bool (for use as an argparse ``type=``).

    ``type=bool`` cannot be used for this: ``bool('False')`` is ``True`` because any
    non-empty string is truthy. The empty string maps to ``False``, matching what
    ``bool('')`` used to produce.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    '''
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    '''Parse command-line arguments and launch the Tacotron-2 training pipeline.

    Raises:
        ValueError: if ``--model`` is not one of the accepted models.
    '''
    train_data_base = '/xxx/tacotron2_wavernn/'
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default=train_data_base)
    parser.add_argument('--hparams', default='',
        help='Hyperparameter overrides as a comma-separated list of name=value pairs')
    parser.add_argument('--tacotron_input', default='training_data/train.txt')
    parser.add_argument('--name', help='Name of logging directory.')
    parser.add_argument('--model', default='Tacotron-2')
    parser.add_argument('--input_dir', default=train_data_base + 'training_data', help='folder to contain inputs sentences/targets')
    parser.add_argument('--output_dir', default='output', help='folder to contain synthesized mel spectrograms')
    parser.add_argument('--mode', default='synthesis', help='mode for synthesis of tacotron after training')
    parser.add_argument('--gta_output', default=train_data_base + 'training_data/')
    parser.add_argument('--GTA', default='True', help='Ground truth aligned synthesis, defaults to True, only considered in Tacotron synthesis mode')
    # Use an explicit converter so that "--restore False" really disables restoring.
    parser.add_argument('--restore', type=_str_to_bool, default=True, help='Set this to False to do a fresh training')
    parser.add_argument('--summary_interval', type=int, default=250,
        help='Steps between running summary ops')
    parser.add_argument('--embedding_interval', type=int, default=5000,
        help='Steps between updating embeddings projection visualization')
    parser.add_argument('--checkpoint_interval', type=int, default=5000,
        help='Steps between writing checkpoints')
    parser.add_argument('--eval_interval', type=int, default=5000,
        help='Steps between eval on test data')
    parser.add_argument('--tacotron_train_steps', type=int, default=400000, help='total number of tacotron training steps')
    parser.add_argument('--tf_log_level', type=int, default=3, help='Tensorflow C++ log level.')
    parser.add_argument('--slack_url', default=None, help='slack webhook notification destination link')
    args = parser.parse_args()

    accepted_models = ['Tacotron-2']
    if args.model not in accepted_models:
        raise ValueError('please enter a valid model to train: {}'.format(accepted_models))

    # Named modified_hp so the module-level `hparams` import is not shadowed inside main().
    log_dir, modified_hp = prepare_run(args)

    if args.model == 'Tacotron-2':
        train(args, log_dir, modified_hp)
    else:
        # Defensive: only reachable if a model is added to accepted_models without a
        # matching dispatch branch here.
        raise ValueError('Model provided {} unknown! {}'.format(args.model, accepted_models))
# Start the training pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()