From f827d1c1af6ab4f26b04fd6a8b43c5070ee85597 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 31 May 2024 12:08:44 +0200 Subject: [PATCH] refactor(runner): improve pipeline NFSW return type This commit ensure that the NSFW flag is set as optional. This was done since a None list can be returned. --- runner/app/pipelines/image_to_image.py | 6 +++--- runner/app/pipelines/image_to_video.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 2409ccf8..9ddf687a 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -11,7 +11,7 @@ from huggingface_hub import file_download, hf_hub_download import torch import PIL -from typing import List, Tuple +from typing import List, Tuple, Optional import logging import os @@ -133,10 +133,10 @@ def __init__(self, model_id: str): def __call__( self, prompt: str, image: PIL.Image, **kwargs - ) -> Tuple[List[PIL.Image], List[bool]]: + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + seed = kwargs.pop("seed", None) safety_check = kwargs.pop("safety_check", True) - seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index 4f0088c5..4cfaa4a7 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -5,7 +5,7 @@ from huggingface_hub import file_download import torch import PIL -from typing import List, Tuple +from typing import List, Tuple, Optional import logging import os import time @@ -102,16 +102,20 @@ def __init__(self, model_id: str): from app.pipelines.optim.deepcache import enable_deepcache self.ldm = enable_deepcache(self.ldm) - + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() self._safety_checker = SafetyChecker(device=safety_checker_device) - def __call__(self, image: PIL.Image, **kwargs) -> Tuple[List[PIL.Image], List[bool]]: + def __call__( + self, image: PIL.Image, **kwargs + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + seed = kwargs.pop("seed", None) + safety_check = kwargs.pop("safety_check", True) + if "decode_chunk_size" not in kwargs: # Decrease decode_chunk_size to reduce memory usage. kwargs["decode_chunk_size"] = 4 - seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( @@ -121,8 +125,7 @@ def __call__(self, image: PIL.Image, **kwargs) -> Tuple[List[PIL.Image], List[bo kwargs["generator"] = [ torch.Generator(get_torch_device()).manual_seed(s) for s in seed ] - - safety_check = kwargs.pop("safety_check", True) + if safety_check: logger.info("checking input image for nsfw") _, has_nsfw_concept = self._safety_checker.check_nsfw_images([image])