train.py
import argparse
import cPickle
import codecs
import math
import os
import sys
import tempfile
from collections import defaultdict as dd

import numpy as np

from loader import encode_sentence, read_datafile
from model import build_model, Params
from utils import get_morph_analyzes, create_single_word_single_line_format
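
# train.py drives a neural morphological disambiguation model and supports three commands
# (see create_parser below):
#   train        - read the train/test data files, build the model and fit it
#   predict      - load a trained model and report disambiguation accuracy on the test data
#   disambiguate - read raw sentences from stdin, run the morphological analyzer on them and
#                  print the model's chosen analysis for every word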
def create_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--command", default="train", required=True,
choices=["train", "predict", "disambiguate"])
parser.add_argument("--train_filepath", required=True)
parser.add_argument("--test_filepath", required=True)
parser.add_argument("--run_name", required=True)
parser.add_argument("--model_path")
parser.add_argument("--label2ids_path")
return parser
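
# Example training invocation; the flags come from the parser above, the file paths and
# run name are only illustrative:
#   python train.py --command train --train_filepath data/train.merge.utf8 \
#       --test_filepath data/test.merge.utf8 --run_name my-run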
def sample_generator(sentences, label2ids, batch_size=32, return_sentence=False):
while True:
n_in_batch = 0
batch = [[], []]
decoded_sentences_in_batch = []
shuffled_indices = np.random.permutation(len(sentences))
for sentence_i in shuffled_indices:
sentence = sentences[sentence_i]
sentences_word_root_input, sentences_analysis_input, surface_form_input, correct_tags_input, shuffled_positions_record = \
encode_sentence(sentence, label2ids)
if return_sentence:
decoded_sentences_in_batch.append([sentence, shuffled_positions_record])
input_array = []
for array in [sentences_word_root_input, sentences_analysis_input, surface_form_input,
correct_tags_input]:
# print array.shape
# print array.reshape([-1] + list(array.shape)).shape
input_array += [np.expand_dims(np.copy(array), axis=0)]
# print array.shape
# print np.expand_dims(array, axis=0).shape
batch[0].append(input_array[:-1])
batch[1].append([input_array[-1]])
n_in_batch += 1
if n_in_batch == batch_size:
# yield np.concatenate(batch[0], axis=0), np.concatenate(batch[1], axis=0)
# for b in batch[1]:
# for i in range(1):
# print i
# print b[i].shape
encoded_sentences_in_batch = ([np.concatenate([b[i] for b in batch[0]], axis=0) for i in range(3)], \
[np.concatenate([b[0] for b in batch[1]], axis=0)])
if not return_sentence:
yield encoded_sentences_in_batch
else:
yield encoded_sentences_in_batch, decoded_sentences_in_batch
# yield batch[0], batch[1]
n_in_batch = 0
batch = [[], []]
decoded_sentences_in_batch = []
if n_in_batch > 0:
encoded_sentences_in_batch = ([np.concatenate([b[i] for b in batch[0]], axis=0) for i in
range(3)], \
[np.concatenate([b[0] for b in batch[1]], axis=0)])
if not return_sentence:
yield encoded_sentences_in_batch
else:
yield encoded_sentences_in_batch, decoded_sentences_in_batch
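
# Minimal sketch of what sample_generator yields (the sentence objects come from
# read_datafile; the variable names below are only illustrative):
#   gen = sample_generator(train_sentences, label2ids, batch_size=32)
#   inputs, targets = next(gen)
#   # inputs  : [word_root_input, analysis_input, surface_form_input], stacked along axis 0
#   # targets : [correct_tags_input]
# With return_sentence=True each yield also carries the decoded sentences and the
# shuffled-position records needed to map predictions back to the original analysis order.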
def load_label2ids_and_params(args):
if args.label2ids_path:
with open(args.label2ids_path, "r") as f:
label2ids = cPickle.load(f)
else:
if os.path.exists(args.model_path + ".label2ids"):
with open(args.model_path + ".label2ids", "r") as f:
label2ids = cPickle.load(f)
else:
train_and_test_sentences, label2ids = read_datafile(args.train_filepath,
args.test_filepath)
with open(args.model_path + ".label2ids", "w") as f:
cPickle.dump(label2ids, f)
params = create_params(label2ids)
return label2ids, params, train_and_test_sentences
params = create_params(label2ids)
return label2ids, params, []
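
# load_label2ids_and_params returns (label2ids, params, sentences); the sentences list is
# non-empty only when label2ids had to be rebuilt from the data files, otherwise it is [].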
def create_params(label2ids):
params = Params()
params.max_sentence_length = label2ids['max_sentence_length']
params.max_n_analyses = label2ids['max_n_analysis']
params.batch_size = 1
params.n_subepochs = 40
params.max_surface_form_length = label2ids['max_surface_form_length']
params.max_word_root_length = label2ids['max_word_root_length']
params.max_analysis_length = label2ids['max_analysis_length']
params.char_vocabulary_size = label2ids['character_unique_count']['value']
params.tag_vocabulary_size = label2ids['morph_token_unique_count']['value']
params.char_lstm_dim = 100
params.char_embedding_dim = 100
params.tag_lstm_dim = params.char_lstm_dim
params.tag_embedding_dim = 100
params.sentence_level_lstm_dim = 2 * params.char_lstm_dim
return params
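
# The max_* lengths and vocabulary sizes above come from the label2ids dictionary built by
# read_datafile; the embedding and LSTM dimensions are fixed hyperparameters.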
def disambiguate_single_line_sentence(line, model, label2ids, params, print_prediction_lines=True):
"""
:param line:
:param model:
:param label2ids:
:param params:
:param print_prediction_lines:
:return: An array of strings which represent the words and disambiguated analyzes
"""
string_output = get_morph_analyzes(line)# print "XXX", string_output, "YYY"
analyzer_output_string = create_single_word_single_line_format(string_output)
# print string_output_single_line.decode("iso-8859-9")
# print type(string_output_single_line)
fd, f_path = tempfile.mkstemp()
with codecs.open(f_path, "w", encoding="utf8") as f:
f.write(analyzer_output_string.decode("iso-8859-9"))
os.close(fd)
# print f_path
train_and_test_sentences, _ = read_datafile(f_path, f_path, preloaded_label2ids=label2ids)
# print train_and_test_sentences[1]
sample_batch, decoded_sample_batch = iter(sample_generator(train_and_test_sentences[1],
label2ids,
batch_size=params.batch_size,
return_sentence=True)).next()
# for batch_idx, (sample_batch, decoded_sample_batch) in enumerate(sample_generator(train_and_test_sentences[1],
# label2ids,
# batch_size=params.batch_size,
# return_sentence=True)
# ):
pred_probs = model.predict(sample_batch[0], batch_size=params.batch_size, verbose=1)
pred_tags = np.argmax(pred_probs[0], axis=1)
# print pred_tags
first_sentence = decoded_sample_batch[0][0]
first_shuffled_positions = decoded_sample_batch[0][1]
sentence_length = len(first_sentence['roots'])
# print sentence_length
pred_probs_copy = np.copy(pred_probs)
for row_idx, first_shuffled_positions_row in enumerate(first_shuffled_positions):
# print pred_probs_copy[0, row_idx]
# print first_shuffled_positions_row
for col_idx, first_shuffled_position in enumerate(first_shuffled_positions_row):
pred_probs_copy[0, row_idx, first_shuffled_position] = pred_probs[
0, row_idx, col_idx]
# print pred_probs_copy[0, row_idx]
pred_tags = np.argmax(pred_probs_copy[0], axis=1)
# print pred_tags
prediction_lines = [
first_sentence['surface_forms'][word_idx] + " " + first_sentence['roots'][word_idx][
pred_tag] + "+" + "+".join(first_sentence['morph_tokens'][word_idx][pred_tag]) for
word_idx, pred_tag in enumerate(pred_tags[:sentence_length])]
prediction_lines_raw = [
[first_sentence['surface_forms'][word_idx], first_sentence['roots'][word_idx][
pred_tag] + "+" + "+".join(first_sentence['morph_tokens'][word_idx][pred_tag])] for
word_idx, pred_tag in enumerate(pred_tags[:sentence_length])]
if print_prediction_lines:
print "\n".join(prediction_lines)
print ""
return prediction_lines
else:
print [type(x) for x in prediction_lines]
return analyzer_output_string, prediction_lines, prediction_lines_raw
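
# Usage sketch, assuming model, label2ids and params were obtained from
# create_model_for_disambiguation below; the input sentence is arbitrary:
#   prediction_lines = disambiguate_single_line_sentence(line, model, label2ids, params)
#   # each element looks like "<surface form> <root>+<morph tag>+<morph tag>..."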
def create_model_for_disambiguation(args):
    label2ids, params, _ = load_label2ids_and_params(args)
    model = build_model(params)
    assert args.model_path, "--model_path should be given in the arguments for 'disambiguate'"
model.load_weights(args.model_path)
return label2ids, params, model
if __name__ == "__main__":
parser = create_parser()
args = parser.parse_args()
if args.command == "train":
train_and_test_sentences, label2ids = read_datafile(args.train_filepath, args.test_filepath)
        # identical hyperparameter setup to create_params above
        params = create_params(label2ids)
# train_and_test_sentences, label2ids = read_datafile("test.merge.utf8", "test.merge.utf8")
model = build_model(params)
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
checkpointer = ModelCheckpoint(filepath="./models/ntd-{run_name}".format(run_name=args.run_name)
+ '.epoch-{epoch:02d}-val_acc-{val_acc:.5f}.hdf5',
verbose=1,
monitor="val_acc",
save_best_only=True)
tensorboard_callback = TensorBoard(log_dir=os.path.join('./logs', args.run_name),
histogram_freq=1,
write_graph=True,
write_images=True,
embeddings_freq=1,
embeddings_layer_names=None,
embeddings_metadata={'char_embedding_layer': 'char_embedding_layer.tsv',
'tag_embedding_layer': 'tag_embedding_layer.tsv'})
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
factor=0.1,
patience=10,
verbose=0,
mode='auto',
epsilon=0.0001,
cooldown=1,
min_lr=0)
model.fit_generator(sample_generator(train_and_test_sentences[0], label2ids, batch_size=params.batch_size),
steps_per_epoch=len(train_and_test_sentences[0])/params.batch_size/params.n_subepochs,
epochs=10*params.n_subepochs,
validation_data=sample_generator(train_and_test_sentences[1], label2ids, batch_size=params.batch_size),
validation_steps=len(train_and_test_sentences[1])/params.batch_size+1,
callbacks=[checkpointer, tensorboard_callback, reduce_lr])
print "Saving"
model.save("./models/ntd-{run_name}-final.hdf5".format(run_name=args.run_name))
# model.evaluate()
elif args.command == "predict":
label2ids, params, train_and_test_sentences = load_label2ids_and_params(args)
        train_and_test_sentences, label2ids_from_input_file = read_datafile(args.train_filepath, args.test_filepath)
        model = build_model(params)
assert args.model_path, "--model_path should be given in the arguments for 'predict'"
model.load_weights(args.model_path)
total_correct = 0
total_tokens = 0
total_correct_all_structure = 0
total_tokens_all_structure = 0
total_correct_ambigious = 0
total_tokens_ambigious = 0
correct_counts = dd(int)
total_counts = dd(int)
for batch_idx, (sample_batch, decoded_sample_batch) in enumerate(sample_generator(train_and_test_sentences[1],
label2ids,
batch_size=params.batch_size,
return_sentence=True)):
# print sample
# pred_probs = model.predict(sample_batch[0], batch_size=params.batch_size, verbose=1)
# print pred_probs
# print pred_probs.shape
# print np.argmax(pred_probs, axis=2)
# print sample_batch[1][0]
# print sample_batch[1][0].shape
# print np.argmax(sample_batch[1][0], axis=2)
# print decoded_sample_batch
# print "decoded_sample_batch: ", decoded_sample_batch
correct_tags = np.argmax(sample_batch[1][0], axis=2)
pred_probs = model.predict(sample_batch[0], batch_size=params.batch_size, verbose=1)
pred_tags = np.argmax(pred_probs, axis=2)
for idx, correct_tag in enumerate(correct_tags):
sentence_length = len(decoded_sample_batch[idx][0]['surface_form_lengths'])
pred_tag = pred_tags[idx]
# print correct_tag
# print correct_tag.shape
# print pred_tag
# print pred_tag.shape
n_correct = np.sum(correct_tag[:sentence_length] == pred_tag[:sentence_length])
total_correct += n_correct
total_tokens += sentence_length
# print "sentence_length: ", sentence_length
# print "n_correct: ", n_correct
# print "sentence_length: ", sentence_length
total_correct_all_structure += n_correct + correct_tag.shape[0] - sentence_length
total_tokens_all_structure += correct_tag.shape[0]
baseline_log_prob = 0
for j in range(len(decoded_sample_batch[idx][0]['roots'])):
if j >= sentence_length:
break
n_analyses = len(decoded_sample_batch[idx][0]['roots'][j])
baseline_log_prob += math.log(1/float(n_analyses))
# print n_analyses
assert n_analyses >= 1
if n_analyses > 1:
if correct_tag[j] == pred_tag[j]:
total_correct_ambigious += 1
correct_counts[n_analyses] += 1
total_tokens_ambigious += 1
total_counts[n_analyses] += 1
if batch_idx % 100 == 0:
print "only the filled part of the sentence"
print total_correct
print total_tokens
print float(total_correct)/total_tokens
print "all the sentence"
print total_correct_all_structure
print total_tokens_all_structure
print float(total_correct_all_structure)/total_tokens_all_structure
print "==="
print "ambigous"
print total_correct_ambigious
print total_tokens_ambigious
print float(total_correct_ambigious)/total_tokens_ambigious
print "==="
for key in correct_counts:
print "disambiguations out of n_analyses: %d ===> %lf" % (key, float(correct_counts[key])/total_counts[key])
print "==="
if batch_idx*params.batch_size >= len(train_and_test_sentences[1]):
print "Evaluation finished, batch_id: %d" % batch_idx
print "only the filled part of the sentence"
print total_correct
print total_tokens
print float(total_correct)/total_tokens
print "all the sentence"
print total_correct_all_structure
print total_tokens_all_structure
print float(total_correct_all_structure) / total_tokens_all_structure
print "==="
print "ambigous"
print total_correct_ambigious
print total_tokens_ambigious
print float(total_correct_ambigious)/total_tokens_ambigious
print "==="
for key in correct_counts:
print "disambiguations out of n_analyses: %d ===> %lf %d %d" % (key, float(correct_counts[key])/total_counts[key], correct_counts[key], total_counts[key])
print "==="
break
elif args.command == "disambiguate":
label2ids, params, model = create_model_for_disambiguation(args)
line = sys.stdin.readline()
while line:
line = line.strip("\n")
disambiguate_single_line_sentence(line, model, label2ids, params)
line = sys.stdin.readline()
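
# The disambiguate command keeps reading sentences from stdin until EOF, e.g.
# (file paths and model name are only illustrative):
#   cat sentences.txt | python train.py --command disambiguate \
#       --train_filepath data/train.merge.utf8 --test_filepath data/test.merge.utf8 \
#       --run_name my-run --model_path models/ntd-my-run-final.hdf5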