-
Notifications
You must be signed in to change notification settings - Fork 2
/
S2SModel.py
executable file
·86 lines (64 loc) · 3.06 KB
/
S2SModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from torch.autograd import Variable
import torch.nn as nn
from Statistics import Statistics
from UtilClass import bottle
from CopyGenerator import CopyGenerator, ProdGenerator
from RegularDecoder import RegularDecoder
from ProdDecoder import ProdDecoder
from RegularEncoder import RegularEncoder
from ConcodeEncoder import ConcodeEncoder
from ConcodeDecoder import ConcodeDecoder
from decoders import DecoderState
class S2SModel(nn.Module):
    """Encoder/decoder sequence-to-sequence model with a copy/production generator.

    Wires together one of the project's encoders (regular or ConCode), one of
    its decoders (regular, prod, or ConCode), and the matching generator, then
    moves all parameters to the GPU.  NOTE(review): the model unconditionally
    calls ``.cuda()``, so a CUDA device is required.
    """

    def __init__(self, opt, vocabs):
        """Build encoder, decoder and generator from ``opt`` settings.

        Args:
            opt: options namespace; reads ``encoder_type``, ``decoder_type``
                and ``decoder_rnn_size``.
            vocabs: project vocabulary container, passed through to the
                encoder, decoder and generator constructors.

        Raises:
            ValueError: if ``opt.encoder_type`` is not "regular" or "concode".
                (The original code left ``self.encoderClass`` unset in that
                case and crashed one line later with an AttributeError.)
        """
        super(S2SModel, self).__init__()
        self.opt = opt
        self.vocabs = vocabs
        # Encoder selection.  Unlike the decoder below, there is no sensible
        # default here, so fail loudly on an unrecognized type.
        if self.opt.encoder_type == "regular":
            self.encoderClass = RegularEncoder
        elif self.opt.encoder_type == "concode":
            self.encoderClass = ConcodeEncoder
        else:
            raise ValueError(
                "unknown encoder_type: %s" % self.opt.encoder_type)
        self.encoder = self.encoderClass(vocabs, opt)
        # Decoder selection falls back to RegularDecoder for any other type.
        if self.opt.decoder_type == "prod":
            self.decoderClass = ProdDecoder
        elif self.opt.decoder_type == "concode":
            self.decoderClass = ConcodeDecoder
        else:
            self.decoderClass = RegularDecoder
        self.decoder = self.decoderClass(vocabs, opt)
        # Production-rule decoders share the ProdGenerator; everything else
        # uses the copy-attention generator.
        if self.opt.decoder_type in ["prod", "concode"]:
            generator = ProdGenerator
        else:
            generator = CopyGenerator
        self.generator = generator(self.opt.decoder_rnn_size, vocabs, self.opt)
        self.cuda()

    def forward(self, batch):
        """Run one supervised step and return ``(loss, Statistics)``.

        Args:
            batch: dict of tensors; reads ``'seq2seq'`` for the batch size,
                and for the ConCode decoder also ``'concode_src_map_vars'``
                and ``'concode_src_map_methods'``; otherwise ``'src_map'``.
                Temporarily inserts (and deletes) a ``'parent_states'`` entry
                for the ConCode decoder.

        Returns:
            Tuple of the scalar loss Variable and a ``Statistics`` record
            built from loss value, token totals/correct counts and the
            encoder's source-word count.
        """
        batch_size = batch['seq2seq'].size(0)
        # Initial (zero) parent hidden states for the ConCode decoder.
        # NOTE(review): indentation was lost in transit; the inner guard is
        # redundant inside the "concode" branch but is preserved as written.
        if self.opt.decoder_type == "concode":
            batch['parent_states'] = {}
            for j in range(0, batch_size):
                batch['parent_states'][j] = {}
                if self.opt.decoder_type in ["prod", "concode"]:
                    batch['parent_states'][j][0] = Variable(
                        torch.zeros(1, 1, self.opt.decoder_rnn_size).cuda(),
                        requires_grad=False)
        context, context_lengths, enc_hidden = self.encoder(batch)
        decInitState = DecoderState(
            enc_hidden,
            Variable(torch.zeros(batch_size, 1, self.opt.decoder_rnn_size).cuda(),
                     requires_grad=False))
        output, attn, copy_attn = self.decoder(
            batch, context, context_lengths, decInitState)
        if self.opt.decoder_type == "concode":
            # Parent states are per-forward scratch space; drop them so the
            # batch dict is unchanged for the caller.
            del batch['parent_states']
        # Other generators will not use the extra parameters
        # Let the generator put the src_map in cuda if it uses it
        # TODO: Make sec_map variable again in generator
        src_map = torch.zeros(0, 0)
        if self.opt.decoder_type == "concode":
            src_map = torch.cat((batch['concode_src_map_vars'],
                                 batch['concode_src_map_methods']), 1)
        scores = self.generator(
            bottle(output), bottle(copy_attn),
            src_map if self.opt.encoder_type in ["concode"] else batch['src_map'],
            batch)
        loss, total, correct = self.generator.computeLoss(scores, batch)
        return loss, Statistics(loss.data.item(), total.item(), correct.item(),
                                self.encoder.n_src_words)

    # This only works for a batch size of 1
    def predict(self, batch, opt):
        """Beam-search decode a single example.

        Args:
            batch: dict of tensors; ``batch['seq2seq']`` must have batch
                size exactly 1 (asserted).
            opt: options namespace; reads ``beam_size``, ``max_sent_length``
                and ``replace_unk``.

        Returns:
            Whatever ``self.decoder.predict`` returns (project type).
        """
        curr_batch_size = batch['seq2seq'].size(0)
        assert(curr_batch_size == 1)
        context, context_lengths, enc_hidden = self.encoder(batch)
        return self.decoder.predict(enc_hidden, context, context_lengths,
                                    batch, opt.beam_size, opt.max_sent_length,
                                    self.generator, opt.replace_unk)