forked from yxtay/char-rnn-text-generation
-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate.py
122 lines (105 loc) · 4.01 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys
import random
import utils
import numpy as np
from argparse import ArgumentParser
from keras.models import load_model, Sequential
def main():
    """
    command line entry point.
    parses the generation options, runs generation and prints
    the resulting text to stdout.
    """
    parser = ArgumentParser(
        description="generate synthetic text from a pre-trained LSTM text generation model")
    # generate args
    parser.add_argument(
        "--checkpoint-path", required=True,
        help="path to load model checkpoints (required)")
    # exactly one source for the seed has to be supplied
    seed_source = parser.add_mutually_exclusive_group(required=True)
    seed_source.add_argument(
        "--text-path", help="path of text file to generate seed")
    seed_source.add_argument(
        "--seed", default=None, help="seed character sequence")
    parser.add_argument(
        "--length", type=int, default=1024,
        help="length of character sequence to generate (default: %(default)s)")
    parser.add_argument(
        "--top-n", type=int, default=3,
        help="number of top choices to sample (default: %(default)s)")

    print(generate(parser.parse_args()))
def generate(args):
    """
    loads the checkpoint named in args, prepares an inference copy of
    the model and returns the text generated from the requested seed.
    main method for generate subcommand.
    """
    # the trained model supplies both the layer config and the weights
    trained = load_model(args.checkpoint_path)
    # rebuild the network for inference, then copy the trained weights across
    inference_model = build_inference_model(trained)
    inference_model.set_weights(trained.get_weights())
    print("model loaded: {}.".format(args.checkpoint_path), file=sys.stderr)

    if args.seed is not None:
        seed = args.seed
    else:
        # no explicit seed given: draw one from the text file instead
        with open(args.text_path) as text_file:
            corpus = text_file.read()
        seed = generate_seed(corpus)
        print("seed sequence generated from {}".format(args.text_path), file=sys.stderr)

    return generate_text(inference_model, seed, args.length, args.top_n)
def generate_text(model, seed, length=512, top_n=10):
    """
    generates text of specified length from trained model
    with given seed character sequence.

    model: inference model fed one character id at a time as input of
        shape (1, 1); its reset_states() and predict() methods are used.
    seed: character sequence used to prime the model; it is also the
        prefix of the returned text.
    length: number of characters to generate after the seed.
    top_n: number of top choices to sample from at each step.

    returns the seed followed by the generated characters.
    raises ValueError if the seed encodes to an empty sequence.
    """
    print("generating {} characters from top {} choices.".format(length, top_n), file=sys.stderr)
    print('generating with seed: "{}".'.format(seed), file=sys.stderr)

    encoded = utils.encode_text(seed)
    # the last encoded id is the first input of the sampling loop, so an
    # empty encoding cannot be used; fail with a clear message instead of
    # an IndexError further down
    if len(encoded) == 0:
        raise ValueError("seed must contain at least one encodable character")

    model.reset_states()

    # feed every seed character but the last purely to set internal states;
    # verbose=0 keeps Keras versions that print a progress bar per predict
    # call from flooding the output
    for idx in encoded[:-1]:
        # input shape: (1, 1)
        model.predict(np.array([[idx]]), verbose=0)

    next_index = encoded[-1]
    # collect the pieces and join once instead of growing a string in the loop
    pieces = [seed]
    for _ in range(length):
        x = np.array([[next_index]])
        # input shape: (1, 1)
        probs = model.predict(x, verbose=0)
        # output shape: (1, 1, vocab_size)
        next_index = sample_from_probs(probs.squeeze(), top_n)
        # append to sequence
        pieces.append(utils.ID2CHAR[next_index])
    return "".join(pieces)
def build_inference_model(model, batch_size=1, seq_len=1):
    """
    build inference model from model config
    input shape modified to (batch_size, seq_len), (1, 1) by default.

    model: trained Sequential model whose config is reused.
    batch_size: batch dimension of the new input shape.
    seq_len: sequence dimension of the new input shape.

    returns a new, non-trainable Sequential model built from the edited
    config. weights are not copied here; the caller transfers them.
    """
    print("building inference model.", file=sys.stderr)
    config = model.get_config()
    # older Keras releases return the Sequential config as a plain list of
    # layer configs, later ones wrap that list in a dict under "layers";
    # accept both so config[0] does not fail on the dict form
    layers = config["layers"] if isinstance(config, dict) else config
    # edit batch_size and seq_len
    layers[0]["config"]["batch_input_shape"] = (batch_size, seq_len)
    inference_model = Sequential.from_config(config)
    inference_model.trainable = False
    return inference_model
def generate_seed(text, seq_lens=(2, 4, 8, 16, 32)):
    """
    select subsequence randomly from input text

    text: non-empty string to draw the seed from.
    seq_lens: candidate seed lengths; one is picked at random among the
        positive lengths that are shorter than the text.

    returns a substring of text. when no candidate length is shorter than
    the text, the whole text is returned.
    raises ValueError if text is empty.
    """
    if not text:
        raise ValueError("cannot generate a seed from empty text")

    # keep only lengths for which a valid start index exists, so that
    # randint below is never asked for an empty range
    usable_lens = [n for n in seq_lens if 0 < n < len(text)]
    if not usable_lens:
        # text is too short for every candidate: use all of it
        return text

    # randomly choose sequence length
    seq_len = random.choice(usable_lens)
    # randomly choose start index (randint includes both end points)
    start_index = random.randint(0, len(text) - seq_len - 1)
    return text[start_index: start_index + seq_len]
def sample_from_probs(probs, top_n=10):
    """
    truncated weighted random choice.

    probs: one-dimensional sequence of non-negative weights; it is copied,
        the caller's data is not modified.
    top_n: number of highest-weight entries to keep before sampling.
        0, or a value of at least len(probs), disables truncation.

    returns the sampled index into probs.
    raises ValueError if top_n is negative or the weights do not have a
    positive sum.
    """
    # a negative top_n would turn the slice below into "zero the lowest
    # entries", which is never what the caller means
    if top_n < 0:
        raise ValueError("top_n must be non-negative, got {}".format(top_n))

    # need 64 floating point precision
    probs = np.array(probs, dtype=np.float64)

    if 0 < top_n < len(probs):
        # set probabilities after top_n to 0
        probs[np.argsort(probs)[:-top_n]] = 0

    total = np.sum(probs)
    # also catches a NaN total; avoids dividing by zero and passing NaNs on
    if not total > 0:
        raise ValueError("probabilities must have a positive sum")

    # renormalise probabilities
    probs /= total
    return np.random.choice(len(probs), p=probs)
# run the command line interface only when executed as a script
if __name__ == "__main__":
    main()