-
Notifications
You must be signed in to change notification settings - Fork 1
/
uientry.py
108 lines (93 loc) · 3.87 KB
/
uientry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from dataloader import poet_dataset
# Compute device for every model and tensor in this module: CUDA when
# torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Mode_LSTM:
    """Acrostic poem generator backed by a pickled recurrent (LSTM) model.

    From its use below, the loaded model must provide ``info()`` returning
    ``(n_sents, n_words)``, a ``data_path`` attribute, a ``pre_word``
    iterable, and be callable as ``model(ipt, hidden) -> (output, hidden)``.
    """

    def __init__(self, name, path):
        """Load the checkpoint at *path* and prime the model's hidden state.

        Args:
            name: Display name of the poem form; printed and stored as
                ``self.name``.
            path: Path to a checkpoint that holds a whole pickled model
                object (not a ``state_dict``).
        """
        import inspect

        print(f'Loading model for 『{name}』...')
        self.name = name
        # The checkpoint is a whole pickled model. torch>=2.6 defaults to
        # weights_only=True and refuses to unpickle it, so opt out explicitly
        # where the parameter exists (older torch versions do not have it).
        # SECURITY: unpickling can execute arbitrary code - only load trusted
        # checkpoint files here.
        load_kwargs = {'map_location': device}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        self.model = torch.load(path, **load_kwargs)
        self.hidden = None
        self.n_sents, self.n_words = self.model.info()
        self.dataset = poet_dataset(self.model.data_path)
        # Encoded separator token, fed to the model between generated lines.
        self.sep = self.dataset.head2vec(self.dataset.sep).to(device)
        self.model.eval()
        # Run every word of model.pre_word through the model once; the
        # resulting hidden state is the starting point of every poem that
        # entry() produces.
        with torch.no_grad():
            for word in self.model.pre_word:
                ipt = self.dataset.head2vec(word).to(device)
                _, self.hidden = self.model(ipt, self.hidden)

    def entry(self, heads):
        """Generate a poem whose lines start with the characters of *heads*.

        Args:
            heads: Sequence of head characters, one per line (typically a str).

        Returns:
            The poem as newline-joined lines, or the string ``"Invalid input"``
            when ``len(heads)`` differs from the model's number of lines.
        """
        if len(heads) != self.n_sents:
            return "Invalid input"
        self.model.eval()
        poet = []
        # Each call starts from the primed state; self.hidden is never
        # reassigned here, so repeated calls are independent of each other.
        cur_hidden = self.hidden
        with torch.no_grad():
            for i, head in enumerate(heads):
                sent = [head]
                ipt = self.dataset.head2vec(head).to(device)
                for _ in range(self.n_words - 1):
                    opt, cur_hidden = self.model(ipt, cur_hidden)
                    # Greedy decoding: pick the highest-scoring token.
                    word_idx = torch.argmax(opt.squeeze()).item()
                    word = self.dataset.num2word(word_idx)
                    sent.append(word)
                    ipt = self.dataset.head2vec(word).to(device)
                # Feed the line's last character so the state covers the whole line.
                _, cur_hidden = self.model(ipt, cur_hidden)
                # i is 0-based: the 1st, 3rd, ... lines end with a comma,
                # the others with a full stop.
                sent.append(',' if i % 2 == 0 else '。')
                # The model always sees dataset.sep between lines, whichever
                # punctuation is shown to the user.
                _, cur_hidden = self.model(self.sep, cur_hidden)
                poet.append(''.join(sent))
        return '\n'.join(poet)
class Mode_Transformer:
    """Acrostic poem generator backed by a pickled Transformer model.

    From its use below, the loaded model must provide ``info()`` returning
    ``(n_sents, n_words)``, a ``data_path`` attribute,
    ``generate_square_subsequent_mask(size)``, and be callable as
    ``model(tokens, mask) -> output``.
    """

    def __init__(self, name, path):
        """Load the checkpoint at *path* and prepare the dataset helpers.

        Args:
            name: Display name of the poem form; printed and stored as
                ``self.name``.
            path: Path to a checkpoint that holds a whole pickled model
                object (not a ``state_dict``).
        """
        import inspect

        print(f'Loading model for 『{name}』...')
        self.name = name
        # The checkpoint is a whole pickled model. torch>=2.6 defaults to
        # weights_only=True and refuses to unpickle it, so opt out explicitly
        # where the parameter exists (older torch versions do not have it).
        # SECURITY: unpickling can execute arbitrary code - only load trusted
        # checkpoint files here.
        load_kwargs = {'map_location': device}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        self.model = torch.load(path, **load_kwargs)
        # Not used by this class; kept so both Mode_* classes expose the
        # same attributes.
        self.hidden = None
        self.n_sents, self.n_words = self.model.info()
        self.dataset = poet_dataset(self.model.data_path)
        # Encoded separator token, appended to the sequence after each line.
        self.sep = self.dataset.head2vec(self.dataset.sep).to(device)
        self.model.eval()

    def entry(self, heads):
        """Generate a poem whose lines start with the characters of *heads*.

        Args:
            heads: Sequence of head characters, one per line (typically a str).

        Returns:
            The poem as newline-joined lines, or the string ``"Invalid input"``
            when ``len(heads)`` differs from the model's number of lines.
        """
        if len(heads) != self.n_sents:
            return "Invalid input"
        self.model.eval()
        poet = []
        with torch.no_grad():
            # Full token history of the poem so far; tokens are concatenated
            # along dim 0, which the causal mask below treats as the sequence
            # axis. The whole history is re-run every step.
            tokens = None
            for i, head in enumerate(heads):
                sent = [head]
                ipt = self.dataset.head2vec(head).to(device)
                tokens = ipt if tokens is None else torch.cat([tokens, ipt], dim=0)
                for _ in range(self.n_words - 1):
                    mask = self.model.generate_square_subsequent_mask(tokens.shape[0]).to(device)
                    output = self.model(tokens, mask)
                    # Greedy decoding on the scores for the last position.
                    word_idx = torch.argmax(output[-1]).item()
                    word = self.dataset.num2word(word_idx)
                    sent.append(word)
                    ipt = self.dataset.head2vec(word).to(device)
                    tokens = torch.cat([tokens, ipt], dim=0)
                # i is 0-based: the 1st, 3rd, ... lines end with a comma,
                # the others with a full stop.
                sent.append(',' if i % 2 == 0 else '。')
                tokens = torch.cat([tokens, self.sep], dim=0)
                poet.append(''.join(sent))
        return '\n'.join(poet)
# Generator class instantiated for every poem form below
# (Mode_LSTM or Mode_Transformer).
Mode = Mode_LSTM
# One generator per poem form, in display order. Every checkpoint is
# loaded eagerly when this module is imported.
modes = [
    Mode(form, checkpoint)
    for form, checkpoint in (
        ('五言绝句', 'wuyanjueju_final_model.pt'),
        ('七言绝句', 'qiyanjueju_final_model.pt'),
        ('五言律诗', 'wuyanlvshi_final_model.pt'),
        ('七言律诗', 'qiyanlvshi_final_model.pt'),
    )
]
# Position in ``modes`` of the currently selected poem form.
curmode = 0
def getAllModes():
    """Return the module-level list of poem generators itself (not a copy)."""
    all_modes = modes
    return all_modes
def getCurMode():
    """Return the generator stored at position ``curmode`` in ``modes``."""
    selected = curmode
    return modes[selected]
def setCurModeIndex(index):
    """Make position *index* of ``modes`` the current poem form.

    The value is stored unchecked: it is only used when getCurMode() indexes
    ``modes``, so an out-of-range index raises there and a negative one
    counts from the end of the list.
    """
    global curmode
    new_index = index
    curmode = new_index
if __name__ == "__main__":
    # Interactive console loop: read head characters and print the poem.
    # Reuse the 五言绝句 generator already built in ``modes`` at import time
    # instead of loading the same checkpoint a second time.
    model = next(m for m in modes if m.name == '五言绝句')
    while True:
        try:
            # Strip stray surrounding whitespace so it does not count
            # towards the length check in entry().
            heads = input('Head: ').strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session without a traceback.
            print()
            break
        output = model.entry(heads)
        print(output)