# tutorial_rnnlm_minibatch.py
from collections import Counter, defaultdict
from itertools import count
import random
import dynet as dy
import numpy as np

# format of files: each line is "word1 word2 ..."
train_file = "CHAR_TRAIN"
test_file = "CHAR_DEV"

MB_SIZE = 10  # number of sentences per minibatch
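# Because the corpus is sorted by length further below, each slice of MB_SIZE
# consecutive sentences has similar lengths, so little computation is wasted
# on padded, masked-out positions.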

class Vocab:
    def __init__(self, w2i=None):
        if w2i is None: w2i = defaultdict(count(0).__next__)
        self.w2i = dict(w2i)
        self.i2w = {i: w for w, i in w2i.items()}

    @classmethod
    def from_corpus(cls, corpus):
        w2i = defaultdict(count(0).__next__)
        for sent in corpus:
            [w2i[word] for word in sent]  # touch every word so it gets an id
        return Vocab(w2i)

    def size(self): return len(self.w2i)
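# For example, ids are assigned in first-seen order:
#   Vocab.from_corpus([["a", "b", "a"]]).w2i == {"a": 0, "b": 1}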

def read(fname):
    """
    Read a file where each line is of the form "word1 word2 ..."
    Yields lists of the form [word1, word2, ..., "<s>"]
    """
    with open(fname) as fh:
        for line in fh:
            sent = line.strip().split()
            sent.append("<s>")
            yield sent
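
# Note: "<s>" is appended to every sentence as an end-of-sentence marker; the
# same symbol is reused below as the start symbol fed to the RNN and as padding.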

train = list(read(train_file))
test = list(read(test_file))
words = []
wc = Counter()
for sent in train:
    for w in sent:
        words.append(w)
        wc[w] += 1

vw = Vocab.from_corpus([words])
S = vw.w2i["<s>"]
nwords = vw.size()
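# S, the id of "<s>", doubles as the padding id for positions past a sentence's
# end; those steps contribute nothing to the loss because they are masked out.
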
# DyNet Starts
model = dy.Model()
trainer = dy.AdamTrainer(model)
# Lookup parameters for word embeddings
WORDS_LOOKUP = model.add_lookup_parameters((nwords, 64))
# Word-level LSTM (layers=1, input=64, output=128, model)
RNN = dy.LSTMBuilder(1, 64, 128, model)
# Softmax weights/biases on top of LSTM outputs
W_sm = model.add_parameters((nwords, 128))
b_sm = model.add_parameters(nwords)
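
# Architecture recap: 64-dim embeddings feed a 1-layer LSTM with a 128-dim
# hidden state, whose output is mapped by (W_sm, b_sm) to one logit per word.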

# Build the language-model graph for one minibatch of sentences
def calc_lm_loss(sents):
    dy.renew_cg()
    # parameters -> expressions
    W_exp = dy.parameter(W_sm)
    b_exp = dy.parameter(b_sm)

    # initialize the RNN
    f_init = RNN.initial_state()

    # get the word ids and masks for each step
    tot_words = 0
    wids = []
    masks = []
    for i in range(len(sents[0])):
        wids.append([(vw.w2i[sent[i]] if len(sent) > i else S) for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)
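    # Example: for a minibatch with sentence lengths 3 and 2 (sents[0] is the
    # longest because the corpus is sorted), the per-step masks are
    # [1, 1], [1, 1], [1, 0]: the last step is ignored for the shorter sentence.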
    # start the rnn by inputting "<s>"
    init_ids = [S] * len(sents)
    s = f_init.add_input(dy.lookup_batch(WORDS_LOOKUP, init_ids))

    # feed word vectors into the RNN and predict the next word
    losses = []
    for wid, mask in zip(wids, masks):
        # calculate the softmax and loss
        score = W_exp * s.output() + b_exp
        loss = dy.pickneglogsoftmax_batch(score, wid)
        # mask the loss if at least one sentence is shorter; since sentences
        # are sorted by length, checking the last (shortest) one suffices
        if mask[-1] != 1:
            mask_expr = dy.inputVector(mask)
            mask_expr = dy.reshape(mask_expr, (1,), MB_SIZE)
            loss = loss * mask_expr
        losses.append(loss)
        # update the state of the RNN
        wemb = dy.lookup_batch(WORDS_LOOKUP, wid)
        s = s.add_input(wemb)

    return dy.sum_batches(dy.esum(losses)), tot_words
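
# dy.esum sums the per-step losses (still one value per batch element) and
# dy.sum_batches collapses them into a single scalar; exp(loss / tot_words)
# would give the minibatch's perplexity.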

num_tagged = cum_loss = 0
# Sort sentences in descending order of length and precompute the start index
# of each minibatch
train.sort(key=lambda x: -len(x))
test.sort(key=lambda x: -len(x))
train_order = [x * MB_SIZE for x in range(len(train) // MB_SIZE)]
test_order = [x * MB_SIZE for x in range(len(test) // MB_SIZE)]
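# e.g. with 1003 training sentences and MB_SIZE = 10, train_order becomes
# [0, 10, ..., 990]; the 3 leftover sentences from the integer division are
# simply never visited.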

# Perform training
for ITER in range(50):
    random.shuffle(train_order)
    for i, sid in enumerate(train_order, 1):
        if i % (500 // MB_SIZE) == 0:
            trainer.status()
            print(cum_loss / num_tagged)
            num_tagged = cum_loss = 0
        if i % (10000 // MB_SIZE) == 0 or i == len(train_order) - 1:
            dev_loss = dev_words = 0
            # iterate with a separate variable so the training index sid
            # is not overwritten by the dev loop
            for tid in test_order:
                loss_exp, mb_words = calc_lm_loss(test[tid:tid + MB_SIZE])
                dev_loss += loss_exp.scalar_value()
                dev_words += mb_words
            print(dev_loss / dev_words)
        # train on the minibatch
        loss_exp, mb_words = calc_lm_loss(train[sid:sid + MB_SIZE])
        cum_loss += loss_exp.scalar_value()
        num_tagged += mb_words
        loss_exp.backward()
        trainer.update()
    print("epoch %r finished" % ITER)
    trainer.update_epoch(1.0)
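
# Usage sketch (an assumption, not from the original file): with CHAR_TRAIN and
# CHAR_DEV present in the working directory, run e.g.
#   python tutorial_rnnlm_minibatch.py --dynet-mem 512
# DyNet consumes --dynet-* flags such as --dynet-mem and --dynet-seed at import.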