Skip to content

Commit

Permalink
dev(narugo): add upscale
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 13, 2023
1 parent a488bf4 commit 6495a5c
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 4 deletions.
17 changes: 15 additions & 2 deletions imgutils/upscale/cdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
return ort, scale_h


_CDC_INPUT_UNIT = 16


def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
silent: bool = False) -> Image.Image:
Expand All @@ -37,9 +40,19 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
ort, scale = _open_cdc_upscaler_model(model)

def _method(ix):
ox, = ort.run(['output'], {'input': ix.astype(np.float32)})
ix = ix.astype(np.float32)
batch, channels, height, width = ix.shape
p_height = 0 if height % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (height % _CDC_INPUT_UNIT)
p_width = 0 if width % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (width % _CDC_INPUT_UNIT)
if p_height > 0 or p_width > 0: # align to 16
ix = np.pad(ix, ((0, 0), (0, 0), (0, p_height), (0, p_width)), mode='reflect')
actual_height, actual_width = height, width

ox, = ort.run(['output'], {'input': ix})
batch, channels, scale_, height, scale_, width = ox.shape
return ox.reshape((batch, channels, scale_ * height, scale_ * width))
ox = ox.reshape((batch, channels, scale_ * height, scale_ * width))
ox = ox[..., :scale_ * actual_height, :scale_ * actual_width] # crop back
return ox

output_ = area_batch_run(
input_, _method,
Expand Down
4 changes: 2 additions & 2 deletions imgutils/utils/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def area_batch_run(origin_input: np.ndarray, func, scale: int = 1,

tile = min(tile_size, height, width)
stride = tile - tile_overlap
h_idx_list = list(range(0, height - tile, stride)) + [height - tile]
w_idx_list = list(range(0, width - tile, stride)) + [width - tile]
h_idx_list = sorted(set(list(range(0, height - tile, stride)) + [height - tile]))
w_idx_list = sorted(set(list(range(0, width - tile, stride)) + [width - tile]))
sum_ = np.zeros((batch, output_channels, height * scale, width * scale), dtype=origin_input.dtype)
weight = np.zeros_like(sum_, dtype=origin_input.dtype)

Expand Down
Binary file added test/testfile/surtr_logo_2x.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/testfile/surtr_logo_4x.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/testfile/surtr_logo_small_2x.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/testfile/surtr_logo_small_4x.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added test/upscale/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions test/upscale/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from imgutils.data import load_image
from test.testings import get_testfile


@pytest.fixture()
def sample_image():
yield load_image(get_testfile('surtr_logo.png'), mode='RGB', force_background='white')


@pytest.fixture()
def sample_image_small(sample_image):
yield sample_image.resize((127, 126))
34 changes: 34 additions & 0 deletions test/upscale/test_cdc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from PIL import Image

from imgutils.metrics import psnr
from imgutils.upscale import upscale_with_cdc


@pytest.mark.unittest
class TestUpscaleCDC:
def test_upscale_with_cdc_4x(self, sample_image):
assert psnr(
upscale_with_cdc(sample_image),
sample_image.resize((sample_image.width * 4, sample_image.height * 4), Image.LANCZOS)
) >= 34.5

def test_upscale_with_cdc_2x(self, sample_image):
assert psnr(
upscale_with_cdc(sample_image, model='HGSR-MHR_X2_1680'),
sample_image.resize((sample_image.width * 2, sample_image.height * 2), Image.LANCZOS)
) >= 35.5

def test_upscale_with_cdc_small_4x(self, sample_image_small, sample_image):
assert psnr(
upscale_with_cdc(sample_image_small)
.resize(sample_image.size, Image.LANCZOS),
sample_image,
) >= 28.5

def test_upscale_with_cdc_small_2x(self, sample_image_small, sample_image):
assert psnr(
upscale_with_cdc(sample_image_small, model='HGSR-MHR_X2_1680')
.resize(sample_image.size, Image.LANCZOS),
sample_image,
) >= 28.0

0 comments on commit 6495a5c

Please sign in to comment.