-
Notifications
You must be signed in to change notification settings - Fork 48
/
net.py
26 lines (22 loc) · 851 Bytes
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import copy
import os

import torch.nn as nn
from torchvision import models
# Set before the weights are fetched so torchvision's download cache lives
# under ./models instead of the default torch hub directory.
os.environ['TORCH_HOME'] = 'models'
# Pretrained AlexNet whose layers AlexNetPlusLatent builds on.
# torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum (the
# legacy flag emits a deprecation warning and maps to IMAGENET1K_V1); fall
# back to the flag on older releases that lack the enum.
if hasattr(models, 'AlexNet_Weights'):
    alexnet_model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    alexnet_model = models.alexnet(pretrained=True)
class AlexNetPlusLatent(nn.Module):
    """AlexNet backbone followed by a sigmoid latent layer and a linear classifier.

    ``features`` (the convolutional stack) and ``remain`` (AlexNet's
    classifier without its final Linear layer) are copied from the
    module-level pretrained ``alexnet_model``. ``Linear1`` projects the
    4096-d output of ``remain`` to ``bits`` units, which a sigmoid squashes
    into (0, 1); ``Linear2`` maps those latent activations to class scores.

    Args:
        bits: Width of the latent (sigmoid) layer.
        num_classes: Output size of ``Linear2``. Defaults to 10, the value
            that used to be hard-coded.
    """

    def __init__(self, bits, num_classes=10):
        super(AlexNetPlusLatent, self).__init__()
        self.bits = bits
        # Deep-copy the pretrained layers so each instance owns its own
        # parameters. Reusing the module objects directly would make every
        # instance -- and the module-level ``alexnet_model`` -- share and
        # co-train the same weights. Sequential keys ('0', '1', ...) and thus
        # state_dict keys are the same as before.
        self.features = nn.Sequential(
            *copy.deepcopy(list(alexnet_model.features.children())))
        # All of AlexNet's classifier except its last Linear layer, so the
        # output is the 4096-d penultimate activation fed to Linear1.
        self.remain = nn.Sequential(
            *copy.deepcopy(list(alexnet_model.classifier.children())[:-1]))
        self.Linear1 = nn.Linear(4096, self.bits)
        self.sigmoid = nn.Sigmoid()
        self.Linear2 = nn.Linear(self.bits, num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image batch; presumably (batch, 3, H, W) -- TODO confirm.

        Returns:
            Tuple ``(features, result)``: the sigmoid latent activations,
            shape (batch, bits), and the class scores, shape
            (batch, num_classes).
        """
        x = self.features(x)
        # Pool to the 6x6 grid the pretrained classifier expects, as
        # torchvision's own AlexNet does. This is an exact no-op when the
        # feature map is already 6x6 (the only case the old hard-coded
        # reshape handled) and lets other input sizes work instead of failing.
        x = nn.functional.adaptive_avg_pool2d(x, (6, 6))
        x = x.flatten(1)  # (batch, 256 * 6 * 6)
        x = self.remain(x)
        x = self.Linear1(x)
        features = self.sigmoid(x)
        result = self.Linear2(features)
        return features, result