-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathmodel.py
224 lines (184 loc) · 8.17 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
class ConvInputModel(nn.Module):
    """Four-stage convolutional image encoder.

    Every stage is a 3x3 convolution with stride 2 and padding 1, followed by
    2-D batch normalization and a ReLU.  The first stage consumes 3 input
    channels; every stage emits 24 channels, so the output is
    (B x 24 x H' x W') with each spatial side roughly halved per stage.
    """

    def __init__(self):
        super(ConvInputModel, self).__init__()
        # Layers are registered in the order conv1, batchNorm1, conv2, ...
        # so parameter names/order (and thus saved checkpoints) stay stable.
        self.conv1 = nn.Conv2d(3, 24, 3, stride=2, padding=1)
        self.batchNorm1 = nn.BatchNorm2d(24)
        self.conv2 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(24)
        self.conv3 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm4 = nn.BatchNorm2d(24)

    def forward(self, img):
        """Run ``img`` through the four conv -> batch-norm -> ReLU stages."""
        stages = (
            (self.conv1, self.batchNorm1),
            (self.conv2, self.batchNorm2),
            (self.conv3, self.batchNorm3),
            (self.conv4, self.batchNorm4),
        )
        features = img
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))
        return features
class QuestionEmbedModel(nn.Module):
    """Encode a batch of word-index sequences into one vector per question.

    Word indices are looked up in an embedding table with ``in_size + 1``
    rows of width ``embed`` and fed to a single-layer, batch-first LSTM with
    ``hidden`` units.  The LSTM's final hidden state is the question embedding.
    """

    def __init__(self, in_size, embed=32, hidden=128):
        super(QuestionEmbedModel, self).__init__()
        self.wembedding = nn.Embedding(in_size + 1, embed)
        self.lstm = nn.LSTM(embed, hidden, batch_first=True)
        self.hidden = hidden

    def forward(self, question):
        """Return the last LSTM hidden state, shape (batch, hidden)."""
        embedded = self.wembedding(question)
        # Reset the weight layout so the fast (cuDNN) code path can be used.
        self.lstm.flatten_parameters()
        # No initial (h0, c0) is passed, so PyTorch starts from zeros.
        _, (h_n, _c_n) = self.lstm(embedded)
        # h_n is (num_layers, batch, hidden); the LSTM has a single layer.
        return h_n[0]
class RelationalLayerBase(nn.Module):
    """Shared parts of a relational layer: the f MLP, dropout and config.

    Args:
        in_size: stored as ``self.in_size`` for subclasses (not used here).
        out_size: output width of the last f layer.
        qst_size: stored as ``self.qst_size`` for subclasses.
        hyp: hyper-parameter dict; this class reads ``"g_layers"`` (only its
            last entry, the input width of f), ``"f_fc1"``, ``"f_fc2"`` and
            ``"dropout"``.
    """

    def __init__(self, in_size, out_size, qst_size, hyp):
        super().__init__()
        # f consumes the output of the last g layer.
        self.f_fc1 = nn.Linear(hyp["g_layers"][-1], hyp["f_fc1"])
        self.f_fc2 = nn.Linear(hyp["f_fc1"], hyp["f_fc2"])
        self.f_fc3 = nn.Linear(hyp["f_fc2"], out_size)
        self.dropout = nn.Dropout(p=hyp["dropout"])
        self.on_gpu = False
        self.hyp = hyp
        self.qst_size = qst_size
        self.in_size = in_size
        self.out_size = out_size

    def cuda(self, device=None):
        """Move the layer to the GPU, set ``on_gpu`` and return ``self``.

        Mirrors ``nn.Module.cuda``: accepts an optional ``device`` and returns
        the module so ``layer = layer.cuda()`` keeps working.
        """
        self.on_gpu = True
        return super().cuda(device)
class RelationalLayer(RelationalLayerBase):
    """Relation-network layer: g over every ordered object pair, summed, then f.

    The question embedding is concatenated to the input of the g layer whose
    index equals ``hyp["question_injection_position"]`` (that layer is built
    with ``qst_size`` extra input features).  With ``extraction=True`` the
    forward pass stops after the g layers and returns ``None``.
    """

    def __init__(self, in_size, out_size, qst_size, hyp, extraction=False):
        super().__init__(in_size, out_size, qst_size, hyp)
        self.quest_inject_position = hyp["question_injection_position"]
        self.in_size = in_size
        self.g_layers_size = hyp["g_layers"]
        # Build all g layers; the "h" layer (question injection point) is kept
        # in the same pool for simpler bookkeeping.
        g_layers = []
        for idx, g_layer_size in enumerate(self.g_layers_size):
            in_s = in_size if idx == 0 else self.g_layers_size[idx - 1]
            if idx == self.quest_inject_position:
                in_s += qst_size
            g_layers.append(nn.Linear(in_s, g_layer_size))
        self.g_layers = nn.ModuleList(g_layers)
        self.extraction = extraction

    def forward(self, x, qst):
        """Score the relations in ``x`` conditioned on ``qst``.

        Args:
            x: (B, d, k) per-object features; ``self.in_size`` must equal
                ``2 * k`` for the pair reshape below to succeed.
            qst: (B, qst_size) question embedding.

        Returns:
            (B, out_size) log-probabilities, or ``None`` when ``extraction``.
        """
        b, d, k = x.size()

        # Cast all pairs against each other: position [i, j] holds [x_j, x_i].
        # expand() gives broadcast views; torch.cat materializes them once.
        x_i = x.unsqueeze(1).expand(b, d, d, k)  # [:, i, j] = x[:, j]
        x_j = x.unsqueeze(2).expand(b, d, d, k)  # [:, i, j] = x[:, i]
        x_ = torch.cat([x_i, x_j], 3).view(b * d ** 2, self.in_size)

        # g layers, injecting the question at quest_inject_position.
        for idx, g_layer in enumerate(self.g_layers):
            if idx == self.quest_inject_position:
                in_size = self.in_size if idx == 0 else self.g_layers_size[idx - 1]
                x_img = x_.view(b, d, d, in_size)
                qst_pairs = qst.unsqueeze(1).unsqueeze(2).expand(b, d, d, qst.size(1))
                x_concat = torch.cat([x_img, qst_pairs], 3)
                x_ = x_concat.view(b * d ** 2, in_size + self.qst_size)
            x_ = F.relu(g_layer(x_))

        if self.extraction:
            return None

        # Sum g over all d*d pairs -> (B, g_out).  No squeeze: it would drop
        # the feature axis when the last g layer has width 1.
        x_g = x_.view(b, d ** 2, self.g_layers_size[-1]).sum(1)

        # f
        x_f = F.relu(self.f_fc1(x_g))
        x_f = self.f_fc2(x_f)
        x_f = self.dropout(x_f)
        x_f = F.relu(x_f)
        x_f = self.f_fc3(x_f)
        return F.log_softmax(x_f, dim=1)
class RN(nn.Module):
    """Relation network: image CNN (or state description) + question LSTM + relational layer."""

    def __init__(self, args, hyp, extraction=False):
        super(RN, self).__init__()
        self.coord_tensor = None
        self.on_gpu = False
        # CNN
        self.conv = ConvInputModel()
        self.state_desc = hyp['state_description']
        # LSTM
        hidden_size = hyp["lstm_hidden"]
        self.text = QuestionEmbedModel(args.qdict_size, embed=hyp["lstm_word_emb"], hidden=hidden_size)
        # RELATIONAL LAYER
        self.rl_in_size = hyp["rl_in_size"]
        self.rl_out_size = args.adict_size
        self.rl = RelationalLayer(self.rl_in_size, self.rl_out_size, hidden_size, hyp, extraction)
        if hyp["question_injection_position"] != 0:
            print('Supposing IR model')
        else:
            print('Supposing original DeepMind model')

    def forward(self, img, qst_idxs):
        """Encode image and question, then apply the relational layer.

        In state-description mode ``img`` is passed to the relational layer
        unchanged (presumably (B, n_objects, n_features) — TODO confirm).
        Otherwise the CNN output (B, k, d, d) is flattened to d*d objects,
        each extended with 2 coordinate features.
        """
        if self.state_desc:
            x = img
        else:
            x = self.conv(img)  # (B x k x d x d)
            b, k, d, _ = x.size()
            x = x.view(b, k, d * d)  # (B x k x d*d)
            # Reuse the cached coordinates only when they fit this batch:
            # a smaller final batch or a different device needs a rebuild.
            ct = self.coord_tensor
            if ct is None or ct.size() != (b, 2, d * d) or ct.device != x.device:
                self.build_coord_tensor(b, d, device=x.device)  # (B x 2 x d x d)
                self.coord_tensor = self.coord_tensor.view(b, 2, d * d)
            x = torch.cat([x, self.coord_tensor], 1)  # (B x k+2 x d*d)
            x = x.permute(0, 2, 1)  # (B x d*d x k+2)
        qst = self.text(qst_idxs)
        return self.rl(x, qst)

    def build_coord_tensor(self, b, d, device=None):
        """Store in ``self.coord_tensor`` a (b, 2, d, d) grid of coordinates.

        Both channels use ``linspace(-d/2, d/2, d)``: channel 0 varies along
        columns, channel 1 along rows.  The tensor is placed on ``device``
        when given, otherwise on the GPU if ``self.on_gpu`` is set.
        """
        coords = torch.linspace(-d / 2., d / 2., d)
        x = coords.unsqueeze(0).repeat(d, 1)
        y = coords.unsqueeze(1).repeat(1, d)
        ct = torch.stack((x, y))
        # One copy per batch element.
        ct = ct.unsqueeze(0).repeat(b, 1, 1, 1)
        if device is not None:
            ct = ct.to(device)
        elif self.on_gpu:
            ct = ct.cuda()
        # Plain tensors do not require grad by default.
        self.coord_tensor = ct

    def cuda(self, device=None):
        """Move the model to the GPU, set ``on_gpu`` flags and return ``self``."""
        self.on_gpu = True
        # Let the relational layer record its own on_gpu flag as well.
        if device is None:
            self.rl.cuda()
        else:
            self.rl.cuda(device)
        return super(RN, self).cuda(device)