-
Notifications
You must be signed in to change notification settings - Fork 5
/
models.py
104 lines (93 loc) · 3.34 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Definition of the FFDNet model and its custom layers
Copyright (C) 2018, Matias Tassano <[email protected]>
This program is free software: you can use, modify and/or
redistribute it under the terms of the GNU General Public
License as published by the Free Software Foundation, either
version 3 of the License, or (at your option) any later
version. You should have received a copy of this license along
this program. If not, see <http://www.gnu.org/licenses/>.
"""
import torch.nn as nn
from torch.autograd import Variable
import functions
class UpSampleFeatures(nn.Module):
    r"""Final stage of the FFDNet pipeline.

    This module owns no parameters or sub-modules. Its ``forward`` hands the
    incoming feature maps to ``functions.upsamplefeatures`` and returns the
    result unchanged.
    """

    def __init__(self):
        # Single inheritance: equivalent to super(UpSampleFeatures, self).__init__()
        nn.Module.__init__(self)

    def forward(self, x):
        """Return ``functions.upsamplefeatures(x)``.

        The exact tensor layout expected here is defined by the project's
        ``functions`` module (not visible in this file) — TODO confirm there.
        """
        upsampled = functions.upsamplefeatures(x)
        return upsampled
class IntermediateDnCNN(nn.Module):
    r"""Implements the middle part of the FFDNet architecture, which
    is basically a DnCNN net.

    The network is a plain stack of bias-free 3x3 convolutions with
    padding 1 (default stride 1, so spatial size is preserved):

        Conv(input_features -> middle_features), ReLU
        (num_conv_layers - 2) x [Conv(middle -> middle), BatchNorm2d, ReLU]
        Conv(middle_features -> output_features)

    For ``num_conv_layers >= 2`` this yields exactly ``num_conv_layers``
    convolutions. Smaller values still produce the first and last
    convolution only, because the middle loop runs zero times.

    Args:
        input_features: number of input channels. Only 5 (grayscale,
            giving 4 output features) and 15 (RGB, giving 12 output
            features) are accepted.
        middle_features: number of channels of every hidden layer.
        num_conv_layers: total number of convolutional layers.

    Raises:
        ValueError: if ``input_features`` is neither 5 nor 15. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """

    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4  # Grayscale image
        elif self.input_features == 15:
            self.output_features = 12  # RGB image
        else:
            raise ValueError('Invalid number of input features')

        # Layers are created in the same order as the original implementation
        # so that Sequential indices (state_dict keys) and the order of random
        # weight initialisation are unchanged.
        layers = [self._conv(self.input_features, self.middle_features),
                  nn.ReLU(inplace=True)]
        for _ in range(self.num_conv_layers - 2):
            layers.extend([
                self._conv(self.middle_features, self.middle_features),
                nn.BatchNorm2d(self.middle_features),
                nn.ReLU(inplace=True),
            ])
        layers.append(self._conv(self.middle_features, self.output_features))
        # NOTE: "itermediate" is a historical misspelling. The attribute name
        # is part of the state_dict keys (e.g. "itermediate_dncnn.0.weight"),
        # so renaming it would break loading of existing checkpoints.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def _conv(self, in_channels, out_channels):
        """Return a bias-free 3x3 convolution that preserves spatial size."""
        return nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=self.kernel_size,
                         padding=self.padding,
                         bias=False)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        out = self.itermediate_dncnn(x)
        return out
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture.

    ``forward`` builds the network input with
    ``functions.concatenate_input_noise_map(x, noise_sigma)``, passes it
    through an ``IntermediateDnCNN`` and then through ``UpSampleFeatures``.
    The returned tensor is named ``pred_noise`` in the code, which suggests
    it is the predicted noise rather than the denoised image. Confirm with
    the calling code.

    Per-channel configuration:

        num_input_channels == 1 (grayscale): 64 feature maps, 15 conv
            layers, 5 downsampled channels, 4 output features.
        num_input_channels == 3 (RGB): 96 feature maps, 12 conv layers,
            15 downsampled channels, 12 output features.

    Args:
        num_input_channels: 1 for grayscale or 3 for RGB images.

    Raises:
        ValueError: if ``num_input_channels`` is neither 1 nor 3. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """

    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            # Grayscale image
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
            self.output_features = 4
        elif self.num_input_channels == 3:
            # RGB image
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
            self.output_features = 12
        else:
            raise ValueError('Invalid number of input features')
        # output_features is not passed on below: IntermediateDnCNN derives
        # its own output width from downsampled_channels. The attribute is
        # kept because external code may read it.
        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        """Run the full FFDNet pipeline.

        Args:
            x: input image batch. Presumably (N, C, H, W) with
                C == num_input_channels; TODO confirm against
                functions.concatenate_input_noise_map.
            noise_sigma: tensor of noise levels. Its expected shape is
                defined by functions.concatenate_input_noise_map and is not
                visible here.

        Returns:
            The output of ``UpSampleFeatures`` (``pred_noise``).
        """
        # `.data` hands the helper tensors that are detached from the autograd
        # graph, so no gradient flows back into `x` or `noise_sigma`. Only the
        # network parameters are trained.
        concat_noise_x = functions.concatenate_input_noise_map(
            x.data, noise_sigma.data)
        # Legacy autograd wrapper; on PyTorch >= 0.4 it returns a plain Tensor.
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise