Skip to content

Commit

Permalink
fix(runner): improve 'num_inference_steps' logic (#205)
Browse files Browse the repository at this point in the history
This commit prevents a `KeyError` from being raised when the pipelines
are called directly.
  • Loading branch information
rickstaa authored Sep 18, 2024
1 parent af9bba3 commit d2a2545
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
5 changes: 3 additions & 2 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

if seed is not None:
Expand All @@ -189,7 +188,9 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if num_inference_steps is None or num_inference_steps < 1:
if "num_inference_steps" in kwargs and (
kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
):
del kwargs["num_inference_steps"]

if (
Expand Down
5 changes: 3 additions & 2 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def __call__(
self, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

if "decode_chunk_size" not in kwargs:
Expand All @@ -126,7 +125,9 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if num_inference_steps is None or num_inference_steps < 1:
if "num_inference_steps" in kwargs and (
kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
):
del kwargs["num_inference_steps"]

if safety_check:
Expand Down
5 changes: 3 additions & 2 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def __call__(
self, prompt: str, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

if seed is not None:
Expand All @@ -219,7 +218,9 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if num_inference_steps is None or num_inference_steps < 1:
if "num_inference_steps" in kwargs and (
kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
):
del kwargs["num_inference_steps"]

if (
Expand Down
5 changes: 3 additions & 2 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

if seed is not None:
Expand All @@ -110,7 +109,9 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if num_inference_steps is None or num_inference_steps < 1:
if "num_inference_steps" in kwargs and (
kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
):
del kwargs["num_inference_steps"]

output = self.ldm(prompt, image=image, **kwargs)
Expand Down

0 comments on commit d2a2545

Please sign in to comment.