Skip to content

Commit

Permalink
Be more consistent about which layers hparams are passed down to.
Browse files Browse the repository at this point in the history
  • Loading branch information
keithito committed Apr 17, 2018
1 parent a3531b8 commit 038ea80
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 20 deletions.
1 change: 0 additions & 1 deletion hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ref_level_db=20,

# Model:
# TODO: add more configurable hparams
outputs_per_step=5,
embed_depth=256,
prenet_depth1=256,
Expand Down
11 changes: 5 additions & 6 deletions models/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from tensorflow.contrib.rnn import GRUCell


def prenet(inputs, is_training, h_params, scope=None):
layer_sizes=[h_params.prenet_depth1, h_params.prenet_depth2]
def prenet(inputs, is_training, layer_sizes, scope=None):
x = inputs
drop_rate = 0.5 if is_training else 0.0
with tf.variable_scope(scope or 'prenet'):
Expand All @@ -13,7 +12,7 @@ def prenet(inputs, is_training, h_params, scope=None):
return x


def encoder_cbhg(inputs, input_lengths, is_training, h_params):
def encoder_cbhg(inputs, input_lengths, is_training, depth):
input_channels = inputs.get_shape()[2]
return cbhg(
inputs,
Expand All @@ -22,18 +21,18 @@ def encoder_cbhg(inputs, input_lengths, is_training, h_params):
scope='encoder_cbhg',
K=16,
projections=[128, input_channels],
depth=h_params.encoder_depth)
depth=depth)


def post_cbhg(inputs, input_dim, is_training, h_params):
def post_cbhg(inputs, input_dim, is_training, depth):
return cbhg(
inputs,
None,
is_training,
scope='post_cbhg',
K=8,
projections=[256, input_dim],
depth=h_params.postnet_depth)
depth=depth)


def cbhg(inputs, input_lengths, is_training, scope, K, projections, depth):
Expand Down
7 changes: 3 additions & 4 deletions models/rnn_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

class DecoderPrenetWrapper(RNNCell):
'''Runs RNN inputs through a prenet before sending them to the cell.'''
def __init__(self, cell, is_training, h_params):
def __init__(self, cell, is_training, layer_sizes):
super(DecoderPrenetWrapper, self).__init__()
self._cell = cell
self._is_training = is_training
self._h_params = h_params
self._layer_sizes = layer_sizes

@property
def state_size(self):
Expand All @@ -21,8 +21,7 @@ def output_size(self):
return self._cell.output_size

def call(self, inputs, state):
prenet_out = prenet(inputs, self._is_training, self._h_params,
scope='decoder_prenet')
prenet_out = prenet(inputs, self._is_training, self._layer_sizes, scope='decoder_prenet')
return self._cell(prenet_out, state)

def zero_state(self, batch_size, dtype):
Expand Down
20 changes: 11 additions & 9 deletions models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,17 @@ def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=Non
embedding_table = tf.get_variable(
'embedding', [len(symbols), hp.embed_depth], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.5))
embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs) # [N, T_in, embed_depth=256]
embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs) # [N, T_in, embed_depth=256]

# Encoder
prenet_outputs = prenet(embedded_inputs, is_training, hp) # [N, T_in, prenet_depth2=128]
encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths, is_training, # [N, T_in, encoder_depth=256]
hp)
prenet_layer_sizes = [hp.prenet_depth1, hp.prenet_depth2]
prenet_outputs = prenet(embedded_inputs, is_training, prenet_layer_sizes) # [N, T_in, prenet_depth2=128]
encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths, is_training, # [N, T_in, encoder_depth=256]
hp.encoder_depth)

# Attention
attention_cell = AttentionWrapper(
DecoderPrenetWrapper(GRUCell(hp.attention_depth), is_training, hp),
DecoderPrenetWrapper(GRUCell(hp.attention_depth), is_training, prenet_layer_sizes),
BahdanauAttention(hp.attention_depth, encoder_outputs),
alignment_history=True,
output_attention=False) # [N, T_in, attention_depth=256]
Expand All @@ -75,14 +76,15 @@ def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=Non

(decoder_outputs, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
BasicDecoder(output_cell, helper, decoder_init_state),
maximum_iterations=hp.max_iters) # [N, T_out/r, M*r]
maximum_iterations=hp.max_iters) # [N, T_out/r, M*r]

# Reshape outputs to be one output per entry
mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hp.num_mels]) # [N, T_out, M]
mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hp.num_mels]) # [N, T_out, M]

# Add post-processing CBHG:
post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training, hp) # [N, T_out, prenet_depth2=128]
linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F]
post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training, # [N, T_out, postnet_depth]
hp.postnet_depth)
linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F]

# Grab alignments from the final decoder state:
alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0])
Expand Down

0 comments on commit 038ea80

Please sign in to comment.