feat: improve I2I pipeline num_inference_steps behavior
This commit ensures that the `strength` is set correctly when the
sdxl-turbo model is used and cleans up the `num_inference_steps`
implementation.
rickstaa committed Jul 17, 2024
1 parent 22f43f1 commit 32d8bce
Showing 2 changed files with 23 additions and 23 deletions.
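
For context on the pipeline change below: diffusers image-to-image pipelines only run roughly int(num_inference_steps * strength) denoising steps, so a low strength combined with the very small step counts used by sdxl-turbo can leave zero steps to execute. The sketch below is not part of the commit; it just illustrates the clamping rule applied when num_inference_steps is supplied (clamp_strength is a made-up helper name).

```python
from typing import Optional


def clamp_strength(num_inference_steps: Optional[int], strength: float = 0.5) -> float:
    """Return a strength that guarantees at least one denoising step."""
    if not num_inference_steps:
        # Without an explicit step count, leave strength to the pipeline default.
        return strength
    return max(1.0 / num_inference_steps, strength)


# sdxl-turbo is typically run with very few steps, e.g. num_inference_steps=2.
# With strength=0.3 the pipeline would run int(2 * 0.3) == 0 steps; clamping
# raises strength to 0.5, so int(2 * 0.5) == 1 step actually executes.
print(clamp_strength(2, 0.3))  # 0.5
```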
35 changes: 16 additions & 19 deletions runner/app/pipelines/image_to_image.py
@@ -5,14 +5,6 @@

 import PIL
 import torch
-from app.pipelines.base import Pipeline
-from app.pipelines.utils import (
-    SafetyChecker,
-    get_model_dir,
-    get_torch_device,
-    is_lightning_model,
-    is_turbo_model,
-)
 from diffusers import (
     AutoPipelineForImage2Image,
     EulerAncestralDiscreteScheduler,
@@ -25,6 +17,15 @@
 from PIL import ImageFile
 from safetensors.torch import load_file

+from app.pipelines.base import Pipeline
+from app.pipelines.utils import (
+    SafetyChecker,
+    get_model_dir,
+    get_torch_device,
+    is_lightning_model,
+    is_turbo_model,
+)
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True

 logger = logging.getLogger(__name__)
@@ -171,6 +172,7 @@ def __call__(
         self, prompt: str, image: PIL.Image, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
+        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)

         if seed is not None:
@@ -183,7 +185,7 @@ def __call__
                     torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                 ]

-        if "num_inference_steps" in kwargs and kwargs["num_inference_steps"] < 1:
+        if num_inference_steps is None or num_inference_steps < 1:
             del kwargs["num_inference_steps"]

         if (
@@ -194,15 +196,12 @@ def __call__
             # it should be set to 0
             kwargs["guidance_scale"] = 0.0

-            # num_inference_steps * strength should be >= 1 because
-            # the pipeline will be run for int(num_inference_steps * strength) steps
-            kwargs["strength"] = kwargs.get("strength", 0.5)
-            if (
-                kwargs.get("num_inference_steps")
-                and kwargs["strength"] * kwargs["num_inference_steps"] < 1
-            ):
+            # Ensure num_inference_steps * strength >= 1 for minimum pipeline
+            # execution steps.
+            if "num_inference_steps" in kwargs:
                 kwargs["strength"] = max(
-                    1.0 / kwargs["num_inference_steps"], kwargs["strength"]
+                    1.0 / kwargs.get("num_inference_steps", 1),
+                    kwargs.get("strength", 0.5),
                 )
         elif ModelName.SDXL_LIGHTNING.value in self.model_id:
             # SDXL-Lightning models should have guidance_scale = 0 and use
@@ -218,8 +217,6 @@ def __call__(
             else:
                 # Default to 2step
                 kwargs["num_inference_steps"] = 2
-        elif ModelName.INSTRUCT_PIX2PIX.value in self.model_id:
-            kwargs["num_inference_steps"] = kwargs.get("num_inference_steps", 10)

         output = self.ldm(prompt, image=image, **kwargs)

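Read end to end, the new __call__ logic first validates num_inference_steps and then applies the per-model adjustments shown above. A simplified, stand-alone sketch of that flow (normalize_kwargs is a hypothetical helper, not code from this repository, and it uses dict.pop with a default so the sketch also runs when the key is absent):

```python
def normalize_kwargs(model_id: str, **kwargs) -> dict:
    """Hypothetical condensation of the num_inference_steps handling above."""
    num_inference_steps = kwargs.get("num_inference_steps", None)

    # Drop missing or invalid values so downstream defaults can apply.
    if num_inference_steps is None or num_inference_steps < 1:
        kwargs.pop("num_inference_steps", None)

    if "turbo" in model_id:
        # Turbo models are used without guidance, and strength is clamped so
        # that at least one denoising step runs.
        kwargs["guidance_scale"] = 0.0
        if "num_inference_steps" in kwargs:
            kwargs["strength"] = max(
                1.0 / kwargs["num_inference_steps"],
                kwargs.get("strength", 0.5),
            )
    return kwargs


print(normalize_kwargs("stabilityai/sdxl-turbo", num_inference_steps=2, strength=0.3))
# -> {'num_inference_steps': 2, 'strength': 0.5, 'guidance_scale': 0.0}
```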
11 changes: 7 additions & 4 deletions runner/app/routes/image_to_image.py
@@ -3,14 +3,15 @@
 import random
 from typing import Annotated

-from app.dependencies import get_pipeline
-from app.pipelines.base import Pipeline
-from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
 from fastapi import APIRouter, Depends, File, Form, UploadFile, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from PIL import Image, ImageFile

+from app.dependencies import get_pipeline
+from app.pipelines.base import Pipeline
+from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True

 router = APIRouter()
@@ -44,7 +45,9 @@ async def image_to_image(
     negative_prompt: Annotated[str, Form()] = "",
     safety_check: Annotated[bool, Form()] = True,
     seed: Annotated[int, Form()] = None,
-    num_inference_steps: Annotated[int, Form()] = 25,  # TODO: Make optional.
+    num_inference_steps: Annotated[
+        int, Form()
+    ] = 100,  # NOTE: Hardcoded due to varying pipeline values.
     num_images_per_prompt: Annotated[int, Form()] = 1,
     pipeline: Pipeline = Depends(get_pipeline),
     token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
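With the route change, clients that omit num_inference_steps fall back to the new default of 100, which each pipeline then adapts to its model family. A minimal client sketch, assuming the router is mounted at /image-to-image on a local runner (neither the path nor the port appears in this diff) and that the requests package is available:

```python
import requests

with open("input.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/image-to-image",  # assumed path and port
        files={"image": ("input.png", f, "image/png")},
        data={
            "prompt": "a watercolor painting of a fox",
            # num_inference_steps is omitted on purpose: the route's new Form
            # default of 100 applies, and the pipeline rescales it per model.
        },
    )

print(resp.status_code)
print(resp.json())
```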
