run_sentences.py
import json
from argparse import ArgumentParser
import tensorflow as tf
from train import SpacingModel
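
# Interactive inference script for the spacing model: it restores a trained
# SpacingModel checkpoint together with its training config, then repeatedly
# reads a sentence from stdin and prints the model's re-spaced version.
#
# Example invocation (the file names below are placeholders, not paths shipped
# with the repository):
#   python run_sentences.py \
#       --char-file chars.txt \
#       --model-file model.ckpt \
#       --training-config config.json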
parser = ArgumentParser()
parser.add_argument("--char-file", type=str, required=True)
parser.add_argument("--model-file", type=str, required=True)
parser.add_argument("--training-config", type=str, required=True)
def main():
    args = parser.parse_args()

    with open(args.training_config) as f:
        config = json.load(f)

    with open(args.char_file) as f:
        content = f.read()
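
    # The vocabulary is four special tokens followed by every character in the
    # char file, with ids assigned in order starting at 0 for <pad>. Characters
    # missing from the table fall back to id 3, i.e. <unk>.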
keys = ["<pad>", "<s>", "</s>", "<unk>"] + list(content)
values = list(range(len(keys)))
vocab_initializer = tf.lookup.KeyValueTensorInitializer(keys, values, key_dtype=tf.string, value_dtype=tf.int32)
vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=3)
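
    # Rebuild the model from the hyperparameters stored in the training config,
    # restore the trained weights, and run one symbolic call so the layers are
    # built before printing the summary.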
    model = SpacingModel(
        config["vocab_size"],
        config["hidden_size"],
        conv_activation=config["conv_activation"],
        dense_activation=config["dense_activation"],
        conv_kernel_and_filter_sizes=config["conv_kernel_and_filter_sizes"],
        dropout_rate=config["dropout_rate"],
    )
    model.load_weights(args.model_file)
    model(tf.keras.Input([None], dtype=tf.int32))
    model.summary()
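
    # Read sentences from stdin one at a time and print the re-spaced output.
    # inference() returns a 1-D tensor of UTF-8 byte strings (one per character),
    # which is joined back into a single string for display.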
    inference = get_inference_fn(model, vocab_table)

    while True:
        input_str = input("Str: ")
        input_str = tf.constant(input_str)
        result = inference(input_str).numpy()
        print(b"".join(result).decode("utf8"))
def get_inference_fn(model, vocab_table):
    @tf.function
    def inference(tensors):
        byte_array = tf.concat(
            [["<s>"], tf.strings.unicode_split(tf.strings.regex_replace(tensors, " +", " "), "UTF-8"), ["</s>"]], axis=0
        )
        strings = vocab_table.lookup(byte_array)[tf.newaxis, :]
        model_output = tf.argmax(model(strings), axis=-1)[0]
        return convert_output_to_string(byte_array, model_output)

    return inference
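

# Reassemble the output string inside the graph with a tf.while_loop over the
# predicted labels: label 1 emits the character followed by a space, label 2 on
# an existing space drops that space, and any other label copies the character
# through unchanged.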
def convert_output_to_string(byte_array, model_output):
    sequence_length = tf.size(model_output)
    while_condition = lambda i, *_: i < sequence_length

    def while_body(i, o):
        o = tf.cond(
            model_output[i] == 1,
            lambda: tf.concat([o, [byte_array[i], " "]], axis=0),
            lambda: tf.cond(
                # Python's `and` cannot combine two tensors in graph mode, so use
                # tf.logical_and to join the two scalar boolean conditions.
                tf.logical_and(model_output[i] == 2, byte_array[i] == " "),
                lambda: o,
                lambda: tf.concat([o, [byte_array[i]]], axis=0),
            ),
        )
        return i + 1, o

    _, strings_result = tf.while_loop(
        while_condition,
        while_body,
        (tf.constant(0), tf.constant([], dtype=tf.string)),
        shape_invariants=(tf.TensorShape([]), tf.TensorShape([None])),
    )
    return strings_result


if __name__ == "__main__":
    main()