Skip to content

Commit

Permalink
feat(pipelines): add optional NSFW safety check to T2I and I2I pipelines
Browse files Browse the repository at this point in the history
This commit incorporates the CompVis/stable-diffusion-safety-checker
into the text-to-image and image-to-image pipelines. By enabling the
`safety_check` input variable, users are notified when a generated image
is flagged as NSFW.
  • Loading branch information
rickstaa committed May 6, 2024
1 parent 2f8ab8a commit 40c713f
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 32 deletions.
22 changes: 18 additions & 4 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker

from diffusers import (
AutoPipelineForImage2Image,
Expand All @@ -11,7 +11,7 @@
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List
from typing import List, Tuple
import logging
import os

Expand Down Expand Up @@ -111,7 +111,14 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[bool]]:
safety_check = kwargs.pop("safety_check", False)

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
Expand Down Expand Up @@ -153,7 +160,14 @@ def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
# Default to 2step
kwargs["num_inference_steps"] = 2

return self.ldm(prompt, image=image, **kwargs).images
output = self.ldm(prompt, image=image, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
else:
has_nsfw_concept = [None] * len(output.images)

return output.images, has_nsfw_concept

def __str__(self) -> str:
    """Return a short label naming this pipeline and the model it wraps."""
    model_id = self.model_id
    return f"ImageToImagePipeline model_id={model_id}"
37 changes: 26 additions & 11 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir
import logging
import os
from typing import List, Tuple, Optional

import PIL
import torch
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List
import logging
import os
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -130,7 +131,14 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
safety_check = kwargs.pop("safety_check", False)

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
Expand Down Expand Up @@ -167,7 +175,14 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
# Default to 2step
kwargs["num_inference_steps"] = 2

return self.ldm(prompt, **kwargs).images
output = self.ldm(prompt, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
else:
has_nsfw_concept = [None] * len(output.images)

return output.images, has_nsfw_concept

def __str__(self) -> str:
    """Return a short label naming this pipeline and the model it wraps."""
    model_id = self.model_id
    return f"TextToImagePipeline model_id={model_id}"
88 changes: 88 additions & 0 deletions runner/app/pipelines/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import torch
import os
import numpy as np
from torch import dtype as TorchDtype
from pathlib import Path
from PIL import Image
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def get_model_dir() -> Path:
Expand All @@ -18,3 +27,82 @@ def get_torch_device():
return torch.device("mps")
else:
return torch.device("cpu")


def validate_torch_device(device_name: str) -> bool:
    """Check whether a PyTorch device name is valid and usable on this host.

    Args:
        device_name: Name of the device (e.g. 'cuda:0', 'cuda', 'mps', 'cpu').

    Returns:
        True if the name parses and the device is available, False otherwise.
        Never raises for malformed or non-string input.
    """
    # Keep the try body to the single call that can raise.
    try:
        device = torch.device(device_name)
    except (RuntimeError, TypeError):
        # RuntimeError: malformed or unknown device string.
        # TypeError: non-string input such as None.
        return False

    if device.type == "cuda":
        if not torch.cuda.is_available():
            return False
        # Without an explicit index any visible CUDA device will do; with one,
        # it must be within the number of visible devices.
        return device.index is None or device.index < torch.cuda.device_count()

    if device.type == "mps":
        # Older torch builds have no MPS backend module at all.
        mps_backend = getattr(torch.backends, "mps", None)
        return mps_backend is not None and mps_backend.is_available()

    return True


class SafetyChecker:
    """Checks images for unsafe or inappropriate content using a pretrained model.

    Attributes:
        device: Device used for inference. This is the (lower-cased) requested
            device name when it is usable; otherwise it is whatever
            `get_torch_device()` returned.
    """

    def __init__(
        self,
        device: Optional[str] = "cuda",
        dtype: Optional[TorchDtype] = torch.float16,
    ):
        """Initializes the SafetyChecker.

        Loads the "CompVis/stable-diffusion-safety-checker" model onto the
        selected device and the "openai/clip-vit-base-patch32" feature
        extractor used to prepare its inputs.

        Args:
            device: Device for inference. Defaults to "cuda". When it is
                `None`, empty, or not available, a warning is logged and the
                device chosen by `get_torch_device()` is used instead.
            dtype: Data type the CLIP input is cast to before inference.
                Defaults to `torch.float16`. `None` leaves the input dtype
                untouched.
        """
        device = device.lower() if device else device
        # A missing device must not reach validate_torch_device: torch.device()
        # rejects None with an exception instead of reporting "unavailable".
        if not device or not validate_torch_device(device):
            default_device = get_torch_device()
            logger.warning(
                f"Device '{device}' not found. Defaulting to '{default_device}'."
            )
            device = default_device

        self.device = device
        self._dtype = dtype
        # NOTE(review): the model is loaded without an explicit dtype while the
        # inputs are cast to `dtype` in check_nsfw_images; confirm the installed
        # transformers version reconciles the two dtypes.
        self._safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker"
        ).to(self.device)
        self._feature_extractor = CLIPFeatureExtractor.from_pretrained(
            "openai/clip-vit-base-patch32"
        )

    def check_nsfw_images(
        self, images: "list[Image.Image]"
    ) -> "tuple[list[Image.Image], list[bool]]":
        """Checks images for unsafe content.

        Args:
            images: Images to check.

        Returns:
            Tuple of the input images (returned as passed in) and the
            corresponding NSFW flags, one per image. An empty input yields an
            empty list of flags without running the model.
        """
        # Nothing to classify; skip the feature extractor and the model.
        if not images:
            return images, []

        safety_checker_input = self._feature_extractor(images, return_tensors="pt").to(
            self.device
        )
        clip_input = safety_checker_input.pixel_values
        if self._dtype is not None:
            clip_input = clip_input.to(self._dtype)

        # The checker receives numpy copies; its first return value is
        # discarded and the caller's original images are handed back.
        images_np = [np.array(img) for img in images]
        _, has_nsfw_concept = self._safety_checker(
            images=images_np,
            clip_input=clip_input,
        )
        return images, has_nsfw_concept
10 changes: 7 additions & 3 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def image_to_image(
strength: Annotated[float, Form()] = 0.8,
guidance_scale: Annotated[float, Form()] = 7.5,
negative_prompt: Annotated[str, Form()] = "",
safety_check: Annotated[bool, Form()] = False,
seed: Annotated[int, Form()] = None,
num_images_per_prompt: Annotated[int, Form()] = 1,
pipeline: Pipeline = Depends(get_pipeline),
Expand Down Expand Up @@ -76,12 +77,13 @@ async def image_to_image(
image = img

try:
images = pipeline(
images, has_nsfw_concept = pipeline(
prompt=prompt,
image=image,
strength=strength,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
safety_check=safety_check,
seed=seed,
num_images_per_prompt=num_images_per_prompt,
)
Expand All @@ -97,7 +99,9 @@ async def image_to_image(
seeds = [seeds]

output_images = []
for img, s in zip(images, seeds):
output_images.append({"url": image_to_data_url(img), "seed": s})
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)

return {"images": output_images}
9 changes: 6 additions & 3 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TextToImageParams(BaseModel):
width: int = None
guidance_scale: float = 7.5
negative_prompt: str = ""
safety_check: bool = False
seed: int = None
num_images_per_prompt: int = 1

Expand Down Expand Up @@ -63,7 +64,7 @@ async def text_to_image(
]

try:
images = pipeline(**params.model_dump())
images, has_nsfw_concept = pipeline(**params.model_dump())
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
Expand All @@ -76,7 +77,9 @@ async def text_to_image(
seeds = [seeds]

output_images = []
for img, sd in zip(images, seeds):
output_images.append({"url": image_to_data_url(img), "seed": sd})
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)

return {"images": output_images}
6 changes: 4 additions & 2 deletions runner/app/routes/util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import base64
import io
from typing import List, Optional

from PIL import Image
import base64
from pydantic import BaseModel
from typing import List


class Media(BaseModel):
    # One generated output item as returned to API clients. Comments are used
    # instead of a class docstring so the generated schema is left untouched.

    # The image routes fill this with `image_to_data_url(img)`.
    url: str
    # Seed reported alongside this output.
    seed: int
    # NSFW flag for this output; the pipelines produce None when the safety
    # check was not requested. No default is given, and openapi.json lists
    # "nsfw" among the required properties.
    nsfw: Optional[bool]


class ImageResponse(BaseModel):
Expand Down
13 changes: 8 additions & 5 deletions runner/bench.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import argparse
import os
from time import time
from typing import List

import numpy as np
import torch
from PIL import Image
from app.main import load_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.text_to_image import TextToImagePipeline
from app.pipelines.image_to_image import ImageToImagePipeline
from app.pipelines.image_to_video import ImageToVideoPipeline
from app.pipelines.text_to_image import TextToImagePipeline
from PIL import Image
from pydantic import BaseModel
import os
import numpy as np

PROMPT = "a mountain lion"
IMAGE = "images/test.png"
Expand Down Expand Up @@ -47,6 +48,8 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
for i in range(runs):
start = time()
output = call_pipeline(pipeline, batch_size)
if isinstance(output, tuple):
output = output[0]
assert len(output) == batch_size

inference_time[i] = time() - start
Expand Down Expand Up @@ -124,7 +127,7 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
print(f"pipeline load max GPU memory reserved: {load_max_mem_reserved:.3f}GiB")

if os.getenv("SFAST", "").strip().lower() == "true":

print(f"avg warmup inference time: {warmup_metrics.inference_time:.3f}s")
print(
f"avg warmup inference time per output: {warmup_metrics.inference_time_per_output:.3f}s"
Expand Down
30 changes: 26 additions & 4 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
},
"servers": [
{
"url": "http://gateway-endpoint.ai/",
"description": "Example Gateway"
"url": "https://dream-gateway.livepeer.cloud",
"description": "Livepeer Cloud Community Gateway"
}
],
"paths": {
Expand Down Expand Up @@ -266,6 +266,11 @@
"title": "Negative Prompt",
"default": ""
},
"safety_check": {
"type": "boolean",
"title": "Safety Check",
"default": false
},
"seed": {
"type": "integer",
"title": "Seed"
Expand Down Expand Up @@ -392,12 +397,24 @@
"seed": {
"type": "integer",
"title": "Seed"
},
"nsfw": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"title": "Nsfw"
}
},
"type": "object",
"required": [
"url",
"seed"
"seed",
"nsfw"
],
"title": "Media"
},
Expand Down Expand Up @@ -430,6 +447,11 @@
"title": "Negative Prompt",
"default": ""
},
"safety_check": {
"type": "boolean",
"title": "Safety Check",
"default": false
},
"seed": {
"type": "integer",
"title": "Seed"
Expand Down Expand Up @@ -506,4 +528,4 @@
}
}
}
}
}

0 comments on commit 40c713f

Please sign in to comment.