Merge pull request #196 from Haidra-Org/main
Feat: Adds support for Stable Cascade (#195)
tazlin authored Feb 22, 2024
2 parents 9f0f693 + 7a06553 commit 3d6f9ff
Showing 23 changed files with 1,058 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -166,3 +166,4 @@ tmp/
*.ckpt
*.pth
.gitignore
pipeline_debug.json
4 changes: 3 additions & 1 deletion hordelib/comfy_horde.py
@@ -691,9 +691,11 @@ def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None:
# This is useful for dumping the entire pipeline to the terminal when
# developing and debugging new pipelines. A badly structured pipeline
# file just results in a cryptic error from comfy
pretty_pipeline = pformat(pipeline)
if False: # This isn't here, Tazlin :)
pretty_pipeline = pformat(pipeline)
logger.warning(pretty_pipeline)
with open("pipeline_debug.json", "w") as outfile:
outfile.write(json.dumps(pipeline, indent=4))

# The client_id parameter here is just so we receive comfy callbacks for debugging.
# We pretend we are a web client and want async callbacks.
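For reference, the gated branch above only runs when a developer flips the `if False:` locally; it pretty-prints the assembled pipeline and writes it to pipeline_debug.json (now gitignored). A minimal standalone sketch of what that produces, using an illustrative two-node pipeline dict rather than a real ComfyUI graph:

```python
import json
from pprint import pformat

# Illustrative stand-in for a real ComfyUI pipeline graph
pipeline = {
    "model_loader": {
        "class_type": "HordeCheckpointLoader",
        "inputs": {"ckpt_name": "model.safetensors"},
    },
    "sampler": {"class_type": "KSampler", "inputs": {"seed": 1234, "cfg": 7.5}},
}

# What logger.warning(pretty_pipeline) would emit to the terminal
print(pformat(pipeline))

# What the gated branch writes to disk for inspection
with open("pipeline_debug.json", "w") as outfile:
    outfile.write(json.dumps(pipeline, indent=4))
```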
2 changes: 1 addition & 1 deletion hordelib/consts.py
@@ -6,7 +6,7 @@

from hordelib.config_path import get_hordelib_path

COMFYUI_VERSION = "c5a369a33ddb622827552716d9b0119035a2e666"
COMFYUI_VERSION = "18c151b3e3f6838fab4028e7a8ba526e30e610d3"
"""The exact version of ComfyUI version to load."""

REMOTE_PROXY = ""
54 changes: 50 additions & 4 deletions hordelib/horde.py
@@ -147,6 +147,8 @@ class HordeLib:
"scheduler": {"datatype": str, "values": SCHEDULERS, "default": "normal"},
"tiling": {"datatype": bool, "default": False},
"model_name": {"datatype": str, "default": "stable_diffusion"}, # Used internally by hordelib
"stable_cascade_stage_b": {"datatype": str, "default": None}, # Stable Cascade
"stable_cascade_stage_c": {"datatype": str, "default": None}, # Stable Cascade
}

LORA_SCHEMA = {
@@ -192,6 +194,22 @@ class HordeLib:
"upscale_sampler.sampler_name": "sampler_name",
"controlnet_apply.strength": "control_strength",
"controlnet_model_loader.control_net_name": "control_type",
# Stable Cascade
"stable_cascade_empty_latent_image.width": "width",
"stable_cascade_empty_latent_image.height": "height",
"stable_cascade_empty_latent_image.batch_size": "n_iter",
"sampler_stage_c.sampler_name": "sampler_name",
"sampler_stage_b.sampler_name": "sampler_name",
"sampler_stage_c.cfg": "cfg_scale",
"sampler_stage_c.denoise": "denoising_strength",
"sampler_stage_b.seed": "seed",
"sampler_stage_c.seed": "seed",
"model_loader_stage_c.ckpt_name": "stable_cascade_stage_c",
"model_loader_stage_c.model_name": "stable_cascade_stage_c",
"model_loader_stage_c.horde_model_name": "model_name",
"model_loader_stage_b.ckpt_name": "stable_cascade_stage_b",
"model_loader_stage_b.model_name": "stable_cascade_stage_b",
"model_loader_stage_b.horde_model_name": "model_name",
}

_comfyui_callback: Callable[[str, dict, str], None] | None = None
@@ -310,8 +328,19 @@ def _apply_aihorde_compatibility_hacks(self, payload):

if payload.get("model"):
payload["model_name"] = payload["model"]
# Comfy expects the "model" key to be the filename.
# We also send the "generic" model name along in the "model_name" key
# so that it can be looked up in the model manager.
if SharedModelManager.manager.compvis.is_model_available(payload["model"]):
payload["model"] = SharedModelManager.manager.compvis.get_model_filename(payload["model"])
model_files = SharedModelManager.manager.compvis.get_model_filenames(payload["model"])
payload["model"] = model_files[0]["file_path"]
for file_entry in model_files:
# If we have a file_type, we also add each file_path to the payload,
# with its file_type as the key.
# These keys are then defined in PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING
# to be injected in the right part of the pipeline
if "file_type" in file_entry:
payload[file_entry["file_type"]] = file_entry["file_path"]
else:
post_processor_model_managers = SharedModelManager.manager.get_model_manager_instances(
[MODEL_CATEGORY_NAMES.codeformer, MODEL_CATEGORY_NAMES.esrgan, MODEL_CATEGORY_NAMES.gfpgan],
@@ -321,12 +350,12 @@ def _apply_aihorde_compatibility_hacks(self, payload):

for post_processor_model_manager in post_processor_model_managers:
if post_processor_model_manager.is_model_available(payload["model"]):
payload["model"] = post_processor_model_manager.get_model_filename(payload["model"])
model_files = post_processor_model_manager.get_model_filenames(payload["model"])
payload["model"] = model_files[0]["file_path"]
found_model = True

if not found_model:
raise RuntimeError(f"Model {payload['model']} not found! Is it in a Model Reference?")

# Rather than specify a scheduler, only karras or not karras is specified
if payload.get("karras", False):
payload["scheduler"] = "karras"
@@ -368,7 +397,6 @@ def _apply_aihorde_compatibility_hacks(self, payload):
# del payload["denoising_strength"]
# else:
# del payload["denoising_strength"]

return payload
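
The net effect of the hack above, for a multi-file model, is that the payload gains one extra key per typed file, which the parameter mapping later routes to the right loader node. A minimal sketch of that injection, with an assumed model-manager result for a hypothetical Stable Cascade entry:

```python
# Assumed output of get_model_filenames() for a hypothetical two-file model
model_files = [
    {"file_path": "stable_cascade_stage_c.safetensors", "file_type": "stable_cascade_stage_c"},
    {"file_path": "stable_cascade_stage_b.safetensors", "file_type": "stable_cascade_stage_b"},
]

payload = {"model": "Stable Cascade 1.0", "model_name": "Stable Cascade 1.0"}

payload["model"] = model_files[0]["file_path"]  # the primary file keeps the old single-file behaviour
for file_entry in model_files:
    if "file_type" in file_entry:
        payload[file_entry["file_type"]] = file_entry["file_path"]

# payload now also carries "stable_cascade_stage_c" and "stable_cascade_stage_b" keys
print(payload)
```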

def _final_pipeline_adjustments(self, payload, pipeline_data):
@@ -615,10 +643,24 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
pipeline_params[newkey] = payload.get(key)
else:
logger.error(f"Parameter {key} not found")
# We inject these parameters to ensure the HordeCheckpointLoader knows what file to load, if necessary.
# We don't want to hardcode this into the pipeline.json, as we export that file directly from ComfyUI
# and don't want to have to remember to re-add these keys.
if "model_loader_stage_c.ckpt_name" in pipeline_params:
pipeline_params["model_loader_stage_c.file_type"] = "stable_cascade_stage_c"
if "model_loader_stage_b.ckpt_name" in pipeline_params:
pipeline_params["model_loader_stage_b.file_type"] = "stable_cascade_stage_b"
pipeline_params["model_loader.file_type"] = None # To allow normal SD pipelines to keep working

# Inject our model manager
# pipeline_params["model_loader.model_manager"] = SharedModelManager
pipeline_params["model_loader.will_load_loras"] = bool(payload.get("loras"))
pipeline_params["model_loader_stage_c.will_load_loras"] = False # FIXME: Once we support loras
# Does this have to be a required var in the modelloader?
pipeline_params["model_loader_stage_c.seamless_tiling_enabled"] = False
pipeline_params["model_loader_stage_b.will_load_loras"] = False # FIXME: Once we support loras
# Does this have to be a required var in the modelloader?
pipeline_params["model_loader_stage_b.seamless_tiling_enabled"] = False

# For hires fix, change the image sizes as we create an intermediate image first
if payload.get("hires_fix", False):
@@ -685,6 +727,10 @@ def _get_appropriate_pipeline(self, params):
# image_upscale

# controlnet, controlnet_hires_fix controlnet_annotator
if params.get("model_name"):
model_details = SharedModelManager.manager.compvis.get_model_reference_info(params["model_name"])
if model_details.get("baseline") == "stable_cascade":
return "stable_cascade"
if params.get("control_type"):
if params.get("return_control_map", False):
return "controlnet_annotator"
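The dotted keys in PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING name a pipeline node and one of its inputs; _final_pipeline_adjustments copies each matching payload value across, so both cascade samplers share one seed while each stage loads its own file. A minimal sketch of that translation, using an illustrative excerpt of the mapping and payload:

```python
# Illustrative excerpt of PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING
MAPPING = {
    "sampler_stage_c.cfg": "cfg_scale",
    "sampler_stage_c.seed": "seed",
    "sampler_stage_b.seed": "seed",
    "model_loader_stage_c.ckpt_name": "stable_cascade_stage_c",
    "model_loader_stage_b.ckpt_name": "stable_cascade_stage_b",
}

payload = {
    "cfg_scale": 4.0,
    "seed": 1234,
    "stable_cascade_stage_c": "stable_cascade_stage_c.safetensors",
    "stable_cascade_stage_b": "stable_cascade_stage_b.safetensors",
}

pipeline_params = {}
for newkey, key in MAPPING.items():
    if key in payload:
        pipeline_params[newkey] = payload[key]  # "node.input" -> value

# Both stages were seeded identically from the single payload seed
assert pipeline_params["sampler_stage_b.seed"] == pipeline_params["sampler_stage_c.seed"]
```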
36 changes: 23 additions & 13 deletions hordelib/model_manager/base.py
@@ -165,29 +165,35 @@ def get_free_ram_mb(self) -> int:
def get_model_reference_info(self, model_name: str) -> dict | None:
return self.model_reference.get(model_name, None)

def get_model_filename(self, model_name: str) -> Path:
"""Return the filename of the model for a given model name.
def get_model_filenames(self, model_name: str) -> list[dict]: # TODO: Convert dict into class
"""Return the filenames of the model for a given model name.
Args:
model_name (str): The name of the model to get the filename for.
Returns:
Path: The filename of the model.
list[dict]: Each dict has at least a "file_path" key with the Path to the file.
Optionally, it also has a "file_type" key giving this file's role for the model.
Raises:
ValueError: If the model name is not in the model reference.
"""
if model_name not in self.model_reference:
raise ValueError(f"Model {model_name} not found in model reference")

model_file_entries = self.model_reference.get(model_name, {}).get("config", {}).get("files", [])

model_files = []
for model_file_entry in model_file_entries:
path_config_item = model_file_entry.get("path")
path_config_type = model_file_entry.get("file_type")
if path_config_item:
if path_config_item.endswith((".ckpt", ".safetensors", ".pt", ".pth", ".bin")):
return Path(path_config_item)

raise ValueError(f"Model {model_name} does not have a valid file entry")
path_entry = {"file_path": Path(path_config_item)}
if path_config_type:
path_entry["file_type"] = path_config_type
model_files.append(path_entry)
if len(model_files) == 0:
raise ValueError(f"Model {model_name} does not have a valid file entry")
return model_files

def get_model_config_files(self, model_name: str) -> list[dict]:
"""Return the config files for a given model name.
@@ -267,11 +273,11 @@ def validate_model(self, model_name: str, skip_checksum: bool = False) -> bool |
Returns:
bool | None: `True` if the model is valid, `False` if not, `None` if the model is not on disk.
"""
model_file = self.get_model_filename(model_name)
logger.debug(f"Validating {model_name}. File: {model_file}")

if not self.is_file_available(model_file):
return None
model_files = self.get_model_filenames(model_name)
logger.debug(f"Validating {model_name}. Files: {model_files}")
for file_entry in model_files:
if not self.is_file_available(file_entry["file_path"]):
return None

file_details = self.get_model_config_files(model_name)

@@ -725,7 +731,11 @@ def is_model_available(self, model_name: str) -> bool:
if model_name in self.tainted_models:
return False

return self.is_file_available(self.get_model_filename(model_name))
model_files = self.get_model_filenames(model_name)
for file_entry in model_files:
if not self.is_file_available(file_entry["file_path"]):
return False
return True

def is_model_url_from_civitai(self, url: str) -> bool:
return CIVITAI_API_PATH in url
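
To illustrate the new return shape: get_model_filenames generalises the old single-Path return into a list of dicts, so multi-file models can tag each file with its role while single-file models keep working via the first entry. A standalone sketch of the same filtering logic, with an assumed reference "files" config:

```python
from pathlib import Path

# Assumed "config" -> "files" entries for a hypothetical two-file model
model_file_entries = [
    {"path": "stable_cascade_stage_c.safetensors", "file_type": "stable_cascade_stage_c"},
    {"path": "stable_cascade_stage_b.safetensors", "file_type": "stable_cascade_stage_b"},
    {"path": "model.yaml"},  # non-checkpoint entries are filtered out
]

model_files = []
for entry in model_file_entries:
    path_config_item = entry.get("path")
    if path_config_item and path_config_item.endswith((".ckpt", ".safetensors", ".pt", ".pth", ".bin")):
        path_entry = {"file_path": Path(path_config_item)}
        if entry.get("file_type"):
            path_entry["file_type"] = entry["file_type"]
        model_files.append(path_entry)

# model_files[0]["file_path"] plays the role of the old single return value
print(model_files)
```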
29 changes: 25 additions & 4 deletions hordelib/nodes/node_model_loader.py
@@ -19,6 +19,7 @@ def INPUT_TYPES(s):
"seamless_tiling_enabled": ("<bool>",),
"horde_model_name": ("<horde model name>",),
"ckpt_name": ("<ckpt name>",),
"file_type": ("<file type>",), # TODO: Make this optional
},
}

@@ -33,11 +34,15 @@ def load_checkpoint(
seamless_tiling_enabled: bool,
horde_model_name: str,
ckpt_name: str | None = None,
file_type: str | None = None,
output_vae=True,
output_clip=True,
preloading=False,
):
logger.debug(f"Loading model {horde_model_name}")
if file_type is not None:
logger.debug(f"Loading model {horde_model_name}:{file_type}")
else:
logger.debug(f"Loading model {horde_model_name}")
logger.debug(f"Will load Loras: {will_load_loras}, seamless tiling: {seamless_tiling_enabled}")
if ckpt_name:
logger.debug(f"Checkpoint name: {ckpt_name}")
@@ -48,7 +53,11 @@
if SharedModelManager.manager.compvis is None:
raise ValueError("CompVisModelManager is not initialised.")

same_loaded_model = SharedModelManager.manager._models_in_ram.get(horde_model_name)
horde_in_memory_name = horde_model_name
if file_type is not None:
horde_in_memory_name = f"{horde_model_name}:{file_type}"
same_loaded_model = SharedModelManager.manager._models_in_ram.get(horde_in_memory_name)
logger.debug([horde_in_memory_name, file_type, same_loaded_model])

# Check if the model was previously loaded and if so, not loaded with Loras
if same_loaded_model and not same_loaded_model[1]:
@@ -67,7 +76,19 @@
if not SharedModelManager.manager.compvis.is_model_available(horde_model_name):
raise ValueError(f"Model {horde_model_name} is not available.")

ckpt_name = SharedModelManager.manager.compvis.get_model_filename(horde_model_name).name
file_entries = SharedModelManager.manager.compvis.get_model_filenames(horde_model_name)
for file_entry in file_entries:
if file_type is not None:
# If a file_type has been passed, we look at the available files for this model
# to find the appropriate type.
if file_entry.get("file_type") == file_type:
ckpt_name = file_entry["file_path"].name
break
else:
# If no file_type was passed, we follow the previous approach and pick the first file
# (there should be only one).
ckpt_name = file_entry["file_path"].name
break

# Clear references so comfy can free memory as needed
SharedModelManager.manager._models_in_ram = {}
@@ -80,7 +101,7 @@
embedding_directory=folder_paths.get_folder_paths("embeddings"),
)

SharedModelManager.manager._models_in_ram[horde_model_name] = result, will_load_loras
SharedModelManager.manager._models_in_ram[horde_in_memory_name] = result, will_load_loras

if seamless_tiling_enabled:
result[0].model.apply(make_circular)
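Two details in the loader above are worth isolating: the checkpoint file is chosen by matching file_type against the model's file entries, and the in-RAM cache is keyed by "name:file_type" so stage B and stage C of one model do not evict each other. A minimal sketch under those assumptions, with illustrative file entries:

```python
from pathlib import Path

# Assumed file entries for a hypothetical two-stage model
file_entries = [
    {"file_path": Path("stable_cascade_stage_c.safetensors"), "file_type": "stable_cascade_stage_c"},
    {"file_path": Path("stable_cascade_stage_b.safetensors"), "file_type": "stable_cascade_stage_b"},
]

def pick_checkpoint(horde_model_name: str, file_type: str | None = None) -> tuple[str, str]:
    # Cache key: disambiguate per stage so both stages can live in RAM at once
    in_memory_name = f"{horde_model_name}:{file_type}" if file_type else horde_model_name
    for entry in file_entries:
        # With no file_type, keep the old behaviour: take the first (only) file
        if file_type is None or entry.get("file_type") == file_type:
            return in_memory_name, entry["file_path"].name
    raise ValueError(f"No file of type {file_type!r} for {horde_model_name}")

print(pick_checkpoint("Stable Cascade 1.0", "stable_cascade_stage_b"))
# -> ('Stable Cascade 1.0:stable_cascade_stage_b', 'stable_cascade_stage_b.safetensors')
```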