# decode.py (forked from Doreenruirui/ACL2018_Multi_Input_OCR)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import json

import numpy as np
import tensorflow as tf
from os.path import join as pjoin

import model as ocr_model
from util import read_vocab, padded
import util
from flag import FLAGS

# Filled in by decode(); shared with the tokenization helpers below.
reverse_vocab, vocab = None, None

def create_model(session, vocab_size, forward_only):
    model = ocr_model.Model(FLAGS.size, vocab_size,
                            FLAGS.num_layers, FLAGS.max_gradient_norm,
                            FLAGS.learning_rate, FLAGS.learning_rate_decay_factor,
                            forward_only=forward_only, decode=FLAGS.decode)
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path, file=sys.stderr)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.", file=sys.stderr)
        session.run(tf.global_variables_initializer())
    return model

def tokenize_multi(sents, vocab):
    # Tokenize each witness, pad to a common length, and transpose to the
    # time-major layout (len_inp x num_wit) that the encoder expects.
    token_ids = []
    for sent in sents:
        token_ids.append(util.sentenc_to_token_ids(sent, vocab))
    token_ids = padded(token_ids)
    source = np.array(token_ids).T
    source_mask = (source != 0).astype(np.int32)
    return source, source_mask
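
# Illustrative shapes for tokenize_multi (hypothetical ids; PAD is id 0):
# two witnesses tokenized to [4, 5] and [4] are padded to [[4, 5], [4, 0]],
# so source is the time-major array [[4, 4], [5, 0]] and source_mask is
# [[1, 1], [1, 0]].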

def tokenize_single(sent, vocab):
    # A single witness needs no padding: build a (len_inp x 1) column of ids
    # with an all-ones mask.
    token_ids = util.sentenc_to_token_ids(sent, vocab)
    ones = [1] * len(token_ids)
    source = np.array(token_ids).reshape([-1, 1])
    mask = np.array(ones).reshape([-1, 1])
    return source, mask
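
# Illustrative shape for tokenize_single (hypothetical ids): token_ids
# [4, 5, 6] becomes the (3, 1) column [[4], [5], [6]] with an all-ones mask
# of the same shape, i.e. a batch holding one unpadded sequence.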

def detokenize(sents, reverse_vocab):
    # Map token ids back to characters, skipping the special symbols
    # (ids below len(util._START_VOCAB), e.g. padding and sentence markers).
    def detok_sent(sent):
        outsent = ''
        for t in sent:
            if t >= len(util._START_VOCAB):
                outsent += reverse_vocab[t]
        return outsent
    return [detok_sent(s) for s in sents]
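
# Example (hypothetical ids, assuming reverse_vocab[4] == 'a' and
# reverse_vocab[5] == 'b', with the special symbols occupying the low ids as
# in the NLC convention): detokenize([[1, 4, 5, 2]], reverse_vocab) skips the
# special ids 1 and 2 and returns ['ab'].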

def fix_sent(model, sess, sents):
    if FLAGS.decode == 'single':
        input_toks, mask = tokenize_single(sents[0], vocab)
        # len_inp * batch_size * num_units
        encoder_output = model.encode(sess, input_toks, mask)
        s1 = encoder_output.shape[0]
    else:
        input_toks, mask = tokenize_multi(sents, vocab)
        # len_inp * num_wit * num_units
        encoder_output = model.encode(sess, input_toks, mask)
        # len_inp * num_wit * (2 * size)
        s1, s2, s3 = encoder_output.shape
        # num_wit * len_inp
        mask = np.transpose(mask, (1, 0))
        # num_wit * len_inp * (2 * size)
        encoder_output = np.transpose(encoder_output, (1, 0, 2))
    beam_toks, probs = model.decode_beam(sess, encoder_output, mask, s1, FLAGS.beam_size)
    beam_toks = beam_toks.tolist()
    probs = probs.tolist()
    # De-tokenize the beam-search candidates.
    beam_strs = detokenize(beam_toks, reverse_vocab)
    return beam_strs, probs
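
# fix_sent thus yields FLAGS.beam_size candidate corrections per input line:
# decode_beam produces the token sequences and their scores, detokenize maps
# each sequence back to a string, and the two returned lists stay parallel.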

def cleanWit(s):
    # Join words hyphenated across line breaks (soft hyphen + newline), then
    # flatten any remaining newlines to spaces.
    return s.replace(u'\u00ad\n', '').replace(u'\n', ' ')
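
# Example: a witness hyphenated across a line break, u'exam\u00ad\nple one\n',
# comes back as 'example one ' (soft-hyphen breaks joined, newlines spaced).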

def decode():
    global reverse_vocab, vocab
    print("Preparing NLC data in %s" % FLAGS.data_dir, file=sys.stderr)
    vocab_path = pjoin(FLAGS.voc_dir, "vocab.dat")
    vocab, reverse_vocab = read_vocab(vocab_path)
    vocab_size = len(vocab)
    print("Vocabulary size: %d" % vocab_size, file=sys.stderr)
    sess = tf.Session()
    print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size), file=sys.stderr)
    model = create_model(sess, vocab_size, True)
    for rec in sys.stdin:
        doc = json.loads(rec)
        res = list()
        for line in doc['lines']:
            # The OCR line itself is always the first variant; any aligned
            # witnesses are appended after cleanup.
            variants = [line['text']]
            if 'wits' in line:
                variants += [cleanWit(w['text']) for w in line['wits']]
            hyps, probs = fix_sent(model, sess, variants)
            line['hyps'] = [{'text': h, 'p': p} for h, p in zip(hyps, probs)]
            res.append(line)
        print(json.dumps({'id': doc['id'], 'lines': res}))
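
# Expected stdin format, one JSON document per line (a sketch inferred from
# the field accesses above; no fields beyond these are assumed):
#   {"id": "d1", "lines": [{"text": "ocr line", "wits": [{"text": "witness"}]}]}
# Each output record mirrors its input, with line['hyps'] extended to hold
# {'text': hypothesis, 'p': probability} pairs.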

def main(_):
    decode()


if __name__ == "__main__":
    tf.app.run()