diff --git a/hordelib/comfy_horde.py b/hordelib/comfy_horde.py index c03f6d5a..c2163c50 100644 --- a/hordelib/comfy_horde.py +++ b/hordelib/comfy_horde.py @@ -24,7 +24,7 @@ from loguru import logger from hordelib.settings import UserSettings -from hordelib.utils.ioredirect import OutputCollector +from hordelib.utils.ioredirect import ComfyUIProgress, OutputCollector from hordelib.config_path import get_hordelib_path # Note It may not be abundantly clear with no context what is going on below, and I will attempt to clarify: @@ -669,7 +669,12 @@ def send_sync(self, label: str, data: dict, _id: str) -> None: # Execute the named pipeline and pass the pipeline the parameter provided. # For the horde we assume the pipeline returns an array of images. - def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None: + def _run_pipeline( + self, + pipeline: dict, + params: dict, + comfyui_progress_callback: typing.Callable[[ComfyUIProgress, str], None] | None = None, + ) -> list[dict] | None: if _comfy_current_loaded_models is None: raise RuntimeError("hordelib.initialise() must be called before using comfy_horde.") # Wipe any previous images, if they exist. @@ -692,7 +697,7 @@ def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None: # The client_id parameter here is just so we receive comfy callbacks for debugging. # We pretend we are a web client and want async callbacks. 
- stdio = OutputCollector() + stdio = OutputCollector(comfyui_progress_callback=comfyui_progress_callback) with contextlib.redirect_stdout(stdio), contextlib.redirect_stderr(stdio): # validate_prompt from comfy returns [bool, str, list] # Which gives us these nice hardcoded list indexes, which valid[2] is the output node list @@ -720,7 +725,12 @@ def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None: return self.images # Run a pipeline that returns an image in pixel space - def run_image_pipeline(self, pipeline, params: dict) -> list[dict[str, typing.Any]]: + def run_image_pipeline( + self, + pipeline, + params: dict, + comfyui_progress_callback: typing.Callable[[ComfyUIProgress, str], None] | None = None, + ) -> list[dict[str, typing.Any]]: # From the horde point of view, let us assume the output we are interested in # is always in a HordeImageOutput node named "output_image". This is an array of # dicts of the form: @@ -748,7 +758,7 @@ def run_image_pipeline(self, pipeline, params: dict) -> list[dict[str, typing.An if idle_time > 1 and UserSettings.enable_idle_time_warning.active: logger.warning(f"No job ran for {round(idle_time, 3)} seconds") - result = self._run_pipeline(pipeline_data, params) + result = self._run_pipeline(pipeline_data, params, comfyui_progress_callback) if result: return result diff --git a/hordelib/consts.py b/hordelib/consts.py index c6334082..8ba150a0 100644 --- a/hordelib/consts.py +++ b/hordelib/consts.py @@ -6,7 +6,7 @@ from hordelib.config_path import get_hordelib_path -COMFYUI_VERSION = "f81dbe26e2e363c28ad043db67b59c11bb33f446" +COMFYUI_VERSION = "2a813c3b09292c9aeab622ddf65d77e5d8171d0d" """The exact version of ComfyUI version to load.""" REMOTE_PROXY = "" diff --git a/hordelib/horde.py b/hordelib/horde.py index 98e7f14d..aa43f874 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -10,6 +10,7 @@ import typing from collections.abc import Callable from copy import deepcopy +from enum import Enum, auto from 
horde_sdk.ai_horde_api.apimodels import ImageGenerateJobPopResponse from horde_sdk.ai_horde_api.apimodels.base import ( @@ -18,12 +19,32 @@ from horde_sdk.ai_horde_api.consts import KNOWN_FACEFIXERS, KNOWN_UPSCALERS, METADATA_TYPE, METADATA_VALUE from loguru import logger from PIL import Image +from pydantic import BaseModel from hordelib.comfy_horde import Comfy_Horde from hordelib.consts import MODEL_CATEGORY_NAMES from hordelib.shared_model_manager import SharedModelManager from hordelib.utils.dynamicprompt import DynamicPromptParser from hordelib.utils.image_utils import ImageUtils +from hordelib.utils.ioredirect import ComfyUIProgress + + +class ProgressState(Enum): + """The state of the progress report""" + + started = auto() + progress = auto() + post_processing = auto() + finished = auto() + + +class ProgressReport(BaseModel): + """A progress message sent to a callback""" + + hordelib_progress_state: ProgressState + comfyui_progress: ComfyUIProgress | None = None + progress: float | None = None + hordelib_message: str | None = None class ResultingImageReturn: @@ -869,6 +890,7 @@ def _inference( payload: dict, *, single_image_expected: bool = True, + comfyui_progress_callback: Callable[[ComfyUIProgress, str], None] | None = None, ) -> list[ResultingImageReturn] | ResultingImageReturn: payload, pipeline_data, faults = self._get_validated_payload_and_pipeline_data(payload) @@ -901,7 +923,7 @@ def _inference( # Call the inference pipeline # logger.debug(payload) - images = self.generator.run_image_pipeline(pipeline_data, payload) + images = self.generator.run_image_pipeline(pipeline_data, payload, comfyui_progress_callback) results = self._process_results(images) ret_results = [ @@ -920,8 +942,15 @@ def _inference( return ret_results - def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[ResultingImageReturn]: + def basic_inference( + self, + payload: dict | ImageGenerateJobPopResponse, + *, + progress_callback: 
Callable[[ProgressReport], None] | None = None, + ) -> list[ResultingImageReturn]: post_processing_requested: list[str] | None = None + if isinstance(payload, dict): + post_processing_requested = payload.get("post_processing") faults = [] if isinstance(payload, ImageGenerateJobPopResponse): # TODO move this to _inference() @@ -968,7 +997,37 @@ def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[R sub_payload["model"] = payload.model payload = sub_payload - result = self._inference(payload, single_image_expected=False) + if progress_callback is not None: + try: + progress_callback( + ProgressReport( + hordelib_progress_state=ProgressState.started, + hordelib_message="Initiating inference...", + progress=0, + ), + ) + except Exception as e: + logger.error(f"Progress callback failed ({type(e)}): {e}") + + def _default_progress_callback(comfyui_progress: ComfyUIProgress, message: str) -> None: + nonlocal progress_callback + if progress_callback is not None: + try: + progress_callback( + ProgressReport( + hordelib_progress_state=ProgressState.progress, + hordelib_message=message, + comfyui_progress=comfyui_progress, + ), + ) + except Exception as e: + logger.error(f"Progress callback failed ({type(e)}): {e}") + + result = self._inference( + payload, + single_image_expected=False, + comfyui_progress_callback=_default_progress_callback, + ) if not isinstance(result, list): raise RuntimeError(f"Expected a list of PIL.Image.Image but got {type(result)}") @@ -981,11 +1040,29 @@ def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[R post_processed: list[ResultingImageReturn] | None = None if post_processing_requested is not None: + if progress_callback is not None: + try: + progress_callback( + ProgressReport( + hordelib_progress_state=ProgressState.post_processing, + hordelib_message="Post Processing.", + ), + ) + except Exception as e: + logger.error(f"Progress callback failed ({type(e)}): {e}") + post_processed = [] for 
ret in return_list: single_image_faults = [] final_image = ret.image final_rawpng = ret.rawpng + + # Facefixers get key 1, so this stable sort runs them LAST, not first - NOTE(review): confirm intent + post_processing_requested = sorted( + post_processing_requested, + key=lambda x: 1 if x in KNOWN_FACEFIXERS.__members__ else 0, + ) + for post_processing in post_processing_requested: if ( post_processing in KNOWN_UPSCALERS.__members__ @@ -1025,6 +1102,18 @@ def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[R ResultingImageReturn(image=final_image, rawpng=final_rawpng, faults=single_image_faults), ) + if progress_callback is not None: + try: + progress_callback( + ProgressReport( + hordelib_progress_state=ProgressState.finished, + hordelib_message="Inference complete.", + progress=100, + ), + ) + except Exception as e: + logger.error(f"Progress callback failed ({type(e)}): {e}") + if post_processed is not None: logger.debug(f"Post-processing complete. Returning {len(post_processed)} images.") return post_processed diff --git a/hordelib/initialisation.py b/hordelib/initialisation.py index d3359db5..0577637c 100644 --- a/hordelib/initialisation.py +++ b/hordelib/initialisation.py @@ -26,6 +26,7 @@ def initialise( process_id: int | None = None, force_normal_vram_mode: bool = True, extra_comfyui_args: list[str] | None = None, + disable_smart_memory: bool = False, ): """Initialise hordelib. This is required before using any other hordelib functions. 
@@ -75,6 +76,7 @@ def initialise( hordelib.comfy_horde.do_comfy_import( force_normal_vram_mode=force_normal_vram_mode, extra_comfyui_args=extra_comfyui_args, + disable_smart_memory=disable_smart_memory, ) vram_on_start_free = hordelib.comfy_horde.get_torch_free_vram_mb() diff --git a/hordelib/nodes/node_model_loader.py b/hordelib/nodes/node_model_loader.py index 1615ac6b..111f2cee 100644 --- a/hordelib/nodes/node_model_loader.py +++ b/hordelib/nodes/node_model_loader.py @@ -94,12 +94,13 @@ def load_checkpoint( SharedModelManager.manager._models_in_ram = {} ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - result = comfy.sd.load_checkpoint_guess_config( - ckpt_path, - output_vae=True, - output_clip=True, - embedding_directory=folder_paths.get_folder_paths("embeddings"), - ) + with torch.no_grad(): + result = comfy.sd.load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=folder_paths.get_folder_paths("embeddings"), + ) SharedModelManager.manager._models_in_ram[horde_in_memory_name] = result, will_load_loras diff --git a/hordelib/nodes/node_upscale_model_loader.py b/hordelib/nodes/node_upscale_model_loader.py index c8e920dd..ff32ee80 100644 --- a/hordelib/nodes/node_upscale_model_loader.py +++ b/hordelib/nodes/node_upscale_model_loader.py @@ -20,6 +20,8 @@ def INPUT_TYPES(s): def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) + if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""}) out = model_loading.load_state_dict(sd).eval() return (out,) diff --git a/hordelib/utils/ioredirect.py b/hordelib/utils/ioredirect.py index 6c55cffe..9b4a66f9 100644 --- a/hordelib/utils/ioredirect.py +++ b/hordelib/utils/ioredirect.py @@ -1,9 +1,39 @@ import io from collections import deque +from collections.abc import Callable +from 
enum import Enum from time import perf_counter import regex from loguru import logger +from pydantic import BaseModel + + +class ComfyUIProgressUnit(Enum): + """An enum to represent the different types of progress bars that ComfyUI can output. + + This is used to determine how to parse the progress bar and log it. + """ + + ITERATIONS_PER_SECOND = 1 + SECONDS_PER_ITERATION = 2 + UNKNOWN = 3 + + +class ComfyUIProgress(BaseModel): + """A pydantic model representing the progress of a ComfyUI job. + + Holds the values parsed from a progress bar line: percent, step counts, and rate with its unit. + """ + + percent: int + current_step: int + total_steps: int + rate: float + rate_unit: ComfyUIProgressUnit + + def __str__(self): + return f"{self.percent}%: {self.current_step}/{self.total_steps} ({self.rate} {self.rate_unit})" class OutputCollector(io.TextIOWrapper): @@ -16,9 +46,17 @@ class OutputCollector(io.TextIOWrapper): start_time: float slow_message_count: int = 0 - def __init__(self): + capture_deque: deque + + comfyui_progress_callback: Callable[[ComfyUIProgress, str], None] | None = None + """A callback function that is called when a progress bar is detected in the output. The callback function should \ + accept two arguments: a ComfyUIProgress object and a string. The ComfyUIProgress object contains the parsed \ + progress bar information, and the string contains the original message that was captured.""" + + def __init__(self, *, comfyui_progress_callback: Callable[[ComfyUIProgress, str], None] | None = None): logger.disable("tqdm") # just.. 
no - self.deque = deque() + self.capture_deque = deque() + self.comfyui_progress_callback = comfyui_progress_callback self.start_time = perf_counter() def write(self, message: str): @@ -44,7 +82,7 @@ def write(self, message: str): if not matches: logger.debug(f"Unknown progress bar format?: {message}") - self.deque.append(message) + self.capture_deque.append(message) return # Remove everything in between '|' and '|' @@ -84,11 +122,27 @@ def write(self, message: str): ): logger.info(message) - self.deque.append(message) + if self.comfyui_progress_callback: + self.comfyui_progress_callback( + ComfyUIProgress( + percent=int(matches.group(1)), + current_step=found_current_step, + total_steps=found_total_steps, + rate=float(iteration_rate) if iteration_rate != "?" else -1.0, + rate_unit=( + ComfyUIProgressUnit.ITERATIONS_PER_SECOND + if is_iterations_per_second + else ComfyUIProgressUnit.SECONDS_PER_ITERATION + ), + ), + message, + ) + + self.capture_deque.append(message) def set_size(self, size): - while len(self.deque) > size: - self.deque.popleft() + while len(self.capture_deque) > size: + self.capture_deque.popleft() def flush(self): pass @@ -102,5 +156,5 @@ def close(self): def replay(self): logger.debug("Replaying output. Seconds in parentheses is the elapsed time spent in ComfyUI. 
") - while len(self.deque): - logger.debug(self.deque.popleft()) + while len(self.capture_deque): + logger.debug(self.capture_deque.popleft()) diff --git a/images_expected/text_to_image_callback_0.png b/images_expected/text_to_image_callback_0.png new file mode 100644 index 00000000..f58294a8 Binary files /dev/null and b/images_expected/text_to_image_callback_0.png differ diff --git a/images_expected/text_to_image_callback_1.png b/images_expected/text_to_image_callback_1.png new file mode 100644 index 00000000..020b2ee3 Binary files /dev/null and b/images_expected/text_to_image_callback_1.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 47bb70ca..f8962842 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,7 +35,7 @@ def init_horde(): import hordelib - hordelib.initialise(setup_logging=True, logging_verbosity=5) + hordelib.initialise(setup_logging=True, logging_verbosity=5, disable_smart_memory=True) from hordelib.settings import UserSettings UserSettings.set_ram_to_leave_free_mb("100%") diff --git a/tests/test_horde_inference.py b/tests/test_horde_inference.py index 7e0c5e7e..e4858382 100644 --- a/tests/test_horde_inference.py +++ b/tests/test_horde_inference.py @@ -385,3 +385,75 @@ def test_text_to_image_hires_fix_n_iter( img_pairs_to_check.append((f"images_expected/{img_filename}", image_result.image)) assert check_list_inference_images_similarity(img_pairs_to_check) + + def test_callback_with_post_processors( + self, + hordelib_instance: HordeLib, + stable_diffusion_model_name_for_testing: str, + ): + + data = { + "sampler_name": "k_euler", + "cfg_scale": 7.5, + "denoising_strength": 1.0, + "seed": 1312, + "height": 512, + "width": 512, + "karras": False, + "tiling": False, + "hires_fix": True, + "clip_skip": 1, + "control_type": None, + "image_is_control": False, + "return_control_map": False, + "prompt": "a portrait of an intense woman looking at the camera", + "ddim_steps": 25, + "n_iter": 2, + "post_processing": 
["4x_AnimeSharp", "CodeFormers"], + "model": stable_diffusion_model_name_for_testing, + } + + from hordelib.horde import ProgressReport, ProgressState + + starting_messages = 0 + post_processing_messages = 0 + finished_messages = 0 + + def callback_function(progress_report: ProgressReport): + assert progress_report is not None + if progress_report.hordelib_progress_state == ProgressState.started: + nonlocal starting_messages + starting_messages += 1 + elif progress_report.hordelib_progress_state == ProgressState.post_processing: + nonlocal post_processing_messages + post_processing_messages += 1 + elif progress_report.hordelib_progress_state == ProgressState.finished: + nonlocal finished_messages + finished_messages += 1 + + if progress_report.comfyui_progress is not None: + assert progress_report.comfyui_progress.rate == -1 or progress_report.comfyui_progress.rate > 0 + assert progress_report.hordelib_progress_state == ProgressState.progress + assert progress_report.comfyui_progress.current_step >= 0 + assert progress_report.comfyui_progress.total_steps > 0 + + image_results = hordelib_instance.basic_inference(data, progress_callback=callback_function) + + assert len(image_results) == 2 + assert starting_messages == 1 + assert post_processing_messages == 1 + assert finished_messages == 1 + + image_filename_base = "text_to_image_callback_{0}.png" + + for i, image_result in enumerate(image_results): + assert image_result.image is not None + assert isinstance(image_result.image, Image.Image) + + image_filename = image_filename_base.format(i) + image_result.image.save(f"images/{image_filename}", quality=100) + + assert check_single_inference_image_similarity( + f"images_expected/{image_filename}", + image_result.image, + )