-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
344 lines (281 loc) · 14 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
#! -*- encoding: utf-8 -*-
from __future__ import print_function
from data_reader import DataReader
from params import hparams
import tensorflow as tf
import time
import argparse
import os
import sys
import numpy as np
from scipy.io import wavfile
from datetime import datetime
from glow import WaveGlow, compute_waveglow_loss
from tensorflow.python.client import timeline
from data_reader import read_binary_lc
# Import-time timestamp such as "2024-01-31T13-05-59" (hyphens, no colons).
STARTED_DATESTRING = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
def get_arguments():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``filelist``, ``wave_dir`` and ``lc_dir``
        (all required) plus ``ngpu``, ``run_name``, ``restore_from`` and
        ``store_metadata``.
    """
    bool_words = {'true': True, 'false': False}

    def _str_to_bool(s):
        """Map 'true'/'false' (any letter case) to a bool for argparse."""
        key = s.lower()
        if key in bool_words:
            return bool_words[key]
        raise ValueError('Argument needs to be a '
                         'boolean, got {}'.format(s))

    parser = argparse.ArgumentParser(description='Parallel WaveNet Network')

    # Input locations; each one must be supplied on the command line.
    required_paths = (
        ('--filelist', 'filelist path for training data.'),
        ('--wave_dir', 'wave data directory for training data.'),
        ('--lc_dir', 'local condition directory for training data.'),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, default=None, required=True,
                            help=help_text)

    parser.add_argument('--ngpu', type=int, default=1, help='gpu numbers')
    parser.add_argument('--run_name', type=str, default='waveglow',
                        help='run name for log saving')
    parser.add_argument('--restore_from', type=str, default=None,
                        help='restore model from checkpoint')
    # A custom converter is used because argparse's type=bool would treat
    # any non-empty string, including 'false', as True.
    parser.add_argument('--store_metadata', type=_str_to_bool, default=False,
                        help='Whether to store advanced debugging information')
    return parser.parse_args()
def write_wav(waveform, sample_rate, filename):
    """Write a float waveform to ``filename`` as 16-bit PCM.

    Args:
        waveform: array-like of float samples, nominally in [-1, 1].
            Samples outside that range are clipped to it.
        sample_rate: sample rate stored in the WAV header.
        filename: output path.
    """
    samples = np.asarray(waveform, dtype=np.float32)
    # Clip first: a value such as 1.5 * 32767 does not fit in int16, and
    # astype(np.int16) would wrap it around to a large negative sample.
    samples = np.clip(samples, -1.0, 1.0)
    pcm = (samples * 32767).astype(np.int16)
    wavfile.write(filename, sample_rate, pcm)
    print('Updated wav file at {}'.format(filename))
def save(saver, sess, logdir, step, write_meta_graph=False):
    """Store a checkpoint of ``sess`` under ``logdir`` with prefix ``model.ckpt``.

    Args:
        saver: object whose ``save`` method writes the checkpoint
            (a tf.train.Saver in this script).
        sess: session passed through to ``saver.save``.
        logdir: target directory; it is created if it does not exist.
        step: forwarded as ``global_step`` (tf.train.Saver appends it to the
            checkpoint file name).
        write_meta_graph: forwarded to ``saver.save``.
    """
    print('Storing checkpoint to {} ...'.format(logdir), end="")
    sys.stdout.flush()

    if not os.path.exists(logdir):
        os.makedirs(logdir)

    prefix = os.path.join(logdir, 'model.ckpt')
    saver.save(sess, prefix,
               global_step=step,
               write_meta_graph=write_meta_graph)
    print(' Done.')
def load(saver, sess, logdir):
    """Restore the latest checkpoint recorded in ``logdir`` into ``sess``.

    Args:
        saver: tf.train.Saver used for the restore.
        sess: session whose variables are restored.
        logdir: directory passed to ``tf.train.get_checkpoint_state``.

    Returns:
        The step parsed from the checkpoint file name (the integer after the
        last '-', e.g. ``model.ckpt-1200`` -> 1200), or None when no
        checkpoint is found.

    Raises:
        ValueError: if the checkpoint file name does not end in ``-<int>``.
    """
    print("Trying to restore saved checkpoints from {} ...".format(logdir),
          end="")
    ckpt = tf.train.get_checkpoint_state(logdir)
    if not ckpt or not ckpt.model_checkpoint_path:
        print(" No checkpoint found.")
        return None

    ckpt_path = ckpt.model_checkpoint_path
    print(" Checkpoint found: {}".format(ckpt_path))
    # basename() instead of split('/') so Windows-style paths also work.
    global_step = int(os.path.basename(ckpt_path).rsplit('-', 1)[-1])
    print(" Global step was: {}".format(global_step))
    print(" Restoring...", end="")
    saver.restore(sess, ckpt_path)
    print(" Done.")
    return global_step
def average_gradients(tower_grads):
    """Average per-tower gradients, variable by variable.

    This is the synchronization point between towers: every averaged
    gradient depends on the gradients of all contributing towers.

    Args:
        tower_grads: one entry per tower. Each entry is a list of
            (gradient, variable) pairs as returned by
            ``optimizer.compute_gradients``. All towers must list the
            variables in the same order, because entries are paired
            positionally with ``zip``.

    Returns:
        List of (gradient, variable) pairs, one per variable.
        - The gradient is the mean over the towers that produced a
          non-None gradient, or None when no tower did.
        - The variable is taken from the first tower, since towers share
          their variables.
    """
    averaged = []
    for per_tower in zip(*tower_grads):
        # per_tower == ((g_tower0, v), (g_tower1, v), ..., (g_towerN, v))
        shared_var = per_tower[0][1]
        # Give each gradient a leading "tower" axis so they can be stacked.
        stacked = [tf.expand_dims(g, 0) for g, _ in per_tower if g is not None]
        if not stacked:
            averaged.append((None, shared_var))
            continue
        mean_grad = tf.reduce_mean(tf.concat(stacked, 0), 0)
        averaged.append((mean_grad, shared_var))
    return averaged
def stats_graph(graph):
    """Print and return the float-op and trainable-parameter counts of ``graph``.

    Uses the TF1 profiler.

    Returns:
        Tuple ``(total_float_ops, total_parameters)``.
    """
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    # The profile result exposes `total_parameters`; the previous `params.tra`
    # raised AttributeError.
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
    return flops.total_float_ops, params.total_parameters
def count():
    """Return the total number of scalar parameters across tf.trainable_variables().

    Assumes every trainable variable has a fully defined static shape
    (each ``dim.value`` is multiplied together).
    """
    def _num_elements(variable):
        # Product of the static dimensions; a scalar (empty shape) counts as 1.
        size = 1
        for dim in variable.get_shape():
            size *= dim.value
        return size

    return sum(_num_elements(v) for v in tf.trainable_variables())
def main():
    """Build the multi-GPU WaveGlow training graph and run the training loop.

    Reads CLI options (``get_arguments``) and hyper-parameters (``hparams``).
    Writes checkpoints and TensorBoard summaries to
    ``hparams.logdir_root/<run_name>`` and test audio to its ``wave``
    subdirectory.

    Raises:
        ValueError: if ``hparams.upsampling_rate != hparams.hop_length`` or
            ``--ngpu`` is smaller than 1.
    """
    args = get_arguments()
    args.logdir = os.path.join(hparams.logdir_root, args.run_name)
    os.makedirs(args.logdir, exist_ok=True)
    args.gen_wave_dir = os.path.join(args.logdir, 'wave')
    os.makedirs(args.gen_wave_dir, exist_ok=True)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if hparams.upsampling_rate != hparams.hop_length:
        raise ValueError('upsampling rate should be same as hop_length')
    if args.ngpu < 1:
        # At least one tower is needed; `glow` is reused for inference below.
        raise ValueError('--ngpu must be at least 1, got {}'.format(args.ngpu))

    # Create coordinator.
    coord = tf.train.Coordinator()
    global_step = tf.get_variable("global_step", [], initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(hparams.lr, global_step, hparams.decay_steps, 0.95, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    with tf.device('/cpu:0'):
        with tf.name_scope('inputs'):
            reader = DataReader(coord, args.filelist, args.wave_dir, args.lc_dir)

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True))
    reader.start_threads()

    audio_placeholder = tf.placeholder(tf.float32, shape=[None, None, 1], name='audio')
    lc_placeholder = tf.placeholder(tf.float32, shape=[None, None, hparams.num_mels], name='lc')

    tower_losses = []
    tower_grads = []
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        for i in range(args.ngpu):
            with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i):
                glow = WaveGlow(lc_dim=hparams.num_mels,
                                n_flows=hparams.n_flows,
                                n_group=hparams.n_group,
                                n_early_every=hparams.n_early_every,
                                n_early_size=hparams.n_early_size)
                print('create network %i' % i)
                # Tower i consumes rows [i*batch_size, (i+1)*batch_size) of the fed batch.
                local_audio_placeholder = audio_placeholder[i * hparams.batch_size:(i + 1) * hparams.batch_size, :, :]
                local_lc_placeholder = lc_placeholder[i * hparams.batch_size:(i + 1) * hparams.batch_size, :, :]
                output_audio, log_s_list, log_det_W_list = glow.create_forward_network(local_audio_placeholder,
                                                                                       local_lc_placeholder)
                loss = compute_waveglow_loss(output_audio, log_s_list, log_det_W_list, sigma=hparams.sigma)
                grads = optimizer.compute_gradients(loss, var_list=tf.trainable_variables())
                tower_losses.append(loss)
                tower_grads.append(grads)
                tf.summary.scalar('loss_tower_%d' % i, loss)

        print("create network finished")
        loss = tf.reduce_mean(tower_losses)
        averaged_gradients = average_gradients(tower_grads)
        # NOTE: the averaged gradients are applied without clipping.
        train_ops = optimizer.apply_gradients(averaged_gradients, global_step=global_step)

        tf.summary.scalar('loss', loss)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(args.logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Inference graph sharing the weights of the (last) training tower.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        lc_placeholder_infer = tf.placeholder(tf.float32, shape=[1, None, hparams.num_mels], name='lc_infer')
        audio_infer_ops = glow.infer(lc_placeholder_infer, sigma=hparams.sigma)

    # Set up session
    init = tf.global_variables_initializer()
    sess.run(init)
    print('parameters initialization finished')

    total_parameters = count()
    print("######################################################")
    print("### Total Trainable Params is {} ###".format(total_parameters))
    print("######################################################")

    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=30)

    saved_global_step = 0
    if args.restore_from is not None:
        try:
            saved_global_step = load(saver, sess, args.restore_from)
            if saved_global_step is None:
                # No checkpoint found: start fresh. The first training step
                # is saved_global_step + 1, i.e. step 1.
                saved_global_step = 0
        except Exception:
            print("Something went wrong while restoring checkpoint. "
                  "We will terminate training to avoid accidentally overwriting "
                  "the previous model.")
            raise
        # The saver only covers trainable variables, so `global_step`
        # (trainable=False) is never restored. Resync it so the
        # exponential learning-rate decay resumes instead of restarting.
        sess.run(global_step.assign(saved_global_step))
        print("restore model successfully!")

    print('start training.')
    last_saved_step = saved_global_step
    # Bind `step` before the loop so the `finally` clause cannot hit a
    # NameError when the loop body never runs (or fails before step 1).
    step = saved_global_step
    try:
        for step in range(saved_global_step + 1, hparams.train_steps):
            audio, lc = reader.dequeue(num_elements=hparams.batch_size * args.ngpu)
            if hparams.lc_conv1d or hparams.lc_encode or hparams.transposed_upsampling:
                # With conv1d / bi-lstm encoding or transposed-conv upsampling,
                # the local condition is upsampled inside the TF graph.
                lc = np.reshape(lc, [hparams.batch_size * args.ngpu, -1, hparams.num_mels])
            else:
                # Upsample by repeating each frame upsampling_rate times.
                lc = np.tile(lc, [1, 1, hparams.upsampling_rate])
                lc = np.reshape(lc, [hparams.batch_size * args.ngpu, -1, hparams.num_mels])

            start_time = time.time()
            if step % 100 == 0 and args.store_metadata:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _, lr = sess.run(
                    [summaries, loss, train_ops, learning_rate],
                    feed_dict={audio_placeholder: audio,
                               lc_placeholder: lc},
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(args.logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _, lr = sess.run([summaries, loss, train_ops, learning_rate],
                                                      feed_dict={audio_placeholder: audio, lc_placeholder: lc})
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            step_log = 'step {:d} - loss = {:.3f}, lr={:.8f}, time cost={:.4f}' \
                .format(step, loss_value, lr, duration)
            print(step_log)

            if step % hparams.save_model_every == 0:
                save(saver, sess, args.logdir, step)
                last_saved_step = step

            if step % hparams.gen_test_wave_every == 0:
                generate_wave(lc_placeholder_infer, audio_infer_ops, sess, step, args.gen_wave_dir)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, args.logdir, step)
        coord.request_stop()
        coord.join()
def generate_wave(lc_placeholder_infer, audio_infer_ops, sess, step, path):
    """Synthesize audio for ``hparams.gen_file`` and save it as a WAV file.

    The output file is ``<path>/<step zero-padded to 8 chars>.wav``.

    Args:
        lc_placeholder_infer: inference placeholder fed with the local
            condition reshaped to ``[1, -1, hparams.num_mels]``.
        audio_infer_ops: tensor(s) evaluated to obtain the audio.
        sess: session used to run ``audio_infer_ops``.
        step: training step, used for the file name.
        path: directory that receives the WAV file.
    """
    wav_path = os.path.join(path, str(step).zfill(8) + '.wav')

    mel = read_binary_lc(hparams.gen_file, hparams.num_mels)
    upsampled_in_graph = hparams.lc_conv1d or hparams.lc_encode or hparams.transposed_upsampling
    if not upsampled_in_graph:
        # Repeat each frame upsampling_rate times along time
        # (assumes `mel` is (frames, num_mels) - TODO confirm read_binary_lc).
        mel = np.tile(mel, [1, 1, hparams.upsampling_rate])
    mel = np.reshape(mel, [1, -1, hparams.num_mels])

    waveform = sess.run(audio_infer_ops, feed_dict={lc_placeholder_infer: mel}).flatten()
    write_wav(waveform, hparams.sample_rate, wav_path)
# Start training only when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()