diff --git a/eval.py b/eval.py
new file mode 100644
index 00000000..d5f88501
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,59 @@
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf
+
+import argparse
+import codecs
+import time
+import os
+from six.moves import cPickle
+
+from utils import TextLoader
+from model import Model
+
+from six import text_type
+
+def main():
+    """Parse the command line and print the perplexity of a text file."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--save_dir', type=str, default='save',
+                       help='model directory to store checkpointed models')
+    # required: eval() opens this path unconditionally
+    parser.add_argument('--text', type=str, required=True,
+                       help='filename of text to evaluate on')
+    args = parser.parse_args()
+    eval(args)
+
+def eval(args):
+    """Restore the latest checkpoint in args.save_dir and score args.text.
+
+    The saved training configuration is reloaded with batch_size=1 and
+    seq_length=200, which is the row width Model.eval feeds.
+    """
+    # NOTE: cPickle.load executes whatever the pickle encodes; only point
+    # --save_dir at directories you trust.
+    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
+        saved_args = cPickle.load(f)
+    saved_args.batch_size = 1
+    saved_args.seq_length = 200
+    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
+        chars, vocab = cPickle.load(f)
+    model = Model(saved_args, False)
+
+    with codecs.open(args.text, 'r', encoding='utf-8') as f:
+        text = f.read()
+
+    with tf.Session() as sess:
+        tf.initialize_all_variables().run()
+        saver = tf.train.Saver(tf.all_variables())
+        ckpt = tf.train.get_checkpoint_state(args.save_dir)
+        if ckpt and ckpt.model_checkpoint_path:
+            saver.restore(sess, ckpt.model_checkpoint_path)
+            ppl = model.eval(sess, chars, vocab, text)
+            print('perplexity: {0}'.format(ppl))
+        else:
+            # report the missing checkpoint instead of exiting silently
+            print('no checkpoint found in {0}'.format(args.save_dir))
+
+if __name__ == '__main__':
+    main()
diff --git a/model.py b/model.py
index 5ea675f4..a2909a66 100644
--- a/model.py
+++ b/model.py
@@ -1,5 +1,6 @@
 import tensorflow as tf
 from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn
 from tensorflow.python.ops import seq2seq
 
 import numpy as np
@@ -10,47 +11,85 @@ def __init__(self, args, infer=False):
         if infer:
             args.batch_size = 1
             args.seq_length = 1
-
-        if args.model == 'rnn':
+        if args.cell == 'rnn':
             cell_fn = rnn_cell.BasicRNNCell
-        elif args.model == 'gru':
+        elif args.cell == 'gru':
             cell_fn = rnn_cell.GRUCell
-        elif args.model == 'lstm':
+        elif args.cell == 'lstm':
             cell_fn = rnn_cell.BasicLSTMCell
         else:
-            raise Exception("model type not supported: {}".format(args.model))
+            raise Exception("cell type not supported: {}".format(args.cell))
 
-        cell = cell_fn(args.rnn_size, state_is_tuple=True)
-
-        self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
-
-        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
-        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
-        self.initial_state = cell.zero_state(args.batch_size, tf.float32)
-
-        with tf.variable_scope('rnnlm'):
-            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
-            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
-            with tf.device("/cpu:0"):
-                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
-                inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
-                inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
-
-        def loop(prev, _):
-            prev = tf.matmul(prev, softmax_w) + softmax_b
-            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
-            return tf.nn.embedding_lookup(embedding, prev_symbol)
+        def build_unirnn(args):
+            """Build the unidirectional graph: stacked cells run by seq2seq.rnn_decoder."""
+            cell = cell_fn(args.rnn_size, state_is_tuple=True)
+
+            self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
+            self.cell2 = cell_dummy = cell_fn(args.rnn_size, state_is_tuple=True)
+
+            self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+            self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+            self.initial_state = cell.zero_state(args.batch_size, tf.float32)
+            self.initial_state2 = cell_dummy.zero_state(args.batch_size, tf.float32)
+
+            with tf.variable_scope('rnnlm'):
+                softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
+                softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+                with tf.device("/cpu:0"):
+                    embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
+                    input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
+                    inputs = tf.unpack(input_embeddings, axis=1)
+
+            def loop(prev, _):
+                prev = tf.matmul(prev, softmax_w) + softmax_b
+                prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+                return tf.nn.embedding_lookup(embedding, prev_symbol)
+
+            outputs, self.last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
+            # dummy state in this case
+            self.last_state2 = cell_dummy.zero_state(args.batch_size, tf.float32)
+            output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
+            self.logits = tf.matmul(output, softmax_w) + softmax_b
+
+        def build_birnn(args):
+            """Build the bidirectional graph: one forward and one backward cell joined by rnn.bidirectional_rnn."""
+            self.cell = fw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
+            self.cell2 = bw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
+
+            self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+            self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+            self.initial_state = fw_cell.zero_state(args.batch_size, tf.float32)
+            self.initial_state2 = bw_cell.zero_state(args.batch_size, tf.float32)
+
+            with tf.variable_scope('rnnlm'):
+                softmax_w = tf.get_variable("softmax_w", [2*args.rnn_size, args.vocab_size])
+                softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+                with tf.device("/cpu:0"):
+                    embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
+                    input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
+                    inputs = tf.unpack(input_embeddings, axis=1)
+
+            outputs, self.last_state, self.last_state2 = rnn.bidirectional_rnn(fw_cell,
+                                                                               bw_cell,
+                                                                               inputs,
+                                                                               initial_state_fw=self.initial_state,
+                                                                               initial_state_bw=self.initial_state2,
+                                                                               dtype=tf.float32)
+            output = tf.reshape(tf.concat(1, outputs), [-1, 2*args.rnn_size])
+            self.logits = tf.matmul(output, softmax_w) + softmax_b
+
+        if args.model == 'uni':
+            build_unirnn(args)
+        elif args.model == 'bi':
+            build_birnn(args)
+        else:
+            raise Exception("model type not supported: {}".format(args.model))
 
-        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
-        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
-        self.logits = tf.matmul(output, softmax_w) + softmax_b
         self.probs = tf.nn.softmax(self.logits)
-        loss = seq2seq.sequence_loss_by_example([self.logits],
+        self.loss = seq2seq.sequence_loss_by_example([self.logits],
                 [tf.reshape(self.targets, [-1])],
                 [tf.ones([args.batch_size * args.seq_length])],
                 args.vocab_size)
-        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
-        self.final_state = last_state
+        self.cost = tf.reduce_sum(self.loss) / args.batch_size / args.seq_length
         self.lr = tf.Variable(0.0, trainable=False)
         tvars = tf.trainable_variables()
         grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
@@ -58,13 +97,49 @@ def loop(prev, _):
         optimizer = tf.train.AdamOptimizer(self.lr)
         self.train_op = optimizer.apply_gradients(zip(grads, tvars))
 
+    def eval(self, sess, chars, vocab, text):
+        """Return the perplexity of `text` under this model.
+
+        `text` is encoded with `vocab` (characters missing from it map to
+        'UNK'), wrapped in the start/end token, and fed through the graph
+        in rows of `batch_size` symbols, carrying the RNN state from one
+        row to the next.  The placeholders must therefore have been built
+        with batch_size=1 and seq_length=200, as eval.py does.
+        """
+        batch_size = 200
+        state = sess.run(self.cell.zero_state(1, tf.float32))
+        state2 = sess.run(self.cell2.zero_state(1, tf.float32))
+        x = [vocab[c] if c in vocab else vocab['UNK'] for c in text]
+        x = [vocab['']] + x + [vocab['']]
+        total_len = len(x) - 1
+        # pad x so that batch_size divides the number of targets
+        num_pad = -total_len % batch_size
+        x.extend([vocab[' ']] * num_pad)
+        y = np.array(x[1:]).reshape((-1, batch_size))
+        x = np.array(x[:-1]).reshape((-1, batch_size))
+
+        total_loss = 0.0
+        for i in range(x.shape[0]):
+            feed = {self.input_data: x[i:i+1, :], self.targets: y[i:i+1, :],
+                    self.initial_state: state, self.initial_state2: state2}
+            [state, state2, loss] = sess.run([self.last_state, self.last_state2, self.loss], feed)
+            total_loss += loss.sum()
+        # Subtract the loss of the padding targets, which fill the last
+        # num_pad slots of the final row.  Skip this when nothing was
+        # padded: loss[-0:] would select (and discard) the whole row.
+        if num_pad:
+            total_loss -= loss[-num_pad:].sum()
+        avg_entropy = total_loss / len(text)
+        return np.exp(avg_entropy)  # this is the perplexity
+
     def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
         state = sess.run(self.cell.zero_state(1, tf.float32))
+        state2 = sess.run(self.cell2.zero_state(1, tf.float32))
         for char in prime[:-1]:
             x = np.zeros((1, 1))
             x[0, 0] = vocab[char]
-            feed = {self.input_data: x, self.initial_state:state}
-            [state] = sess.run([self.final_state], feed)
+            feed = {self.input_data: x, self.initial_state: state, self.initial_state2: state2}
+            [state, state2] = sess.run([self.last_state, self.last_state2], feed)
 
         def weighted_pick(weights):
             t = np.cumsum(weights)
@@ -77,7 +152,7 @@ def weighted_pick(weights):
             x = np.zeros((1, 1))
             x[0, 0] = vocab[char]
-            feed = {self.input_data: x, self.initial_state:state}
-            [probs, state] = sess.run([self.probs, self.final_state], feed)
+            feed = {self.input_data: x, self.initial_state: state, self.initial_state2: state2}
+            [probs, state, state2] = sess.run([self.probs, self.last_state, self.last_state2], feed)
             p = probs[0]
 
             if sampling_type == 0:
diff --git a/train.py b/train.py
index 8f78f1a5..e793b1e5 100644
--- a/train.py
+++ b/train.py
@@ -20,8 +20,10 @@ def main():
                        help='size of RNN hidden state')
     parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
-    parser.add_argument('--model', type=str, default='lstm',
+    parser.add_argument('--cell', type=str, default='lstm',
                        help='rnn, gru, or lstm')
+    parser.add_argument('--model', type=str, default='uni',
+                       help='uni, or bi')
     parser.add_argument('--batch_size', type=int, default=50,
                        help='minibatch size')
     parser.add_argument('--seq_length', type=int, default=50,
@@ -91,14 +93,15 @@ def train(args):
             sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
             data_loader.reset_batch_pointer()
             state = sess.run(model.initial_state)
+            state2 = sess.run(model.initial_state2)
             for b in range(data_loader.num_batches):
                 start = time.time()
                 x, y = data_loader.next_batch()
-                feed = {model.input_data: x, model.targets: y}
-                for i, (c, h) in enumerate(model.initial_state):
-                    feed[c] = state[i].c
-                    feed[h] = state[i].h
-                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
+                if b == 0 and e == 0:
+                    feed = {model.input_data: x, model.targets: y}
+                else:
+                    feed = {model.input_data: x, model.targets: y, model.initial_state: state, model.initial_state2: state2}
+                train_loss, state, state2, _ = sess.run([model.cost, model.last_state, model.last_state2, model.train_op], feed)
                 end = time.time()
                 print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                     .format(e * data_loader.num_batches + b,
diff --git a/utils.py b/utils.py
index 4df553ff..ddac5b91 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,7 @@
 from six.moves import cPickle
 import numpy as np
 
+
 class TextLoader():
     def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
         self.data_dir = data_dir
@@ -28,13 +29,14 @@ def preprocess(self, input_file, vocab_file, tensor_file):
         with codecs.open(input_file, "r", encoding=self.encoding) as f:
             data = f.read()
         counter = collections.Counter(data)
+        counter.update(('', '', 'UNK'))  # add tokens for start end and unk
         count_pairs = sorted(counter.items(), key=lambda x: -x[1])
         self.chars, _ = zip(*count_pairs)
         self.vocab_size = len(self.chars)
         self.vocab = dict(zip(self.chars, range(len(self.chars))))
         with open(vocab_file, 'wb') as f:
             cPickle.dump(self.chars, f)
-        self.tensor = np.array(list(map(self.vocab.get, data)))
+        self.tensor = np.array(list(map(self.vocab.get, [''] + list(data) + [''])))
         np.save(tensor_file, self.tensor)
 
     def load_preprocessed(self, vocab_file, tensor_file):