diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index d00e5da9..2409ccf8 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -1,5 +1,5 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir +from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker from diffusers import ( AutoPipelineForImage2Image, @@ -11,7 +11,7 @@ from huggingface_hub import file_download, hf_hub_download import torch import PIL -from typing import List +from typing import List, Tuple import logging import os @@ -128,7 +128,14 @@ def __init__(self, model_id: str): self.ldm = enable_deepcache(self.ldm) - def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() + self._safety_checker = SafetyChecker(device=safety_checker_device) + + def __call__( + self, prompt: str, image: PIL.Image, **kwargs + ) -> Tuple[List[PIL.Image], List[bool]]: + safety_check = kwargs.pop("safety_check", True) + seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): @@ -170,7 +177,14 @@ def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: # Default to 2step kwargs["num_inference_steps"] = 2 - return self.ldm(prompt, image=image, **kwargs).images + output = self.ldm(prompt, image=image, **kwargs) + + if safety_check: + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + else: + has_nsfw_concept = [None] * len(output.images) + + return output.images, has_nsfw_concept def __str__(self) -> str: return f"ImageToImagePipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index d96e68dc..411c2d28 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -1,19 +1,20 @@ -from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir +import logging +import os +from typing import List, Tuple, Optional +import PIL +import torch from diffusers import ( AutoPipelineForText2Image, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, - EulerDiscreteScheduler, ) -from safetensors.torch import load_file from huggingface_hub import file_download, hf_hub_download -import torch -import PIL -from typing import List -import logging -import os +from safetensors.torch import load_file + +from app.pipelines.base import Pipeline +from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker logger = logging.getLogger(__name__) @@ -147,7 +148,14 @@ def __init__(self, model_id: str): self.ldm = enable_deepcache(self.ldm) - def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() + self._safety_checker = SafetyChecker(device=safety_checker_device) + + def __call__( + self, prompt: str, **kwargs + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + safety_check = kwargs.pop("safety_check", True) + seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): @@ -184,7 +192,14 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: # Default to 2step kwargs["num_inference_steps"] = 2 - return self.ldm(prompt, **kwargs).images + output = self.ldm(prompt, **kwargs) + + if safety_check: + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + else: + has_nsfw_concept = [None] * len(output.images) + + return output.images, has_nsfw_concept def __str__(self) -> str: return f"TextToImagePipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/util.py b/runner/app/pipelines/util.py index 5ce8d80c..2a73b798 100644 --- a/runner/app/pipelines/util.py +++ b/runner/app/pipelines/util.py @@ -1,6 +1,15 @@ import torch import os +import numpy as np +from torch import dtype as TorchDtype from pathlib import Path +from PIL import Image +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from transformers import CLIPFeatureExtractor +from typing import Optional +import logging + +logger = logging.getLogger(__name__) def get_model_dir() -> Path: @@ -18,3 +27,82 @@ def get_torch_device(): return torch.device("mps") else: return torch.device("cpu") + + +def validate_torch_device(device_name: str) -> bool: + """Checks if the given PyTorch device name is valid and available. + + Args: + device_name: Name of the device ('cuda:0', 'cuda', 'cpu'). + + Returns: + True if valid and available, False otherwise. + """ + try: + device = torch.device(device_name) + if device.type == "cuda": + # Check if CUDA is available and the specified index is within range + if device.index is None: + return torch.cuda.is_available() + else: + return device.index < torch.cuda.device_count() + return True + except RuntimeError: + return False + + +class SafetyChecker: + """Checks images for unsafe or inappropriate content using a pretrained model. + + Attributes: + device (str): Device for inference. + """ + + def __init__( + self, + device: Optional[str] = "cuda", + dtype: Optional[TorchDtype] = torch.float16, + ): + """Initializes the SafetyChecker. + + Args: + device: Device for inference. Defaults to "cuda". + dtype: Data type for inference. Defaults to `torch.float16`. + """ + device = device.lower() if device else device + if not validate_torch_device(device): + default_device = get_torch_device() + logger.warning( + f"Device '{device}' not found. Defaulting to '{default_device}'." + ) + device = default_device + + self.device = device + self._dtype = dtype + self._safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ).to(self.device) + self._feature_extractor = CLIPFeatureExtractor.from_pretrained( + "openai/clip-vit-base-patch32" + ) + + def check_nsfw_images( + self, images: list[Image.Image] + ) -> tuple[list[Image.Image], list[bool]]: + """Checks images for unsafe content. + + Args: + images: Images to check. + + Returns: + Tuple of images and corresponding NSFW flags. + """ + safety_checker_input = self._feature_extractor(images, return_tensors="pt").to( + self.device + ) + images_np = [np.array(img) for img in images] + _, has_nsfw_concept = self._safety_checker( + images=images_np, + clip_input=safety_checker_input.pixel_values.to(self._dtype), + ) + return images, has_nsfw_concept diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 8a7b005d..6bd212f6 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -21,7 +21,7 @@ responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}} -# TODO: Make model_id and other properties optional once Go codegen tool supports +# TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post("/image-to-image", response_model=ImageResponse, responses=responses) @router.post( @@ -37,6 +37,7 @@ async def image_to_image( strength: Annotated[float, Form()] = 0.8, guidance_scale: Annotated[float, Form()] = 7.5, negative_prompt: Annotated[str, Form()] = "", + safety_check: Annotated[bool, Form()] = True, seed: Annotated[int, Form()] = None, num_images_per_prompt: Annotated[int, Form()] = 1, pipeline: Pipeline = Depends(get_pipeline), @@ -76,12 +77,13 @@ async def image_to_image( image = img try: - images = pipeline( + images, has_nsfw_concept = pipeline( prompt=prompt, image=image, strength=strength, guidance_scale=guidance_scale, negative_prompt=negative_prompt, + safety_check=safety_check, seed=seed, num_images_per_prompt=num_images_per_prompt, ) @@ -97,7 +99,12 @@ async def image_to_image( seeds = [seeds] output_images = [] - for img, s in zip(images, seeds): - output_images.append({"url": image_to_data_url(img), "seed": s}) + for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): + # TODO: Return None once Go codegen tool supports optional properties + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + is_nsfw = is_nsfw or False + output_images.append( + {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} + ) return {"images": output_images} diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index c12e35cd..63f47021 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -21,7 +21,7 @@ responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}} -# TODO: Make model_id and other properties optional once Go codegen tool supports +# TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post("/image-to-video", response_model=VideoResponse, responses=responses) @router.post( diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 061f5603..3a747ae1 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -15,14 +15,15 @@ class TextToImageParams(BaseModel): - # TODO: Make model_id and other properties optional once Go codegen tool supports - # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + # TODO: Make model_id and other None properties optional once Go codegen tool + # supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 model_id: str = "" prompt: str height: int = None width: int = None guidance_scale: float = 7.5 negative_prompt: str = "" + safety_check: bool = True seed: int = None num_inference_steps: int = 50 # TODO: Make optional. num_images_per_prompt: int = 1 @@ -64,7 +65,7 @@ async def text_to_image( ] try: - images = pipeline(**params.model_dump()) + images, has_nsfw_concept = pipeline(**params.model_dump()) except Exception as e: logger.error(f"TextToImagePipeline error: {e}") logger.exception(e) @@ -77,7 +78,12 @@ async def text_to_image( seeds = [seeds] output_images = [] - for img, sd in zip(images, seeds): - output_images.append({"url": image_to_data_url(img), "seed": sd}) + for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): + # TODO: Return None once Go codegen tool supports optional properties + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + is_nsfw = is_nsfw or False + output_images.append( + {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} + ) return {"images": output_images} diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index c24353de..97f09db8 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -1,13 +1,17 @@ +import base64 import io +from typing import List + from PIL import Image -import base64 from pydantic import BaseModel -from typing import List class Media(BaseModel): url: str seed: int + # TODO: Make nsfw property optional once Go codegen tool supports + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + nsfw: bool class ImageResponse(BaseModel): diff --git a/runner/bench.py b/runner/bench.py index 96995f36..c286a9a0 100644 --- a/runner/bench.py +++ b/runner/bench.py @@ -1,16 +1,17 @@ import argparse +import os from time import time from typing import List + +import numpy as np import torch -from PIL import Image from app.main import load_pipeline from app.pipelines.base import Pipeline -from app.pipelines.text_to_image import TextToImagePipeline from app.pipelines.image_to_image import ImageToImagePipeline from app.pipelines.image_to_video import ImageToVideoPipeline +from app.pipelines.text_to_image import TextToImagePipeline +from PIL import Image from pydantic import BaseModel -import os -import numpy as np PROMPT = "a mountain lion" IMAGE = "images/test.png" @@ -47,6 +48,8 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics: for i in range(runs): start = time() output = call_pipeline(pipeline, batch_size) + if isinstance(output, tuple): + output = output[0] assert len(output) == batch_size inference_time[i] = time() - start diff --git a/runner/openapi.json b/runner/openapi.json index 300f5d93..6eed3a96 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -266,6 +266,11 @@ "title": "Negative Prompt", "default": "" }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "default": true + }, "seed": { "type": "integer", "title": "Seed" @@ -392,12 +397,17 @@ "seed": { "type": "integer", "title": "Seed" + }, + "nsfw": { + "type": "boolean", + "title": "Nsfw" } }, "type": "object", "required": [ "url", - "seed" + "seed", + "nsfw" ], "title": "Media" }, @@ -430,6 +440,11 @@ "title": "Negative Prompt", "default": "" }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "default": true + }, "seed": { "type": "integer", "title": "Seed" diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 406a3b8e..f22de353 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -165,6 +165,10 @@ components: type: string title: Negative Prompt default: '' + safety_check: + type: boolean + title: Safety Check + default: true seed: type: integer title: Seed @@ -258,10 +262,14 @@ components: seed: type: integer title: Seed + nsfw: + type: boolean + title: Nsfw type: object required: - url - seed + - nsfw title: Media TextToImageParams: properties: @@ -286,6 +294,10 @@ components: type: string title: Negative Prompt default: '' + safety_check: + type: boolean + title: Safety Check + default: true seed: type: integer title: Seed diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 54904307..3ab72ff4 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -39,6 +39,7 @@ type BodyImageToImageImageToImagePost struct { NegativePrompt *string `json:"negative_prompt,omitempty"` NumImagesPerPrompt *int `json:"num_images_per_prompt,omitempty"` Prompt string `json:"prompt"` + SafetyCheck *bool `json:"safety_check,omitempty"` Seed *int `json:"seed,omitempty"` Strength *float32 `json:"strength,omitempty"` } @@ -77,6 +78,7 @@ type ImageResponse struct { // Media defines model for Media. type Media struct { + Nsfw bool `json:"nsfw"` Seed int `json:"seed"` Url string `json:"url"` } @@ -90,6 +92,7 @@ type TextToImageParams struct { NumImagesPerPrompt *int `json:"num_images_per_prompt,omitempty"` NumInferenceSteps *int `json:"num_inference_steps,omitempty"` Prompt string `json:"prompt"` + SafetyCheck *bool `json:"safety_check,omitempty"` Seed *int `json:"seed,omitempty"` Width *int `json:"width,omitempty"` } @@ -1078,26 +1081,27 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xX3W/bNhD/Vwhuj47teM06+C3pttbYsgax1z0EgcFIZ5mtRHL8SGsE/t8HHmWJ+vDs", - "DGkGFHmyJd3H7+5+dzw+0EQWSgoQ1tDpAzXJGgqGf8+vZr9oLbX/r7RUoC0H/FKYzP9YbnOgU3ppMjqg", - "dqP8g7Gai4xutwOq4W/HNaR0eoMqt4NKpbJd6cm7j5BYuh3QC5lulrxgGSytLP+0HpU0tgsrczxlIoGl", - "SZj38kBTWDGXWzp9PTyrnb8t5cgc5SoIwhV3oD0E9OINrKQumKVTescF0xtaG5mhSCfsAS1kCvmSpw3/", - "NNK89AJklvYpC8iY5fewVFoWyu618UcpR66CXJ8pV4RsmaUC3WfwNLLnCoIRGXIFumOVCwtZSE1tZ6e7", - "H4IBSGPJuX/uM2qsBpHZdQPeePhTDXC+k+hUq0U0tUMTahhx7lheHaTkPU9Bth/7KblSphHTjzWcX5Xp", - "zcUaeLZuFursdaT3LnzvU/3faFtIy6VY3rnkE9i2kdPJ69iKlyQXKNmwFsUhJDewZC5b7iHGeBJR1wuT", - "c5eR/Rx5BBU/87Tl7nQ8eVW7+wu/dzVbNDzAvv0U6mHfu8Xias8kTsEynvt/32tY0Sn9blTP81E5zEfV", - "tG2jLNUjmLWvPUA+sJynzBfxICRuoTCHsLXtbWssPwdLFRCmNdtgDDHatoE+3MByu36zhuRTF6+xzLpm", - "l9L3v9F49KBA3wlX92TtoMc/Nt01GCWFgS6CMKWPztglpJzFeQqDuy9PHUaauNZNWD24g6duxo7tJafz", - "WO5PnR/cExzKoIcIaQDSg3ABX+xCYiBXTLOQvK+1FdST+YhZ/I2vAWhWrEADptZC66A7G7es7mTJHGWf", - "bbWo5vkjB3gJJiJhl2s9hDw4HnOZNDqdic37FZ3ePHRifOhAvI2a/neZoJtO2w866zkYs+fQDy9qUcRM", - "Fv7toUb1cQRXpWSUqSNG8gd/4u0fiSvNitZIfORsbOWk2rqC4QOzsnQfh9TA2wkIGZk4ze1m7qEE7P54", - "ugCmQVdXK690F15VRtbWKrr1NrhYydBGJtFcYX2n9FwQplTOQ8GJlUQ7Qc5nRHEFORchnh0v+D0oAO2/", - "Xzsh0NE9aBNsjYenw7FPiFQgmOJ0Sn/AVwOqmF0j7NEajzMcnoD96EuDzmdpddpRn7KQD9SajMf+J5HC", - "gkCtCPToo/Hud/fLQ2WMz1NMTDMhc5ckYMzK5aQqCZbAFYVfdyuI/uUIp9+JlSfVerzb1ZthYWeXDU4D", - "H8BYv7e14ipcbrli2o78nn2SMsuOD+3YW8i2yUmrHWy/Ysabu8CxOR/QV09Z9Wr37PF/wVJyHUqCfieT", - "J/XbWUO7CGoRUq2qZ88V/kxY0ILlZA76HjSp9/nd3MEzJJ44N7fb27gnsMRkIcMJ3+oNvIEc7A2cgs/V", - "G/vvSM/cG83Z/9Ib33JvBIZjb1j4Yo84NqK18F87478H3108Xw6HlwZ42gbwHIvPBtT1xgyqNv1VO+ab", - "XLqUvJFF4QS3G/KWWfjMNrS8+uNma6ajUaqBFSdZ+DrMS/Vh4tX9teafAAAA///iPqr4hhgAAA==", + "H4sIAAAAAAAC/+xY224bNxN+FYL/fylLshrXhe5st02E1olhqemFYQj07mjFZJdkebAjGHr3gkNpl3tQ", + "JbeOiwa+snZ3Dt8MvznQjzSRhZIChDV0/EhNsoSC4c+zq8lPWkvtfystFWjLAb8UJvN/LLc50DG9NBnt", + "UbtS/sFYzUVG1+se1fCH4xpSOr5BldteqVLaLvXk3SdILF336LlMV3NesAzmVm5+NB6VNLYNK3M8ZSKB", + "uUmY9/JIU1gwl1s6Pu2fVM7fbuTIFOVKCMIVd6A9BPTiDSykLpilY3rHBdMrWhmZoEgr7B4tZAr5nKc1", + "/zTSvPQCZJJ2KQvImOX3MFdaFsrutPF+I0euglyXKVeEbJm5At1l8Diy5wqCERlyBbpllQsLWUhNZWer", + "uxuCYQuwq3myhORzzbPVDirnUxQjFyhWmrmTMgcm0A5AGnuc+ucucMZqEJld1pwN+z9EvrYSrVNvEFZt", + "owpciLh7KD/3UvuepyCbj93UXihTi+n7Cs7PynTmYgk8W9YP/OQ00nsXvnep/mv0L6TlUszvXPIZbNPI", + "8eg0tuIlyTlK1qxFcQjJDcyZy+Y7iDEcRSXghcmZy8hujjyBig88bbg7Ho7eVO5+x+9tzQYN97BvN4U6", + "2PduNrva0dFTsIzn/tf/NSzomP5vUM2FwWYoDMqu3US5UY9gVr52APnIcp4yf4h7IXELhdmHrWlvXWH5", + "MVgqgTCt2QpjiNE2DXThBpbb5cW2ndXxGsusq1cp/fALjVsPCnRNyqomKwcd/rHorsEoKQy0EYRuf3DG", + "LiHlLM5TGABdeWox0sRnXYfVgTt4auEVZvEQ19J7//yP+r/TeSz3m873LiYOZUywiIiiyALwjohm8MXO", + "JAZ+xTQLyf5a20jVyQ/o3d/4+oFmxQI0YGotNAbjybBhdStLpij7n1tpyjnyxMGxCSoic5uzHcTe25Zz", + "mdQ6DBOrDws6vnls5eqxBfE2aja/ygTdtNpNr3W9AGN2LBvhRSWKmMnMv91X9z6O4GojGWXqgFHw0U/a", + "3a14oVnRaMVP7MmNnJTbXjC8p0dv3Mch1fC2AkJGJk5zu5p6KAG7H4vnwDTo8mqINA6vSiNLaxVdextc", + "LGSoCpNorvB8x/RMEKZUzsOBEyuJdoKcTYjiCnIuQjxbXvB7UADaf792QqCje9Am2Br2j/tDnxCpQDDF", + "6Zh+h696VDG7RNiDJY5RbMKAde2PBp1P0nLKUp+ykA/UGg2H/k8ihQWBWhHowSfj3W/vx/uOMZ7jmJh6", + "QqYuScCYhctJeSR4BK4o/JpdQvQvB9hFj6w8Ktfy7R2hHhZW9qbAaeADGOv3xUZchcstV0zbgd/vj1Jm", + "2eGhHXr7Wdc56dvj+itmvL6DHJrzHn3znKde7rwd/s9ZSq7DkaDf0ehZ/bbW3zaCSoSUK/LJS4U/ERa0", + "YDmZgr4HTap7xLbv4AyJO87N7fo2rgk8YjKTYVNo1AbefPbWBnbBl6qN3XezF66Neu9/rY1vuTYCw7E2", + "LHyxB4yNaC38y8r4+8G3F8/X4fBaAM9bAJ5j8WxAXW/MoGrdX7ljXuTSpeRCFoUT3K7IW2bhga3o5j8J", + "uNma8WCQamDFURa+9vONej/x6v5a82cAAAD//zzGF0pGGQAA", } // GetSwagger returns the content of the embedded swagger specification file