-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathIRNp_model.py
294 lines (232 loc) · 11.6 KB
/
IRNp_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.modules.loss import GANLoss, ReconstructionLoss
from models.modules.Quantization import Quantization
# Module-level logger, looked up by the fixed name 'base' (not __name__).
logger = logging.getLogger('base')
class IRNpModel(BaseModel):
    """Model wrapper that trains a generator against a discriminator.

    ``netG`` is called in two directions: ``netG(x=...)`` on the GT image and
    ``netG(x=..., rev=True)`` on a quantized low-resolution image concatenated
    with Gaussian noise. In training mode a discriminator ``netD`` and,
    when ``feature_weight > 0``, a feature network ``netF`` are built too.
    """

    def __init__(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        Args:
            opt: Nested option mapping. ``dist``, ``train`` and ``test`` are
                read here; ``path`` is read by :meth:`load`.

        Raises:
            NotImplementedError: If ``train.lr_scheme`` is neither
                ``'MultiStepLR'`` nor ``'CosineAnnealingLR_Restart'``.
        """
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = self._wrap_parallel(
            networks.define_G(opt).to(self.device), opt['dist'])
        # print network
        self.print_network()
        # netD is not built yet, so this call only restores netG.
        self.load()

        self.Quantization = Quantization()

        # Created unconditionally so get_current_log() works in every mode.
        self.log_dict = OrderedDict()

        if self.is_train:
            self.netD = self._wrap_parallel(
                networks.define_D(opt).to(self.device), opt['dist'])
            # load() ran before netD existed; restore its weights now.
            self._load_D()

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])
                self.l_fea_w = train_opt['feature_weight']
                self.netF = self._wrap_parallel(
                    networks.define_F(opt, use_bn=False).to(self.device),
                    opt['dist'])
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters (falsy values fall back to defaults)
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers: one per optimizer
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer, train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

    @staticmethod
    def _wrap_parallel(net, dist):
        """Wrap ``net`` in DistributedDataParallel if ``dist`` else DataParallel."""
        if dist:
            return DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()])
        return DataParallel(net)

    def feed_data(self, data):
        """Move the ``'LQ'`` and ``'GT'`` entries of ``data`` to the model device."""
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        """Return a standard-normal tensor of shape ``dims`` on the model device."""
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y):
        """Weighted reconstruction loss between the first 3 channels of ``out`` and ``y``."""
        l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out[:, :3, :, :], y)
        return l_forw_fit

    def loss_backward(self, x, x_samples):
        """Compute the generator-side losses for the reverse pass.

        Args:
            x: Reference tensor the reconstruction is compared against.
            x_samples: Output of the reverse pass; only its first three
                channels are used.

        Returns:
            Tuple ``(l_back_rec, l_back_fea, l_back_gan)``.

        Raises:
            NotImplementedError: If ``train.gan_type`` is not ``'gan'`` or
                ``'ragan'``.
        """
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        # feature loss
        if self.l_fea_w > 0:
            l_back_fea = self.feature_loss(x, x_samples_image)
        else:
            l_back_fea = torch.tensor(0)

        # GAN loss
        gan_type = self.opt['train']['gan_type']
        pred_g_fake = self.netD(x_samples_image)
        if gan_type == 'gan':
            l_back_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
        elif gan_type == 'ragan':
            pred_d_real = self.netD(x).detach()
            l_back_gan = self.l_gan_w * (
                self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
        else:
            # Without this branch l_back_gan would be unbound at the return.
            raise NotImplementedError(
                'GAN type [{}] is not supported by IRNpModel.'.format(gan_type))

        return l_back_rec, l_back_fea, l_back_gan

    def feature_loss(self, real, fake):
        """Weighted loss between ``netF`` features of ``real`` (detached) and ``fake``."""
        real_fea = self.netF(real).detach()
        fake_fea = self.netF(fake)
        l_g_fea = self.l_fea_w * self.Reconstructionf(real_fea, fake_fea)
        return l_g_fea

    def optimize_parameters(self, step):
        """Run one training iteration.

        The generator is stepped only when ``step`` is a multiple of
        ``D_update_ratio`` and greater than ``D_init_iters``; the
        discriminator is stepped on every call.

        Args:
            step: Current iteration index.

        Raises:
            NotImplementedError: If ``train.gan_type`` is not ``'gan'`` or
                ``'ragan'``.
        """
        update_G = step % self.D_update_ratio == 0 and step > self.D_init_iters

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.input = self.real_H
        self.output = self.netG(x=self.input)

        # Channels past the first three are replaced by fresh noise below.
        zshape = self.output[:, 3:, :, :].shape
        LR = self.Quantization(self.output[:, :3, :, :])

        gaussian_scale = self.train_opt['gaussian_scale']
        if gaussian_scale is None:
            gaussian_scale = 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1)

        self.fake_H = self.netG(x=y_, rev=True)

        if update_G:
            l_forw_fit = self.loss_forward(self.output, self.ref_L)
            l_back_rec, l_back_fea, l_back_gan = self.loss_backward(self.real_H, self.fake_H)

            loss = l_forw_fit + l_back_rec + l_back_fea + l_back_gan
            loss.backward()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        gan_type = self.opt['train']['gan_type']
        pred_d_real = self.netD(self.real_H)
        # detach(): the discriminator step must not backpropagate into netG
        pred_d_fake = self.netD(self.fake_H.detach())
        if gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif gan_type == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
            l_d_total = (l_d_real + l_d_fake) / 2
        else:
            raise NotImplementedError(
                'GAN type [{}] is not supported by IRNpModel.'.format(gan_type))

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if update_G:
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_back_fea'] = l_back_fea.item()
            self.log_dict['l_back_gan'] = l_back_gan.item()
        self.log_dict['l_d'] = l_d_total.item()

    def test(self):
        """Run forward then reverse passes on the fed data without gradients.

        Stores the quantized forward output in ``self.forw_L`` and the first
        three channels of the reverse output in ``self.fake_H``. The
        training/eval mode ``netG`` had on entry is restored on exit.
        """
        Lshape = self.ref_L.shape
        input_dim = Lshape[1]

        self.input = self.real_H

        zshape = [Lshape[0], input_dim * (self.opt['scale'] ** 2) - Lshape[1], Lshape[2], Lshape[3]]

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        was_training = self.netG.training
        self.netG.eval()
        try:
            with torch.no_grad():
                self.forw_L = self.netG(x=self.input)[:, :3, :, :]
                self.forw_L = self.Quantization(self.forw_L)
                y_forw = torch.cat((self.forw_L, gaussian_scale * self.gaussian_batch(zshape)), dim=1)
                self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
        finally:
            # Restore the previous mode even if the forward pass raised.
            self.netG.train(was_training)

    def downscale(self, HR_img):
        """Return the quantized first three channels of ``netG(x=HR_img)``.

        Runs without gradients; ``netG``'s prior training/eval mode is restored.
        """
        was_training = self.netG.training
        self.netG.eval()
        try:
            with torch.no_grad():
                LR_img = self.netG(x=HR_img)[:, :3, :, :]
                LR_img = self.Quantization(LR_img)
        finally:
            self.netG.train(was_training)

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        """Run the reverse pass on ``LR_img`` padded with scaled Gaussian noise.

        Args:
            LR_img: 4-D input tensor; noise is concatenated along dim 1.
            scale: Factor used to size the noise, which has
                ``LR_img.shape[1] * (scale**2 - 1)`` channels.
            gaussian_scale: Multiplier applied to the sampled noise.

        Returns:
            The first three channels of ``netG(x=..., rev=True)``.
        """
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale ** 2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)), dim=1)

        was_training = self.netG.training
        self.netG.eval()
        try:
            with torch.no_grad():
                HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        finally:
            self.netG.train(was_training)

        return HR_img

    def get_current_log(self):
        """Return the dict of most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Return the first batch item of each stored tensor as a CPU float tensor.

        Keys: ``'LR_ref'``, ``'SR'``, ``'LR'``, ``'GT'``. Reads ``forw_L`` and
        ``fake_H``, which :meth:`test` sets.
        """
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the structure and parameter count of ``netG`` (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            # Report both the parallel wrapper and the wrapped module class.
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained weights for ``netG`` and, if already built, ``netD``.

        Paths come from ``opt['path']``; a ``None`` path skips that network.
        """
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

        # netD exists only in training mode and only after __init__ built it.
        if hasattr(self, 'netD'):
            self._load_D()

    def _load_D(self):
        """Load pretrained ``netD`` weights when ``pretrain_model_D`` is set."""
        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save ``netG`` and ``netD`` under the labels ``'G'`` and ``'D'``."""
        self.save_network(self.netG, 'G', iter_label)
        self.save_network(self.netD, 'D', iter_label)