BRNN + perplexity evaluation #65

Open

wants to merge 2 commits into base: master
47 changes: 47 additions & 0 deletions eval.py
@@ -0,0 +1,47 @@
from __future__ import print_function
import numpy as np
import tensorflow as tf

import argparse
import codecs
import time
import os
from six.moves import cPickle

from utils import TextLoader
from model import Model

from six import text_type

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
    parser.add_argument('--text', type=str,
                        help='filename of text to evaluate on')
    args = parser.parse_args()
    eval(args)

def eval(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    saved_args.batch_size = 1
    saved_args.seq_length = 200
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, False)

    with codecs.open(args.text, 'r', encoding='utf-8') as f:
        text = f.read()

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ppl = model.eval(sess, chars, vocab, text)
            print('perplexity: {0}'.format(ppl))

if __name__ == '__main__':
    main()
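Note on the number eval.py prints: perplexity here is just the exponential of the average per-character cross-entropy reported by the model. A minimal NumPy sketch of that arithmetic, using invented loss values rather than real model output:

import numpy as np

# Invented per-character cross-entropy losses (in nats), standing in for the
# values that seq2seq.sequence_loss_by_example would return for a real text.
char_losses = np.array([1.2, 0.8, 1.5, 0.9, 1.1])

avg_entropy = char_losses.sum() / len(char_losses)  # average negative log-likelihood per character
perplexity = np.exp(avg_entropy)                    # same exponentiation Model.eval applies
print('perplexity: {0}'.format(perplexity))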
126 changes: 93 additions & 33 deletions model.py
@@ -1,5 +1,6 @@
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn
from tensorflow.python.ops import seq2seq

import numpy as np
@@ -10,61 +11,120 @@ def __init__(self, args, infer=False):
        if infer:
            args.batch_size = 1
            args.seq_length = 1

        if args.model == 'rnn':
        if args.cell == 'rnn':
            cell_fn = rnn_cell.BasicRNNCell
        elif args.model == 'gru':
        elif args.cell == 'gru':
            cell_fn = rnn_cell.GRUCell
        elif args.model == 'lstm':
        elif args.cell == 'lstm':
            cell_fn = rnn_cell.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))

        cell = cell_fn(args.rnn_size, state_is_tuple=True)

        self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)

        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
                inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(embedding, prev_symbol)
        def build_unirnn(args):
            cell = cell_fn(args.rnn_size, state_is_tuple=True)

            self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
            self.cell2 = cell_dummy = cell_fn(args.rnn_size, state_is_tuple=True)

            self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
            self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
            self.initial_state = cell.zero_state(args.batch_size, tf.float32)
            self.initial_state2 = cell_dummy.zero_state(args.batch_size, tf.float32)

            with tf.variable_scope('rnnlm'):
                softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
                softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
                with tf.device("/cpu:0"):
                    embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                    input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
                    inputs = tf.unpack(input_embeddings, axis=1)
            def loop(prev, _):
                prev = tf.matmul(prev, softmax_w) + softmax_b
                prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
                return tf.nn.embedding_lookup(embedding, prev_symbol)

            outputs, self.last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
            # dummy state in this case
            self.last_state2 = cell_dummy.zero_state(args.batch_size, tf.float32)
            output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
            self.logits = tf.matmul(output, softmax_w) + softmax_b

        def build_birnn(args):
            self.cell = fw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
            self.cell2 = bw_cell = cell_fn(args.rnn_size, state_is_tuple=True)

            self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
            self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
            self.initial_state = fw_cell.zero_state(args.batch_size, tf.float32)
            self.initial_state2 = bw_cell.zero_state(args.batch_size, tf.float32)

            with tf.variable_scope('rnnlm'):
                softmax_w = tf.get_variable("softmax_w", [2*args.rnn_size, args.vocab_size])
                softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
                with tf.device("/cpu:0"):
                    embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                    input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
                    inputs = tf.unpack(input_embeddings, axis=1)

            outputs, self.last_state, self.last_state2 = rnn.bidirectional_rnn(fw_cell,
                                                                               bw_cell,
                                                                               inputs,
                                                                               initial_state_fw=self.initial_state,
                                                                               initial_state_bw=self.initial_state2,
                                                                               dtype=tf.float32)
            output = tf.reshape(tf.concat(1, outputs), [-1, 2*args.rnn_size])
            self.logits = tf.matmul(output, softmax_w) + softmax_b

        if args.model == 'uni':
            build_unirnn(args)
        else:
            build_birnn(args)

        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        loss = seq2seq.sequence_loss_by_example([self.logits],
        self.loss = seq2seq.sequence_loss_by_example([self.logits],
                [tf.reshape(self.targets, [-1])],
                [tf.ones([args.batch_size * args.seq_length])],
                args.vocab_size)
        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
        self.final_state = last_state
        self.cost = tf.reduce_sum(self.loss) / args.batch_size / args.seq_length
        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                args.grad_clip)
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def eval(self, sess, chars, vocab, text):
        batch_size = 200
        state = sess.run(self.cell.zero_state(1, tf.float32))
        state2 = sess.run(self.cell2.zero_state(1, tf.float32))
        x = [vocab[c] if c in vocab else vocab['UNK'] for c in text]
        x = [vocab['<S>']] + x + [vocab['</S>']]
        total_len = len(x) - 1
        # pad x so the batch_size divides it
        while len(x) % 200 != 1:
            x.append(vocab[' '])
        y = np.array(x[1:]).reshape((-1, batch_size))
        x = np.array(x[:-1]).reshape((-1, batch_size))

        total_loss = 0.0
        for i in range(x.shape[0]):
            feed = {self.input_data: x[i:i+1, :], self.targets: y[i:i+1, :],
                    self.initial_state: state, self.initial_state2: state2}
            [state, state2, loss] = sess.run([self.last_state, self.last_state2, self.loss], feed)
            total_loss += loss.sum()
        # need to subtract off loss from padding tokens
        total_loss -= loss[total_len % batch_size - batch_size:].sum()
        avg_entropy = total_loss / len(text)
        return np.exp(avg_entropy) # this is the perplexity

    def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
        state = sess.run(self.cell.zero_state(1, tf.float32))
        state2 = sess.run(self.cell2.zero_state(1, tf.float32))
        for char in prime[:-1]:
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state:state}
            [state] = sess.run([self.final_state], feed)
            feed = {self.input_data: x, self.initial_state: state, self.initial_state2: state2}
            [state, state2] = sess.run([self.last_state, self.last_state2], feed)

        def weighted_pick(weights):
            t = np.cumsum(weights)
@@ -77,7 +137,7 @@ def weighted_pick(weights):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state:state}
            [probs, state] = sess.run([self.probs, self.final_state], feed)
            [probs, state, state2] = sess.run([self.probs, self.last_state, self.last_state2], feed)
            p = probs[0]

            if sampling_type == 0:
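Note on the padding correction in Model.eval above: the token list is padded with spaces until the 200-token evaluation batches divide it evenly, and the loss contributed by those padded positions in the final batch is then subtracted again through a negative slice. A small self-contained sketch of that slice arithmetic, with invented numbers:

import numpy as np

batch_size = 200
total_len = 450                    # invented count of real (unpadded) prediction targets
loss = np.full(batch_size, 0.5)    # stand-in for the final batch's per-position losses

num_pad = batch_size - total_len % batch_size                    # padded positions in the last batch
padding_loss = loss[total_len % batch_size - batch_size:].sum()  # the slice Model.eval subtracts
assert np.isclose(padding_loss, loss[-num_pad:].sum())           # equivalent, slightly more explicit form
print(num_pad, padding_loss)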
15 changes: 9 additions & 6 deletions train.py
@@ -20,8 +20,10 @@ def main():
                        help='size of RNN hidden state')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--model', type=str, default='lstm',
    parser.add_argument('--cell', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument('--model', type=str, default='uni',
                        help='uni, or bi')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='minibatch size')
    parser.add_argument('--seq_length', type=int, default=50,
@@ -91,14 +93,15 @@ def train(args):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            state2 = sess.run(model.initial_state2)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                if b == 0 and e == 0:
                    feed = {model.input_data: x, model.targets: y}
                else:
                    feed = {model.input_data: x, model.targets: y, model.initial_state: state, model.initial_state2: state2}
                train_loss, state, state2, _ = sess.run([model.cost, model.last_state, model.last_state2, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
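The training-loop change above threads the recurrent state across batches: only the very first batch of the first epoch starts from the zero state, and every later sess.run call feeds back the states returned by the previous one. A toy, self-contained sketch of that pattern, with a dummy step function standing in for sess.run:

import numpy as np

def run_step(x, state, state2):
    # stand-in for sess.run([cost, last_state, last_state2, train_op], feed)
    cost = float(x.mean())
    return cost, state + 1, state2 + 1   # pretend the new states depend on the old ones

state = np.zeros(2)    # plays the role of sess.run(model.initial_state)
state2 = np.zeros(2)   # plays the role of sess.run(model.initial_state2)
for b in range(3):
    x = np.random.rand(4, 5)
    # first batch: zero states; later batches: states carried over from the previous call
    cost, state, state2 = run_step(x, state, state2)
print(state, state2)   # [3. 3.] [3. 3.]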
4 changes: 3 additions & 1 deletion utils.py
@@ -4,6 +4,7 @@
from six.moves import cPickle
import numpy as np


class TextLoader():
    def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
        self.data_dir = data_dir
@@ -28,13 +29,14 @@ def preprocess(self, input_file, vocab_file, tensor_file):
        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            data = f.read()
        counter = collections.Counter(data)
        counter.update(('<S>', '</S>', 'UNK')) # add tokens for start end and unk
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        self.tensor = np.array(list(map(self.vocab.get, data)))
        self.tensor = np.array(list(map(self.vocab.get, ['<S>'] + list(data) + ['</S>'])))
        np.save(tensor_file, self.tensor)

    def load_preprocessed(self, vocab_file, tensor_file):
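The utils.py change registers '<S>', '</S>' and 'UNK' as vocabulary entries and brackets the training tensor with the start/end ids; Model.eval performs the matching lookup on the evaluation text. A toy illustration of that mapping, using an invented vocabulary rather than one built by TextLoader.preprocess:

# Invented toy vocabulary; the real one is derived from character counts in the corpus.
vocab = {'a': 0, 'b': 1, ' ': 2, '<S>': 3, '</S>': 4, 'UNK': 5}

text = "ab?a"   # '?' is out of vocabulary
ids = [vocab['<S>']] + [vocab[c] if c in vocab else vocab['UNK'] for c in text] + [vocab['</S>']]
print(ids)      # [3, 0, 1, 5, 0, 4]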