Skip to content

Commit

Permalink
feat(image-to-video): integrate NSFW safety checker (livepeer#90)
Browse files Browse the repository at this point in the history
* add saftey checker to image to video

* refactor(runner): apply black formatter

This commit applies the black formatter onto the codebase.

---------

Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
2 people authored and eliteprox committed Jun 9, 2024
1 parent 802af8e commit 7a0aed7
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 29 deletions.
20 changes: 15 additions & 5 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker

from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
import torch
import PIL
from typing import List
from typing import List, Tuple
import logging
import os
import time
Expand Down Expand Up @@ -102,8 +102,11 @@ def __init__(self, model_id: str):
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
def __call__(self, image: PIL.Image, **kwargs) -> Tuple[List[PIL.Image], List[bool]]:
if "decode_chunk_size" not in kwargs:
# Decrease decode_chunk_size to reduce memory usage.
kwargs["decode_chunk_size"] = 4
Expand All @@ -118,8 +121,15 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

return self.ldm(image, **kwargs).frames

safety_check = kwargs.pop("safety_check", True)
if safety_check:
logger.info("checking input image for nsfw")
_, has_nsfw_concept = self._safety_checker.check_nsfw_images([image])
else:
has_nsfw_concept = [None]

return self.ldm(image, **kwargs).frames, has_nsfw_concept

def __str__(self) -> str:
return f"ImageToVideoPipeline model_id={self.model_id}"
8 changes: 5 additions & 3 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def image_to_video(
motion_bucket_id: Annotated[int, Form()] = 127,
noise_aug_strength: Annotated[float, Form()] = 0.02,
seed: Annotated[int, Form()] = None,
safety_check: Annotated[bool, Form()] = True,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand Down Expand Up @@ -73,14 +74,15 @@ async def image_to_video(
seed = random.randint(0, 2**32 - 1)

try:
batch_frames = pipeline(
batch_frames, has_nsfw_concept = pipeline(
image=Image.open(image.file).convert("RGB"),
height=height,
width=width,
fps=fps,
motion_bucket_id=motion_bucket_id,
noise_aug_strength=noise_aug_strength,
seed=seed,
safety_check=safety_check,
seed=seed
)
except Exception as e:
logger.error(f"ImageToVideoPipeline error: {e}")
Expand All @@ -93,7 +95,7 @@ async def image_to_video(
for frames in batch_frames:
output_frames.append(
[
{"url": image_to_data_url(frame), "seed": seed, "nsfw": False}
{"url": image_to_data_url(frame), "seed": seed, "nsfw": has_nsfw_concept[0]}
for frame in frames
]
)
Expand Down
5 changes: 5 additions & 0 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@
"seed": {
"type": "integer",
"title": "Seed"
},
"safety_check": {
"type": "boolean",
"title": "Safety Check",
"default": true
}
},
"type": "object",
Expand Down
4 changes: 4 additions & 0 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ components:
seed:
type: integer
title: Seed
safety_check:
type: boolean
title: Safety Check
default: true
type: object
required:
- image
Expand Down
9 changes: 9 additions & 0 deletions runner/test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"images": [
{
"url": "",
"seed": 0,
"nsfw": true
}
]
}
43 changes: 22 additions & 21 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7a0aed7

Please sign in to comment.