diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index b013eb56..4f0088c5 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -1,11 +1,11 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir +from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker from diffusers import StableVideoDiffusionPipeline from huggingface_hub import file_download import torch import PIL -from typing import List +from typing import List, Tuple import logging import os import time @@ -102,8 +102,11 @@ def __init__(self, model_id: str): from app.pipelines.optim.deepcache import enable_deepcache self.ldm = enable_deepcache(self.ldm) + + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() + self._safety_checker = SafetyChecker(device=safety_checker_device) - def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: + def __call__(self, image: PIL.Image, **kwargs) -> Tuple[List[PIL.Image], List[bool]]: if "decode_chunk_size" not in kwargs: # Decrease decode_chunk_size to reduce memory usage. kwargs["decode_chunk_size"] = 4 @@ -118,8 +121,15 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: kwargs["generator"] = [ torch.Generator(get_torch_device()).manual_seed(s) for s in seed ] - - return self.ldm(image, **kwargs).frames + + safety_check = kwargs.pop("safety_check", True) + if safety_check: + logger.info("checking input image for nsfw") + _, has_nsfw_concept = self._safety_checker.check_nsfw_images([image]) + else: + has_nsfw_concept = [None] + + return self.ldm(image, **kwargs).frames, has_nsfw_concept def __str__(self) -> str: return f"ImageToVideoPipeline model_id={self.model_id}" diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index d150dc9e..5d66aa74 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -39,6 +39,7 @@ async def image_to_video( motion_bucket_id: Annotated[int, Form()] = 127, noise_aug_strength: Annotated[float, Form()] = 0.02, seed: Annotated[int, Form()] = None, + safety_check: Annotated[bool, Form()] = True, pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -73,14 +74,15 @@ async def image_to_video( seed = random.randint(0, 2**32 - 1) try: - batch_frames = pipeline( + batch_frames, has_nsfw_concept = pipeline( image=Image.open(image.file).convert("RGB"), height=height, width=width, fps=fps, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, - seed=seed, + safety_check=safety_check, + seed=seed ) except Exception as e: logger.error(f"ImageToVideoPipeline error: {e}") @@ -93,7 +95,7 @@ async def image_to_video( for frames in batch_frames: output_frames.append( [ - {"url": image_to_data_url(frame), "seed": seed, "nsfw": False} + {"url": image_to_data_url(frame), "seed": seed, "nsfw": has_nsfw_concept[0]} for frame in frames ] ) diff --git a/runner/openapi.json b/runner/openapi.json index 6eed3a96..dbafb47f 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -328,6 +328,11 @@ "seed": { "type": "integer", "title": "Seed" + }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "default": true } }, "type": "object", diff --git a/runner/openapi.yaml b/runner/openapi.yaml index f22de353..2f9b27ea 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -214,6 +214,10 @@ components: seed: type: integer title: Seed + safety_check: + type: boolean + title: Safety Check + default: true type: object required: - image diff --git a/runner/test.txt b/runner/test.txt new file mode 100644 index 00000000..f727069e --- /dev/null +++ b/runner/test.txt @@ -0,0 +1,9 @@ +{ + "images": [ + { + "url": "", + "seed": 0, + "nsfw": true + } + ] +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 3ab72ff4..155bf270 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -52,6 +52,7 @@ type BodyImageToVideoImageToVideoPost struct { ModelId *string `json:"model_id,omitempty"` MotionBucketId *int `json:"motion_bucket_id,omitempty"` NoiseAugStrength *float32 `json:"noise_aug_strength,omitempty"` + SafetyCheck *bool `json:"safety_check,omitempty"` Seed *int `json:"seed,omitempty"` Width *int `json:"width,omitempty"` } @@ -1081,27 +1082,27 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xY224bNxN+FYL/fylLshrXhe5st02E1olhqemFYQj07mjFZJdkebAjGHr3gkNpl3tQ", - "JbeOiwa+snZ3Dt8MvznQjzSRhZIChDV0/EhNsoSC4c+zq8lPWkvtfystFWjLAb8UJvN/LLc50DG9NBnt", - "UbtS/sFYzUVG1+se1fCH4xpSOr5BldteqVLaLvXk3SdILF336LlMV3NesAzmVm5+NB6VNLYNK3M8ZSKB", - "uUmY9/JIU1gwl1s6Pu2fVM7fbuTIFOVKCMIVd6A9BPTiDSykLpilY3rHBdMrWhmZoEgr7B4tZAr5nKc1", - "/zTSvPQCZJJ2KQvImOX3MFdaFsrutPF+I0euglyXKVeEbJm5At1l8Diy5wqCERlyBbpllQsLWUhNZWer", - "uxuCYQuwq3myhORzzbPVDirnUxQjFyhWmrmTMgcm0A5AGnuc+ucucMZqEJld1pwN+z9EvrYSrVNvEFZt", - "owpciLh7KD/3UvuepyCbj93UXihTi+n7Cs7PynTmYgk8W9YP/OQ00nsXvnep/mv0L6TlUszvXPIZbNPI", - "8eg0tuIlyTlK1qxFcQjJDcyZy+Y7iDEcRSXghcmZy8hujjyBig88bbg7Ho7eVO5+x+9tzQYN97BvN4U6", - "2PduNrva0dFTsIzn/tf/NSzomP5vUM2FwWYoDMqu3US5UY9gVr52APnIcp4yf4h7IXELhdmHrWlvXWH5", - "MVgqgTCt2QpjiNE2DXThBpbb5cW2ndXxGsusq1cp/fALjVsPCnRNyqomKwcd/rHorsEoKQy0EYRuf3DG", - "LiHlLM5TGABdeWox0sRnXYfVgTt4auEVZvEQ19J7//yP+r/TeSz3m873LiYOZUywiIiiyALwjohm8MXO", - "JAZ+xTQLyf5a20jVyQ/o3d/4+oFmxQI0YGotNAbjybBhdStLpij7n1tpyjnyxMGxCSoic5uzHcTe25Zz", - "mdQ6DBOrDws6vnls5eqxBfE2aja/ygTdtNpNr3W9AGN2LBvhRSWKmMnMv91X9z6O4GojGWXqgFHw0U/a", - "3a14oVnRaMVP7MmNnJTbXjC8p0dv3Mch1fC2AkJGJk5zu5p6KAG7H4vnwDTo8mqINA6vSiNLaxVdextc", - "LGSoCpNorvB8x/RMEKZUzsOBEyuJdoKcTYjiCnIuQjxbXvB7UADaf792QqCje9Am2Br2j/tDnxCpQDDF", - "6Zh+h696VDG7RNiDJY5RbMKAde2PBp1P0nLKUp+ykA/UGg2H/k8ihQWBWhHowSfj3W/vx/uOMZ7jmJh6", - "QqYuScCYhctJeSR4BK4o/JpdQvQvB9hFj6w8Ktfy7R2hHhZW9qbAaeADGOv3xUZchcstV0zbgd/vj1Jm", - "2eGhHXr7Wdc56dvj+itmvL6DHJrzHn3znKde7rwd/s9ZSq7DkaDf0ehZ/bbW3zaCSoSUK/LJS4U/ERa0", - "YDmZgr4HTap7xLbv4AyJO87N7fo2rgk8YjKTYVNo1AbefPbWBnbBl6qN3XezF66Neu9/rY1vuTYCw7E2", - "LHyxB4yNaC38y8r4+8G3F8/X4fBaAM9bAJ5j8WxAXW/MoGrdX7ljXuTSpeRCFoUT3K7IW2bhga3o5j8J", - "uNma8WCQamDFURa+9vONej/x6v5a82cAAAD//zzGF0pGGQAA", + "H4sIAAAAAAAC/+xYW2/bNhT+KwS3R8d2vGYZ/JZkW2tsaYPY6x6CwGCkY5mtRHK8JDUC//eBh7ZEXTy7", + "W5piRZ5iSefyncPvXJhHmshCSQHCGjp+pCZZQsHw59nV5Betpfa/lZYKtOWAXwqT+T+W2xzomF6ajPao", + "XSn/YKzmIqPrdY9q+MtxDSkd36DKba9UKW2XevLuAySWrnv0XKarOS9YBnMrNz8aj0oa24aVOZ4ykcDc", + "JMx7eaQpLJjLLR2f9k8q5683cmSKciUE4Yo70B4CevEGFlIXzNIxveOC6RWtjExQpBV2jxYyhXzO05p/", + "GmleegEySbuUBWTM8nuYKy0LZXfaeLuRI1dBrsuUK0K2zFyB7jJ4HNlzBcGIDLkC3bLKhYUspKays9Xd", + "DcGwBdjVPFlC8rHm2WoHlfMpipELFCvN3EmZAxNoByCNPU79cxc4YzWIzC5rzob9nyJfW4nWqTcIq7ZR", + "BS5E3D2Un3upfc9TkM3HbmovlKnF9GMF51dlOnOxBJ4t6wd+chrpvQnfu1S/Gv0LabkU8zuXfATbNHI8", + "Oo2teElyjpI1a1EcQnIDc+ay+Q5iDEdRCXhhcuYyspsjX4HSDzxtwD4ejl5Vnv7E723NBp33sHg3FTtY", + "/GY2u9oxGVKwjOf+1/caFnRMvxtU82WwGS6Dsvs3UW7UI5iVrx1A3rOcp8yTYS8kbqEw+7A17a0rLD8H", + "SyUQpjVbYQwx2qaBLtzAcru82HKojtdYZl292um732jcwlCga+JWtV056PCPxXsNRklhoI0gTI2DM3YJ", + "KWdxnsIg6cpTi5EmPus6rA7cwVMLrzCLh7iW3vrn/1R0Tuex3B8637vgOJQxwSIiiiILwDsimsEnO5MY", + "+BXTLCT7S2011UQ4YAZ842sMmhUL0ICptdAYsCfDhtWtLJmi7P9uNSrnyGcOjk1QEZnbnO0g9t62nMuk", + "1mGYWL1b0PHNYytXjy2It1Gz+V0m6KbVbnqtawoYs2NpCS8qUcRMZv7tvrr3cQRXG8koUweMgvd+0u5u", + "xQvNikYr/sye3MhJuTUGw3t69MZ9HFINbysgZGTiNLerqYcSsPuxeA5Mgy6vmEjj8Ko0srRW0bW3wcVC", + "hqowieYKz3dMzwRhSuU8HDixkmgnyNmEKK4g5yLEs+UFvwcFoP33aycEOroHbYKtYf+4P/QJkQoEU5yO", + "6Q/4qkcVs0uEPVjiGMUmDFjX/mjQ+SQtpyz1KQv5QK3RcOj/JFJYEKgVgR58MN799p697xjjOY6JqSdk", + "6pIEjFm4nJRHgkfgisKv6yVE/3KAXfTIyqNyvd/eNephYWVvCpwGPoCxfl9sxFW43HLFtB34e8JRyiw7", + "PLRDb1HrOid9e1x/wYzXd5BDc96jr57y1Mudt8P/OUvJdTgS9DsaPanf1vrbRlCJkHJFPnmu8CfCghYs", + "J1PQ96BJdY/Y9h2cIXHHubld38Y1gUdMZjJsCo3awJvP3trALvhctbH7bvbMtVHv/S+18S3XRmA41oaF", + "T/aAsRGthf9YGf8++Pbi+TIcXgrgaQvAcyyeDajrjRlUrfsrd8yLXLqUXMiicILbFXnNLDywFd38JwE3", + "WzMeDFINrDjKwtd+vlHvJ17dX2v+DgAA//+bzF7djhkAAA==", } // GetSwagger returns the content of the embedded swagger specification file