From 0e3ebe24f7887df08db745bfecdc15f811de0649 Mon Sep 17 00:00:00 2001
From: primetang
Date: Wed, 28 Dec 2016 11:12:33 +0800
Subject: [PATCH] better code to employ decaying learning rate

---
 model.py |  6 ++++--
 train.py | 10 +++++-----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/model.py b/model.py
index 5ea675f4..f56f2aa3 100644
--- a/model.py
+++ b/model.py
@@ -51,12 +51,14 @@ def loop(prev, _):
                 args.vocab_size)
         self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
         self.final_state = last_state
-        self.lr = tf.Variable(0.0, trainable=False)
+        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        self.lr = tf.train.exponential_decay(
+            args.learning_rate, self.global_step, args.decay_step, args.decay_rate)
         tvars = tf.trainable_variables()
         grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                 args.grad_clip)
         optimizer = tf.train.AdamOptimizer(self.lr)
-        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
+        self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)
 
     def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
         state = sess.run(self.cell.zero_state(1, tf.float32))
diff --git a/train.py b/train.py
index 08ab2957..bc1781e2 100644
--- a/train.py
+++ b/train.py
@@ -50,7 +50,8 @@ def main():
 def train(args):
     data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
     args.vocab_size = data_loader.vocab_size
-
+    args.decay_step = data_loader.num_batches
+
     # check compatibility if training is continued from previously saved model
     if args.init_from is not None:
         # check if all necessary files exist
@@ -88,7 +89,6 @@ def train(args):
         if args.init_from is not None:
             saver.restore(sess, ckpt.model_checkpoint_path)
         for e in range(args.num_epochs):
-            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
             data_loader.reset_batch_pointer()
             state = sess.run(model.initial_state)
             for b in range(data_loader.num_batches):
@@ -98,12 +98,12 @@ def train(args):
                 for i, (c, h) in enumerate(model.initial_state):
                     feed[c] = state[i].c
                     feed[h] = state[i].h
-                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
+                lr, train_loss, state, _ = sess.run([model.lr, model.cost, model.final_state, model.train_op], feed)
                 end = time.time()
-                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
+                print("{}/{} (epoch {}), lr = {:.6f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                     .format(e * data_loader.num_batches + b,
                             args.num_epochs * data_loader.num_batches,
-                            e, train_loss, end - start))
+                            e, lr, train_loss, end - start))
                 if (e * data_loader.num_batches + b) % args.save_every == 0\
                         or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                     checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
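
Note (not part of the patch): a minimal sketch of what the new decay schedule does, assuming TF 1.x and hypothetical values standing in for args.learning_rate, args.decay_rate, and the per-epoch decay_step set from data_loader.num_batches.

    import tensorflow as tf

    learning_rate = 0.002   # hypothetical value for args.learning_rate
    decay_step = 1000       # hypothetical batches per epoch (args.decay_step)
    decay_rate = 0.97       # hypothetical value for args.decay_rate

    # global_step is incremented once per training step because it is passed
    # to optimizer.apply_gradients(..., global_step=global_step) in the patch.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # With staircase=False (the default), the rate decays smoothly:
    #     lr = learning_rate * decay_rate ** (global_step / decay_step)
    lr = tf.train.exponential_decay(learning_rate, global_step, decay_step, decay_rate)

So the rate shrinks by a factor of decay_rate every decay_step batches (one epoch here), which mirrors the old per-epoch `tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e))` without an explicit assign op, and the schedule is carried in the checkpoint via global_step when training resumes.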