-
Notifications
You must be signed in to change notification settings - Fork 47
/
models.py
105 lines (85 loc) · 3.09 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn as nn
class Encoder(nn.Module):
    """Convolutional encoder producing a 64-d global code plus an intermediate
    feature map.

    The linear head assumes the spatial size is 20x20 after the four 4x4
    stride-1 convolutions (e.g. a 32x32 input: 32 -> 29 -> 26 -> 23 -> 20).
    """

    def __init__(self):
        super().__init__()
        # Four 4x4 stride-1 convolutions, roughly doubling channels each stage.
        self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
        self.l1 = nn.Linear(512 * 20 * 20, 64)
        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        """Return ``(encoded, features)``: the 64-d code and the post-c1 map."""
        out = F.relu(self.c0(x))
        # The second stage's activation is also returned to the caller
        # (used alongside the global code by the discriminators below).
        features = F.relu(self.b1(self.c1(out)))
        out = F.relu(self.b2(self.c2(features)))
        out = F.relu(self.b3(self.c3(out)))
        flat = out.view(x.shape[0], -1)
        encoded = self.l1(flat)
        return encoded, features
class GlobalDiscriminator(nn.Module):
    """Scores a (global code ``y``, feature map ``M``) pair with one scalar.

    The two small convolutions reduce M's channels/spatial size; the flattened
    result is concatenated with the 64-d code and pushed through an MLP.
    The linear head assumes M flattens to 32*22*22 after the convs
    (i.e. a 26x26 input map: 26 -> 24 -> 22).
    """

    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(128, 64, kernel_size=3)
        self.c1 = nn.Conv2d(64, 32, kernel_size=3)
        self.l0 = nn.Linear(32 * 22 * 22 + 64, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, y, M):
        """Return an unnormalized score of shape ``(batch, 1)``."""
        conv_out = F.relu(self.c0(M))
        conv_out = self.c1(conv_out)
        flat = conv_out.view(y.shape[0], -1)
        joined = torch.cat((y, flat), dim=1)
        joined = F.relu(self.l0(joined))
        joined = F.relu(self.l1(joined))
        return self.l2(joined)
class LocalDiscriminator(nn.Module):
    """Per-location scorer built entirely from 1x1 convolutions.

    Acts as a position-wise MLP over 192 input channels — presumably the
    128-channel feature map concatenated with the 64-d code broadcast over
    space (the concatenation happens in the caller; verify there).
    """

    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(192, 512, kernel_size=1)
        self.c1 = nn.Conv2d(512, 512, kernel_size=1)
        self.c2 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        """Return a one-channel score map with the same spatial size as ``x``."""
        scores = F.relu(self.c0(x))
        scores = F.relu(self.c1(scores))
        return self.c2(scores)
class PriorDiscriminator(nn.Module):
    """MLP discriminator over the 64-d code; outputs a sigmoid probability."""

    def __init__(self):
        super().__init__()
        self.l0 = nn.Linear(64, 1000)
        self.l1 = nn.Linear(1000, 200)
        self.l2 = nn.Linear(200, 1)

    def forward(self, x):
        """Return scores in (0, 1) of shape ``(batch, 1)``."""
        hidden = F.relu(self.l0(x))
        hidden = F.relu(self.l1(hidden))
        return torch.sigmoid(self.l2(hidden))
class Classifier(nn.Module):
    """Small batch-normalized MLP mapping a 64-d code to 10 class probabilities."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(64, 15)
        self.bn1 = nn.BatchNorm1d(15)
        self.l2 = nn.Linear(15, 10)
        self.bn2 = nn.BatchNorm1d(10)
        self.l3 = nn.Linear(10, 10)
        self.bn3 = nn.BatchNorm1d(10)

    def forward(self, x):
        """``x`` is an (encoded, features) pair; only the code is used here."""
        encoded, _features = x[0], x[1]
        out = F.relu(self.bn1(self.l1(encoded)))
        out = F.relu(self.bn2(self.l2(out)))
        # NOTE(review): emitting softmax probabilities means a downstream
        # CrossEntropyLoss would double-apply log-softmax — confirm the
        # training loss expects probabilities (e.g. NLL on log of these).
        return F.softmax(self.bn3(self.l3(out)), dim=1)
class DeepInfoAsLatent(nn.Module):
    """Pretrained encoder followed by a trainable classifier head.

    Loads encoder weights from ``<model_dir>/<run>/encoder<epoch>.wgt``.
    The global code is detached before classification, so gradients do not
    flow into the encoder through it.
    """

    def __init__(self, run, epoch, model_dir=r'c:/data/deepinfomax/models'):
        """Build the model and load the encoder checkpoint.

        Args:
            run: training-run identifier (checkpoint subdirectory name).
            epoch: checkpoint epoch used in the weight filename.
            model_dir: root directory holding per-run checkpoints. Defaults to
                the original hard-coded Windows path for backward compatibility;
                pass an explicit path on other machines.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        super().__init__()
        model_path = Path(model_dir) / str(run) / ('encoder' + str(epoch) + '.wgt')
        self.encoder = Encoder()
        self.encoder.load_state_dict(torch.load(str(model_path)))
        self.classifier = Classifier()

    def forward(self, x):
        z, features = self.encoder(x)
        # Detach the code so the encoder stays frozen along this path;
        # only the classifier receives gradients from it.
        z = z.detach()
        return self.classifier((z, features))