forked from Fugtemypt123/mytransformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
120 lines (92 loc) · 2.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import spacy
# Loaded spaCy pipelines keyed by model name, so repeated calls to
# translate_sentence reuse one pipeline instead of reloading it from disk.
_SPACY_PIPELINES = {}


def _load_spacy(name):
    """Return the spaCy pipeline ``name``, loading it on first use only."""
    nlp = _SPACY_PIPELINES.get(name)
    if nlp is None:
        nlp = spacy.load(name)
        _SPACY_PIPELINES[name] = nlp
    return nlp


def translate_sentence(model, sentence, zh_vocab, en_ivocab, device, max_len,
                       *, tokenizer=None, sos_idx=1, eos_idx=2, pad_idx=0):
    """Greedily translate ``sentence`` with a seq2seq ``model``.

    The sentence is tokenized and each token is mapped to an id through
    ``zh_vocab``. Tokens missing from the vocabulary are skipped. The id list
    is wrapped in ``sos_idx`` ... ``eos_idx`` to form the source. Starting
    from ``[sos_idx]``, the model is called repeatedly as ``model(src, trg)``.
    Each time, the arg-max id at the last target position is appended.
    Generation stops after ``eos_idx`` or ``pad_idx`` is produced, or after
    ``max_len`` steps.

    Args:
        model: Callable ``model(src, trg)``. ``src`` and ``trg`` are integer
            tensors of shape ``(1, length)``. ``output.argmax(2)[:, -1]`` must
            yield a single id, so presumably the output is
            ``(1, trg_len, vocab_size)`` scores (confirm against the model).
        sentence: Source sentence text.
        zh_vocab: Token -> id mapping. Lookups raising ``KeyError`` drop
            the token.
        en_ivocab: Id -> word lookup, indexable by int.
        device: Torch device on which the id tensors are created.
        max_len: Maximum number of generation steps.
        tokenizer: Optional callable ``sentence -> iterable of str``. When
            None, the spaCy ``zh_core_web_sm`` pipeline is used. It is loaded
            once and cached.
        sos_idx, eos_idx, pad_idx: Special-token ids. The defaults
            (1, 2, 0) match the values previously hard-coded here.

    Returns:
        The generated ids (including the leading ``sos_idx`` and any final
        ``eos_idx``/``pad_idx``) mapped through ``en_ivocab`` and joined with
        single spaces.
    """
    if tokenizer is None:
        nlp = _load_spacy('zh_core_web_sm')
        zh_tokens = [tok.text for tok in nlp(sentence)]
    else:
        zh_tokens = list(tokenizer(sentence))

    src_ids = [sos_idx]
    for word in zh_tokens:
        try:
            src_ids.append(zh_vocab[word])
        except KeyError:
            # Out-of-vocabulary tokens are dropped; there is no <unk> id here.
            continue
    src_ids.append(eos_idx)
    # Batch of one: shape (1, src_len).
    src = torch.tensor(src_ids, device=device).unsqueeze(0)

    out_ids = [sos_idx]
    with torch.no_grad():
        for _ in range(max_len):
            trg = torch.tensor(out_ids, device=device).unsqueeze(0)
            output = model(src, trg)
            next_id = output.argmax(2)[:, -1].item()
            out_ids.append(next_id)
            # A pad ends the text just like <eos>, so generating past it
            # would only waste forward passes.
            if next_id == eos_idx or next_id == pad_idx:
                break

    # Every generated id is kept; the previous version dropped the final
    # token whenever no <eos> was produced.
    return " ".join(en_ivocab[idx] for idx in out_ids)
def tensor2sentence(output, en_ivocab, max_len):
    """Convert a matrix of per-position scores into a space-separated sentence.

    ``output.argmax(1)`` picks one id per row, so ``output`` is a 2-D tensor
    whose rows are positions and whose columns are vocabulary scores. At most
    ``max_len`` rows are decoded. Decoding stops after the first word equal to
    ``"<eos>"``, which is included in the result.

    Args:
        output: 2-D score tensor (positions x vocabulary).
        en_ivocab: Id -> word lookup, indexable by int.
        max_len: Maximum number of positions to decode.

    Returns:
        The decoded words joined by single spaces, with no trailing space.
    """
    best_words = output.argmax(1)
    words = []
    # Slicing bounds the loop by the real number of rows, so a short output
    # without "<eos>" no longer raises IndexError.
    for num in best_words[:max(max_len, 0)].tolist():
        word = en_ivocab[num]
        words.append(word)
        if word == "<eos>":
            break
    return " ".join(words)
def list2sentence(best_words, zh_ivocab, max_len):
    """Map a sequence of word ids to a space-separated sentence.

    At most ``max_len`` ids are decoded. Decoding stops after the first word
    equal to ``"<eos>"``, which is included in the result.

    Args:
        best_words: Iterable of ids. Elements may be Python ints or
            single-element tensors; each is converted with ``int()``.
        zh_ivocab: Id -> word lookup, indexable by int.
        max_len: Maximum number of ids to decode.

    Returns:
        The decoded words joined by single spaces, with no trailing space.
    """
    words = []
    # Iterating (rather than indexing up to max_len) avoids IndexError when
    # the sequence is shorter than max_len and has no "<eos>".
    for position, token in enumerate(best_words):
        if position >= max_len:
            break
        word = zh_ivocab[int(token)]
        words.append(word)
        if word == "<eos>":
            break
    return " ".join(words)
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Persist ``state`` to ``filename`` with ``torch.save``.

    A short progress line is printed to stdout before writing. ``state`` is
    serialized as-is; any existing file at ``filename`` is overwritten by
    ``torch.save``.
    """
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(state, f=filename)
def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` from an in-memory checkpoint mapping.

    ``checkpoint["state_dict"]`` is loaded into ``model`` first, then
    ``checkpoint["optimizer"]`` into ``optimizer``. Both objects are updated
    in place and nothing is returned. The mapping is not read from disk here,
    so callers pass an already-loaded object (e.g. the result of
    ``torch.load``).
    """
    print("=> Loading checkpoint")
    # Order matters: the model is restored before the optimizer key is read.
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])