forked from spro/char-rnn.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate.py
executable file
·59 lines (46 loc) · 1.83 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python
# https://github.com/spro/char-rnn.pytorch
import torch
import os
import argparse
from helpers import *
from model import *
def generate(decoder, prime_str='A', predict_len=100, temperature=0.8, cuda=False):
    """Sample text from a character-level decoder, continuing ``prime_str``.

    The priming string is fed through ``decoder`` one character at a time to
    build up its hidden state. Then ``predict_len`` characters are sampled:
    each is drawn from the decoder's temperature-scaled output scores and fed
    back in as the next input.

    Args:
        decoder: model exposing ``init_hidden(1)`` and
            ``decoder(input, hidden) -> (output, hidden)``.
            NOTE(review): ``output`` is assumed to hold one unnormalized
            score per entry of ``all_characters`` -- confirm against model.py.
        prime_str: non-empty seed text; it is kept at the start of the
            returned string.
        predict_len: number of characters to sample after ``prime_str``.
        temperature: positive divisor applied to the scores before sampling;
            values < 1 sharpen the distribution, values > 1 flatten it.
        cuda: if True, move the hidden state and inputs to the GPU. The
            decoder itself must already live on the GPU.

    Returns:
        ``prime_str`` followed by ``predict_len`` sampled characters.

    Raises:
        ValueError: if ``prime_str`` is empty or ``temperature`` is not
            positive.
    """
    # Without at least one character there is nothing to seed the sampling
    # loop with (the original failed later with an opaque IndexError).
    if not prime_str:
        raise ValueError("prime_str must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive, got %r" % (temperature,))

    hidden = decoder.init_hidden(1)
    prime_input = char_tensor(prime_str).unsqueeze(0)
    if cuda:
        # Recurrent cells that carry several state tensors (e.g. an LSTM's
        # (h, c)) return a tuple; otherwise the hidden state is one tensor.
        if isinstance(hidden, tuple):
            hidden = tuple(h.cuda() for h in hidden)
        else:
            hidden = hidden.cuda()
        prime_input = prime_input.cuda()

    predicted = [prime_str]
    # Sampling never back-propagates. Disabling autograd keeps a graph from
    # growing through the hidden state on every step.
    with torch.no_grad():
        # Feed all but the last priming character to "build up" the hidden
        # state; the last one becomes the first input of the sampling loop.
        for p in range(len(prime_str) - 1):
            _, hidden = decoder(prime_input[:, p], hidden)
        inp = prime_input[:, -1]

        for _ in range(predict_len):
            output, hidden = decoder(inp, hidden)

            # Sample from the network as a multinomial distribution.
            # Subtracting the max before exp() keeps the weight ratios
            # identical but stops exp() from overflowing to inf, which would
            # make torch.multinomial reject the distribution.
            scores = output.view(-1).div(temperature)
            weights = (scores - scores.max()).exp()
            top_i = int(torch.multinomial(weights, 1)[0])

            # Add the predicted character and use it as the next input.
            predicted_char = all_characters[top_i]
            predicted.append(predicted_char)
            inp = char_tensor(predicted_char).unsqueeze(0)
            if cuda:
                inp = inp.cuda()

    return ''.join(predicted)
# Run as standalone script
if __name__ == '__main__':
    import inspect

    # Parse command line arguments
    argparser = argparse.ArgumentParser(
        description='Generate text from a trained char-rnn model.')
    argparser.add_argument('filename', type=str,
                           help='path to a model object saved with torch.save')
    argparser.add_argument('-p', '--prime_str', type=str, default='A',
                           help='seed text the sample starts with (default: %(default)r)')
    argparser.add_argument('-l', '--predict_len', type=int, default=100,
                           help='number of characters to generate (default: %(default)s)')
    argparser.add_argument('-t', '--temperature', type=float, default=0.8,
                           help='sampling temperature, must be > 0 (default: %(default)s)')
    argparser.add_argument('--cuda', action='store_true',
                           help='run the model on the GPU')
    args = argparser.parse_args()

    # Restore the model onto the device it will run on. Without map_location,
    # a model saved from the GPU is put back on the GPU even for CPU
    # generation, which fails on CUDA-less machines. Such a model would also
    # mismatch generate()'s CPU inputs.
    # SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
    load_kwargs = {'map_location': None if args.cuda else 'cpu'}
    # PyTorch >= 2.6 defaults to weights_only=True, which cannot restore the
    # full model object needed here (generate() calls its methods). Older
    # releases do not accept the argument at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    decoder = torch.load(args.filename, **load_kwargs)
    if args.cuda:
        # generate() moves inputs/hidden state to the GPU, so the model must
        # be there too, even if it was saved from the CPU.
        decoder = decoder.cuda()
    # Inference mode: turns off dropout (if the model uses any) while sampling.
    decoder.eval()

    print(generate(decoder,
                   prime_str=args.prime_str,
                   predict_len=args.predict_len,
                   temperature=args.temperature,
                   cuda=args.cuda))