From 40c713fc9c7f01a0a6542379f9a0028a8dca24cd Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Sun, 5 May 2024 17:10:41 +0200 Subject: [PATCH] feat(pipelines): add optional NSFW safety check to T2I and I2I pipelines This commit incorporates the CompVis/stable-diffusion-safety-checker into the text-to-image and image-to-image pipelines. By enabling the `safety_check` input variable, users get notified of the generation of NSFW images. --- runner/app/pipelines/image_to_image.py | 22 +++++-- runner/app/pipelines/text_to_image.py | 37 +++++++---- runner/app/pipelines/util.py | 88 ++++++++++++++++++++++++++ runner/app/routes/image_to_image.py | 10 ++- runner/app/routes/text_to_image.py | 9 ++- runner/app/routes/util.py | 6 +- runner/bench.py | 13 ++-- runner/openapi.json | 30 +++++++-- 8 files changed, 183 insertions(+), 32 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 68d96892..e72fd16e 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -1,5 +1,5 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir +from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker from diffusers import ( AutoPipelineForImage2Image, @@ -11,7 +11,7 @@ from huggingface_hub import file_download, hf_hub_download import torch import PIL -from typing import List +from typing import List, Tuple import logging import os @@ -111,7 +111,14 @@ def __init__(self, model_id: str): "call may be slow if 'SFAST' is enabled." ) - def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() + self._safety_checker = SafetyChecker(device=safety_checker_device) + + def __call__( + self, prompt: str, image: PIL.Image, **kwargs + ) -> Tuple[List[PIL.Image], List[bool]]: + safety_check = kwargs.pop("safety_check", False) + seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): @@ -153,7 +160,14 @@ def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: # Default to 2step kwargs["num_inference_steps"] = 2 - return self.ldm(prompt, image=image, **kwargs).images + output = self.ldm(prompt, image=image, **kwargs) + + if safety_check: + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + else: + has_nsfw_concept = [None] * len(output.images) + + return output.images, has_nsfw_concept def __str__(self) -> str: return f"ImageToImagePipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 9c58ad9b..9c6e03cd 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -1,19 +1,20 @@ -from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir +import logging +import os +from typing import List, Tuple, Optional +import PIL +import torch from diffusers import ( AutoPipelineForText2Image, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, - EulerDiscreteScheduler, ) -from safetensors.torch import load_file from huggingface_hub import file_download, hf_hub_download -import torch -import PIL -from typing import List -import logging -import os +from safetensors.torch import load_file + +from app.pipelines.base import Pipeline +from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker logger = logging.getLogger(__name__) @@ -130,7 +131,14 @@ def __init__(self, model_id: str): "call may be slow if 'SFAST' is enabled." ) - def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() + self._safety_checker = SafetyChecker(device=safety_checker_device) + + def __call__( + self, prompt: str, **kwargs + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + safety_check = kwargs.pop("safety_check", False) + seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): @@ -167,7 +175,14 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: # Default to 2step kwargs["num_inference_steps"] = 2 - return self.ldm(prompt, **kwargs).images + output = self.ldm(prompt, **kwargs) + + if safety_check: + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + else: + has_nsfw_concept = [None] * len(output.images) + + return output.images, has_nsfw_concept def __str__(self) -> str: return f"TextToImagePipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/util.py b/runner/app/pipelines/util.py index 5ce8d80c..2a73b798 100644 --- a/runner/app/pipelines/util.py +++ b/runner/app/pipelines/util.py @@ -1,6 +1,15 @@ import torch import os +import numpy as np +from torch import dtype as TorchDtype from pathlib import Path +from PIL import Image +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from transformers import CLIPFeatureExtractor +from typing import Optional +import logging + +logger = logging.getLogger(__name__) def get_model_dir() -> Path: @@ -18,3 +27,82 @@ def get_torch_device(): return torch.device("mps") else: return torch.device("cpu") + + +def validate_torch_device(device_name: str) -> bool: + """Checks if the given PyTorch device name is valid and available. + + Args: + device_name: Name of the device ('cuda:0', 'cuda', 'cpu'). + + Returns: + True if valid and available, False otherwise. + """ + try: + device = torch.device(device_name) + if device.type == "cuda": + # Check if CUDA is available and the specified index is within range + if device.index is None: + return torch.cuda.is_available() + else: + return device.index < torch.cuda.device_count() + return True + except RuntimeError: + return False + + +class SafetyChecker: + """Checks images for unsafe or inappropriate content using a pretrained model. + + Attributes: + device (str): Device for inference. + """ + + def __init__( + self, + device: Optional[str] = "cuda", + dtype: Optional[TorchDtype] = torch.float16, + ): + """Initializes the SafetyChecker. + + Args: + device: Device for inference. Defaults to "cuda". + dtype: Data type for inference. Defaults to `torch.float16`. + """ + device = device.lower() if device else device + if not validate_torch_device(device): + default_device = get_torch_device() + logger.warning( + f"Device '{device}' not found. Defaulting to '{default_device}'." + ) + device = default_device + + self.device = device + self._dtype = dtype + self._safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ).to(self.device) + self._feature_extractor = CLIPFeatureExtractor.from_pretrained( + "openai/clip-vit-base-patch32" + ) + + def check_nsfw_images( + self, images: list[Image.Image] + ) -> tuple[list[Image.Image], list[bool]]: + """Checks images for unsafe content. + + Args: + images: Images to check. + + Returns: + Tuple of images and corresponding NSFW flags. + """ + safety_checker_input = self._feature_extractor(images, return_tensors="pt").to( + self.device + ) + images_np = [np.array(img) for img in images] + _, has_nsfw_concept = self._safety_checker( + images=images_np, + clip_input=safety_checker_input.pixel_values.to(self._dtype), + ) + return images, has_nsfw_concept diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 8a7b005d..1ab8e154 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -37,6 +37,7 @@ async def image_to_image( strength: Annotated[float, Form()] = 0.8, guidance_scale: Annotated[float, Form()] = 7.5, negative_prompt: Annotated[str, Form()] = "", + safety_check: Annotated[bool, Form()] = False, seed: Annotated[int, Form()] = None, num_images_per_prompt: Annotated[int, Form()] = 1, pipeline: Pipeline = Depends(get_pipeline), @@ -76,12 +77,13 @@ async def image_to_image( image = img try: - images = pipeline( + images, has_nsfw_concept = pipeline( prompt=prompt, image=image, strength=strength, guidance_scale=guidance_scale, negative_prompt=negative_prompt, + safety_check=safety_check, seed=seed, num_images_per_prompt=num_images_per_prompt, ) @@ -97,7 +99,9 @@ async def image_to_image( seeds = [seeds] output_images = [] - for img, s in zip(images, seeds): - output_images.append({"url": image_to_data_url(img), "seed": s}) + for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): + output_images.append( + {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} + ) return {"images": output_images} diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 9e7542ae..114da2b6 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -23,6 +23,7 @@ class TextToImageParams(BaseModel): width: int = None guidance_scale: float = 7.5 negative_prompt: str = "" + safety_check: bool = False seed: int = None num_images_per_prompt: int = 1 @@ -63,7 +64,7 @@ async def text_to_image( ] try: - images = pipeline(**params.model_dump()) + images, has_nsfw_concept = pipeline(**params.model_dump()) except Exception as e: logger.error(f"TextToImagePipeline error: {e}") logger.exception(e) @@ -76,7 +77,9 @@ async def text_to_image( seeds = [seeds] output_images = [] - for img, sd in zip(images, seeds): - output_images.append({"url": image_to_data_url(img), "seed": sd}) + for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): + output_images.append( + {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} + ) return {"images": output_images} diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index c24353de..1f93f1ab 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -1,13 +1,15 @@ +import base64 import io +from typing import List, Optional + from PIL import Image -import base64 from pydantic import BaseModel -from typing import List class Media(BaseModel): url: str seed: int + nsfw: Optional[bool] class ImageResponse(BaseModel): diff --git a/runner/bench.py b/runner/bench.py index 024e9a5d..8a6f789f 100644 --- a/runner/bench.py +++ b/runner/bench.py @@ -1,16 +1,17 @@ import argparse +import os from time import time from typing import List + +import numpy as np import torch -from PIL import Image from app.main import load_pipeline from app.pipelines.base import Pipeline -from app.pipelines.text_to_image import TextToImagePipeline from app.pipelines.image_to_image import ImageToImagePipeline from app.pipelines.image_to_video import ImageToVideoPipeline +from app.pipelines.text_to_image import TextToImagePipeline +from PIL import Image from pydantic import BaseModel -import os -import numpy as np PROMPT = "a mountain lion" IMAGE = "images/test.png" @@ -47,6 +48,8 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics: for i in range(runs): start = time() output = call_pipeline(pipeline, batch_size) + if isinstance(output, tuple): + output = output[0] assert len(output) == batch_size inference_time[i] = time() - start @@ -124,7 +127,7 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics: print(f"pipeline load max GPU memory reserved: {load_max_mem_reserved:.3f}GiB") if os.getenv("SFAST", "").strip().lower() == "true": - + print(f"avg warmup inference time: {warmup_metrics.inference_time:.3f}s") print( f"avg warmup inference time per output: {warmup_metrics.inference_time_per_output:.3f}s" diff --git a/runner/openapi.json b/runner/openapi.json index 8debda86..6c2eef5a 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -7,8 +7,8 @@ }, "servers": [ { - "url": "http://gateway-endpoint.ai/", - "description": "Example Gateway" + "url": "https://dream-gateway.livepeer.cloud", + "description": "Livepeer Cloud Community Gateway" } ], "paths": { @@ -266,6 +266,11 @@ "title": "Negative Prompt", "default": "" }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "default": false + }, "seed": { "type": "integer", "title": "Seed" @@ -392,12 +397,24 @@ "seed": { "type": "integer", "title": "Seed" + }, + "nsfw": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Nsfw" } }, "type": "object", "required": [ "url", - "seed" + "seed", + "nsfw" ], "title": "Media" }, @@ -430,6 +447,11 @@ "title": "Negative Prompt", "default": "" }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "default": false + }, "seed": { "type": "integer", "title": "Seed" @@ -506,4 +528,4 @@ } } } -} +} \ No newline at end of file