From 13c8617a096b3426d7ef1c6cf0a1483d3258ac50 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 23 Aug 2024 15:36:01 -0400 Subject: [PATCH] fix: support changes to 'prompt' caching mechanism In comfyui internals, 'prompt` actually refers to an entire set of nodes and settings. Previously, we were hijacking the `recursive_output_delete_if_changed` call, but a recent change in comfyui has switched this behavior to the class `IsChangedCache`. The hook in place before had more to do with detecting bugs to do with caching than anything else, so I've implemented a similar sort of hijack for `IsChangedCache`. --- hordelib/comfy_horde.py | 84 +++++++++++++---------------- hordelib/nodes/node_model_loader.py | 13 +++-- 2 files changed, 47 insertions(+), 50 deletions(-) diff --git a/hordelib/comfy_horde.py b/hordelib/comfy_horde.py index 72109f48..e30c275e 100644 --- a/hordelib/comfy_horde.py +++ b/hordelib/comfy_horde.py @@ -84,7 +84,7 @@ _comfy_cleanup_models: types.FunctionType _comfy_soft_empty_cache: types.FunctionType -_comfy_recursive_output_delete_if_changed: types.FunctionType +_comfy_is_changed_cache_get: types.FunctionType _canny: types.ModuleType _hed: types.ModuleType @@ -138,7 +138,6 @@ def do_comfy_import( global _comfy_current_loaded_models global _comfy_load_models_gpu global _comfy_nodes, _comfy_PromptExecutor, _comfy_validate_prompt - global _comfy_recursive_output_delete_if_changed global _comfy_folder_names_and_paths, _comfy_supported_pt_extensions global _comfy_load_checkpoint_guess_config global _comfy_get_torch_device, _comfy_get_free_memory, _comfy_get_total_memory @@ -169,10 +168,15 @@ def do_comfy_import( from execution import nodes as _comfy_nodes from execution import PromptExecutor as _comfy_PromptExecutor from execution import validate_prompt as _comfy_validate_prompt - from execution import recursive_output_delete_if_changed - _comfy_recursive_output_delete_if_changed = recursive_output_delete_if_changed # type: ignore - execution.recursive_output_delete_if_changed = recursive_output_delete_if_changed_hijack + # from execution import recursive_output_delete_if_changed + from execution import IsChangedCache + + global _comfy_is_changed_cache_get + _comfy_is_changed_cache_get = IsChangedCache.get # type: ignore + + IsChangedCache.get = IsChangedCache_get_hijack # type: ignore + from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths # type: ignore from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions # type: ignore from comfy.sd import load_checkpoint_guess_config as _comfy_load_checkpoint_guess_config @@ -197,22 +201,8 @@ def do_comfy_import( uniformer as _uniformer, ) - import comfy.model_management - - # comfy.model_management.vram_state = comfy.model_management.VRAMState.HIGH_VRAM - # comfy.model_management.set_vram_to = comfy.model_management.VRAMState.HIGH_VRAM - logger.info("Comfy_Horde initialised") - # def always_cpu(parameters, dtype): - # return torch.device("cpu") - - # comfy.model_management.unet_inital_load_device = always_cpu - # comfy.model_management.DISABLE_SMART_MEMORY = True - # comfy.model_management.lowvram_available = True - - # comfy.model_management.unet_offload_device = _unet_offload_device_hijack - log_free_ram() output_collector.replay() @@ -221,39 +211,39 @@ def do_comfy_import( _last_pipeline_settings_hash = "" +import PIL.Image + + +def default_json_serializer_pil_image(obj): + if isinstance(obj, PIL.Image.Image): + return str(hash(obj.__str__())) + return obj + + +def IsChangedCache_get_hijack(self, *args, **kwargs): + global _comfy_is_changed_cache_get + result = _comfy_is_changed_cache_get(self, *args, **kwargs) -def recursive_output_delete_if_changed_hijack(prompt: dict, old_prompt, outputs, current_item): global _last_pipeline_settings_hash - if current_item == "prompt": - try: - pipeline_settings_hash = hashlib.md5(json.dumps(prompt).encode("utf-8")).hexdigest() - logger.debug(f"pipeline_settings_hash: {pipeline_settings_hash}") - - if pipeline_settings_hash != _last_pipeline_settings_hash: - _last_pipeline_settings_hash = pipeline_settings_hash - logger.debug("Pipeline settings changed") - - if old_prompt: - old_pipeline_settings_hash = hashlib.md5(json.dumps(old_prompt).encode("utf-8")).hexdigest() - logger.debug(f"old_pipeline_settings_hash: {old_pipeline_settings_hash}") - if pipeline_settings_hash != old_pipeline_settings_hash: - logger.debug("Pipeline settings changed from old_prompt") - except TypeError: - logger.debug("could not print hash due to source image in payload") - if current_item == "prompt" or current_item == "negative_prompt": - try: - prompt_text = prompt[current_item]["inputs"]["text"] - prompt_hash = hashlib.md5(prompt_text.encode("utf-8")).hexdigest() - logger.debug(f"{current_item} hash: {prompt_hash}") - except KeyError: - pass - global _comfy_recursive_output_delete_if_changed - return _comfy_recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item) + prompt = self.dynprompt.original_prompt + + pipeline_settings_hash = hashlib.md5( + json.dumps(prompt, default=default_json_serializer_pil_image).encode(), + ).hexdigest() + + if pipeline_settings_hash != _last_pipeline_settings_hash: + _last_pipeline_settings_hash = pipeline_settings_hash + logger.debug(f"Pipeline settings changed: {pipeline_settings_hash}") + logger.debug(f"Cache length: {len(self.outputs_cache.cache)}") + logger.debug(f"Subcache length: {len(self.outputs_cache.subcaches)}") + + logger.debug(f"IsChangedCache.dynprompt.all_node_ids: {self.dynprompt.all_node_ids()}") + if result: + logger.debug(f"IsChangedCache.get: {result}") -# def cleanup(): -# _comfy_soft_empty_cache() + return result def unload_all_models_vram(): diff --git a/hordelib/nodes/node_model_loader.py b/hordelib/nodes/node_model_loader.py index 0105d019..f9e9dbc4 100644 --- a/hordelib/nodes/node_model_loader.py +++ b/hordelib/nodes/node_model_loader.py @@ -41,8 +41,8 @@ def load_checkpoint( horde_model_name: str, ckpt_name: str | None = None, file_type: str | None = None, - output_vae=True, - output_clip=True, + output_vae=True, # this arg is required by comfyui internals + output_clip=True, # this arg is required by comfyui internals preloading=False, ): log_free_ram() @@ -115,8 +115,15 @@ def load_checkpoint( if ckpt_name is not None and Path(ckpt_name).is_absolute(): ckpt_path = ckpt_name + elif ckpt_name is not None: + full_path = folder_paths.get_full_path("checkpoints", ckpt_name) + + if full_path is None: + raise ValueError(f"Checkpoint {ckpt_name} not found.") + + ckpt_path = full_path else: - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + raise ValueError("No checkpoint name provided.") with torch.no_grad(): result = comfy.sd.load_checkpoint_guess_config(