Skip to content

Commit

Permalink
refactor(runner): improve pipeline NFSW return type
Browse files Browse the repository at this point in the history
This commit ensure that the NSFW flag is set as optional. This was done
since a None list can be returned.
  • Loading branch information
rickstaa committed May 31, 2024
1 parent a34622e commit f827d1c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
6 changes: 3 additions & 3 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List, Tuple
from typing import List, Tuple, Optional
import logging
import os

Expand Down Expand Up @@ -133,10 +133,10 @@ def __init__(self, model_id: str):

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[bool]]:
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
Expand Down
15 changes: 9 additions & 6 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from huggingface_hub import file_download
import torch
import PIL
from typing import List, Tuple
from typing import List, Tuple, Optional
import logging
import os
import time
Expand Down Expand Up @@ -102,16 +102,20 @@ def __init__(self, model_id: str):
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(self, image: PIL.Image, **kwargs) -> Tuple[List[PIL.Image], List[bool]]:
def __call__(
self, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

if "decode_chunk_size" not in kwargs:
# Decrease decode_chunk_size to reduce memory usage.
kwargs["decode_chunk_size"] = 4

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
Expand All @@ -121,8 +125,7 @@ def __call__(self, image: PIL.Image, **kwargs) -> Tuple[List[PIL.Image], List[bo
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

safety_check = kwargs.pop("safety_check", True)

if safety_check:
logger.info("checking input image for nsfw")
_, has_nsfw_concept = self._safety_checker.check_nsfw_images([image])
Expand Down

0 comments on commit f827d1c

Please sign in to comment.