forked from SJTU-Lucy/CS238-VirtualReality
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vgg.py
88 lines (73 loc) · 3.08 KB
/
vgg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch as T
import torch.nn as nn
from neural_monitor import logger
import h5py
import numpy as np
import torch.nn.functional as F
import os
import utils
class Normalization(T.nn.Module):
    """Fixed 1x1 convolution that reorders, rescales and shifts input channels.

    Output channel ``k`` equals ``255 * input[2 - k] + bias[k]``. The channel
    order is reversed, values are multiplied by 255, and the constants
    (-103.939, -116.779, -123.68) are added per output channel. This
    presumably turns RGB images in [0, 1] into the mean-subtracted BGR
    [0, 255] layout that the pretrained VGG weights expect.
    NOTE(review): confirm the input range with the callers.

    Both tensors are registered as buffers rather than parameters. They
    follow ``.to(device)`` and appear in ``state_dict`` under the keys
    ``kern`` and ``bias``, but ``parameters()`` never yields them.
    """

    def __init__(self):
        super().__init__()
        # Anti-diagonal 3x3 matrix scaled by 255, shaped as a conv weight of
        # (out_channels, in_channels, 1, 1): output k reads input 2 - k.
        channel_swap = T.tensor(
            [[0., 0., 255.],
             [0., 255., 0.],
             [255., 0., 0.]],
            dtype=T.float32,
        ).reshape(3, 3, 1, 1)
        offsets = T.tensor([-103.939, -116.779, -123.68], dtype=T.float32)
        self.register_buffer('kern', channel_swap)
        self.register_buffer('bias', offsets)

    def forward(self, input):
        """Apply the fixed channel transform.

        ``input`` must have 3 channels, because the kernel is (3, 3, 1, 1).
        """
        return F.conv2d(input, self.kern, bias=self.bias, padding=0)
class ConvRelu(T.nn.Sequential):
    """Reflection-pad, then Conv2d, then ReLU, run in that order by ``nn.Sequential``.

    Padding is applied by ``nn.ReflectionPad2d(padding)`` rather than inside
    the convolution, so the wrapped ``nn.Conv2d`` always receives
    ``padding=0``. As a result, ``padding_mode`` has no effect on the output.

    The children are registered as ``pad``, ``conv`` and ``relu``. Their
    names and registration order fix both the forward order and the
    ``state_dict`` keys (``conv.weight`` and, if ``bias`` is true,
    ``conv.bias``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 bias: bool = True,
                 padding_mode: str = 'zeros'
                 ):
        super().__init__()
        reflect = nn.ReflectionPad2d(padding)
        convolution = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=0, dilation=dilation,
            groups=groups, bias=bias, padding_mode=padding_mode,
        )
        # nn.Sequential runs children in insertion order: pad, conv, ReLU.
        self.add_module('pad', reflect)
        self.add_module('conv', convolution)
        self.add_module('relu', nn.ReLU())
class VGG19(T.nn.Sequential):
    """Frozen VGG-19 convolutional trunk with pretrained weights loaded from HDF5.

    Children are registered in forward order, so ``nn.Sequential`` runs
    ``norm``, then five stages of ``ConvRelu`` layers, each stage ending in a
    2x2/stride-2 max-pool (``pool1`` .. ``pool5``). No fully-connected layers
    are built. ``requires_grad`` is turned off for every parameter before the
    pretrained weights are copied in.

    Args:
        weight_file: Path of the HDF5 weight file. If it does not exist, it
            is downloaded from the ``ftokarev/tf-vgg-weights`` repository.
    """

    def __init__(self, weight_file):
        super(VGG19, self).__init__()
        self.norm = Normalization()
        self.conv1_1 = ConvRelu(3, 64)
        self.conv1_2 = ConvRelu(64, 64)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.conv2_1 = ConvRelu(64, 128)
        self.conv2_2 = ConvRelu(128, 128)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.conv3_1 = ConvRelu(128, 256)
        self.conv3_2 = ConvRelu(256, 256)
        self.conv3_3 = ConvRelu(256, 256)
        self.conv3_4 = ConvRelu(256, 256)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        self.conv4_1 = ConvRelu(256, 512)
        self.conv4_2 = ConvRelu(512, 512)
        self.conv4_3 = ConvRelu(512, 512)
        self.conv4_4 = ConvRelu(512, 512)
        self.pool4 = nn.MaxPool2d(2, stride=2)
        self.conv5_1 = ConvRelu(512, 512)
        self.conv5_2 = ConvRelu(512, 512)
        self.conv5_3 = ConvRelu(512, 512)
        self.conv5_4 = ConvRelu(512, 512)
        self.pool5 = nn.MaxPool2d(2, stride=2)
        # The network is used as a fixed feature extractor; freeze everything.
        for p in self.parameters():
            p.requires_grad_(False)
        self.load_params(weight_file)

    def load_params(self, param_file):
        """Copy pretrained weights from the HDF5 file ``param_file`` into this model.

        If ``param_file`` does not exist, it is first downloaded with
        ``utils.download_file``. Each top-level member of the file is read
        as a float32 array. 4-D arrays are transposed with axes
        (3, 2, 0, 1); this presumably converts TensorFlow's
        (H, W, in, out) kernel layout to PyTorch's (out, in, H, W).
        NOTE(review): confirm this against the weight source.

        Arrays are paired with ``self.parameters()`` purely by position.
        ``zip`` stops at the shorter sequence, so surplus arrays or
        parameters are silently ignored.
        NOTE(review): this relies on the file's member iteration order
        matching the parameter order (conv1_1 weight, conv1_1 bias, ...).
        """
        if not os.path.exists(param_file):
            utils.download_file('https://github.com/ftokarev/tf-vgg-weights/raw/master/vgg19_weights_normalized.h5',
                                param_file)
        # Materialize every dataset inside the `with` block so the HDF5 file
        # handle is released deterministically instead of being left open.
        with h5py.File(param_file, mode='r') as f:
            trained = [np.array(dataset, 'float32') for dataset in f.values()]
        weight_value_tuples = []
        for param, value in zip(self.parameters(), trained):
            if value.ndim == 4:
                value = np.transpose(value, (3, 2, 0, 1))
            weight_value_tuples.append((param, value))
        # zip(*pairs) splits the pairs into (params, values) positional args.
        utils.batch_set_value(*zip(*weight_value_tuples))
        logger.info('Pretrained weights loaded successfully!')