diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py
index f7099f34..feb11349 100644
--- a/runner/app/pipelines/text_to_image.py
+++ b/runner/app/pipelines/text_to_image.py
@@ -5,14 +5,6 @@
 import PIL
 import torch
 
-from app.pipelines.base import Pipeline
-from app.pipelines.utils import (
-    SafetyChecker,
-    get_model_dir,
-    get_torch_device,
-    is_lightning_model,
-    is_turbo_model,
-)
 from diffusers import (
     AutoPipelineForText2Image,
     EulerDiscreteScheduler,
@@ -23,6 +15,15 @@
 from huggingface_hub import file_download, hf_hub_download
 from safetensors.torch import load_file
 
+from app.pipelines.base import Pipeline
+from app.pipelines.utils import (
+    SafetyChecker,
+    get_model_dir,
+    get_torch_device,
+    is_lightning_model,
+    is_turbo_model,
+)
+
 logger = logging.getLogger(__name__)
 
 
@@ -183,6 +184,7 @@ def __call__(
         self, prompt: str, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
+        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
 
         if seed is not None:
@@ -195,7 +197,7 @@
                     torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                 ]
 
-        if "num_inference_steps" in kwargs and kwargs["num_inference_steps"] < 1:
-            del kwargs["num_inference_steps"]
+        if num_inference_steps is None or num_inference_steps < 1:
+            kwargs.pop("num_inference_steps", None)
 
         if (
@@ -217,8 +219,8 @@
             elif "8step" in self.model_id:
                 kwargs["num_inference_steps"] = 8
             else:
-                # Default to 2step
-                kwargs["num_inference_steps"] = 2
+                # Default to 8step
+                kwargs["num_inference_steps"] = 8
 
         output = self.ldm(prompt, **kwargs)
 
diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py
index 27fe63a9..b6d42747 100644
--- a/runner/app/routes/text_to_image.py
+++ b/runner/app/routes/text_to_image.py
@@ -26,7 +26,7 @@ class TextToImageParams(BaseModel):
     negative_prompt: str = ""
     safety_check: bool = True
     seed: int = None
-    num_inference_steps: int = 50  # TODO: Make optional.
+    num_inference_steps: int = 50  # NOTE: Hardcoded due to varying pipeline values.
     num_images_per_prompt: int = 1