diff --git a/eval.py b/eval.py
new file mode 100644
index 00000000..d5f88501
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,47 @@
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf
+
+import argparse
+import codecs
+import time
+import os
+from six.moves import cPickle
+
+from utils import TextLoader
+from model import Model
+
+from six import text_type
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--save_dir', type=str, default='save',
+ help='model directory to store checkpointed models')
+ parser.add_argument('--text', type=str,
+ help='filename of text to evaluate on')
+ args = parser.parse_args()
+ eval(args)
+
+def eval(args):
+ with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
+ saved_args = cPickle.load(f)
+ saved_args.batch_size = 1
+ saved_args.seq_length = 200
+ with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
+ chars, vocab = cPickle.load(f)
+ model = Model(saved_args, False)
+
+ with codecs.open(args.text, 'r', encoding='utf-8') as f:
+ text = f.read()
+
+ with tf.Session() as sess:
+ tf.initialize_all_variables().run()
+ saver = tf.train.Saver(tf.all_variables())
+ ckpt = tf.train.get_checkpoint_state(args.save_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ saver.restore(sess, ckpt.model_checkpoint_path)
+ ppl = model.eval(sess, chars, vocab, text)
+ print('perplexity: {0}'.format(ppl))
+
+if __name__ == '__main__':
+ main()
diff --git a/model.py b/model.py
index 5ea675f4..a2909a66 100644
--- a/model.py
+++ b/model.py
@@ -1,5 +1,6 @@
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn
from tensorflow.python.ops import seq2seq
import numpy as np
@@ -10,47 +11,81 @@ def __init__(self, args, infer=False):
if infer:
args.batch_size = 1
args.seq_length = 1
-
- if args.model == 'rnn':
+ if args.cell == 'rnn':
cell_fn = rnn_cell.BasicRNNCell
- elif args.model == 'gru':
+ elif args.cell == 'gru':
cell_fn = rnn_cell.GRUCell
- elif args.model == 'lstm':
+ elif args.cell == 'lstm':
cell_fn = rnn_cell.BasicLSTMCell
else:
raise Exception("model type not supported: {}".format(args.model))
- cell = cell_fn(args.rnn_size, state_is_tuple=True)
-
- self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
-
- self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
- self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
- self.initial_state = cell.zero_state(args.batch_size, tf.float32)
-
- with tf.variable_scope('rnnlm'):
- softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
- softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
- with tf.device("/cpu:0"):
- embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
- inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
- inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
-
- def loop(prev, _):
- prev = tf.matmul(prev, softmax_w) + softmax_b
- prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
- return tf.nn.embedding_lookup(embedding, prev_symbol)
+ def build_unirnn(args):
+ cell = cell_fn(args.rnn_size, state_is_tuple=True)
+
+ self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
+ self.cell2 = cell_dummy = cell_fn(args.rnn_size, state_is_tuple=True)
+
+ self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+ self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+ self.initial_state = cell.zero_state(args.batch_size, tf.float32)
+ self.initial_state2 = cell_dummy.zero_state(args.batch_size, tf.float32)
+
+ with tf.variable_scope('rnnlm'):
+ softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
+ softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+ with tf.device("/cpu:0"):
+ embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
+ input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
+ inputs = tf.unpack(input_embeddings, axis=1)
+ def loop(prev, _):
+ prev = tf.matmul(prev, softmax_w) + softmax_b
+ prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+ return tf.nn.embedding_lookup(embedding, prev_symbol)
+
+ outputs, self.last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
+ # dummy state in this case
+ self.last_state2 = cell_dummy.zero_state(args.batch_size, tf.float32)
+ output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
+ self.logits = tf.matmul(output, softmax_w) + softmax_b
+
+ def build_birnn(args):
+ self.cell = fw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
+ self.cell2 = bw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
+
+ self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+ self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+ self.initial_state = fw_cell.zero_state(args.batch_size, tf.float32)
+ self.initial_state2 = bw_cell.zero_state(args.batch_size, tf.float32)
+
+ with tf.variable_scope('rnnlm'):
+ softmax_w = tf.get_variable("softmax_w", [2*args.rnn_size, args.vocab_size])
+ softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+ with tf.device("/cpu:0"):
+ embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
+ input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
+ inputs = tf.unpack(input_embeddings, axis=1)
+
+ outputs, self.last_state, self.last_state2 = rnn.bidirectional_rnn(fw_cell,
+ bw_cell,
+ inputs,
+ initial_state_fw=self.initial_state,
+ initial_state_bw=self.initial_state2,
+ dtype=tf.float32)
+ output = tf.reshape(tf.concat(1, outputs), [-1, 2*args.rnn_size])
+ self.logits = tf.matmul(output, softmax_w) + softmax_b
+
+ if args.model == 'uni':
+ build_unirnn(args)
+ else:
+ build_birnn(args)
- outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
- output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
- self.logits = tf.matmul(output, softmax_w) + softmax_b
self.probs = tf.nn.softmax(self.logits)
- loss = seq2seq.sequence_loss_by_example([self.logits],
+ self.loss = seq2seq.sequence_loss_by_example([self.logits],
[tf.reshape(self.targets, [-1])],
[tf.ones([args.batch_size * args.seq_length])],
args.vocab_size)
- self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
- self.final_state = last_state
+ self.cost = tf.reduce_sum(self.loss) / args.batch_size / args.seq_length
self.lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
@@ -58,13 +93,38 @@ def loop(prev, _):
optimizer = tf.train.AdamOptimizer(self.lr)
self.train_op = optimizer.apply_gradients(zip(grads, tvars))
+ def eval(self, sess, chars, vocab, text):
+ batch_size = 200
+ state = sess.run(self.cell.zero_state(1, tf.float32))
+ state2 = sess.run(self.cell2.zero_state(1, tf.float32))
+ x = [vocab[c] if c in vocab else vocab['UNK'] for c in text]
+ x = [vocab['']] + x + [vocab['']]
+ total_len = len(x) - 1
+ # pad x so the batch_size divides it
+ while len(x) % 200 != 1:
+ x.append(vocab[' '])
+ y = np.array(x[1:]).reshape((-1, batch_size))
+ x = np.array(x[:-1]).reshape((-1, batch_size))
+
+ total_loss = 0.0
+ for i in range(x.shape[0]):
+ feed = {self.input_data: x[i:i+1, :], self.targets: y[i:i+1, :],
+ self.initial_state: state, self.initial_state2: state2}
+ [state, state2, loss] = sess.run([self.last_state, self.last_state2, self.loss], feed)
+ total_loss += loss.sum()
+ # need to subtract off loss from padding tokens
+ total_loss -= loss[total_len % batch_size - batch_size:].sum()
+ avg_entropy = total_loss / len(text)
+ return np.exp(avg_entropy) # this is the perplexity
+
def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
state = sess.run(self.cell.zero_state(1, tf.float32))
+ state2 = sess.run(self.cell2.zero_state(1, tf.float32))
for char in prime[:-1]:
x = np.zeros((1, 1))
x[0, 0] = vocab[char]
- feed = {self.input_data: x, self.initial_state:state}
- [state] = sess.run([self.final_state], feed)
+ feed = {self.input_data: x, self.initial_state: state, self.initial_state2: state2}
+ [state, state2] = sess.run([self.last_state, self.last_state2], feed)
def weighted_pick(weights):
t = np.cumsum(weights)
@@ -77,7 +137,7 @@ def weighted_pick(weights):
x = np.zeros((1, 1))
x[0, 0] = vocab[char]
feed = {self.input_data: x, self.initial_state:state}
- [probs, state] = sess.run([self.probs, self.final_state], feed)
+ [probs, state, state2] = sess.run([self.probs, self.last_state, self.last_state2], feed)
p = probs[0]
if sampling_type == 0:
diff --git a/train.py b/train.py
index 8f78f1a5..e793b1e5 100644
--- a/train.py
+++ b/train.py
@@ -20,8 +20,10 @@ def main():
help='size of RNN hidden state')
parser.add_argument('--num_layers', type=int, default=2,
help='number of layers in the RNN')
- parser.add_argument('--model', type=str, default='lstm',
+ parser.add_argument('--cell', type=str, default='lstm',
help='rnn, gru, or lstm')
+ parser.add_argument('--model', type=str, default='uni',
+ help='uni, or bi')
parser.add_argument('--batch_size', type=int, default=50,
help='minibatch size')
parser.add_argument('--seq_length', type=int, default=50,
@@ -91,14 +93,15 @@ def train(args):
sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
data_loader.reset_batch_pointer()
state = sess.run(model.initial_state)
+ state2 = sess.run(model.initial_state2)
for b in range(data_loader.num_batches):
start = time.time()
x, y = data_loader.next_batch()
- feed = {model.input_data: x, model.targets: y}
- for i, (c, h) in enumerate(model.initial_state):
- feed[c] = state[i].c
- feed[h] = state[i].h
- train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
+ if b == 0 and e == 0:
+ feed = {model.input_data: x, model.targets: y}
+ else:
+ feed = {model.input_data: x, model.targets: y, model.initial_state: state, model.initial_state2: state2}
+ train_loss, state, state2, _ = sess.run([model.cost, model.last_state, model.last_state2, model.train_op], feed)
end = time.time()
print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
.format(e * data_loader.num_batches + b,
diff --git a/utils.py b/utils.py
index 4df553ff..ddac5b91 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,7 @@
from six.moves import cPickle
import numpy as np
+
class TextLoader():
def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
self.data_dir = data_dir
@@ -28,13 +29,14 @@ def preprocess(self, input_file, vocab_file, tensor_file):
with codecs.open(input_file, "r", encoding=self.encoding) as f:
data = f.read()
counter = collections.Counter(data)
+ counter.update(('', '', 'UNK')) # add tokens for start end and unk
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
self.chars, _ = zip(*count_pairs)
self.vocab_size = len(self.chars)
self.vocab = dict(zip(self.chars, range(len(self.chars))))
with open(vocab_file, 'wb') as f:
cPickle.dump(self.chars, f)
- self.tensor = np.array(list(map(self.vocab.get, data)))
+ self.tensor = np.array(list(map(self.vocab.get, [''] + list(data) + [''])))
np.save(tensor_file, self.tensor)
def load_preprocessed(self, vocab_file, tensor_file):