From 93aaa943a96073fd431ac55f12abe717841b26c4 Mon Sep 17 00:00:00 2001
From: Mingqing Xiao
Date: Tue, 11 Oct 2022 13:29:44 +0800
Subject: [PATCH] IRN-IJCV

---
 README.md                                     |   4 +-
 codes/README.md                               |  27 +-
 codes/data/LQGT_dataset.py                    | 132 ++++---
 codes/models/IRN_color_model.py               | 208 +++++++++++
 codes/models/IRN_model.py                     |   9 +
 codes/models/IRN_model_CRM.py                 | 332 ++++++++++++++++++
 codes/models/__init__.py                      |   4 +
 codes/models/modules/Apply_jpg.py             |  37 ++
 codes/models/modules/Inv_arch.py              | 179 +++++++++-
 codes/models/modules/RRDB.py                  |  64 ++++
 codes/models/modules/Replace.py               |  19 +
 codes/models/modules/Subnet_constructor.py    |   6 +-
 codes/models/networks.py                      |  41 ++-
 codes/options/options.py                      |   3 +-
 codes/options/test/test_IRN+_x4.yml           |   2 +-
 .../test/test_IRN-Compression_x2_q90.yml      |  62 ++++
 .../test_IRN-Compression_x2_q90_kodak.yml     |  42 +++
 codes/options/test/test_IRN_color.yml         |  49 +++
 codes/options/test/test_IRN_x3.yml            |  51 +++
 codes/options/test/test_IRN_x4.yml            |   2 +-
 codes/options/test/test_IRN_x8.yml            |  50 +++
 codes/options/train/train_IRN+_x4.yml         |   2 +-
 .../train/train_IRN-Compression_x2_q90.yml    | 107 ++++++
 codes/options/train/train_IRN_color.yml       |  86 +++++
 codes/options/train/train_IRN_x2_finetune.yml |  91 +++++
 .../train/train_IRN_x2_finetune_kodak.yml     |  91 +++++
 codes/options/train/train_IRN_x3.yml          |  88 +++++
 codes/options/train/train_IRN_x8.yml          |  87 +++++
 codes/test_IRN-Color.py                       | 127 +++++++
 codes/test_IRN-Compression.py                 | 188 ++++++++++
 codes/train_IRN-Color.py                      | 247 +++++++++++++
 codes/train_IRN-Compression.py                | 266 ++++++++++++++
 32 files changed, 2632 insertions(+), 71 deletions(-)
 create mode 100644 codes/models/IRN_color_model.py
 create mode 100644 codes/models/IRN_model_CRM.py
 create mode 100644 codes/models/modules/Apply_jpg.py
 create mode 100644 codes/models/modules/RRDB.py
 create mode 100644 codes/models/modules/Replace.py
 create mode 100644 codes/options/test/test_IRN-Compression_x2_q90.yml
 create mode 100644 codes/options/test/test_IRN-Compression_x2_q90_kodak.yml
 create mode 100644 codes/options/test/test_IRN_color.yml
 create mode 100644 codes/options/test/test_IRN_x3.yml
 create mode 100644 codes/options/test/test_IRN_x8.yml
 create mode 100644 codes/options/train/train_IRN-Compression_x2_q90.yml
 create mode 100644 codes/options/train/train_IRN_color.yml
 create mode 100644 codes/options/train/train_IRN_x2_finetune.yml
 create mode 100644 codes/options/train/train_IRN_x2_finetune_kodak.yml
 create mode 100644 codes/options/train/train_IRN_x3.yml
 create mode 100644 codes/options/train/train_IRN_x8.yml
 create mode 100644 codes/test_IRN-Color.py
 create mode 100644 codes/test_IRN-Compression.py
 create mode 100644 codes/train_IRN-Color.py
 create mode 100644 codes/train_IRN-Compression.py

diff --git a/README.md b/README.md
index 4933ba4..3a05de8 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
 # Invertible Image Rescaling
-This is the PyTorch implementation of paper: Invertible Image Rescaling (ECCV 2020 Oral). [arxiv](https://arxiv.org/abs/2005.05650).
+This is the PyTorch implementation of the paper: Invertible Image Rescaling (ECCV 2020 Oral). \[[link](https://link.springer.com/chapter/10.1007/978-3-030-58452-8_8)\]\[[arxiv](https://arxiv.org/abs/2005.05650)\].
+
+**2022/10 Update**: Our paper "Invertible Rescaling Network and Its Extensions" has been accepted by IJCV. \[[link](https://link.springer.com/article/10.1007/s11263-022-01688-4)\]\[[arxiv](https://arxiv.org/abs/2210.04188)\]. We have updated the repository with the experiments in the paper. 
The previous version can be found in the ECCV branch.
 
 ## Dependencies and Installation
 - Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
diff --git a/codes/README.md b/codes/README.md
index 72c756d..cabe5de 100644
--- a/codes/README.md
+++ b/codes/README.md
@@ -1,14 +1,35 @@
-# Training
+# Training for image rescaling
 First set a config file in options/train/, then run as following:
 
     python train.py -opt options/train/train_IRN_x4.yml
 
-# Test
+# Testing for image rescaling
 First set a config file in options/test/, then run as following:
 
    python test.py -opt options/test/test_IRN_x4.yml
 
-Pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1-Rah2t-fk3uTcNagvTgTRlRTaK2dHktA?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1U38SjqVlqY5YVMsSFrkTsw) (extraction code: lukj).
+# Training for image decolorization-colorization
+First set a config file in options/train/, then run as follows:
+
+    python train.py -opt options/train/train_IRN_color.yml
+
+# Testing for image decolorization-colorization
+First set a config file in options/test/, then run as follows:
+
+    python test.py -opt options/test/test_IRN_color.yml
+
+# Training for combination with image compression
+First set a config file in options/train/, then run as follows:
+
+    python train.py -opt options/train/train_IRN-Compression_x2_q90.yml
+
+# Testing for combination with image compression
+First set a config file in options/test/, then run as follows:
+
+    python test.py -opt options/test/test_IRN-Compression_x2_q90.yml
+
+
+Pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1ym6DvYNQegDrOy_4z733HxrULa1XIN92?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/14OvTiJNhFpHHN2yU-h7vDg) (extraction code: rx0z).
 
 # Code Framework
 The code framework follows [BasicSR](https://github.com/xinntao/BasicSR/tree/master/codes). It mainly consists of four parts - `Config`, `Data`, `Model` and `Network`.
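For orientation, the `model:` key in each YAML config selects the model class registered in `codes/models/__init__.py` (`IRN`, `IRN+`, `IRN-CRM`, `IRN-Color`). Below is a minimal inference sketch in the spirit of `test.py`, assuming a parsed test config and the `downscale`/`upscale` helpers defined on the rescaling model classes; the config path and the random input are placeholders, not repository code:

    import torch
    import options.options as option
    from models import create_model

    # Parse a test config; dict_to_nonedict (from options.py) makes
    # missing keys read as None instead of raising KeyError.
    opt = option.parse('options/test/test_IRN_x4.yml', is_train=False)
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)  # dispatches on opt['model'], here 'IRN'

    # Stand-in for a real HR image batch with values in [0, 1].
    hr = torch.rand(1, 3, 256, 256).to(model.device)
    lr = model.downscale(hr)                           # forward (downscaling) pass
    sr = model.upscale(lr, scale=4, gaussian_scale=1)  # inverse pass with sampled z

The same pattern applies to the other tasks: an `IRN-Color` model exposes `decolorize`/`colorize`, and an `IRN-CRM` model additionally restores JPEG-compressed low-resolution images through its `netR` before the inverse pass.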
diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 5101b21..2d74fcc 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -32,12 +32,17 @@ def __init__(self, opt): len(self.paths_LQ), len(self.paths_GT)) self.random_scale_list = [1] + self.use_grey = False + if self.opt['use_grey']: + self.use_grey = True + def _init_lmdb(self): # https://github.com/chainer/chainermn/issues/129 self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, meminit=False) - self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, - meminit=False) + if not self.use_grey: + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) def __getitem__(self, index): if self.data_type == 'lmdb': @@ -62,36 +67,37 @@ def __getitem__(self, index): img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] # get LQ image - if self.paths_LQ: - LQ_path = self.paths_LQ[index] - if self.data_type == 'lmdb': - resolution = [int(s) for s in self.sizes_LQ[index].split('_')] - else: - resolution = None - img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) - else: # down-sampling on-the-fly - # randomly scale during training - if self.opt['phase'] == 'train': - random_scale = random.choice(self.random_scale_list) - H_s, W_s, _ = img_GT.shape - - def _mod(n, random_scale, scale, thres): - rlt = int(n * random_scale) - rlt = (rlt // scale) * scale - return thres if rlt < thres else rlt - - H_s = _mod(H_s, random_scale, scale, GT_size) - W_s = _mod(W_s, random_scale, scale, GT_size) - img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) - # force to 3 channels - if img_GT.ndim == 2: - img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) - - H, W, _ = img_GT.shape - # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) + if not self.use_grey: + if self.paths_LQ: + LQ_path = self.paths_LQ[index] + if self.data_type == 'lmdb': + resolution = [int(s) for s in self.sizes_LQ[index].split('_')] + else: + resolution = None + img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) + else: # down-sampling on-the-fly + # randomly scale during training + if self.opt['phase'] == 'train': + random_scale = random.choice(self.random_scale_list) + H_s, W_s, _ = img_GT.shape + + def _mod(n, random_scale, scale, thres): + rlt = int(n * random_scale) + rlt = (rlt // scale) * scale + return thres if rlt < thres else rlt + + H_s = _mod(H_s, random_scale, scale, GT_size) + W_s = _mod(W_s, random_scale, scale, GT_size) + img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) + # force to 3 channels + if img_GT.ndim == 2: + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) + + H, W, _ = img_GT.shape + # using matlab imresize + img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if img_LQ.ndim == 2: + img_LQ = np.expand_dims(img_LQ, axis=2) if self.opt['phase'] == 'train': # if the image size is too small @@ -100,39 +106,59 @@ def _mod(n, random_scale, scale, thres): img_GT = cv2.resize(np.copy(img_GT), (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) - - H, W, C = img_LQ.shape - LQ_size = GT_size // scale - - # randomly crop - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = 
random.randint(0, max(0, W - LQ_size)) - img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + if not self.use_grey: + img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if img_LQ.ndim == 2: + img_LQ = np.expand_dims(img_LQ, axis=2) + + if not self.use_grey: + H, W, C = img_LQ.shape + LQ_size = GT_size // scale + + # randomly crop + rnd_h = random.randint(0, max(0, H - LQ_size)) + rnd_w = random.randint(0, max(0, W - LQ_size)) + img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] + rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + else: + rnd_h_GT = random.randint(0, max(0, H - GT_size)) + rnd_w_GT = random.randint(0, max(0, W - GT_size)) + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] # augmentation - flip, rotate - img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], - self.opt['use_rot']) + if not self.use_grey: + img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], + self.opt['use_rot']) + else: + img_GT = util.augment([img_GT], self.opt['use_flip'], self.opt['use_rot'])[0] # change color space if necessary - if self.opt['color']: - img_LQ = util.channel_convert(C, self.opt['color'], - [img_LQ])[0] # TODO during val no definition + if not self.use_grey: + if self.opt['color']: + img_LQ = util.channel_convert(C, self.opt['color'], + [img_LQ])[0] # TODO during val no definition + if self.use_grey: + img_Grey = cv2.cvtColor(img_GT, cv2.COLOR_BGR2GRAY) # BGR to RGB, HWC to CHW, numpy to tensor if img_GT.shape[2] == 3: img_GT = img_GT[:, :, [2, 1, 0]] - img_LQ = img_LQ[:, :, [2, 1, 0]] + if not self.use_grey: + img_LQ = img_LQ[:, :, [2, 1, 0]] img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() - img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + if not self.use_grey: + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + if self.use_grey: + img_Grey = torch.from_numpy(np.ascontiguousarray(np.expand_dims(img_Grey, 0))).float() if LQ_path is None: LQ_path = GT_path - return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + + if not self.use_grey: + return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + else: + return {'Grey': img_Grey, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} def __len__(self): return len(self.paths_GT) diff --git a/codes/models/IRN_color_model.py b/codes/models/IRN_color_model.py new file mode 100644 index 0000000..cc044c8 --- /dev/null +++ b/codes/models/IRN_color_model.py @@ -0,0 +1,208 @@ +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.modules.loss import ReconstructionLoss +from models.modules.Quantization import Quantization + +logger = logging.getLogger('base') + +class IRNColorModel(BaseModel): + def __init__(self, opt): + super(IRNColorModel, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + test_opt = opt['test'] + self.train_opt = 
train_opt + self.test_opt = test_opt + + self.netG = networks.define_grey(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.load() + + self.Quantization = Quantization() + + if self.is_train: + self.netG.train() + + # loss + self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw']) + self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back']) + + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def feed_data(self, data): + self.ref_Grey = data['Grey'].to(self.device) + self.real_H = data['GT'].to(self.device) # GT + + def gaussian_batch(self, dims): + return torch.randn(tuple(dims)).to(self.device) + + def loss_forward(self, out, y, z): + l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y) + + z = z.reshape([out.shape[0], -1]) + l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0] + + return l_forw_fit, l_forw_ce + + def loss_backward(self, x, y): + x_samples = self.netG(x=y, rev=True) + x_samples_image = x_samples[:, :3, :, :] + l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image) + + return l_back_rec + + + def optimize_parameters(self, step): + self.optimizer_G.zero_grad() + + # forward decolorization + self.input = self.real_H + self.output = self.netG(x=self.input) + + zshape = self.output[:, 1:, :, :].shape + Grey_ref = self.ref_Grey.detach() + + l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :1, :, :], Grey_ref, self.output[:, 1:, :, :]) + + # backward upscaling + Grey = self.Quantization(self.output[:, :1, :, :]) + gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt['gaussian_scale'] != None else 0 + y_ = torch.cat((Grey, gaussian_scale * self.gaussian_batch(zshape)), dim=1) + + l_back_rec = self.loss_backward(self.real_H, y_) + + # total loss + loss = l_forw_fit + l_back_rec + l_forw_ce + loss.backward() + + # gradient clipping + if self.train_opt['gradient_clipping']: + nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) + + self.optimizer_G.step() + + # set log + self.log_dict['l_forw_fit'] = l_forw_fit.item() + 
self.log_dict['l_forw_ce'] = l_forw_ce.item() + self.log_dict['l_back_rec'] = l_back_rec.item() + + def test(self): + Lshape = self.ref_Grey.shape + + self.input = self.real_H + + zshape = [Lshape[0], 2, Lshape[2], Lshape[3]] + + gaussian_scale = 0 + if self.test_opt and self.test_opt['gaussian_scale'] != None: + gaussian_scale = self.test_opt['gaussian_scale'] + + self.netG.eval() + with torch.no_grad(): + self.forw_L = self.netG(x=self.input)[:, :1, :, :] + self.forw_L = self.Quantization(self.forw_L) + y_forw = torch.cat((self.forw_L, gaussian_scale * self.gaussian_batch(zshape)), dim=1) + self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :] + + self.netG.train() + + def decolorize(self, img): + self.netG.eval() + with torch.no_grad(): + Grey_img = self.netG(x=img)[:, :1, :, :] + Grey_img = self.Quantization(Grey_img) + self.netG.train() + + return Grey_img + + def colorize(self, Grey_img, gaussian_scale=0): + Lshape = Grey_img.shape + zshape = [Lshape[0], 2, Lshape[2], Lshape[3]] + y_ = torch.cat((Grey_img, gaussian_scale * self.gaussian_batch(zshape)), dim=1) + + self.netG.eval() + with torch.no_grad(): + img = self.netG(x=y_, rev=True)[:, :3, :, :] + self.netG.train() + + return img + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['Grey_ref'] = self.ref_Grey.detach()[0].float().cpu() + out_dict['Color'] = self.fake_H.detach()[0].float().cpu() + out_dict['Grey'] = self.forw_L.detach()[0].float().cpu() + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) diff --git a/codes/models/IRN_model.py b/codes/models/IRN_model.py index f363888..31df839 100644 --- a/codes/models/IRN_model.py +++ b/codes/models/IRN_model.py @@ -9,6 +9,7 @@ from .base_model import BaseModel from models.modules.loss import ReconstructionLoss from models.modules.Quantization import Quantization +import numpy as np logger = logging.getLogger('base') @@ -115,6 +116,14 @@ def optimize_parameters(self, step): # backward upscaling LR = self.Quantization(self.output[:, :3, :, :]) + + if self.train_opt['add_noise_on_y']: + probability = self.train_opt['y_noise_prob'] + noise_scale = self.train_opt['y_noise_scale'] + prob = np.random.rand() + if prob < probability: + LR = LR + noise_scale * self.gaussian_batch(LR.shape) + gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt['gaussian_scale'] != None else 1 y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1) diff --git a/codes/models/IRN_model_CRM.py b/codes/models/IRN_model_CRM.py new file mode 100644 index 0000000..0c257d1 --- /dev/null +++ b/codes/models/IRN_model_CRM.py @@ -0,0 +1,332 @@ +import logging +from collections import OrderedDict + 
+import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.modules.loss import ReconstructionLoss +from models.modules.Quantization import Quantization +from models.modules.Apply_jpg import apply_jpg +from models.modules.Replace import Replace + +logger = logging.getLogger('base') + +class IRNCRMModel(BaseModel): + def __init__(self, opt): + super(IRNCRMModel, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + test_opt = opt['test'] + self.train_opt = train_opt + self.test_opt = test_opt + + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + + self.netR = networks.define_R(opt).to(self.device) + if opt['dist']: + self.netR = DistributedDataParallel(self.netR, device_ids=[torch.cuda.current_device()]) + else: + self.netR = DataParallel(self.netR) + + # print network + self.print_network() + self.load() + + self.Quantization = Quantization() + self.apply_jpg = apply_jpg() + self.Replace = Replace() + + if self.is_train: + self.netG.train() + self.netR.train() + + # loss + self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw']) + self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back']) + + self.Reconstruction_jpeg = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_jpeg']) + + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + optim_params = [] + for k, v in self.netR.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_R = torch.optim.Adam(optim_params, lr=train_opt['lr_R'], + weight_decay=wd_R, + betas=(train_opt['beta1_R'], train_opt['beta2_R'])) + self.optimizers.append(self.optimizer_R) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def feed_data(self, data): + self.ref_L = data['LQ'].to(self.device) # LQ + self.real_H = data['GT'].to(self.device) 
# GT + + def gaussian_batch(self, dims): + return torch.randn(tuple(dims)).to(self.device) + + def loss_forward(self, out, y, z): + l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y) + + z = z.reshape([out.shape[0], -1]) + l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0] + + return l_forw_fit, l_forw_ce + + def loss_backward(self, x, y): + x_samples = self.netG(x=y, rev=True) + x_samples_image = x_samples[:, :3, :, :] + l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image) + + return l_back_rec + + + def optimize_parameters(self, step): + if self.train_opt['only_jpeg_reconstruction']: + for p in self.netG.parameters(): + p.requires_grad = False + + self.optimizer_R.zero_grad() + + self.input = self.real_H + with torch.no_grad(): + self.output = self.netG(x=self.input) + LR = self.Quantization(self.output[:, :3, :, :]) + LR_ = LR.clone() + quality = self.train_opt['jpg_quality'] + self.output_jpeg = self.apply_jpg(LR_, quality).detach() + + self.output_restore = self.netR(self.output_jpeg) + l_jpeg_rec = self.train_opt['lambda_rec_jpeg'] * self.Reconstruction_jpeg(LR, self.output_restore) + loss = l_jpeg_rec + + if self.train_opt['add_joint_loss']: + start_iter = self.train_opt['joint_loss_iters'] if self.train_opt['joint_loss_iters'] != None else -1 + if step > start_iter: + self.output_restore = self.Quantization(self.output_restore) + zshape = self.output[:, 3:, :, :].shape + gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt['gaussian_scale'] != None else 1 + y_ = torch.cat((self.output_restore, gaussian_scale * self.gaussian_batch(zshape)), dim=1) + x_samples = self.netG(x=y_, rev=True)[:, :3, :, :] + l_back_rec = self.train_opt['lambda_joint_back'] * self.Reconstruction_back(self.real_H, x_samples) + loss += l_back_rec + loss.backward() + + # gradient clipping + if self.train_opt['gradient_clipping']: + nn.utils.clip_grad_norm_(self.netR.parameters(), self.train_opt['gradient_clipping']) + + self.optimizer_R.step() + + # set log + self.log_dict['l_jpeg_rec'] = l_jpeg_rec.item() + if self.train_opt['add_joint_loss'] and step > start_iter: + self.log_dict['l_back_rec'] = l_back_rec.item() + + for p in self.netG.parameters(): + p.requires_grad = True + + else: + self.optimizer_G.zero_grad() + self.optimizer_R.zero_grad() + + # forward downscaling + self.input = self.real_H + self.output = self.netG(x=self.input) + + zshape = self.output[:, 3:, :, :].shape + LR_ref = self.ref_L.detach() + + l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :3, :, :], LR_ref, self.output[:, 3:, :, :]) + + # backward upscaling + LR = self.Quantization(self.output[:, :3, :, :]) + LR_ = LR.clone() + quality = self.train_opt['jpg_quality'] + self.output_jpeg = self.apply_jpg(LR_, quality).detach() + self.output_restore = self.netR(x=self.output_jpeg) + l_jpeg_rec = self.train_opt['lambda_rec_jpeg'] * self.Reconstruction_jpeg(LR, self.output_restore) + + LR = self.Replace(LR, self.Quantization(self.output_restore)) + gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt['gaussian_scale'] != None else 1 + y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1) + + l_back_rec = self.loss_backward(self.real_H, y_) + + # total loss + loss = l_jpeg_rec + l_forw_fit + l_back_rec + l_forw_ce + loss.backward() + + # gradient clipping + if self.train_opt['gradient_clipping']: + nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) + 
nn.utils.clip_grad_norm_(self.netR.parameters(), self.train_opt['gradient_clipping'])
+
+            self.optimizer_G.step()
+            self.optimizer_R.step()
+
+            # set log
+            self.log_dict['l_forw_fit'] = l_forw_fit.item()
+            self.log_dict['l_forw_ce'] = l_forw_ce.item()
+            self.log_dict['l_back_rec'] = l_back_rec.item()
+            self.log_dict['l_jpeg_rec'] = l_jpeg_rec.item()
+
+    def test(self):
+        if self.test_opt and self.test_opt['bic_crm']:
+            self.netR.eval()
+            quality = self.test_opt['jpg_quality']
+            with torch.no_grad():
+                self.jpeg_L = self.apply_jpg(self.ref_L, quality)
+                self.restore_L = self.netR(x=self.jpeg_L)
+                self.restore_L = self.Quantization(self.restore_L)
+            self.forw_L = self.ref_L
+            self.fake_H = self.restore_L
+            self.netR.train()
+            return
+        Lshape = self.ref_L.shape
+
+        input_dim = Lshape[1]
+        self.input = self.real_H
+
+        zshape = [Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1], Lshape[2], Lshape[3]]
+
+        gaussian_scale = 1
+        if self.test_opt and self.test_opt['gaussian_scale'] != None:
+            gaussian_scale = self.test_opt['gaussian_scale']
+
+        self.netG.eval()
+        self.netR.eval()
+        with torch.no_grad():
+            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
+            self.forw_L = self.Quantization(self.forw_L)
+            forw_L_ = self.forw_L.clone()
+            quality = self.test_opt['jpg_quality']
+            self.jpeg_L = self.apply_jpg(forw_L_, quality).detach()
+            if self.test_opt['ignore_restore']:
+                self.restore_L = self.jpeg_L
+            else:
+                self.restore_L = self.netR(x=self.jpeg_L)
+                self.restore_L = self.Quantization(self.restore_L)
+
+            y_forw = torch.cat((self.restore_L, gaussian_scale * self.gaussian_batch(zshape)), dim=1)
+            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
+
+
+        self.netG.train()
+        self.netR.train()
+
+    def downscale(self, HR_img):
+        self.netG.eval()
+        with torch.no_grad():
+            LR_img = self.netG(x=HR_img)[:, :3, :, :]
+            LR_img = self.Quantization(LR_img)
+        self.netG.train()
+
+        return LR_img
+
+    def upscale(self, LR_img, scale, gaussian_scale=1):
+        Lshape = LR_img.shape
+        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
+        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)), dim=1)
+
+        self.netG.eval()
+        with torch.no_grad():
+            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
+        self.netG.train()
+
+        return HR_img
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def get_current_visuals(self):
+        out_dict = OrderedDict()
+        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
+        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
+        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+        out_dict['RLR'] = self.restore_L.detach()[0].float().cpu()
+        return out_dict
+
+    def print_network(self):
+        s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
+            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                             self.netG.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netG.__class__.__name__)
+        if self.rank <= 0:
+            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+            logger.info(s)
+
+        s, n = self.get_network_description(self.netR)
+        if isinstance(self.netR, nn.DataParallel) or isinstance(self.netR, DistributedDataParallel):
+            net_struc_str = '{} - {}'.format(self.netR.__class__.__name__,
+                                             self.netR.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netR.__class__.__name__)
+        if self.rank <= 0:
+            logger.info('Network R 
structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + load_path_R = self.opt['path']['pretrain_model_R'] + if load_path_R is not None: + logger.info('Loading model for R [{:s}] ...'.format(load_path_R)) + self.load_network(load_path_R, self.netR, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) + self.save_network(self.netR, 'R', iter_label) diff --git a/codes/models/__init__.py b/codes/models/__init__.py index ea71ee0..c5452ea 100644 --- a/codes/models/__init__.py +++ b/codes/models/__init__.py @@ -9,6 +9,10 @@ def create_model(opt): from .IRN_model import IRNModel as M elif model == 'IRN+': from .IRNp_model import IRNpModel as M + elif model == 'IRN-CRM': + from .IRN_model_CRM import IRNCRMModel as M + elif model == 'IRN-Color': + from .IRN_color_model import IRNColorModel as M else: raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) m = M(opt) diff --git a/codes/models/modules/Apply_jpg.py b/codes/models/modules/Apply_jpg.py new file mode 100644 index 0000000..6fd46a0 --- /dev/null +++ b/codes/models/modules/Apply_jpg.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn +import numpy as np +import cv2 + +class JPG(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, quality): + output = input + batch_size = input.shape[0] + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + for i in range(batch_size): + tensor = input[i, :, :, :].squeeze().float().cpu().clamp_(0, 1) + img = tensor.numpy() + img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0)) + img = (img * 255.).round().astype(np.uint8) + + _, encimg = cv2.imencode('.jpg', img, encode_param) + decimg = cv2.imdecode(encimg, 1) + decimg = decimg * 1.0 / 255. 
+        decimg = decimg[:, :, [2, 1, 0]]
+        dectensor = torch.from_numpy(np.ascontiguousarray(np.transpose(decimg, (2, 0, 1)))).float()
+        output[i, :, :, :] = dectensor.cuda()
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+class apply_jpg(nn.Module):
+    def __init__(self):
+        super(apply_jpg, self).__init__()
+
+    def forward(self, input, quality):
+        return JPG.apply(input, quality)
diff --git a/codes/models/modules/Inv_arch.py b/codes/models/modules/Inv_arch.py
index ae4f638..ee80e6f 100644
--- a/codes/models/modules/Inv_arch.py
+++ b/codes/models/modules/Inv_arch.py
@@ -84,20 +84,186 @@ def jacobian(self, x, rev=False):
         return self.last_jac
 
 
+class ConvDownsampling(nn.Module):
+    def __init__(self, scale):
+        super(ConvDownsampling, self).__init__()
+        self.scale = scale
+        self.scale2 = self.scale ** 2
+
+        self.conv_weights = torch.eye(self.scale2)
+
+        if self.scale == 2: # haar init
+            self.conv_weights[0] = torch.Tensor([1./4, 1./4, 1./4, 1./4])
+            self.conv_weights[1] = torch.Tensor([1./4, -1./4, 1./4, -1./4])
+            self.conv_weights[2] = torch.Tensor([1./4, 1./4, -1./4, -1./4])
+            self.conv_weights[3] = torch.Tensor([1./4, -1./4, -1./4, 1./4])
+        else:
+            self.conv_weights[0] = torch.Tensor([1./(self.scale2)] * (self.scale2))
+
+        self.conv_weights = nn.Parameter(self.conv_weights)
+
+    def forward(self, x, rev=False):
+        if not rev:
+            # downsample
+            # may need improvement
+            h = x.shape[2]
+            w = x.shape[3]
+            wpad = 0
+            hpad = 0
+            if w % self.scale != 0:
+                wpad = self.scale - w % self.scale
+            if h % self.scale != 0:
+                hpad = self.scale - h % self.scale
+            if wpad != 0 or hpad != 0:
+                padding = (wpad // 2, wpad - wpad // 2, hpad // 2, hpad - hpad // 2)
+                pad = nn.ReplicationPad2d(padding)
+                x = pad(x)
+
+            [B, C, H, W] = list(x.size())
+            x = x.reshape(B, C, H // self.scale, self.scale, W // self.scale, self.scale)
+            x = x.permute(0, 1, 3, 5, 2, 4)
+            x = x.reshape(B, C * self.scale2, H // self.scale, W // self.scale)
+
+            # conv
+            conv_weights = self.conv_weights.reshape(self.scale2, self.scale2, 1, 1)
+            conv_weights = conv_weights.repeat(C, 1, 1, 1)
+
+            out = F.conv2d(x, conv_weights, bias=None, stride=1, groups=C)
+
+            out = out.reshape(B, C, self.scale2, H // self.scale, W // self.scale)
+            out = torch.transpose(out, 1, 2)
+            out = out.reshape(B, C * self.scale2, H // self.scale, W // self.scale)
+
+            return out
+        else:
+            inv_weights = torch.inverse(self.conv_weights)
+            inv_weights = inv_weights.reshape(self.scale2, self.scale2, 1, 1)
+
+            [B, C_, H_, W_] = list(x.size())
+            C = C_ // self.scale2
+            H = H_ * self.scale
+            W = W_ * self.scale
+
+            inv_weights = inv_weights.repeat(C, 1, 1, 1)
+
+            x = x.reshape(B, self.scale2, C, H_, W_)
+            x = torch.transpose(x, 1, 2)
+            x = x.reshape(B, C_, H_, W_)
+
+            out = F.conv2d(x, inv_weights, bias=None, stride=1, groups=C)
+
+            out = out.reshape(B, C, self.scale, self.scale, H_, W_)
+            out = out.permute(0, 1, 4, 2, 5, 3)
+            out = out.reshape(B, C, H, W)
+
+            return out
+
+
 class InvRescaleNet(nn.Module):
-    def __init__(self, channel_in=3, channel_out=3, subnet_constructor=None, block_num=[], down_num=2):
+    def __init__(self, channel_in=3, channel_out=3, subnet_constructor=None, block_num=[], down_num=2, down_first=False, use_ConvDownsampling=False, down_scale=4):
         super(InvRescaleNet, self).__init__()
 
         operations = []
 
+        if use_ConvDownsampling:
+            down_num = 1
+            down_first = True
+
         current_channel = channel_in
-        for i in range(down_num):
-            b = HaarDownsampling(current_channel)
-            operations.append(b)
-            current_channel *= 4
-            for j in 
range(block_num[i]): + if down_first: + for i in range(down_num): + if use_ConvDownsampling: + b = ConvDownsampling(down_scale) + current_channel *= down_scale**2 + else: + b = HaarDownsampling(current_channel) + current_channel *= 4 + operations.append(b) + for j in range(block_num[0]): b = InvBlockExp(subnet_constructor, current_channel, channel_out) operations.append(b) + else: + for i in range(down_num): + b = HaarDownsampling(current_channel) + operations.append(b) + current_channel *= 4 + for j in range(block_num[i]): + b = InvBlockExp(subnet_constructor, current_channel, channel_out) + operations.append(b) + + self.operations = nn.ModuleList(operations) + + def forward(self, x, rev=False, cal_jacobian=False): + out = x + jacobian = 0 + + if not rev: + for op in self.operations: + out = op.forward(out, rev) + if cal_jacobian: + jacobian += op.jacobian(out, rev) + else: + for op in reversed(self.operations): + out = op.forward(out, rev) + if cal_jacobian: + jacobian += op.jacobian(out, rev) + + if cal_jacobian: + return out, jacobian + else: + return out + + +class Conv1x1Grey(nn.Module): + def __init__(self, rgb_type, learnable=True): + super(Conv1x1Grey, self).__init__() + + self.channel_in = 3 + self.conv_weights = torch.eye(self.channel_in) + if rgb_type == 'RGB': + self.conv_weights[0] = torch.Tensor([0.299, 0.587, 0.114]) + self.conv_weights[1] = torch.Tensor([-0.147, -0.289, 0.436]) + self.conv_weights[2] = torch.Tensor([0.615, -0.515, -0.100]) + elif rgb_type == 'BGR': + self.conv_weights[0] = torch.Tensor([0.114, 0.587, 0.299]) + self.conv_weights[1] = torch.Tensor([0.436, -0.289, -0.147]) + self.conv_weights[2] = torch.Tensor([-0.100, -0.515, 0.615]) + else: + print("Error! Undefined RGB type!") + exit(1) + + self.conv_weights = nn.Parameter(self.conv_weights) + + if not learnable: + self.conv_weights.requires_grad = False + + def forward(self, x, rev=False): + if not rev: + conv_weights = self.conv_weights.reshape(self.channel_in, self.channel_in, 1, 1) + out = F.conv2d(x, conv_weights, bias=None, stride=1) + return out + else: + inv_weights = torch.inverse(self.conv_weights) + inv_weights = inv_weights.reshape(self.channel_in, self.channel_in, 1, 1) + out = F.conv2d(x, inv_weights, bias=None, stride=1) + return out + + +class InvGreyNet(nn.Module): + def __init__(self, rgb_type, subnet_constructor=None, block_num=[], Conv1x1Grey_learnable=True): + super(InvGreyNet, self).__init__() + + channel_in = 3 + channel_out = 1 + + operations = [] + + b = Conv1x1Grey(rgb_type, Conv1x1Grey_learnable) + operations.append(b) + + for j in range(block_num[0]): + b = InvBlockExp(subnet_constructor, channel_in, channel_out) + operations.append(b) self.operations = nn.ModuleList(operations) @@ -121,3 +287,4 @@ def forward(self, x, rev=False, cal_jacobian=False): else: return out + diff --git a/codes/models/modules/RRDB.py b/codes/models/modules/RRDB.py new file mode 100644 index 0000000..994587e --- /dev/null +++ b/codes/models/modules/RRDB.py @@ -0,0 +1,64 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.modules.module_util as mutil + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * 
gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + #mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) + mutil.initialize_weights(self.conv5, 0) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + + return x5 * 0.2 + x + + +class RRDB(nn.Module): + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + + return out * 0.2 + x + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + out = self.conv_last(fea) + + return out diff --git a/codes/models/modules/Replace.py b/codes/models/modules/Replace.py new file mode 100644 index 0000000..be67f87 --- /dev/null +++ b/codes/models/modules/Replace.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn + +class replace(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, replace_input): + return replace_input + + @staticmethod + def backward(ctx, grad_output): + return grad_output, grad_output + +class Replace(nn.Module): + def __init__(self): + super(Replace, self).__init__() + + def forward(self, input, replace_input): + return replace.apply(input, replace_input) diff --git a/codes/models/modules/Subnet_constructor.py b/codes/models/modules/Subnet_constructor.py index 5833d91..be77585 100644 --- a/codes/models/modules/Subnet_constructor.py +++ b/codes/models/modules/Subnet_constructor.py @@ -29,13 +29,13 @@ def forward(self, x): return x5 -def subnet(net_structure, init='xavier'): +def subnet(net_structure, init='xavier', gc=32): def constructor(channel_in, channel_out): if net_structure == 'DBNet': if init == 'xavier': - return DenseBlock(channel_in, channel_out, init) + return DenseBlock(channel_in, channel_out, init, gc=gc) else: - return DenseBlock(channel_in, channel_out) + return DenseBlock(channel_in, channel_out, gc=gc) else: return None diff --git a/codes/models/networks.py b/codes/models/networks.py index 0a9323c..8b71e0e 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -3,6 +3,7 @@ import models.modules.discriminator_vgg_arch as SRGAN_arch from models.modules.Inv_arch import * from models.modules.Subnet_constructor import subnet +from models.modules.RRDB import RRDBNet import math logger = logging.getLogger('base') @@ -18,14 +19,52 @@ def define_G(opt): init = opt_net['init'] else: init = 'xavier' + if opt_net['gc']: + gc = opt_net['gc'] + else: + gc = 32 + use_ConvDownsampling = False + down_first = False down_num = 
int(math.log(opt_net['scale'], 2)) - netG = InvRescaleNet(opt_net['in_nc'], opt_net['out_nc'], subnet(subnet_type, init), opt_net['block_num'], down_num) + if which_model['use_ConvDownsampling']: + use_ConvDownsampling = True + down_first = True + down_num = 1 + if which_model['down_first']: + down_first = True + + netG = InvRescaleNet(opt_net['in_nc'], opt_net['out_nc'], subnet(subnet_type, init, gc=gc), opt_net['block_num'], down_num, use_ConvDownsampling=use_ConvDownsampling, down_first=down_first, down_scale=opt_net['scale']) return netG +def define_R(opt): + opt_net = opt['network_R'] + return RRDBNet(opt_net['in_nc'], opt_net['out_nc'], opt_net['nf'], opt_net['nb'], opt_net['gc']) + + +# for Invertible decolorization-colorization +def define_grey(opt): + opt_net = opt['network_grey'] + which_model = opt_net['which_model'] + rgb_type = which_model['rgb_type'] + subnet_type = which_model['subnet_type'] + if opt_net['init']: + init = opt_net['init'] + else: + init = 'xavier' + + Conv1x1Grey_learnable = True + if which_model['Conv1x1Grey_learnable'] == False: + Conv1x1Grey_learnable = False + + net_grey = InvGreyNet(rgb_type, subnet(subnet_type, init), opt_net['block_num'], Conv1x1Grey_learnable) + + return net_grey + + #### Discriminator def define_D(opt): opt_net = opt['network_D'] diff --git a/codes/options/options.py b/codes/options/options.py index 9943269..c037931 100644 --- a/codes/options/options.py +++ b/codes/options/options.py @@ -65,7 +65,8 @@ def parse(opt_path, is_train=True): # network if opt['distortion'] == 'sr': - opt['network_G']['scale'] = scale + if 'network_G' in opt.keys(): + opt['network_G']['scale'] = scale return opt diff --git a/codes/options/test/test_IRN+_x4.yml b/codes/options/test/test_IRN+_x4.yml index 38b89d6..cab60da 100644 --- a/codes/options/test/test_IRN+_x4.yml +++ b/codes/options/test/test_IRN+_x4.yml @@ -4,7 +4,7 @@ model: IRN+ distortion: sr scale: 4 crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels -gpu_ids: [1] +gpu_ids: [0] datasets: test_1: # the 1st test dataset diff --git a/codes/options/test/test_IRN-Compression_x2_q90.yml b/codes/options/test/test_IRN-Compression_x2_q90.yml new file mode 100644 index 0000000..2206417 --- /dev/null +++ b/codes/options/test/test_IRN-Compression_x2_q90.yml @@ -0,0 +1,62 @@ +name: IRN-CRM_x2_q90 +suffix: ~ # add suffix to saved images +model: IRN-CRM +distortion: sr +scale: 2 +crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels
+gpu_ids: [0]
+
+datasets:
+  test_1: # the 1st test dataset
+    name: set5
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_2: # the 2nd test dataset
+    name: set14
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_3: # the 3rd test dataset
+    name: B100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_4: # the 4th test dataset
+    name: Urban100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_5:
+    name: val_DIV2K
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+
+
+#### network
+network_G:
+  which_model_G:
+      subnet_type: DBNet
+  in_nc: 3
+  out_nc: 3
+  block_num: [8]
+  scale: 2
+  init: xavier
+
+network_R:
+  in_nc: 3
+  out_nc: 3
+  nf: 64
+  nb: 8
+  gc: 32
+
+
+test:
+  jpg_quality: 90
+
+
+#### path
+path:
+  pretrain_model_G: ../experiments/pretrained_models/IRN_x2_finetune.pth
+  pretrain_model_R: ../experiments/pretrained_models/CRM_x2_q90.pth
diff --git a/codes/options/test/test_IRN-Compression_x2_q90_kodak.yml b/codes/options/test/test_IRN-Compression_x2_q90_kodak.yml
new file mode 100644
index 0000000..5c1e480
--- /dev/null
+++ b/codes/options/test/test_IRN-Compression_x2_q90_kodak.yml
@@ -0,0 +1,42 @@
+name: IRN-CRM_x2_q90
+suffix: ~ # add suffix to saved images
+model: IRN-CRM
+distortion: sr
+scale: 2
+crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
+gpu_ids: [0]
+
+datasets:
+  test_1: # the 1st test dataset
+    name: Kodak
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+
+
+#### network
+network_G:
+  which_model_G:
+      subnet_type: DBNet
+  in_nc: 3
+  out_nc: 3
+  block_num: [8]
+  scale: 2
+  init: xavier
+
+network_R:
+  in_nc: 3
+  out_nc: 3
+  nf: 64
+  nb: 8
+  gc: 32
+
+
+test:
+  jpg_quality: 90
+
+
+#### path
+path:
+  pretrain_model_G: ../experiments/pretrained_models/IRN_x2_finetune_Kodak.pth
+  pretrain_model_R: ../experiments/pretrained_models/CRM_x2_q90.pth
diff --git a/codes/options/test/test_IRN_color.yml b/codes/options/test/test_IRN_color.yml
new file mode 100644
index 0000000..0b96241
--- /dev/null
+++ b/codes/options/test/test_IRN_color.yml
@@ -0,0 +1,49 @@
+name: IRN_color
+suffix: ~ # add suffix to saved images
+model: IRN-Color
+distortion: sr
+scale: 2
+crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels
+gpu_ids: [0]
+
+datasets:
+  test_1: # the 1st test dataset
+    name: set5
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    use_grey: True
+  test_2: # the 2nd test dataset
+    name: set14
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    use_grey: True
+  test_3: # the 3rd test dataset
+    name: B100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    use_grey: True
+  test_4: # the 4th test dataset
+    name: Urban100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    use_grey: True
+  test_5:
+    name: val_DIV2K
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    use_grey: True
+
+
+#### network
+network_grey:
+  which_model:
+      rgb_type: RGB
+      subnet_type: DBNet
+  block_num: [8]
+  init: xavier
+  Conv1x1Grey_learnable: True
+
+
+#### path
+path:
+  pretrain_model_G: ../experiments/pretrained_models/IRN_color.pth
diff --git a/codes/options/test/test_IRN_x3.yml b/codes/options/test/test_IRN_x3.yml
new file mode 100644
index 0000000..17a0947
--- /dev/null
+++ b/codes/options/test/test_IRN_x3.yml
@@ -0,0 +1,51 @@
+name: IRN_x3
+suffix: ~ # add suffix to saved images
+model: IRN
+distortion: sr
+scale: 3
+crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
+gpu_ids: [0]
+
+datasets:
+  test_1: # the 1st test dataset
+    name: set5
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_2: # the 2nd test dataset
+    name: set14
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_3: # the 3rd test dataset
+    name: B100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_4: # the 4th test dataset
+    name: Urban100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_5:
+    name: val_DIV2K
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+
+
+#### network
+network_G:
+  which_model_G:
+      subnet_type: DBNet
+      use_ConvDownsampling: True
+  in_nc: 3
+  out_nc: 3
+  block_num: [12]
+  scale: 3
+  init: xavier
+
+
+#### path
+path:
+  pretrain_model_G: ../experiments/pretrained_models/IRN_x3.pth
diff --git a/codes/options/test/test_IRN_x4.yml b/codes/options/test/test_IRN_x4.yml
index b0f6830..6cb4a57 100644
--- a/codes/options/test/test_IRN_x4.yml
+++ b/codes/options/test/test_IRN_x4.yml
@@ -4,7 +4,7 @@ model: IRN
 distortion: sr
 scale: 4
 crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
-gpu_ids: [1]
+gpu_ids: [0]
 
 datasets:
   test_1: # the 1st test dataset
diff --git a/codes/options/test/test_IRN_x8.yml b/codes/options/test/test_IRN_x8.yml
new file mode 100644
index 0000000..7659b53
--- /dev/null
+++ b/codes/options/test/test_IRN_x8.yml
@@ -0,0 +1,50 @@
+name: IRN_x8
+suffix: ~ # add suffix to saved images
+model: IRN
+distortion: sr
+scale: 8
+crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels
+gpu_ids: [0]
+
+datasets:
+  test_1: # the 1st test dataset
+    name: set5
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_2: # the 2nd test dataset
+    name: set14
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_3: # the 3rd test dataset
+    name: B100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_4: # the 4th test dataset
+    name: Urban100
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+  test_5:
+    name: val_DIV2K
+    mode: LQGT
+    dataroot_GT: ~ # path to test HR images
+    dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+
+
+#### network
+network_G:
+  which_model_G:
+      subnet_type: DBNet
+  in_nc: 3
+  out_nc: 3
+  block_num: [8, 8, 8]
+  scale: 8
+  init: xavier
+
+
+#### path
+path:
+  pretrain_model_G: ../experiments/pretrained_models/IRN_x8.pth
diff --git a/codes/options/train/train_IRN+_x4.yml b/codes/options/train/train_IRN+_x4.yml
index 71e38f2..6dc3eeb 100644
--- a/codes/options/train/train_IRN+_x4.yml
+++ b/codes/options/train/train_IRN+_x4.yml
@@ -6,7 +6,7 @@ use_tb_logger: true
 model: IRN+
 distortion: sr
 scale: 4
-gpu_ids: [2]
+gpu_ids: [0]
 
 
 #### datasets
diff --git a/codes/options/train/train_IRN-Compression_x2_q90.yml b/codes/options/train/train_IRN-Compression_x2_q90.yml
new file mode 100644
index 0000000..6c42bc3
--- /dev/null
+++ b/codes/options/train/train_IRN-Compression_x2_q90.yml
@@ -0,0 +1,107 @@
+
+#### general settings
+
+name: 01_IRN-CRM_DB_x2_q90_scratch_DIV2K
+use_tb_logger: true
+model: IRN-CRM
+distortion: sr
+scale: 2
+gpu_ids: [0]
+
+
+#### datasets
+
+datasets:
+  train:
+    name: DIV2K
+    mode: LQGT
+    dataroot_GT: ~ # path to training HR images
+    dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+
+    use_shuffle: true
+    n_workers: 6  # per GPU
+    batch_size: 16
+    GT_size: 144
+    use_flip: true
+    use_rot: true
+    color: RGB
+
+  val:
+    name: val_DIV2K
+    mode: LQGT
+    dataroot_GT: ~ # path to validation HR images
+    dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader
+
+
+#### network structures
+
+network_G:
+  which_model_G:
+      subnet_type: DBNet
+  in_nc: 3
+  out_nc: 3
+  block_num: [8]
+  scale: 2
+  init: xavier
+
+network_R:
+  in_nc: 3
+  out_nc: 3
+  nf: 64
+  nb: 8
+  gc: 32
+
+
+#### path
+
+path:
+  pretrain_model_G: ../experiments/pretrained_models/IRN_x2_finetune.pth
+  strict_load: true
+  resume_state: ~
+
+
+#### training settings: learning rate scheme, loss
+
+train:
+  lr_G: !!float 2e-4
+  lr_R: !!float 2e-4
+  beta1: 0.9
+  beta1_R: 0.9
+  beta2: 0.999
+  beta2_R: 0.999
+  niter: 50000
+  warmup_iter: -1  # no warm up
+
+  lr_scheme: MultiStepLR
+  lr_steps: [10000, 20000, 30000, 40000]
+  lr_gamma: 0.5
+
+  pixel_criterion_forw: l2
+  pixel_criterion_back: l1
+  pixel_criterion_jpeg: l2
+
+  manual_seed: 10
+
+  val_freq: !!float 5e3
+
+  lambda_fit_forw: 4. 
+ lambda_rec_back: 1 + lambda_ce_forw: 1 + weight_decay_G: !!float 1e-5 + gradient_clipping: 10 + + jpg_quality: 90 + lambda_rec_jpeg: 1 + lambda_joint_back: 0 + only_jpeg_reconstruction: True + add_joint_loss: False + +test: + jpg_quality: 90 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_IRN_color.yml b/codes/options/train/train_IRN_color.yml new file mode 100644 index 0000000..c532635 --- /dev/null +++ b/codes/options/train/train_IRN_color.yml @@ -0,0 +1,86 @@ + +#### general settings + +name: 01_IRNcolor_DB_scratch_DIV2K +use_tb_logger: true +model: IRN-Color +distortion: sr +scale: 2 +gpu_ids: [0] + + +#### datasets + +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ~ # path to training HR images + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 8 + GT_size: 144 + use_flip: true + use_rot: true + color: RGB + use_grey: True + + val: + name: val_DIV2K + mode: LQGT + dataroot_GT: ~ # path to validation HR images + use_grey: True + + +#### network structures + +network_grey: + which_model: + rgb_type: RGB + subnet_type: DBNet + block_num: [8] + init: xavier + Conv1x1Grey_learnable: True + + +#### path + +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + lr_G: !!float 2e-4 + beta1: 0.9 + beta2: 0.999 + niter: 500000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [100000, 200000, 300000, 400000] + lr_gamma: 0.5 + + pixel_criterion_forw: l2 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 5e3 + + lambda_fit_forw: 3. + lambda_rec_back: 1 + lambda_ce_forw: 1 + weight_decay_G: !!float 1e-5 + gradient_clipping: 10 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_IRN_x2_finetune.yml b/codes/options/train/train_IRN_x2_finetune.yml new file mode 100644 index 0000000..eac1138 --- /dev/null +++ b/codes/options/train/train_IRN_x2_finetune.yml @@ -0,0 +1,91 @@ + +#### general settings + +name: 01_IRN_DB_x2_scratch_DIV2K +use_tb_logger: true +model: IRN +distortion: sr +scale: 2 +gpu_ids: [0] + + +#### datasets + +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ~ # path to training HR images + dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 144 + use_flip: true + use_rot: true + color: RGB + + val: + name: val_DIV2K + mode: LQGT + dataroot_GT: ~ # path to validation HR images + dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + in_nc: 3 + out_nc: 3 + block_num: [8] + scale: 2 + init: xavier + + +#### path + +path: + pretrain_model_G: ../experiments/pretrained_models/IRN_x2.pth + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + lr_G: !!float 1e-5 + beta1: 0.9 + beta2: 0.999 + niter: 100000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [100000] + lr_gamma: 0.5 + + pixel_criterion_forw: l1 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 5e3 + + lambda_fit_forw: 4. 
+ lambda_rec_back: 1 + lambda_ce_forw: 1 + weight_decay_G: !!float 1e-5 + gradient_clipping: 10 + + add_noise_on_y: True + y_noise_prob: 0.5 + y_noise_scale: 0.01 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_IRN_x2_finetune_kodak.yml b/codes/options/train/train_IRN_x2_finetune_kodak.yml new file mode 100644 index 0000000..2d73d71 --- /dev/null +++ b/codes/options/train/train_IRN_x2_finetune_kodak.yml @@ -0,0 +1,91 @@ + +#### general settings + +name: 01_IRN_DB_x2_finetune_kodak_DIV2K +use_tb_logger: true +model: IRN +distortion: sr +scale: 2 +gpu_ids: [0] + + +#### datasets + +datasets: + train: + name: Kodak + mode: LQGT + dataroot_GT: /home/mqxiao/data/Kodak/ # path to training HR images + dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 144 + use_flip: true + use_rot: true + color: RGB + + val: + name: Kodak + mode: LQGT + dataroot_GT: /home/mqxiao/data/Kodak/ # path to validation HR images + dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + in_nc: 3 + out_nc: 3 + block_num: [8] + scale: 2 + init: xavier + + +#### path + +path: + pretrain_model_G: /home/mqxiao/Invertible-Image-Rescaling-master/experiments/pretrained_models/IRN_x2_finetune.pth + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + lr_G: !!float 1e-5 + beta1: 0.9 + beta2: 0.999 + niter: 5000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [100000] + lr_gamma: 0.5 + + pixel_criterion_forw: l1 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 5e3 + + lambda_fit_forw: 4. 
+ lambda_rec_back: 1 + lambda_ce_forw: 1 + weight_decay_G: !!float 1e-5 + gradient_clipping: 10 + + add_noise_on_y: True + y_noise_prob: 0.5 + y_noise_scale: 0.01 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_IRN_x3.yml b/codes/options/train/train_IRN_x3.yml new file mode 100644 index 0000000..366d48d --- /dev/null +++ b/codes/options/train/train_IRN_x3.yml @@ -0,0 +1,88 @@ + +#### general settings + +name: 01_IRN_DB_x3_scratch_DIV2K +use_tb_logger: true +model: IRN +distortion: sr +scale: 3 +gpu_ids: [0] + + +#### datasets + +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ~ # path to training HR images + dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 144 + use_flip: true + use_rot: true + color: RGB + + val: + name: val_DIV2K + mode: LQGT + dataroot_GT: ~ # path to validation HR images + dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + use_ConvDownsampling: True + in_nc: 3 + out_nc: 3 + block_num: [12] + scale: 3 + init: xavier + + +#### path + +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + lr_G: !!float 2e-4 + beta1: 0.9 + beta2: 0.999 + niter: 500000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [100000, 200000, 300000, 400000] + lr_gamma: 0.5 + + pixel_criterion_forw: l2 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 5e3 + + lambda_fit_forw: 9. + lambda_rec_back: 1 + lambda_ce_forw: 1 + weight_decay_G: !!float 1e-5 + gradient_clipping: 10 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_IRN_x8.yml b/codes/options/train/train_IRN_x8.yml new file mode 100644 index 0000000..924c908 --- /dev/null +++ b/codes/options/train/train_IRN_x8.yml @@ -0,0 +1,87 @@ + +#### general settings + +name: 01_IRN_DB_x8_scratch_DIV2K +use_tb_logger: true +model: IRN +distortion: sr +scale: 8 +gpu_ids: [0] + + +#### datasets + +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ~ # path to training HR images + dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 144 + use_flip: true + use_rot: true + color: RGB + + val: + name: val_DIV2K + mode: LQGT + dataroot_GT: ~ # path to validation HR images + dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + in_nc: 3 + out_nc: 3 + block_num: [8, 8, 8] + scale: 8 + init: xavier + + +#### path + +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + lr_G: !!float 2e-4 + beta1: 0.9 + beta2: 0.999 + niter: 500000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [100000, 200000, 300000, 400000] + lr_gamma: 0.5 + + pixel_criterion_forw: l2 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 5e3 + + lambda_fit_forw: 64. 
+  lambda_rec_back: 1
+  lambda_ce_forw: 1
+  weight_decay_G: !!float 1e-5
+  gradient_clipping: 10
+
+
+#### logger
+
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 5e3
diff --git a/codes/test_IRN-Color.py b/codes/test_IRN-Color.py
new file mode 100644
index 0000000..0f98112
--- /dev/null
+++ b/codes/test_IRN-Color.py
@@ -0,0 +1,127 @@
+import os.path as osp
+import logging
+import time
+import argparse
+from collections import OrderedDict
+
+import numpy as np
+import options.options as option
+import utils.util as util
+from data.util import bgr2ycbcr
+from data import create_dataset, create_dataloader
+from models import create_model
+
+#### options
+parser = argparse.ArgumentParser()
+parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
+opt = option.parse(parser.parse_args().opt, is_train=False)
+opt = option.dict_to_nonedict(opt)
+
+util.mkdirs(
+    (path for key, path in opt['path'].items()
+     if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                  screen=True, tofile=True)
+logger = logging.getLogger('base')
+logger.info(option.dict2str(opt))
+
+#### Create test dataset and dataloader
+test_loaders = []
+for phase, dataset_opt in sorted(opt['datasets'].items()):
+    test_set = create_dataset(dataset_opt)
+    test_loader = create_dataloader(test_set, dataset_opt)
+    logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+    test_loaders.append(test_loader)
+
+model = create_model(opt)
+for test_loader in test_loaders:
+    test_set_name = test_loader.dataset.opt['name']
+    logger.info('\nTesting [{:s}]...'.format(test_set_name))
+    test_start_time = time.time()
+    dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
+    util.mkdir(dataset_dir)
+
+    test_results = OrderedDict()
+    test_results['psnr'] = []
+    test_results['ssim'] = []
+
+    test_results['psnr_grey'] = []
+    test_results['ssim_grey'] = []
+
+    for data in test_loader:
+        model.feed_data(data)
+        img_path = data['GT_path'][0]
+        img_name = osp.splitext(osp.basename(img_path))[0]
+
+        model.test()
+        visuals = model.get_current_visuals()
+
+        color_img = util.tensor2img(visuals['Color'])  # uint8
+        gt_img = util.tensor2img(visuals['GT'])  # uint8
+        grey_img = util.tensor2img(visuals['Grey'])  # uint8
+        greygt_img = util.tensor2img(visuals['Grey_ref'])  # uint8
+
+        # save images
+        suffix = opt['suffix']
+        if suffix:
+            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
+        else:
+            save_img_path = osp.join(dataset_dir, img_name + '.png')
+        util.save_img(color_img, save_img_path)
+
+        if suffix:
+            save_img_path = osp.join(dataset_dir, img_name + suffix + '_GT.png')
+        else:
+            save_img_path = osp.join(dataset_dir, img_name + '_GT.png')
+        util.save_img(gt_img, save_img_path)
+
+        if suffix:
+            save_img_path = osp.join(dataset_dir, img_name + suffix + '_Grey.png')
+        else:
+            save_img_path = osp.join(dataset_dir, img_name + '_Grey.png')
+        util.save_img(grey_img, save_img_path)
+
+        if suffix:
+            save_img_path = osp.join(dataset_dir, img_name + suffix + '_Grey_ref.png')
+        else:
+            save_img_path = osp.join(dataset_dir, img_name + '_Grey_ref.png')
+        util.save_img(greygt_img, save_img_path)
+
+        # calculate PSNR and SSIM
+        gt_img = gt_img / 255.
+        color_img = color_img / 255.
+
+        grey_img = grey_img / 255.
+        greygt_img = greygt_img / 255.
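The metric computation that follows delegates to `util.calculate_psnr`; assuming it implements the standard definition for 8-bit images, a minimal equivalent sketch is:

    import numpy as np

    def psnr_8bit(img1, img2):
        # PSNR for images with values in [0, 255]; higher is better.
        mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
        return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))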
+
+        crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale']
+        if crop_border == 0:
+            cropped_color_img = color_img
+            cropped_gt_img = gt_img
+        else:
+            cropped_color_img = color_img[crop_border:-crop_border, crop_border:-crop_border, :]
+            cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]
+
+        psnr = util.calculate_psnr(cropped_color_img * 255, cropped_gt_img * 255)
+        ssim = util.calculate_ssim(cropped_color_img * 255, cropped_gt_img * 255)
+        test_results['psnr'].append(psnr)
+        test_results['ssim'].append(ssim)
+
+        # PSNR and SSIM for grey
+        psnr_grey = util.calculate_psnr(grey_img * 255, greygt_img * 255)
+        ssim_grey = util.calculate_ssim(grey_img * 255, greygt_img * 255)
+        test_results['psnr_grey'].append(psnr_grey)
+        test_results['ssim_grey'].append(ssim_grey)
+
+        logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}. Grey PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim, psnr_grey, ssim_grey))
+
+    # Average PSNR/SSIM results
+    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
+    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
+
+    ave_psnr_grey = sum(test_results['psnr_grey']) / len(test_results['psnr_grey'])
+    ave_ssim_grey = sum(test_results['ssim_grey']) / len(test_results['ssim_grey'])
+
+    logger.info(
+        '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}. Grey PSNR: {:.6f} dB; SSIM: {:.6f}.\n'.format(
+            test_set_name, ave_psnr, ave_ssim, ave_psnr_grey, ave_ssim_grey))
diff --git a/codes/test_IRN-Compression.py b/codes/test_IRN-Compression.py
new file mode 100644
index 0000000..1909b18
--- /dev/null
+++ b/codes/test_IRN-Compression.py
@@ -0,0 +1,188 @@
+import os.path as osp
+import logging
+import time
+import argparse
+from collections import OrderedDict
+
+import numpy as np
+import options.options as option
+import utils.util as util
+from data.util import bgr2ycbcr
+from data import create_dataset, create_dataloader
+from models import create_model
+
+import cv2
+
+#### options
+parser = argparse.ArgumentParser()
+parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
+opt = option.parse(parser.parse_args().opt, is_train=False)
+opt = option.dict_to_nonedict(opt)
+
+util.mkdirs(
+    (path for key, path in opt['path'].items()
+     if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                  screen=True, tofile=True)
+logger = logging.getLogger('base')
+logger.info(option.dict2str(opt))
+
+#### Create test dataset and dataloader
+test_loaders = []
+for phase, dataset_opt in sorted(opt['datasets'].items()):
+    test_set = create_dataset(dataset_opt)
+    test_loader = create_dataloader(test_set, dataset_opt)
+    logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+    test_loaders.append(test_loader)
+
+model = create_model(opt)
+for test_loader in test_loaders:
+    test_set_name = test_loader.dataset.opt['name']
+    logger.info('\nTesting [{:s}]...'.format(test_set_name))
+    test_start_time = time.time()
+    dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
+    util.mkdir(dataset_dir)
+
+    test_results = OrderedDict()
+    test_results['psnr'] = []
+    test_results['ssim'] = []
+    test_results['psnr_y'] = []
+    test_results['ssim_y'] = []
+
+    test_results['psnr_lr'] = []
+    test_results['ssim_lr'] = []
+    test_results['psnr_y_lr'] = []
+    test_results['ssim_y_lr'] = []
+
+    test_results['bpp'] = []
+ test_results['psnr_restore'] = [] + + for data in test_loader: + model.feed_data(data) + img_path = data['GT_path'][0] + img_name = osp.splitext(osp.basename(img_path))[0] + + model.test() + visuals = model.get_current_visuals() + + sr_img = util.tensor2img(visuals['SR']) # uint8 + srgt_img = util.tensor2img(visuals['GT']) # uint8 + lr_img = util.tensor2img(visuals['LR']) # uint8 + lrgt_img = util.tensor2img(visuals['LR_ref']) # uint8 + rlr_img = util.tensor2img(visuals['RLR']) # uint8 + + # save images + suffix = opt['suffix'] + if suffix: + save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') + else: + save_img_path = osp.join(dataset_dir, img_name + '.png') + util.save_img(sr_img, save_img_path) + + if suffix: + save_img_path = osp.join(dataset_dir, img_name + suffix + '_GT.png') + else: + save_img_path = osp.join(dataset_dir, img_name + '_GT.png') + util.save_img(srgt_img, save_img_path) + + if suffix: + save_img_path = osp.join(dataset_dir, img_name + suffix + '_LR.jpg') + else: + save_img_path = osp.join(dataset_dir, img_name + '_LR.jpg') + quality = opt['test']['jpg_quality'] + cv2.imwrite(save_img_path, lr_img, [int(cv2.IMWRITE_JPEG_QUALITY),quality]) + + # get lr image size + lr_size = osp.getsize(save_img_path) + + if suffix: + save_img_path = osp.join(dataset_dir, img_name + suffix + '_LR_ref.png') + else: + save_img_path = osp.join(dataset_dir, img_name + '_LR_ref.png') + util.save_img(lrgt_img, save_img_path) + + # calculate PSNR, SSIM and bpp + h, w = srgt_img.shape[0], srgt_img.shape[1] + bpp = lr_size * 8. / (h * w) + test_results['bpp'].append(bpp) + + sr_img = sr_img / 255. + srgt_img = srgt_img / 255. + + lr_img = lr_img / 255. + lrgt_img = lrgt_img / 255. + + crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] + if crop_border == 0: + cropped_sr_img = sr_img + cropped_srgt_img = srgt_img + else: + cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :] + cropped_srgt_img = srgt_img[crop_border:-crop_border, crop_border:-crop_border, :] + + psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_srgt_img * 255) + ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_srgt_img * 255) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + + # PSNR and SSIM for LR + psnr_lr = util.calculate_psnr(lr_img * 255, lrgt_img * 255) + ssim_lr = util.calculate_ssim(lr_img * 255, lrgt_img * 255) + test_results['psnr_lr'].append(psnr_lr) + test_results['ssim_lr'].append(ssim_lr) + + psnr_restore = util.calculate_psnr(lr_img * 255, rlr_img) + test_results['psnr_restore'].append(psnr_restore) + + if srgt_img.shape[2] == 3: # RGB image + sr_img_y = bgr2ycbcr(sr_img, only_y=True) + srgt_img_y = bgr2ycbcr(srgt_img, only_y=True) + if crop_border == 0: + cropped_sr_img_y = sr_img_y + cropped_srgt_img_y = srgt_img_y + else: + cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border] + cropped_srgt_img_y = srgt_img_y[crop_border:-crop_border, crop_border:-crop_border] + psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_srgt_img_y * 255) + ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_srgt_img_y * 255) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + + lr_img_y = bgr2ycbcr(lr_img, only_y=True) + lrgt_img_y = bgr2ycbcr(lrgt_img, only_y=True) + psnr_y_lr = util.calculate_psnr(lr_img_y * 255, lrgt_img_y * 255) + ssim_y_lr = util.calculate_ssim(lr_img_y * 255, lrgt_img_y * 255) + test_results['psnr_y_lr'].append(psnr_y_lr) + 
test_results['ssim_y_lr'].append(ssim_y_lr)
+
+            logger.info(
+                '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}. PSNR Restore: {:.6f}, bpp: {:.6f}.'.
+                format(img_name, psnr, ssim, psnr_y, ssim_y, psnr_lr, ssim_lr, psnr_y_lr, ssim_y_lr, psnr_restore, bpp))
+        else:
+            logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}. PSNR Restore: {:.6f}, bpp: {:.6f}.'.format(img_name, psnr, ssim, psnr_lr, ssim_lr, psnr_restore, bpp))
+
+    # Average PSNR/SSIM results
+    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
+    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
+
+    ave_psnr_lr = sum(test_results['psnr_lr']) / len(test_results['psnr_lr'])
+    ave_ssim_lr = sum(test_results['ssim_lr']) / len(test_results['ssim_lr'])
+
+    ave_psnr_restore = sum(test_results['psnr_restore']) / len(test_results['psnr_restore'])
+    ave_bpp = sum(test_results['bpp']) / len(test_results['bpp'])
+
+    logger.info(
+        '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}.\n'.format(
+            test_set_name, ave_psnr, ave_ssim, ave_psnr_lr, ave_ssim_lr))
+    if test_results['psnr_y'] and test_results['ssim_y']:
+        ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
+        ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
+
+        ave_psnr_y_lr = sum(test_results['psnr_y_lr']) / len(test_results['psnr_y_lr'])
+        ave_ssim_y_lr = sum(test_results['ssim_y_lr']) / len(test_results['ssim_y_lr'])
+        logger.info(
+            '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}. LR PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.\n'.
+            format(ave_psnr_y, ave_ssim_y, ave_psnr_y_lr, ave_ssim_y_lr))
+    logger.info(
+        '----Average PSNR Restore and bpp----\n\tPSNR Restore: {:.6f} dB; bpp: {:.6f}.\n'.format(
+            ave_psnr_restore, ave_bpp))
diff --git a/codes/train_IRN-Color.py b/codes/train_IRN-Color.py
new file mode 100644
index 0000000..af88b0e
--- /dev/null
+++ b/codes/train_IRN-Color.py
@@ -0,0 +1,247 @@
+import os
+import math
+import argparse
+import random
+import logging
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from data.data_sampler import DistIterSampler
+
+import options.options as option
+from utils import util
+from data import create_dataloader, create_dataset
+from models import create_model
+
+
+def init_dist(backend='nccl', **kwargs):
+    ''' initialization for distributed training '''
+    # if mp.get_start_method(allow_none=True) is None:
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def main():
+    #### options
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    opt = option.parse(args.opt, is_train=True)
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist()
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+
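For the 'pytorch' launcher branch above, `init_dist()` reads the RANK environment variable, which PyTorch's distributed launcher sets per process; an illustrative two-GPU invocation (paths are examples) is:

    python -m torch.distributed.launch --nproc_per_node=2 train_IRN-Color.py -opt options/train/train_IRN_color.yml --launcher pytorch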
#### loading resume state if exists + if opt['path'].get('resume_state', None): + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load(opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + option.check_resume(opt, resume_state['iter']) # check resume options + else: + resume_state = None + + #### mkdir and loggers + if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) + if resume_state is None: + util.mkdir_and_rename( + opt['path']['experiments_root']) # rename experiment folder if exists + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) + + # config loggers. Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + #### random seed + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + if rank <= 0: + logger.info('Random seed: {}'.format(seed)) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + + #### create model + model = create_model(opt) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + 
start_epoch = resume_state['epoch']
+        current_step = resume_state['iter']
+        model.resume_training(resume_state)  # handle optimizers and schedulers
+    else:
+        current_step = 0
+        start_epoch = 0
+
+    #### training
+    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
+    for epoch in range(start_epoch, total_epochs + 1):
+        if opt['dist']:
+            train_sampler.set_epoch(epoch)
+        for _, train_data in enumerate(train_loader):
+            current_step += 1
+            if current_step > total_iters:
+                break
+            #### training
+            model.feed_data(train_data)
+            model.optimize_parameters(current_step)
+
+            #### update learning rate
+            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
+
+            #### log
+            if current_step % opt['logger']['print_freq'] == 0:
+                logs = model.get_current_log()
+                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
+                    epoch, current_step, model.get_current_learning_rate())
+                for k, v in logs.items():
+                    message += '{:s}: {:.4e} '.format(k, v)
+                    # tensorboard logger
+                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                        if rank <= 0:
+                            tb_logger.add_scalar(k, v, current_step)
+                if rank <= 0:
+                    logger.info(message)
+
+            # validation
+            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
+                avg_psnr = 0.0
+                idx = 0
+                for val_data in val_loader:
+                    idx += 1
+                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
+                    img_dir = os.path.join(opt['path']['val_images'], img_name)
+                    util.mkdir(img_dir)
+
+                    model.feed_data(val_data)
+                    model.test()
+
+                    visuals = model.get_current_visuals()
+                    color_img = util.tensor2img(visuals['Color'])  # uint8
+                    gt_img = util.tensor2img(visuals['GT'])  # uint8
+
+                    grey_img = util.tensor2img(visuals['Grey'])
+
+                    gtgrey_img = util.tensor2img(visuals['Grey_ref'])
+
+                    # Save Color images for reference
+                    save_img_path = os.path.join(img_dir,
+                                                 '{:s}_{:d}.png'.format(img_name, current_step))
+                    util.save_img(color_img, save_img_path)
+
+                    # Save Grey images
+                    save_img_path_L = os.path.join(img_dir, '{:s}_forwGrey_{:d}.png'.format(img_name, current_step))
+                    util.save_img(grey_img, save_img_path_L)
+
+                    # Save ground truth
+                    if current_step == opt['train']['val_freq']:
+                        save_img_path_gt = os.path.join(img_dir, '{:s}_GT_{:d}.png'.format(img_name, current_step))
+                        util.save_img(gt_img, save_img_path_gt)
+                        save_img_path_gtl = os.path.join(img_dir, '{:s}_Grey_ref_{:d}.png'.format(img_name, current_step))
+                        util.save_img(gtgrey_img, save_img_path_gtl)
+
+                    # calculate PSNR
+                    crop_size = opt['scale']
+                    gt_img = gt_img / 255.
+                    color_img = color_img / 255.
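The validation step next crops `scale` border pixels from both images before measuring PSNR, since borders are unreliable after rescaling; a minimal sketch of the cropping (function name is illustrative):

    def crop_border(img, border):
        # Drop `border` pixels from each side of an H x W x C array.
        return img if border == 0 else img[border:-border, border:-border, :]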
+                    cropped_color_img = color_img[crop_size:-crop_size, crop_size:-crop_size, :]
+                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
+                    avg_psnr += util.calculate_psnr(cropped_color_img * 255, cropped_gt_img * 255)
+
+                avg_psnr = avg_psnr / idx
+
+                # log
+                logger.info('# Validation # PSNR: {:.4e}.'.format(avg_psnr))
+                logger_val = logging.getLogger('val')  # validation logger
+                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}.'.format(
+                    epoch, current_step, avg_psnr))
+                # tensorboard logger
+                if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
+
+            #### save models and training states
+            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                if rank <= 0:
+                    logger.info('Saving models and training states.')
+                    model.save(current_step)
+                    model.save_training_state(epoch, current_step)
+
+    if rank <= 0:
+        logger.info('Saving the final model.')
+        model.save('latest')
+        logger.info('End of training.')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/codes/train_IRN-Compression.py b/codes/train_IRN-Compression.py
new file mode 100644
index 0000000..6523e61
--- /dev/null
+++ b/codes/train_IRN-Compression.py
@@ -0,0 +1,266 @@
+import os
+import math
+import argparse
+import random
+import logging
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from data.data_sampler import DistIterSampler
+
+import options.options as option
+from utils import util
+from data import create_dataloader, create_dataset
+from models import create_model
+import cv2
+
+
+def init_dist(backend='nccl', **kwargs):
+    ''' initialization for distributed training '''
+    # if mp.get_start_method(allow_none=True) is None:
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def main():
+    #### options
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    opt = option.parse(args.opt, is_train=True)
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist()
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+
+    #### loading resume state if exists
+    if opt['path'].get('resume_state', None):
+        # distributed resuming: all load into default GPU
+        device_id = torch.cuda.current_device()
+        resume_state = torch.load(opt['path']['resume_state'],
+                                  map_location=lambda storage, loc: storage.cuda(device_id))
+        option.check_resume(opt, resume_state['iter'])  # check resume options
+    else:
+        resume_state = None
+
+    #### mkdir and loggers
+    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
+        if resume_state is None:
+            util.mkdir_and_rename(
+                opt['path']['experiments_root'])  # rename experiment folder if exists
+            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
+                         and 'pretrain_model' not in key and 'resume' not in key))
+
+        # config loggers.
Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + #### random seed + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + if rank <= 0: + logger.info('Random seed: {}'.format(seed)) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + + #### create model + model = create_model(opt) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + start_epoch = resume_state['epoch'] + current_step = resume_state['iter'] + model.resume_training(resume_state) # handle optimizers and schedulers + else: + current_step = 0 + start_epoch = 0 + + #### training + logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) + for epoch in range(start_epoch, total_epochs + 1): + if opt['dist']: + train_sampler.set_epoch(epoch) + for _, train_data in enumerate(train_loader): + current_step += 1 + if current_step > total_iters: + break + #### training + model.feed_data(train_data) + model.optimize_parameters(current_step) + + #### update learning rate + model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) + + #### log + if current_step % opt['logger']['print_freq'] == 0: + logs = 
model.get_current_log()
+                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
+                    epoch, current_step, model.get_current_learning_rate())
+                for k, v in logs.items():
+                    message += '{:s}: {:.4e} '.format(k, v)
+                    # tensorboard logger
+                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                        if rank <= 0:
+                            tb_logger.add_scalar(k, v, current_step)
+                if rank <= 0:
+                    logger.info(message)
+
+            # validation
+            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
+                avg_psnr = 0.0
+                avg_psnr_jpeg = 0.0
+                avg_bpp = 0.0
+                idx = 0
+                for val_data in val_loader:
+                    idx += 1
+                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
+                    img_dir = os.path.join(opt['path']['val_images'], img_name)
+                    util.mkdir(img_dir)
+
+                    model.feed_data(val_data)
+                    model.test()
+
+                    visuals = model.get_current_visuals()
+                    sr_img = util.tensor2img(visuals['SR'])  # uint8
+                    gt_img = util.tensor2img(visuals['GT'])  # uint8
+
+                    lr_img = util.tensor2img(visuals['LR'])
+                    rlr_img = util.tensor2img(visuals['RLR'])
+
+                    gtl_img = util.tensor2img(visuals['LR_ref'])
+
+                    # Save SR images for reference
+                    save_img_path = os.path.join(img_dir,
+                                                 '{:s}_{:d}.png'.format(img_name, current_step))
+                    util.save_img(sr_img, save_img_path)
+
+                    # Save LR images
+                    save_img_path_L = os.path.join(img_dir, '{:s}_forwLR_{:d}.png'.format(img_name, current_step))
+                    util.save_img(lr_img, save_img_path_L)
+
+                    save_img_path_jL = os.path.join(img_dir, '{:s}_jpegLR_{:d}.jpg'.format(img_name, current_step))
+                    quality = opt['train']['jpg_quality']
+                    cv2.imwrite(save_img_path_jL, lr_img, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
+                    lr_size = os.path.getsize(save_img_path_jL)
+                    bpp = lr_size * 8. / (gt_img.shape[0] * gt_img.shape[1])
+                    avg_bpp += bpp
+
+                    # Save ground truth
+                    if current_step == opt['train']['val_freq']:
+                        save_img_path_gt = os.path.join(img_dir, '{:s}_GT_{:d}.png'.format(img_name, current_step))
+                        util.save_img(gt_img, save_img_path_gt)
+                        save_img_path_gtl = os.path.join(img_dir, '{:s}_LR_ref_{:d}.png'.format(img_name, current_step))
+                        util.save_img(gtl_img, save_img_path_gtl)
+
+                    # calculate PSNR
+                    crop_size = opt['scale']
+                    gt_img = gt_img / 255.
+                    sr_img = sr_img / 255.
+                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
+                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
+                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
+
+                    lr_img = lr_img / 255.
+                    rlr_img = rlr_img / 255.
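The cropping and JPEG-restore PSNR follow below. Validation measures bpp by writing the LR image to disk with `cv2.imwrite` at the configured quality; the same round trip can be done in memory with `cv2.imencode`/`cv2.imdecode`, e.g. (helper name is illustrative):

    import cv2

    def jpeg_round_trip(img_uint8, quality=90):
        # Encode to JPEG at `quality`, decode back, and report bits per pixel.
        ok, buf = cv2.imencode('.jpg', img_uint8, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
        assert ok, 'JPEG encoding failed'
        bpp = buf.size * 8.0 / (img_uint8.shape[0] * img_uint8.shape[1])
        return cv2.imdecode(buf, cv2.IMREAD_COLOR), bpp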
+                    cropped_lr_img = lr_img[crop_size:-crop_size, crop_size:-crop_size, :]
+                    cropped_rlr_img = rlr_img[crop_size:-crop_size, crop_size:-crop_size, :]
+                    avg_psnr_jpeg += util.calculate_psnr(cropped_lr_img * 255, cropped_rlr_img * 255)
+
+                avg_psnr = avg_psnr / idx
+                avg_psnr_jpeg = avg_psnr_jpeg / idx
+                avg_bpp = avg_bpp / idx
+
+                # log
+                logger.info('# Validation # PSNR: {:.4e}, JPEG Restore PSNR: {:.4e}, bpp: {:.4e}.'.format(avg_psnr, avg_psnr_jpeg, avg_bpp))
+                logger_val = logging.getLogger('val')  # validation logger
+                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, jpeg restore psnr: {:.4e}, bpp: {:.4e}.'.format(
+                    epoch, current_step, avg_psnr, avg_psnr_jpeg, avg_bpp))
+                # tensorboard logger
+                if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
+
+            #### save models and training states
+            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                if rank <= 0:
+                    logger.info('Saving models and training states.')
+                    model.save(current_step)
+                    model.save_training_state(epoch, current_step)
+
+    if rank <= 0:
+        logger.info('Saving the final model.')
+        model.save('latest')
+        logger.info('End of training.')
+
+
+if __name__ == '__main__':
+    main()
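All training configs above use MultiStepLR with lr_gamma 0.5, so the learning rate halves at each milestone in lr_steps; a minimal PyTorch sketch of the 500k-iteration schedule (the parameter is a placeholder):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameters
    optimizer = torch.optim.Adam(params, lr=2e-4, betas=(0.9, 0.999), weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100000, 200000, 300000, 400000], gamma=0.5)
    # After each milestone the lr becomes 1e-4, 5e-5, 2.5e-5, 1.25e-5;
    # scheduler.step() is called once per training iteration.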