Skip to content

Commit

Permalink
Allow different number of prenet layers to be specified.
Browse files Browse the repository at this point in the history
  • Loading branch information
keithito committed Apr 17, 2018
1 parent 038ea80 commit 51a38a1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
3 changes: 1 addition & 2 deletions hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
# Model:
outputs_per_step=5,
embed_depth=256,
prenet_depth1=256,
prenet_depth2=128,
prenet_depths=[256, 128],
encoder_depth=256,
postnet_depth=256,
attention_depth=256,
Expand Down
7 changes: 3 additions & 4 deletions models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=Non
embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs) # [N, T_in, embed_depth=256]

# Encoder
prenet_layer_sizes = [hp.prenet_depth1, hp.prenet_depth2]
prenet_outputs = prenet(embedded_inputs, is_training, prenet_layer_sizes) # [N, T_in, prenet_depth2=128]
prenet_outputs = prenet(embedded_inputs, is_training, hp.prenet_depths) # [N, T_in, prenet_depths[-1]=128]
encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths, is_training, # [N, T_in, encoder_depth=256]
hp.encoder_depth)

# Attention
attention_cell = AttentionWrapper(
DecoderPrenetWrapper(GRUCell(hp.attention_depth), is_training, prenet_layer_sizes),
DecoderPrenetWrapper(GRUCell(hp.attention_depth), is_training, hp.prenet_depths),
BahdanauAttention(hp.attention_depth, encoder_outputs),
alignment_history=True,
output_attention=False) # [N, T_in, attention_depth=256]
Expand Down Expand Up @@ -82,7 +81,7 @@ def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=Non
mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hp.num_mels]) # [N, T_out, M]

# Add post-processing CBHG:
post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training, # [N, T_out, prenet_depth2=128]
post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training, # [N, T_out, postnet_depth=256]
hp.postnet_depth)
linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F]

Expand Down

0 comments on commit 51a38a1

Please sign in to comment.