arch2.py
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F


class InformedSender(nn.Module):
    def __init__(self, game_size, feat_size, embedding_size, hidden_size,
                 vocab_size=100, temp=1.0):
        super(InformedSender, self).__init__()
        self.game_size = game_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.temp = temp

        self.lin1 = nn.Linear(feat_size, embedding_size, bias=False)
        self.conv2 = nn.Conv2d(1, hidden_size,
                               kernel_size=(game_size, 1),
                               stride=(game_size, 1), bias=False)
        self.conv3 = nn.Conv2d(1, 1,
                               kernel_size=(hidden_size, 1),
                               stride=(hidden_size, 1), bias=False)
        self.lin4 = nn.Linear(embedding_size, vocab_size, bias=False)

    def forward(self, x, return_embeddings=False):
        # `return_embeddings` is accepted but not used in this implementation.
        emb = self.return_embeddings(x)
        # in: emb of size (batch_size, 1, game_size, embedding_size)
        # out: h of size (batch_size, hidden_size, 1, embedding_size)
        h = self.conv2(emb)
        h = torch.sigmoid(h)
        # in: h of size (batch_size, hidden_size, 1, embedding_size)
        # out: h of size (batch_size, 1, hidden_size, embedding_size)
        h = h.transpose(1, 2)
        h = self.conv3(h)
        # h of size (batch_size, 1, 1, embedding_size)
        h = torch.sigmoid(h)
        h = h.squeeze(dim=1)
        h = h.squeeze(dim=1)
        # h of size (batch_size, embedding_size)
        h = self.lin4(h)
        h = h.mul(1.0 / self.temp)
        # h of size (batch_size, vocab_size)
        logits = F.log_softmax(h, dim=1)
        return logits

    def return_embeddings(self, x):
        # embed each image (left or right)
        embs = []
        for i in range(self.game_size):
            h = x[i]
            if len(h.size()) == 3:
                h = h.squeeze(dim=-1)
            h_i = self.lin1(h)
            # h_i is of size (batch_size, embedding_size)
            h_i = h_i.unsqueeze(dim=1)
            h_i = h_i.unsqueeze(dim=1)
            # h_i is now of size (batch_size, 1, 1, embedding_size)
            embs.append(h_i)
        # concatenate the embeddings along the game dimension
        h = torch.cat(embs, dim=2)
        return h
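
# Usage sketch (illustrative, not part of the original file). The sender takes
# a list of `game_size` feature tensors, each of shape (batch_size, feat_size),
# and returns per-symbol log-probabilities of shape (batch_size, vocab_size).
# The sizes below (feat_size=4096, etc.) are assumed example values:
#
#     sender = InformedSender(game_size=2, feat_size=4096,
#                             embedding_size=50, hidden_size=20,
#                             vocab_size=100, temp=1.0)
#     x = [torch.randn(32, 4096) for _ in range(2)]
#     logits = sender(x)  # shape: (32, 100)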


class Receiver(nn.Module):
    def __init__(self, game_size, feat_size, embedding_size,
                 vocab_size, reinforce):
        super(Receiver, self).__init__()
        self.game_size = game_size
        self.embedding_size = embedding_size

        self.lin1 = nn.Linear(feat_size, embedding_size, bias=False)
        # if reinforce:
        #     self.lin2 = nn.Embedding(vocab_size, embedding_size)
        # else:
        self.lin2 = nn.Linear(vocab_size, embedding_size, bias=False)

    def forward(self, signal, x):
        # embed each image (left or right)
        emb = self.return_embeddings(x)
        # embed the signal
        if len(signal.size()) == 3:
            signal = signal.squeeze(dim=-1)
        h_s = self.lin2(signal)
        # h_s is of size (batch_size, embedding_size)
        h_s = h_s.unsqueeze(dim=1)
        # h_s is of size (batch_size, 1, embedding_size)
        h_s = h_s.transpose(1, 2)
        # h_s is of size (batch_size, embedding_size, 1)
        out = torch.bmm(emb, h_s)
        # out is of size (batch_size, game_size, 1)
        out = out.squeeze(dim=-1)
        # out is of size (batch_size, game_size)
        log_probs = F.log_softmax(out, dim=1)
        # return the predicted position together with the log-probabilities
        return torch.argmax(log_probs, dim=1), log_probs, log_probs

    def return_embeddings(self, x):
        # embed each image (left or right)
        embs = []
        for i in range(self.game_size):
            h = x[i]
            if len(h.size()) == 3:
                h = h.squeeze(dim=-1)
            h_i = self.lin1(h)
            # h_i is of size (batch_size, embedding_size)
            h_i = h_i.unsqueeze(dim=1)
            # h_i is now of size (batch_size, 1, embedding_size)
            embs.append(h_i)
        h = torch.cat(embs, dim=1)
        return h
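

if __name__ == "__main__":
    # Quick smoke test (added for illustration, not part of the original file).
    # The sizes below are assumed example values; the signal fed to the
    # Receiver is taken here to be a one-hot (batch_size, vocab_size) vector
    # encoding the sender's symbol.
    game_size, feat_size, embedding_size = 2, 4096, 50
    hidden_size, vocab_size, batch_size = 20, 100, 8

    sender = InformedSender(game_size, feat_size, embedding_size,
                            hidden_size, vocab_size, temp=1.0)
    receiver = Receiver(game_size, feat_size, embedding_size,
                        vocab_size, reinforce=False)

    # one feature tensor per candidate image
    x = [torch.randn(batch_size, feat_size) for _ in range(game_size)]

    logits = sender(x)                    # (batch_size, vocab_size)
    symbol = torch.argmax(logits, dim=1)  # greedy symbol choice
    signal = F.one_hot(symbol, num_classes=vocab_size).float()

    choice, log_probs, _ = receiver(signal, x)
    print(logits.shape, log_probs.shape, choice.shape)
    # expected: torch.Size([8, 100]) torch.Size([8, 2]) torch.Size([8])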