models_text_skip.py
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


class RNNVAE(nn.Module):
    """RNN variational autoencoder for text. With skip=1, the latent code z is
    additionally fed (skip-connected) into the decoder's output projection and
    into every decoder LSTM layer beyond the first."""

    def __init__(self, vocab_size=10000,
                 enc_word_dim=512,
                 enc_h_dim=1024,
                 enc_num_layers=1,
                 dec_word_dim=512,
                 dec_h_dim=1024,
                 dec_num_layers=1,
                 dec_dropout=0.5,
                 latent_dim=32,
                 mode='savae',
                 skip=0):
        super(RNNVAE, self).__init__()
        self.skip = skip
        self.enc_h_dim = enc_h_dim
        self.enc_num_layers = enc_num_layers
        self.dec_h_dim = dec_h_dim
        self.dec_num_layers = dec_num_layers

        # Encoder (only built when an inference network is used): word embeddings,
        # an LSTM, and linear maps from the last hidden state to the mean and
        # log-variance of q(z|x).
        if mode == 'savae' or mode == 'vae':
            self.enc_word_vecs = nn.Embedding(vocab_size, enc_word_dim)
            self.latent_linear_mean = nn.Linear(enc_h_dim, latent_dim)
            self.latent_linear_logvar = nn.Linear(enc_h_dim, latent_dim)
            self.enc_rnn = nn.LSTM(enc_word_dim, enc_h_dim, num_layers=enc_num_layers,
                                   batch_first=True)
            self.enc = nn.ModuleList([self.enc_word_vecs, self.enc_rnn,
                                      self.latent_linear_mean, self.latent_linear_logvar])
        elif mode == 'autoreg':
            # Pure autoregressive language model: no latent variable.
            latent_dim = 0

        # Decoder: z is concatenated to every word embedding; with skip == 1 it is
        # also fed directly into the output projection.
        self.dec_word_vecs = nn.Embedding(vocab_size, dec_word_dim)
        dec_input_size = dec_word_dim
        dec_input_size += latent_dim
        if self.skip == 0:
            self.dec_linear = nn.Linear(dec_h_dim, vocab_size)
        else:
            self.dec_linear = nn.Linear(dec_h_dim + latent_dim, vocab_size)
        if self.skip == 0 or self.dec_num_layers == 1:
            self.dec_rnn = nn.LSTM(dec_input_size, dec_h_dim, num_layers=dec_num_layers,
                                   batch_first=True)
            self.dec = nn.ModuleList([self.dec_word_vecs, self.dec_rnn, self.dec_linear])
        else:
            # With skip connections and multiple layers, each layer after the first
            # takes the previous layer's output concatenated with z.
            self.dec_rnn = nn.ModuleList([nn.LSTM(dec_input_size, dec_h_dim, batch_first=True)
                                          if k == 0 else
                                          nn.LSTM(dec_h_dim + latent_dim, dec_h_dim, batch_first=True)
                                          for k in range(self.dec_num_layers)])
            self.dec = nn.ModuleList([self.dec_word_vecs, self.dec_linear, *self.dec_rnn])
        self.dropout = nn.Dropout(dec_dropout)

        if latent_dim > 0:
            # Projects z to the initial hidden state of the last decoder layer.
            self.latent_hidden_linear = nn.Linear(latent_dim, dec_h_dim)
            self.dec.append(self.latent_hidden_linear)

    def _enc_forward(self, sent):
        # Encode a batch of sentences (batch, seq_len) into the parameters of q(z|x).
        word_vecs = self.enc_word_vecs(sent)
        h0 = Variable(torch.zeros(self.enc_num_layers, word_vecs.size(0),
                                  self.enc_h_dim).type_as(word_vecs.data))
        c0 = Variable(torch.zeros(self.enc_num_layers, word_vecs.size(0),
                                  self.enc_h_dim).type_as(word_vecs.data))
        enc_h_states, _ = self.enc_rnn(word_vecs, (h0, c0))
        enc_h_states_last = enc_h_states[:, -1]
        mean = self.latent_linear_mean(enc_h_states_last)
        logvar = self.latent_linear_logvar(enc_h_states_last)
        return mean, logvar

    def _reparameterize(self, mean, logvar, z=None):
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I).
        # Note: sampling uses torch.cuda.FloatTensor, so this path assumes a GPU.
        std = logvar.mul(0.5).exp()
        if z is None:
            z = Variable(torch.cuda.FloatTensor(std.size()).normal_(0, 1))
        return z.mul(std) + mean

    def _dec_forward(self, sent, q_z, init_h=True):
        # Teacher-forced decoding: predict sent[:, 1:] from sent[:, :-1] and the latent code q_z.
        self.word_vecs = self.dropout(self.dec_word_vecs(sent[:, :-1]))
        if init_h:
            self.h0 = Variable(torch.zeros(self.dec_num_layers, self.word_vecs.size(0),
                                           self.dec_h_dim).type_as(self.word_vecs.data),
                               requires_grad=False)
            self.c0 = Variable(torch.zeros(self.dec_num_layers, self.word_vecs.size(0),
                                           self.dec_h_dim).type_as(self.word_vecs.data),
                               requires_grad=False)
        else:
            self.h0.data.zero_()
            self.c0.data.zero_()

        if q_z is not None:
            # Concatenate z to every decoder input embedding.
            q_z_expand = q_z.unsqueeze(1).expand(self.word_vecs.size(0),
                                                 self.word_vecs.size(1), q_z.size(1))
            dec_input = torch.cat([self.word_vecs, q_z_expand], 2)
        else:
            dec_input = self.word_vecs
        if q_z is not None:
            # Initialize the last decoder layer's hidden state from z.
            self.h0[-1] = self.latent_hidden_linear(q_z)
        if self.skip == 0 or self.dec_num_layers == 1:
            memory, _ = self.dec_rnn(dec_input, (self.h0, self.c0))
        else:
            # Skip connections: each layer beyond the first sees the previous
            # layer's output concatenated with z.
            for k in range(self.dec_num_layers):
                memory, _ = self.dec_rnn[k](dec_input,
                                            (self.h0[k].unsqueeze(0), self.c0[k].unsqueeze(0)))
                dec_input = torch.cat([memory, q_z_expand], 2)
        memory = self.dropout(memory)
        if self.skip == 1:
            # Skip connection from z straight to the output projection.
            dec_linear_input = torch.cat([memory.contiguous(), q_z_expand], 2)
        else:
            dec_linear_input = memory.contiguous()
        preds = self.dec_linear(dec_linear_input)
        preds = F.log_softmax(preds, 2)
        return preds
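

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original model code): a minimal, hypothetical
# example of driving the module above, assuming integer-encoded sentences of
# shape (batch, seq_len) and a CUDA device (required by _reparameterize).
# The toy batch, loss terms, and hyperparameters below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = RNNVAE(vocab_size=10000, latent_dim=32, mode='vae', skip=1).cuda()
    sent = torch.randint(0, 10000, (16, 20)).cuda()        # hypothetical toy batch
    mean, logvar = model._enc_forward(sent)                 # q(z|x) parameters, each (16, 32)
    z = model._reparameterize(mean, logvar)                 # sample z ~ q(z|x)
    log_probs = model._dec_forward(sent, z)                 # (16, 19, vocab_size) log-probabilities
    # Reconstruction NLL against the shifted targets, plus the analytic KL to N(0, I).
    nll = F.nll_loss(log_probs.reshape(-1, log_probs.size(2)),
                     sent[:, 1:].reshape(-1))
    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), 1).mean()
    print('NLL %.3f  KL %.3f' % (nll.item(), kl.item()))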