feat: free resources less often; fix: handle non-standard inpainting payloads #199

Merged: 5 commits, Feb 23, 2024
3 changes: 1 addition & 2 deletions examples/run_memory_test.py
@@ -14,7 +14,6 @@

hordelib.initialise(setup_logging=False)

from hordelib.comfy_horde import cleanup
from hordelib.horde import HordeLib
from hordelib.settings import UserSettings
from hordelib.shared_model_manager import SharedModelManager
@@ -155,7 +154,7 @@ def main():
model_index += 1
how_far -= 1
# That would have pushed something to disk, force a memory cleanup
cleanup()
# cleanup()
report_ram()

logger.warning("Loaded all models")
37 changes: 15 additions & 22 deletions hordelib/comfy_horde.py
@@ -97,7 +97,11 @@


# isort: off
def do_comfy_import(force_normal_vram_mode: bool = False, extra_comfyui_args: list[str] | None = None) -> None:
def do_comfy_import(
force_normal_vram_mode: bool = False,
extra_comfyui_args: list[str] | None = None,
disable_smart_memory: bool = False,
) -> None:
global _comfy_current_loaded_models
global _comfy_load_models_gpu
global _comfy_nodes, _comfy_PromptExecutor, _comfy_validate_prompt
@@ -109,9 +113,9 @@ def do_comfy_import(force_normal_vram_mode: bool = False, extra_comfyui_args: li
global _comfy_free_memory, _comfy_cleanup_models, _comfy_soft_empty_cache
global _canny, _hed, _leres, _midas, _mlsd, _openpose, _pidinet, _uniformer

logger.info("Disabling smart memory")

sys.argv.append("--disable-smart-memory")
if disable_smart_memory:
logger.info("Disabling smart memory")
sys.argv.append("--disable-smart-memory")

if force_normal_vram_mode:
logger.info("Forcing normal vram mode")
@@ -222,8 +226,8 @@ def recursive_output_delete_if_changed_hijack(prompt: dict, old_prompt, outputs,
return _comfy_recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item)


def cleanup():
_comfy_soft_empty_cache()
# def cleanup():
# _comfy_soft_empty_cache()


def unload_all_models_vram():
@@ -271,17 +275,6 @@ def get_torch_free_vram_mb():
return round(_comfy_get_free_memory() / (1024 * 1024))


def garbage_collect():
logger.debug("Comfy_Horde garbage_collect called")
gc.collect()
if not torch.cuda.is_available():
logger.debug("CUDA not available, skipping cuda empty cache")
return
if torch.version.cuda:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


class Comfy_Horde:
"""Handles horde-specific behavior against ComfyUI."""

@@ -718,11 +711,11 @@ def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None:

stdio.replay()

# Check if there are any resource to clean up
cleanup()
if time.time() - self._gc_timer > Comfy_Horde.GC_TIME:
self._gc_timer = time.time()
garbage_collect()
# # Check if there are any resource to clean up
# cleanup()
# if time.time() - self._gc_timer > Comfy_Horde.GC_TIME:
# self._gc_timer = time.time()
# garbage_collect()

return self.images

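The net effect of the changes above is that hordelib no longer forces ComfyUI's --disable-smart-memory flag and no longer frees VRAM after every pipeline run. Below is a hedged, caller-side sketch of how a worker could opt back into aggressive unloading via the new do_comfy_import parameter and run its own time-gated cleanup, mirroring the removed garbage_collect helper. The interval value and the idea of doing this in the worker at all are illustrative assumptions, not part of this PR, and whether a worker calls do_comfy_import directly or through hordelib's normal initialisation path is glossed over here.

import gc
import time

import torch

from hordelib.comfy_horde import do_comfy_import

# Opt back into aggressive unloading only if you really want it; after this PR,
# ComfyUI's smart memory management stays enabled by default.
do_comfy_import(disable_smart_memory=True)

GC_INTERVAL_SECONDS = 30  # assumed interval; the removed code used Comfy_Horde.GC_TIME
_last_gc = time.time()


def maybe_garbage_collect() -> None:
    """Time-gated cleanup mirroring the removed garbage_collect() helper."""
    global _last_gc
    if time.time() - _last_gc < GC_INTERVAL_SECONDS:
        return
    _last_gc = time.time()
    gc.collect()
    if not torch.cuda.is_available():
        return
    if torch.version.cuda:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()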
158 changes: 115 additions & 43 deletions hordelib/horde.py
@@ -319,55 +319,127 @@ def _validate_data_structure(self, data, schema_definition=PAYLOAD_SCHEMA):

return data

def _apply_aihorde_compatibility_hacks(self, payload):
def _apply_aihorde_compatibility_hacks(self, payload: dict) -> tuple[dict, list[GenMetadataEntry]]:
"""For use by the AI Horde worker we require various counterintuitive hacks to the payload data.

We encapsulate all of this implicit witchcraft in one function, here.
"""
faults: list[GenMetadataEntry] = []

if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use AI Horde compatibility hacks without compvis loaded!")

payload = deepcopy(payload)

if payload.get("model"):
payload["model_name"] = payload["model"]
# Comfy expects the "model" key to be the filename
# But we are also sending the "generic" model name along in key "model_name" in order to be able
# To look it up in the model manager.
if SharedModelManager.manager.compvis.is_model_available(payload["model"]):
model_files = SharedModelManager.manager.compvis.get_model_filenames(payload["model"])
payload["model"] = model_files[0]["file_path"]
for file_entry in model_files:
# If we have a file_type, we also add to the payload
# each file_path with the key being the file_type
# This is then defined in PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING
# to be injected in the right part of the pipeline
if "file_type" in file_entry:
payload[file_entry["file_type"]] = file_entry["file_path"]
else:
post_processor_model_managers = SharedModelManager.manager.get_model_manager_instances(
[MODEL_CATEGORY_NAMES.codeformer, MODEL_CATEGORY_NAMES.esrgan, MODEL_CATEGORY_NAMES.gfpgan],
)
model = payload.get("model")

if model is None:
raise RuntimeError("No model specified in payload")

# This is translated to "horde_model_name" later for compvis models and used as is for post processors
payload["model_name"] = model

found_model_in_ref = False
found_model_on_disk = False
model_files: list[dict] = [{}]

if model in SharedModelManager.manager.compvis.model_reference:
found_model_in_ref = True

if SharedModelManager.manager.compvis.is_model_available(model):
model_files = SharedModelManager.manager.compvis.get_model_filenames(model)
found_model_on_disk = True

if SharedModelManager.manager.compvis.model_reference[model].get("inpainting") is True:
if payload.get("source_processing") not in ["inpainting", "outpainting"]:
logger.warning(
"Inpainting model detected, but source processing not set to inpainting or outpainting.",
)

payload["source_processing"] = "inpainting"

source_image = payload.get("source_image")
source_mask = payload.get("source_mask")

if source_image is None or not isinstance(source_image, Image.Image):
logger.warning(
"Inpainting model detected, but source image is not a valid image. Using a noise image.",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.source_image,
value=METADATA_VALUE.parse_failed,
),
)
payload["source_image"] = ImageUtils.create_noise_image(
payload["width"],
payload["height"],
)

found_model = False
source_image = payload.get("source_image")

for post_processor_model_manager in post_processor_model_managers:
if post_processor_model_manager.is_model_available(payload["model"]):
model_files = post_processor_model_manager.get_model_filenames(payload["model"])
payload["model"] = model_files[0]["file_path"]
found_model = True
if source_mask is None and (
source_image is None
or (isinstance(source_image, Image.Image) and not ImageUtils.has_alpha_channel(source_image))
):
logger.warning(
"Inpainting model detected, but no source mask provided. Using an all white mask.",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.source_mask,
value=METADATA_VALUE.parse_failed,
),
)
payload["source_mask"] = ImageUtils.create_white_image(
source_image.width if source_image else payload["width"],
source_image.height if source_image else payload["height"],
)

else:
# The node may be a post processor, so we check the other model managers
post_processor_model_managers = SharedModelManager.manager.get_model_manager_instances(
[MODEL_CATEGORY_NAMES.codeformer, MODEL_CATEGORY_NAMES.esrgan, MODEL_CATEGORY_NAMES.gfpgan],
)

for post_processor_model_manager in post_processor_model_managers:
if model in post_processor_model_manager.model_reference:
found_model_in_ref = True
if post_processor_model_manager.is_model_available(model):
model_files = post_processor_model_manager.get_model_filenames(model)
found_model_on_disk = True
break

if not found_model_in_ref:
raise RuntimeError(f"Model {model} not found in model reference!")

if not found_model_on_disk:
raise RuntimeError(f"Model {model} not found on disk!")

if len(model_files) == 0 or (not isinstance(model_files[0], dict)) or "file_path" not in model_files[0]:
raise RuntimeError(f"Model {model} has no files in its reference entry!")

payload["model"] = model_files[0]["file_path"]
for file_entry in model_files:
if "file_type" in file_entry:
payload[file_entry["file_type"]] = file_entry["file_path"]

if not found_model:
raise RuntimeError(f"Model {payload['model']} not found! Is it in a Model Reference?")
# Rather than specify a scheduler, only karras or not karras is specified
if payload.get("karras", False):
payload["scheduler"] = "karras"
else:
payload["scheduler"] = "normal"

prompt = payload.get("prompt")

# Negative and positive prompts are merged together
if payload.get("prompt"):
if "###" in payload.get("prompt"):
split_prompts = payload.get("prompt").split("###")
if prompt is not None:
if "###" in prompt:
split_prompts = prompt.split("###")
payload["prompt"] = split_prompts[0]
payload["negative_prompt"] = split_prompts[1]
elif prompt == "":
logger.warning("Empty prompt detected, this is likely to produce poor results")

# Turn off hires fix if we're not generating a hires image, or if the params are just confused
try:
@@ -397,9 +469,9 @@ def _apply_aihorde_compatibility_hacks(self, payload):
# del payload["denoising_strength"]
# else:
# del payload["denoising_strength"]
return payload
return payload, faults

def _final_pipeline_adjustments(self, payload, pipeline_data):
def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, list[GenMetadataEntry]]:
payload = deepcopy(payload)
faults: list[GenMetadataEntry] = []

Expand Down Expand Up @@ -780,7 +852,7 @@ def _process_results(

def _get_validated_payload_and_pipeline_data(self, payload: dict) -> tuple[dict, dict, list[GenMetadataEntry]]:
# AIHorde hacks to payload
payload = self._apply_aihorde_compatibility_hacks(payload)
payload, compatibility_faults = self._apply_aihorde_compatibility_hacks(payload)
# Check payload types/values and normalise its format
payload = self._validate_data_structure(payload)
# Resize the source image and mask to actual final width/height requested
@@ -789,8 +861,8 @@ def _get_validated_payload_and_pipeline_data(self, payload: dict) -> tuple[dict,
pipeline = self._get_appropriate_pipeline(payload)
# Final adjustments to the pipeline
pipeline_data = self.generator.get_pipeline_data(pipeline)
payload, faults = self._final_pipeline_adjustments(payload, pipeline_data)
return payload, pipeline_data, faults
payload, finale_adjustment_faults = self._final_pipeline_adjustments(payload, pipeline_data)
return payload, pipeline_data, compatibility_faults + finale_adjustment_faults

def _inference(
self,
@@ -985,7 +1057,7 @@ def basic_inference_rawpng(self, payload: dict) -> list[io.BytesIO]:
def image_upscale(self, payload) -> ResultingImageReturn:
logger.debug("image_upscale called")
# AIHorde hacks to payload
payload = self._apply_aihorde_compatibility_hacks(payload)
payload, compatibility_faults = self._apply_aihorde_compatibility_hacks(payload)
# Remember if we were passed width and height, we wouldn't normally be passed width and height
# because the upscale models upscale to a fixed multiple of image size. However, if we *are*
# passed a width and height we rescale the upscale output image to this size.
@@ -996,7 +1068,7 @@ def image_upscale(self, payload) -> ResultingImageReturn:
# Final adjustments to the pipeline
pipeline_name = "image_upscale"
pipeline_data = self.generator.get_pipeline_data(pipeline_name)
payload, faults = self._final_pipeline_adjustments(payload, pipeline_data)
payload, final_adjustment_faults = self._final_pipeline_adjustments(payload, pipeline_data)

# Run the pipeline

@@ -1007,7 +1079,7 @@ def image_upscale(self, payload) -> ResultingImageReturn:
return ResultingImageReturn(
ImageUtils.shrink_image(Image.open(images[0]["imagedata"]), width, height),
rawpng=None,
faults=faults,
faults=final_adjustment_faults,
)
result = self._process_results(images)
if len(result) != 1:
@@ -1017,18 +1089,18 @@ def image_upscale(self, payload) -> ResultingImageReturn:
if not isinstance(image, Image.Image):
raise RuntimeError(f"Expected a PIL.Image.Image but got {type(image)}")

return ResultingImageReturn(image=image, rawpng=rawpng, faults=faults)
return ResultingImageReturn(image=image, rawpng=rawpng, faults=compatibility_faults + final_adjustment_faults)

def image_facefix(self, payload) -> ResultingImageReturn:
logger.debug("image_facefix called")
# AIHorde hacks to payload
payload = self._apply_aihorde_compatibility_hacks(payload)
payload, compatibility_faults = self._apply_aihorde_compatibility_hacks(payload)
# Check payload types/values and normalise its format
payload = self._validate_data_structure(payload)
# Final adjustments to the pipeline
pipeline_name = "image_facefix"
pipeline_data = self.generator.get_pipeline_data(pipeline_name)
payload, faults = self._final_pipeline_adjustments(payload, pipeline_data)
payload, final_adjustment_faults = self._final_pipeline_adjustments(payload, pipeline_data)

# Run the pipeline

@@ -1042,4 +1114,4 @@ def image_facefix(self, payload) -> ResultingImageReturn:
if not isinstance(image, Image.Image):
raise RuntimeError(f"Expected a PIL.Image.Image but got {type(image)}")

return ResultingImageReturn(image=image, rawpng=rawpng, faults=faults)
return ResultingImageReturn(image=image, rawpng=rawpng, faults=compatibility_faults + final_adjustment_faults)
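To make the "non-standard inpainting payloads" half of the PR title concrete: when an inpainting model is selected but the payload arrives without a usable source image or mask, the compatibility hacks above now substitute safe defaults and record faults instead of failing. The following is a distilled restatement of that fallback, not the real implementation; the standalone function shape and the import paths are assumptions.

from PIL import Image

from hordelib.utils.image_utils import ImageUtils  # assumed import path
from horde_sdk.ai_horde_api.apimodels import GenMetadataEntry  # assumed import path
from horde_sdk.ai_horde_api.consts import METADATA_TYPE, METADATA_VALUE  # assumed import path


def repair_inpainting_payload(payload: dict) -> tuple[dict, list[GenMetadataEntry]]:
    """Mimic of the fallback behaviour added in _apply_aihorde_compatibility_hacks."""
    faults: list[GenMetadataEntry] = []
    payload["source_processing"] = "inpainting"

    source_image = payload.get("source_image")
    if not isinstance(source_image, Image.Image):
        # No usable source image: substitute noise at the requested size and record a fault.
        payload["source_image"] = ImageUtils.create_noise_image(payload["width"], payload["height"])
        faults.append(GenMetadataEntry(type=METADATA_TYPE.source_image, value=METADATA_VALUE.parse_failed))

    source_image = payload["source_image"]
    if payload.get("source_mask") is None and not ImageUtils.has_alpha_channel(source_image):
        # No mask and no alpha channel to derive one from: use an all-white mask.
        payload["source_mask"] = ImageUtils.create_white_image(source_image.width, source_image.height)
        faults.append(GenMetadataEntry(type=METADATA_TYPE.source_mask, value=METADATA_VALUE.parse_failed))

    return payload, faults

For the post-processing entry points changed in this diff (image_upscale, image_facefix), such faults surface on ResultingImageReturn.faults, combined with any faults produced by the final pipeline adjustments.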
10 changes: 10 additions & 0 deletions hordelib/nodes/node_image_loader.py
@@ -1,7 +1,9 @@
# horde_image_loader.py
# Load images into the pipeline from PIL, not disk
import numpy as np
import PIL.Image
import torch
from loguru import logger


class HordeImageLoader:
@@ -17,6 +19,14 @@ def INPUT_TYPES(s):
FUNCTION = "load_image"

def load_image(self, image):
if image is None:
logger.error("Input image is None in HordeImageLoader - this is a bug, please report it!")
raise ValueError("Input image is None in HordeImageLoader")

if not isinstance(image, PIL.Image.Image):
logger.error(f"Input image is not a PIL Image, it is a {type(image)}")
raise ValueError(f"Input image is not a PIL Image, it is a {type(image)}")

new_image = image.convert("RGB")
new_image = np.array(new_image).astype(np.float32) / 255.0
new_image = torch.from_numpy(new_image)[None,]
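A small usage sketch of the stricter loader above; the class name and error messages come from this diff, while running it standalone assumes hordelib and its torch dependency are importable in your environment.

import PIL.Image

from hordelib.nodes.node_image_loader import HordeImageLoader

loader = HordeImageLoader()

# A valid PIL image is converted to RGB, scaled to float32 in [0, 1], and given a
# batch axis: np.array -> (64, 64, 3), then [None,] -> a (1, 64, 64, 3) tensor.
image = PIL.Image.new("RGB", (64, 64), color="gray")
loader.load_image(image)

# Anything else is now rejected up front instead of failing deeper in the pipeline.
try:
    loader.load_image("not-an-image")
except ValueError as err:
    print(err)  # Input image is not a PIL Image, it is a <class 'str'>

try:
    loader.load_image(None)
except ValueError as err:
    print(err)  # Input image is None in HordeImageLoader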