From cec9b2b4c1cfa84ae04e638b16452d856a1606c4 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 23 Feb 2024 16:53:47 -0500 Subject: [PATCH] fix: correct logic for missing source_image --- hordelib/horde.py | 47 +++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/hordelib/horde.py b/hordelib/horde.py index 5b0e9614..4bf79175 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -351,32 +351,15 @@ def _apply_aihorde_compatibility_hacks(self, payload: dict) -> tuple[dict, list[ found_model_on_disk = True if SharedModelManager.manager.compvis.model_reference[model].get("inpainting") is True: - if "source_mask" not in payload or payload["source_mask"] is None: - source_image = payload.get("source_image") - if not isinstance(source_image, Image.Image): - logger.warning( - "Inpainting model detected, but no source image provided. Using a noise image.", - ) - - faults.append( - GenMetadataEntry( - type=METADATA_TYPE.source_image, - value=METADATA_VALUE.parse_failed, - ), - ) - - payload["source_image"] = ImageUtils.create_noise_image( - payload.get("width"), - payload.get("height"), - ) - - if not isinstance(source_image, Image.Image): - raise RuntimeError("source_image is not a valid PIL image") - + source_image = payload.get("source_image") + if ( + ("source_mask" not in payload or payload["source_mask"] is None) + and "source_image" in payload + and isinstance(source_image, Image.Image) + ): if not ImageUtils.has_alpha_channel(source_image): - # set mask to an all alpha image logger.warning( - "Inpainting model detected, but no source mask provided. Using an all alpha image.", + "Inpainting model detected, but no source mask provided. Using an all white image.", ) faults.append( @@ -390,6 +373,22 @@ def _apply_aihorde_compatibility_hacks(self, payload: dict) -> tuple[dict, list[ source_image.height, ) + if source_image is None: + logger.warning( + "Inpainting model detected, but no source image provided. Using a noise image.", + ) + + faults.append( + GenMetadataEntry( + type=METADATA_TYPE.source_image, + value=METADATA_VALUE.parse_failed, + ), + ) + payload["source_image"] = ImageUtils.create_noise_image( + payload["width"], + payload["height"], + ) + else: # The node may be a post processor, so we check the other model managers post_processor_model_managers = SharedModelManager.manager.get_model_manager_instances(