Skip to content

Commit

Permalink
dev(narugo): add cdc upscaler
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 13, 2023
1 parent 44f3a41 commit ed12b29
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 6 deletions.
15 changes: 15 additions & 0 deletions docs/source/api_doc/upscale/cdc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
imgutils.upscale.cdc
====================================

.. currentmodule:: imgutils.upscale.cdc

.. automodule:: imgutils.upscale.cdc


upscale_with_cdc
---------------------------

.. autofunction:: upscale_with_cdc



44 changes: 44 additions & 0 deletions docs/source/api_doc/upscale/cdc_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os.path
import random

from huggingface_hub import HfFileSystem

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.upscale.cdc import upscale_with_cdc

hf_fs = HfFileSystem()
repository = 'deepghs/cdc_anime_onnx'
_CDC_MODELS = [
os.path.splitext(os.path.relpath(file, repository))[0]
for file in hf_fs.glob(f'{repository}/*.onnx')
]


class CDCUpscalerBenchmark(BaseBenchmark):
def __init__(self, model: str):
BaseBenchmark.__init__(self)
self.model = model

def load(self):
from imgutils.upscale.cdc import _open_cdc_upscaler_model
_open_cdc_upscaler_model(self.model)

def unload(self):
from imgutils.upscale.cdc import _open_cdc_upscaler_model
_open_cdc_upscaler_model.cache_clear()

def run(self):
image_file = random.choice(self.all_images)
_ = upscale_with_cdc(image_file, model=self.model)


if __name__ == '__main__':
create_plot_cli(
[
(model, CDCUpscalerBenchmark(model))
for model in _CDC_MODELS
],
title='Benchmark for CDCUpscaler Models',
run_times=5,
try_times=10,
)()
35 changes: 35 additions & 0 deletions docs/source/api_doc/upscale/cdc_demo.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

from huggingface_hub import HfFileSystem

from imgutils.upscale import upscale_with_cdc
from imgutils.upscale.cdc import _open_cdc_upscaler_model
from plot import image_plot

hf_fs = HfFileSystem()
repository = 'deepghs/cdc_anime_onnx'
_CDC_MODELS = [
os.path.splitext(os.path.relpath(file, repository))[0]
for file in hf_fs.glob(f'{repository}/*.onnx')
]

if __name__ == '__main__':
demo_images = [
('sample/original.png', 'Small Logo'),
('sample/skadi.jpg', 'Illustration'),
('sample/hutao.png', 'Large Illustration'),
('sample/xx.jpg', 'Illustration #2'),
]

items = []
for file, title in demo_images:
items.append((file, title))
for model in _CDC_MODELS:
_, scale = _open_cdc_upscaler_model(model)
items.append((upscale_with_cdc(file, model=model), f'{title}\n({scale}X By {model})'))

image_plot(
*items,
columns=len(_CDC_MODELS) + 1,
figsize=(2 * (len(_CDC_MODELS) + 1), 3 * len(demo_images)),
)
13 changes: 13 additions & 0 deletions docs/source/api_doc/upscale/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
imgutils.upscale
========================

.. currentmodule:: imgutils.upscale

.. automodule:: imgutils.upscale


.. toctree::
:maxdepth: 3

cdc

Binary file added docs/source/api_doc/upscale/sample/hutao.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/upscale/sample/original.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/upscale/sample/skadi.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/upscale/sample/xx.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ configuration file's structure and their versions.
api_doc/restore/index
api_doc/segment/index
api_doc/tagging/index
api_doc/upscale/index
api_doc/utils/index
api_doc/validate/index

Expand Down
9 changes: 6 additions & 3 deletions imgutils/restore/nafnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def _open_nafnet_model(model: NafNetModelTyping):


def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
tile_size: int = 256, tile_overlap: int = 16, silent: bool = False) -> Image.Image:
tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4,
silent: bool = False) -> Image.Image:
"""
Restore an image using the NAFNet model.
Expand All @@ -60,6 +61,8 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
:type tile_size: int
:param tile_overlap: The overlap between tiles. Default is 16.
:type tile_overlap: int
:param batch_size: The batch size of inference. Default is 4.
:type batch_size: int
:param silent: If True, the progress will not be displayed. Default is False.
:type silent: bool
:return: The restored image.
Expand All @@ -75,8 +78,8 @@ def _method(ix):

output_ = area_batch_run(
input_, _method,
tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
process_title='NafNet Restore',
tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
silent=silent, process_title='NafNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
9 changes: 6 additions & 3 deletions imgutils/restore/scunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def _open_scunet_model(model: SCUNetModelTyping):


def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
tile_size: int = 128, tile_overlap: int = 16, silent: bool = False) -> Image.Image:
tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4,
silent: bool = False) -> Image.Image:
"""
Restore an image using the SCUNet model.
Expand All @@ -56,6 +57,8 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
:type tile_size: int
:param tile_overlap: The overlap between tiles. Default is 16.
:type tile_overlap: int
:param batch_size: The batch size of inference. Default is 4.
:type batch_size: int
:param silent: If True, the progress will not be displayed. Default is False.
:type silent: bool
:return: The restored image.
Expand All @@ -71,8 +74,8 @@ def _method(ix):

output_ = area_batch_run(
input_, _method,
tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
process_title='SCUNet Restore',
tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
silent=silent, process_title='SCUNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
1 change: 1 addition & 0 deletions imgutils/upscale/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .cdc import upscale_with_cdc

Check warning on line 1 in imgutils/upscale/__init__.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/__init__.py#L1

Added line #L1 was not covered by tests
51 changes: 51 additions & 0 deletions imgutils/upscale/cdc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from functools import lru_cache
from typing import Tuple

Check warning on line 2 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L1-L2

Added lines #L1 - L2 were not covered by tests

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession

Check warning on line 7 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L4-L7

Added lines #L4 - L7 were not covered by tests

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

Check warning on line 10 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L9-L10

Added lines #L9 - L10 were not covered by tests


@lru_cache()
def _open_cdc_upscaler_model(model: str) -> Tuple[InferenceSession, int]:
ort = open_onnx_model(hf_hub_download(

Check warning on line 15 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L13-L15

Added lines #L13 - L15 were not covered by tests
f'deepghs/cdc_anime_onnx',
f'{model}.onnx'
))

input_ = np.random.randn(1, 3, 16, 16).astype(np.float32)
output_, = ort.run(['output'], {'input': input_})

Check warning on line 21 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L20-L21

Added lines #L20 - L21 were not covered by tests

batch, channels, scale_h, height, scale_w, width = output_.shape
assert batch == 1 and channels == 3 and height == 16 and width == 16, \

Check warning on line 24 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L23-L24

Added lines #L23 - L24 were not covered by tests
f'Unexpected output size found {output_.shape!r}.'
assert scale_h == scale_w, f'Scale of height and width not match - {output_.shape!r}.'

Check warning on line 26 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L26

Added line #L26 was not covered by tests

return ort, scale_h

Check warning on line 28 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L28

Added line #L28 was not covered by tests


def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',

Check warning on line 31 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L31

Added line #L31 was not covered by tests
tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
silent: bool = False) -> Image.Image:
image = load_image(image, mode='RGB', force_background='white')
input_ = np.array(image).astype(np.float32) / 255.0
input_ = input_.transpose((2, 0, 1))[None, ...]

Check warning on line 36 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L34-L36

Added lines #L34 - L36 were not covered by tests

ort, scale = _open_cdc_upscaler_model(model)

Check warning on line 38 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L38

Added line #L38 was not covered by tests

def _method(ix):
ox, = ort.run(['output'], {'input': ix.astype(np.float32)})
batch, channels, scale_, height, scale_, width = ox.shape
return ox.reshape((batch, channels, scale_ * height, scale_ * width))

Check warning on line 43 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L40-L43

Added lines #L40 - L43 were not covered by tests

output_ = area_batch_run(

Check warning on line 45 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L45

Added line #L45 was not covered by tests
input_, _method,
tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
scale=scale, silent=silent, process_title='CDC Upscale',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')

Check warning on line 51 in imgutils/upscale/cdc.py

View check run for this annotation

Codecov / codecov/patch

imgutils/upscale/cdc.py#L50-L51

Added lines #L50 - L51 were not covered by tests

0 comments on commit ed12b29

Please sign in to comment.