fix(runner): improve 'num_inference_steps' logic (livepeer#205)
This commit prevents a KeyError from being raised when the pipelines are
called directly. Previously, num_inference_steps was read with
kwargs.get(...) but removed with an unguarded del kwargs[...], so the
deletion failed whenever the argument was omitted.
rickstaa authored and jjassonn committed Sep 18, 2024
1 parent 1db0a23 commit 0a66525
Showing 4 changed files with 13 additions and 9 deletions.
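
For context, a minimal standalone sketch of the failure mode and the fix. This is a simplified reproduction, not code from the runner; the run_old/run_new helpers are hypothetical, but their bodies mirror the old and new logic shown in the diffs below.

# Simplified sketch of the bug this commit fixes; run_old/run_new are
# hypothetical helpers, not functions from the runner codebase.

def run_old(**kwargs):
    # Old logic: .get() tolerates a missing key, but the bare del does not.
    num_inference_steps = kwargs.get("num_inference_steps", None)
    if num_inference_steps is None or num_inference_steps < 1:
        del kwargs["num_inference_steps"]  # KeyError if the key was never passed
    return kwargs

def run_new(**kwargs):
    # New logic: only delete the key when it is actually present.
    if "num_inference_steps" in kwargs and (
        kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
    ):
        del kwargs["num_inference_steps"]
    return kwargs

run_new(prompt="a cat")                         # ok: key absent, nothing to delete
run_new(prompt="a cat", num_inference_steps=0)  # ok: invalid value is dropped
# run_old(prompt="a cat")                       # raises KeyError: "num_inference_steps"

With a valid value (e.g. num_inference_steps=25) both versions leave kwargs untouched and the pipeline receives the argument as usual.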
runner/app/pipelines/image_to_image.py (3 additions, 2 deletions)

@@ -197,7 +197,6 @@ def __call__(
         self, prompt: str, image: PIL.Image, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
-        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
 
         if seed is not None:
@@ -206,7 +205,9 @@ def __call__(
         elif isinstance(seed, list):
             kwargs["generator"] = [torch.Generator(get_torch_device()).manual_seed(s) for s in seed]
 
-        if num_inference_steps is None or num_inference_steps < 1:
+        if "num_inference_steps" in kwargs and (
+            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
+        ):
             del kwargs["num_inference_steps"]
 
         if self.model_id in ["stabilityai/sdxl-turbo", "stabilityai/sd-turbo"]:
runner/app/pipelines/image_to_video.py (3 additions, 2 deletions)

@@ -111,7 +111,6 @@ def __call__(
         self, image: PIL.Image, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
-        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
 
         if "decode_chunk_size" not in kwargs:
@@ -128,7 +127,9 @@ def __call__(
                 torch.Generator(get_torch_device()).manual_seed(s) for s in seed
             ]
 
-        if num_inference_steps is None or num_inference_steps < 1:
+        if "num_inference_steps" in kwargs and (
+            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
+        ):
             del kwargs["num_inference_steps"]
 
         if safety_check:
runner/app/pipelines/text_to_image.py (3 additions, 2 deletions)

@@ -226,7 +226,6 @@ def __call__(
         self, prompt: str, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
-        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
 
         if seed is not None:
@@ -239,7 +238,9 @@ def __call__(
                 torch.Generator(get_torch_device()).manual_seed(s) for s in seed
             ]
 
-        if num_inference_steps is None or num_inference_steps < 1:
+        if "num_inference_steps" in kwargs and (
+            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
+        ):
             del kwargs["num_inference_steps"]
 
         if (
runner/app/pipelines/upscale.py (4 additions, 3 deletions)

@@ -118,7 +118,6 @@ def __call__(
         self, prompt: str, image: PIL.Image, **kwargs
     ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
         seed = kwargs.pop("seed", None)
-        num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
 
         if seed is not None:
@@ -131,8 +130,10 @@ def __call__(
                 torch.Generator(get_torch_device()).manual_seed(s) for s in seed
             ]
 
-        if num_inference_steps is None or num_inference_steps < 1:
-            kwargs.pop("num_inference_steps", None)
+        if "num_inference_steps" in kwargs and (
+            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
+        ):
+            del kwargs["num_inference_steps"]
 
         # trying differnt configs of promp_embed for different models
         try:
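
One detail worth noting (an observation from the diff above, not stated in the commit message): upscale.py previously removed the key with kwargs.pop("num_inference_steps", None), which already tolerates a missing key, so it was not affected by the KeyError; the change brings it in line with the membership-guarded del now used by all four pipelines. A short illustration of the difference:

kwargs = {}
kwargs.pop("num_inference_steps", None)  # safe: returns the default, None
# del kwargs["num_inference_steps"]      # would raise KeyError here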
