Skip to content

Commit

Permalink
IRN-IJCV
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuxmq committed Oct 11, 2022
1 parent f424ffc commit 93aaa94
Show file tree
Hide file tree
Showing 32 changed files with 2,632 additions and 71 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Invertible Image Rescaling
This is the PyTorch implementation of paper: Invertible Image Rescaling (ECCV 2020 Oral). [arxiv](https://arxiv.org/abs/2005.05650).
This is the PyTorch implementation of paper: Invertible Image Rescaling (ECCV 2020 Oral). \[[link](https://link.springer.com/chapter/10.1007/978-3-030-58452-8_8)\]\[[arxiv](https://arxiv.org/abs/2005.05650)\].

**2022/10 Update**: Our paper "Invertible Rescaling Network and Its Extensions" has been accepted by IJCV. \[[link](https://link.springer.com/article/10.1007/s11263-022-01688-4)\]\[[arxiv](https://arxiv.org/abs/2210.04188)\]. We update the repository for experiments in the paper. The previous version can be found in the ECCV branch.

## Dependencies and Installation
- Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
Expand Down
27 changes: 24 additions & 3 deletions codes/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
# Training
# Training for image rescaling
First set a config file in options/train/, then run as following:

python train.py -opt options/train/train_IRN_x4.yml

# Test
# Testing for image rescaling
First set a config file in options/test/, then run as following:

python test.py -opt options/test/test_IRN_x4.yml

Pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1-Rah2t-fk3uTcNagvTgTRlRTaK2dHktA?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1U38SjqVlqY5YVMsSFrkTsw) (extraction code: lukj).
# Training for image decolorization-colorization
First set a config file in options/train/, then run as following:

python train.py -opt options/train/train_IRN_color.yml

# Testing for image decolorization-colorization
First set a config file in options/test/, then run as following:

python test.py -opt options/test/test_IRN_color.yml

# Training for combination with image compression
First set a config file in options/train/, then run as following:

python train.py -opt options/train/train_IRN-Compression_x2_q90.yml

# Testing for combination with image compression
First set a config file in options/test/, then run as following:

python test.py -opt options/test/test_IRN-Compression_x2_q90.yml


Pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1ym6DvYNQegDrOy_4z733HxrULa1XIN92?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/14OvTiJNhFpHHN2yU-h7vDg) (extraction code: rx0z).

# Code Framework
The code framework follows [BasicSR](https://github.com/xinntao/BasicSR/tree/master/codes). It mainly consists of four parts - `Config`, `Data`, `Model` and `Network`.
Expand Down
132 changes: 79 additions & 53 deletions codes/data/LQGT_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@ def __init__(self, opt):
len(self.paths_LQ), len(self.paths_GT))
self.random_scale_list = [1]

self.use_grey = False
if self.opt['use_grey']:
self.use_grey = True

def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
if not self.use_grey:
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)

def __getitem__(self, index):
if self.data_type == 'lmdb':
Expand All @@ -62,36 +67,37 @@ def __getitem__(self, index):
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

# get LQ image
if self.paths_LQ:
LQ_path = self.paths_LQ[index]
if self.data_type == 'lmdb':
resolution = [int(s) for s in self.sizes_LQ[index].split('_')]
else:
resolution = None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_GT.shape

def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt

H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
# force to 3 channels
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

H, W, _ = img_GT.shape
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
if not self.use_grey:
if self.paths_LQ:
LQ_path = self.paths_LQ[index]
if self.data_type == 'lmdb':
resolution = [int(s) for s in self.sizes_LQ[index].split('_')]
else:
resolution = None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_GT.shape

def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt

H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
# force to 3 channels
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

H, W, _ = img_GT.shape
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)

if self.opt['phase'] == 'train':
# if the image size is too small
Expand All @@ -100,39 +106,59 @@ def _mod(n, random_scale, scale, thres):
img_GT = cv2.resize(np.copy(img_GT), (GT_size, GT_size),
interpolation=cv2.INTER_LINEAR)
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)

H, W, C = img_LQ.shape
LQ_size = GT_size // scale

# randomly crop
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
if not self.use_grey:
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)

if not self.use_grey:
H, W, C = img_LQ.shape
LQ_size = GT_size // scale

# randomly crop
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
else:
rnd_h_GT = random.randint(0, max(0, H - GT_size))
rnd_w_GT = random.randint(0, max(0, W - GT_size))
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]

# augmentation - flip, rotate
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
self.opt['use_rot'])
if not self.use_grey:
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
self.opt['use_rot'])
else:
img_GT = util.augment([img_GT], self.opt['use_flip'], self.opt['use_rot'])[0]

# change color space if necessary
if self.opt['color']:
img_LQ = util.channel_convert(C, self.opt['color'],
[img_LQ])[0] # TODO during val no definition
if not self.use_grey:
if self.opt['color']:
img_LQ = util.channel_convert(C, self.opt['color'],
[img_LQ])[0] # TODO during val no definition
if self.use_grey:
img_Grey = cv2.cvtColor(img_GT, cv2.COLOR_BGR2GRAY)

# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQ = img_LQ[:, :, [2, 1, 0]]
if not self.use_grey:
img_LQ = img_LQ[:, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
if not self.use_grey:
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
if self.use_grey:
img_Grey = torch.from_numpy(np.ascontiguousarray(np.expand_dims(img_Grey, 0))).float()

if LQ_path is None:
LQ_path = GT_path
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}

if not self.use_grey:
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
else:
return {'Grey': img_Grey, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}

def __len__(self):
return len(self.paths_GT)
Loading

0 comments on commit 93aaa94

Please sign in to comment.