Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes #66

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ def __init__(self, args, infer=False):
else:
raise Exception("model type not supported: {}".format(args.model))

cell = cell_fn(args.rnn_size, state_is_tuple=True)

self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
if args.model == 'lstm':
cell = cell_fn(args.rnn_size, state_is_tuple=True)
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
else:
cell = cell_fn(args.rnn_size)
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)

self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
Expand Down
3 changes: 0 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ def train(args):
start = time.time()
x, y = data_loader.next_batch()
feed = {model.input_data: x, model.targets: y}
for i, (c, h) in enumerate(model.initial_state):
feed[c] = state[i].c
feed[h] = state[i].h
train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @lns, are you sure about this? Could you point me to a resource which says passing the state directly is permitted?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I take a look at the recent sources of tensorflow( rnn_cell_impl.py and seq2seq.py(?)) and the code seems changed a lot. In my experiments with r11(?) it's safe to remove these lines to have reasonable results. I'm not 100% sure about the way of passing states to a lstm cell, as we are using self.initial_state in rnn_decoder in model.py which is not(?) updated across the batches. If we want to pass the states from the last batch to the next batch, maybe there are more lines to be changed?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you, filed an issue tensorflow/models#774

end = time.time()
print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
Expand Down