import os.path
import random

from huggingface_hub import HfFileSystem

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.upscale.cdc import upscale_with_cdc

hf_fs = HfFileSystem()
repository = 'deepghs/cdc_anime_onnx'


def _model_name(onnx_file: str) -> str:
    """Return the path of ``onnx_file`` relative to the repository, without its extension."""
    stem, _ext = os.path.splitext(os.path.relpath(onnx_file, repository))
    return stem


# One benchmark target per ``*.onnx`` file matched directly under the repository.
_CDC_MODELS = list(map(_model_name, hf_fs.glob(f'{repository}/*.onnx')))


class CDCUpscalerBenchmark(BaseBenchmark):
    """
    Benchmark case that exercises ``upscale_with_cdc`` with one specific model.

    :param model: Name of the CDC model to benchmark.
    :type model: str
    """

    def __init__(self, model: str):
        super().__init__()
        self.model = model

    def load(self):
        """Open the model through the cached opener."""
        from imgutils.upscale.cdc import _open_cdc_upscaler_model as open_model
        open_model(self.model)

    def unload(self):
        """Clear the opener's ``lru_cache``."""
        from imgutils.upscale.cdc import _open_cdc_upscaler_model as open_model
        open_model.cache_clear()

    def run(self):
        """Upscale one randomly chosen image; the result is discarded."""
        # ``all_images`` is not defined in this script - presumably supplied by BaseBenchmark.
        chosen = random.choice(self.all_images)
        upscale_with_cdc(chosen, model=self.model)


if __name__ == '__main__':
    benchmarks = [(name, CDCUpscalerBenchmark(name)) for name in _CDC_MODELS]
    plot_cli = create_plot_cli(
        benchmarks,
        title='Benchmark for CDCUpscaler Models',
        run_times=5,
        try_times=10,
    )
    plot_cli()
import os

from huggingface_hub import HfFileSystem

from imgutils.upscale import upscale_with_cdc
from imgutils.upscale.cdc import _open_cdc_upscaler_model
from plot import image_plot

hf_fs = HfFileSystem()
repository = 'deepghs/cdc_anime_onnx'

# Model names are the ``*.onnx`` paths relative to the repository, extension removed.
_CDC_MODELS = []
for _onnx_path in hf_fs.glob(f'{repository}/*.onnx'):
    _CDC_MODELS.append(os.path.splitext(os.path.relpath(_onnx_path, repository))[0])

if __name__ == '__main__':
    # (image path, caption) pairs; each one becomes a row of the figure.
    demo_images = [
        ('sample/original.png', 'Small Logo'),
        ('sample/skadi.jpg', 'Illustration'),
        ('sample/hutao.png', 'Large Illustration'),
        ('sample/xx.jpg', 'Illustration #2'),
    ]
    # One column for the original image plus one per model.
    n_columns = len(_CDC_MODELS) + 1

    items = []
    for path, caption in demo_images:
        row = [(path, caption)]
        for model_name in _CDC_MODELS:
            # Only the scale factor is needed here, for the caption.
            factor = _open_cdc_upscaler_model(model_name)[1]
            upscaled = upscale_with_cdc(path, model=model_name)
            row.append((upscaled, f'{caption}\n({factor}X By {model_name})'))
        items.extend(row)

    image_plot(
        *items,
        columns=n_columns,
        figsize=(2 * n_columns, 3 * len(demo_images)),
    )
toctree:: + :maxdepth: 3 + + cdc + diff --git a/docs/source/api_doc/upscale/sample/hutao.png b/docs/source/api_doc/upscale/sample/hutao.png new file mode 100644 index 00000000000..a8f42dce778 Binary files /dev/null and b/docs/source/api_doc/upscale/sample/hutao.png differ diff --git a/docs/source/api_doc/upscale/sample/original.png b/docs/source/api_doc/upscale/sample/original.png new file mode 100644 index 00000000000..4757b83a42f Binary files /dev/null and b/docs/source/api_doc/upscale/sample/original.png differ diff --git a/docs/source/api_doc/upscale/sample/skadi.jpg b/docs/source/api_doc/upscale/sample/skadi.jpg new file mode 100644 index 00000000000..a585ecb7c85 Binary files /dev/null and b/docs/source/api_doc/upscale/sample/skadi.jpg differ diff --git a/docs/source/api_doc/upscale/sample/xx.jpg b/docs/source/api_doc/upscale/sample/xx.jpg new file mode 100644 index 00000000000..dfa950eeab0 Binary files /dev/null and b/docs/source/api_doc/upscale/sample/xx.jpg differ diff --git a/docs/source/index.rst b/docs/source/index.rst index a72917758fd..ca7da07631e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,6 +38,7 @@ configuration file's structure and their versions. api_doc/restore/index api_doc/segment/index api_doc/tagging/index + api_doc/upscale/index api_doc/utils/index api_doc/validate/index diff --git a/imgutils/restore/nafnet.py b/imgutils/restore/nafnet.py index 1720629eee4..d565b0da904 100644 --- a/imgutils/restore/nafnet.py +++ b/imgutils/restore/nafnet.py @@ -48,7 +48,8 @@ def _open_nafnet_model(model: NafNetModelTyping): def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS', - tile_size: int = 256, tile_overlap: int = 16, silent: bool = False) -> Image.Image: + tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4, + silent: bool = False) -> Image.Image: """ Restore an image using the NAFNet model. 
@@ -60,6 +61,8 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS', :type tile_size: int :param tile_overlap: The overlap between tiles. Default is 16. :type tile_overlap: int + :param batch_size: The batch size of inference. Default is 4. + :type batch_size: int :param silent: If True, the progress will not be displayed. Default is False. :type silent: bool :return: The restored image. @@ -75,8 +78,8 @@ def _method(ix): output_ = area_batch_run( input_, _method, - tile_size=tile_size, tile_overlap=tile_overlap, silent=silent, - process_title='NafNet Restore', + tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, + silent=silent, process_title='NafNet Restore', ) output_ = np.clip(output_, a_min=0.0, a_max=1.0) return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') diff --git a/imgutils/restore/scunet.py b/imgutils/restore/scunet.py index 069400b5fb6..863861020a6 100644 --- a/imgutils/restore/scunet.py +++ b/imgutils/restore/scunet.py @@ -44,7 +44,8 @@ def _open_scunet_model(model: SCUNetModelTyping): def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN', - tile_size: int = 128, tile_overlap: int = 16, silent: bool = False) -> Image.Image: + tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4, + silent: bool = False) -> Image.Image: """ Restore an image using the SCUNet model. @@ -56,6 +57,8 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN', :type tile_size: int :param tile_overlap: The overlap between tiles. Default is 16. :type tile_overlap: int + :param batch_size: The batch size of inference. Default is 4. + :type batch_size: int :param silent: If True, the progress will not be displayed. Default is False. :type silent: bool :return: The restored image. 
from functools import lru_cache
from typing import Tuple

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run


@lru_cache()
def _open_cdc_upscaler_model(model: str) -> Tuple[InferenceSession, int]:
    """
    Open a CDC upscaler ONNX model and detect its scale factor.

    The file ``{model}.onnx`` is fetched from the ``deepghs/cdc_anime_onnx``
    repository with ``hf_hub_download`` and opened with ``open_onnx_model``.
    The scale factor is read from the output shape of a probe run on a random
    ``(1, 3, 16, 16)`` float32 input. Results are cached per model name.

    :param model: Name of the model, i.e. the ONNX filename without extension.
    :type model: str
    :return: The opened inference session and the detected scale factor.
    :rtype: Tuple[InferenceSession, int]
    :raises AssertionError: If the probe output does not have the shape
        ``(1, 3, scale, 16, scale, 16)`` with equal scale on both axes.
    """
    ort = open_onnx_model(hf_hub_download(
        'deepghs/cdc_anime_onnx',
        f'{model}.onnx'
    ))

    # Probe with a tiny random input; only the output shape is of interest.
    input_ = np.random.randn(1, 3, 16, 16).astype(np.float32)
    output_, = ort.run(['output'], {'input': input_})

    batch, channels, scale_h, height, scale_w, width = output_.shape
    assert batch == 1 and channels == 3 and height == 16 and width == 16, \
        f'Unexpected output size found {output_.shape!r}.'
    assert scale_h == scale_w, f'Scale of height and width not match - {output_.shape!r}.'

    return ort, scale_h


def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
                     silent: bool = False) -> Image.Image:
    """
    Upscale an image using a CDC model.

    :param image: The input image.
    :type image: ImageTyping
    :param model: The CDC model to use. Default is ``HGSR-MHR-anime-aug_X4_320``.
    :type model: str
    :param tile_size: The size of processing tiles. Default is 512.
    :type tile_size: int
    :param tile_overlap: The overlap between tiles. Default is 64.
    :type tile_overlap: int
    :param batch_size: The batch size of inference. Default is 1.
    :type batch_size: int
    :param silent: If True, the progress will not be displayed. Default is False.
    :type silent: bool
    :return: The upscaled image.
    :rtype: Image.Image
    """
    image = load_image(image, mode='RGB', force_background='white')
    # HWC values in [0, 255] -> NCHW float32 in [0, 1] with a leading batch axis.
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]

    ort, scale = _open_cdc_upscaler_model(model)

    def _method(ix):
        ox, = ort.run(['output'], {'input': ix.astype(np.float32)})
        # The model emits six axes; fold each scale axis into the spatial axis after it.
        batch, channels, scale_h, height, scale_w, width = ox.shape
        return ox.reshape((batch, channels, scale_h * height, scale_w * width))

    output_ = area_batch_run(
        input_, _method,
        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
        scale=scale, silent=silent, process_title='CDC Upscale',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    # uint8, not int8: the scaled values span 0..255, which a signed 8-bit
    # type cannot represent (casting values above 127 is not well-defined).
    pixels = (output_[0].transpose((1, 2, 0)) * 255).astype(np.uint8)
    return Image.fromarray(pixels, 'RGB')