From 32d8bce19fa361cdb6a5ffceba70c00d9c71403d Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Wed, 17 Jul 2024 13:14:23 +0200
Subject: [PATCH] feat: improve I2I pipeline num_inference_steps behavior

This commit ensures that the `strength` parameter is set correctly when the
sdxl-turbo model is used and cleans up the `num_inference_steps`
implementation.
---
 runner/app/pipelines/image_to_image.py | 35 ++++++++++++--------------
 runner/app/routes/image_to_image.py    | 11 +++++---
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py
index d5c05bc6..4080c919 100644
--- a/runner/app/pipelines/image_to_image.py
+++ b/runner/app/pipelines/image_to_image.py
@@ -5,14 +5,6 @@
 
 import PIL
 import torch
-from app.pipelines.base import Pipeline
-from app.pipelines.utils import (
-    SafetyChecker,
-    get_model_dir,
-    get_torch_device,
-    is_lightning_model,
-    is_turbo_model,
-)
 from diffusers import (
     AutoPipelineForImage2Image,
     EulerAncestralDiscreteScheduler,
@@ -25,6 +17,15 @@
 from PIL import ImageFile
 from safetensors.torch import load_file
 
+from app.pipelines.base import Pipeline
+from app.pipelines.utils import (
+    SafetyChecker,
+    get_model_dir,
+    get_torch_device,
+    is_lightning_model,
+    is_turbo_model,
+)
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 logger = logging.getLogger(__name__)
@@ -171,6 +172,7 @@ def __call__(
         self, prompt: str, image: PIL.Image, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
+        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
 
         if seed is not None:
@@ -183,7 +185,7 @@ def __call__(
                     torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                 ]
 
-        if "num_inference_steps" in kwargs and kwargs["num_inference_steps"] < 1:
+        if num_inference_steps is None or num_inference_steps < 1:
             del kwargs["num_inference_steps"]
 
         if (
@@ -194,15 +196,12 @@ def __call__(
             # it should be set to 0
             kwargs["guidance_scale"] = 0.0
 
-            # num_inference_steps * strength should be >= 1 because
-            # the pipeline will be run for int(num_inference_steps * strength) steps
-            kwargs["strength"] = kwargs.get("strength", 0.5)
-            if (
-                kwargs.get("num_inference_steps")
-                and kwargs["strength"] * kwargs["num_inference_steps"] < 1
-            ):
+            # Ensure num_inference_steps * strength >= 1 for minimum pipeline
+            # execution steps.
+            if "num_inference_steps" in kwargs:
                 kwargs["strength"] = max(
-                    1.0 / kwargs["num_inference_steps"], kwargs["strength"]
+                    1.0 / kwargs.get("num_inference_steps", 1),
+                    kwargs.get("strength", 0.5),
                 )
         elif ModelName.SDXL_LIGHTNING.value in self.model_id:
             # SDXL-Lightning models should have guidance_scale = 0 and use
@@ -218,8 +217,6 @@ def __call__(
             else:
                 # Default to 2step
                 kwargs["num_inference_steps"] = 2
-        elif ModelName.INSTRUCT_PIX2PIX.value in self.model_id:
-            kwargs["num_inference_steps"] = kwargs.get("num_inference_steps", 10)
 
         output = self.ldm(prompt, image=image, **kwargs)
 
diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py
index 8c11f253..6a397628 100644
--- a/runner/app/routes/image_to_image.py
+++ b/runner/app/routes/image_to_image.py
@@ -3,14 +3,15 @@
 import random
 from typing import Annotated
 
-from app.dependencies import get_pipeline
-from app.pipelines.base import Pipeline
-from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
 from fastapi import APIRouter, Depends, File, Form, UploadFile, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from PIL import Image, ImageFile
 
+from app.dependencies import get_pipeline
+from app.pipelines.base import Pipeline
+from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 router = APIRouter()
@@ -44,7 +45,9 @@ async def image_to_image(
     negative_prompt: Annotated[str, Form()] = "",
     safety_check: Annotated[bool, Form()] = True,
     seed: Annotated[int, Form()] = None,
-    num_inference_steps: Annotated[int, Form()] = 25,  # TODO: Make optional.
+    num_inference_steps: Annotated[
+        int, Form()
+    ] = 100,  # NOTE: Hardcoded due to varying pipeline values.
     num_images_per_prompt: Annotated[int, Form()] = 1,
     pipeline: Pipeline = Depends(get_pipeline),
     token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
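
Note on the strength clamp in the turbo branch above: as the removed comment in
image_to_image.py states, the pipeline runs for int(num_inference_steps *
strength) denoising steps, so the product must stay at or above 1 or no steps
execute at all. The sketch below restates that rule in isolation; the helper
name clamp_strength is hypothetical and only mirrors the max() expression used
in the patch.

    def clamp_strength(num_inference_steps: int, strength: float = 0.5) -> float:
        """Raise strength so that int(num_inference_steps * strength) >= 1."""
        return max(1.0 / num_inference_steps, strength)

    # With the new route default of 100 steps, a strength of 0.5 is untouched.
    assert clamp_strength(100, 0.5) == 0.5
    # With 2 steps, a requested strength of 0.3 would mean int(2 * 0.3) == 0
    # executed steps, so it is raised to 0.5.
    assert clamp_strength(2, 0.3) == 0.5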