Skip to content

Commit

Permalink
feat: improve T2I num_inference_steps behavoir
Browse files Browse the repository at this point in the history
This commit improves the implementation and defaults for the
num_inference_steps parameter.
  • Loading branch information
rickstaa committed Jul 17, 2024
1 parent 4d2a6af commit 97fca3a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
24 changes: 13 additions & 11 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
Expand All @@ -23,6 +15,15 @@
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -183,6 +184,7 @@ def __call__(
self, prompt: str, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

if seed is not None:
Expand All @@ -195,7 +197,7 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if "num_inference_steps" in kwargs and kwargs["num_inference_steps"] < 1:
if num_inference_steps is None or num_inference_steps < 1:
del kwargs["num_inference_steps"]

if (
Expand All @@ -217,8 +219,8 @@ def __call__(
elif "8step" in self.model_id:
kwargs["num_inference_steps"] = 8
else:
# Default to 2step
kwargs["num_inference_steps"] = 2
# Default to 8step
kwargs["num_inference_steps"] = 8

output = self.ldm(prompt, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TextToImageParams(BaseModel):
negative_prompt: str = ""
safety_check: bool = True
seed: int = None
num_inference_steps: int = 50 # TODO: Make optional.
num_inference_steps: int = 50 # NOTE: Hardcoded due to varying pipeline values.
num_images_per_prompt: int = 1


Expand Down

0 comments on commit 97fca3a

Please sign in to comment.