Merge pull request #371 from Haidra-Org/main
fix: better lora cache clearing; feat: use comfyui `7a7efe8`
tazlin authored Dec 12, 2024
2 parents a4116a4 + 5de5bc7 commit 1e554a6
Showing 7 changed files with 34 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -168,3 +168,6 @@ tmp/
 *.pth
 .gitignore
 pipeline_debug.json
+inference-time-data*
+kudos_models/
+optuna_stud*.db
6 changes: 3 additions & 3 deletions hordelib/comfy_horde.py
@@ -88,9 +88,9 @@
 _comfy_model_loading: types.ModuleType
 _comfy_free_memory: Callable[[float, torch.device, list], None]
 """Will aggressively unload models from memory"""
-_comfy_cleanup_models: Callable[[bool], None]
+_comfy_cleanup_models: Callable
 """Will unload unused models from memory"""
-_comfy_soft_empty_cache: Callable[[bool], None]
+_comfy_soft_empty_cache: Callable
 """Triggers comfyui and torch to empty their caches"""
 
 _comfy_is_changed_cache_get: Callable
@@ -944,7 +944,7 @@ def _run_pipeline(
         if self.aggressive_unloading:
             global _comfy_cleanup_models
             logger.debug("Cleaning up models")
-            _comfy_cleanup_models(False)
+            _comfy_cleanup_models()
             _comfy_soft_empty_cache()
 
         stdio.replay()
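The loosened `Callable` annotations track the newer ComfyUI pin, whose `cleanup_models()` no longer takes a boolean argument. A minimal sketch of the binding pattern, assuming ComfyUI's `comfy.model_management` module is importable; the wrapper function below is illustrative, not hordelib's exact code:

```python
# Illustrative sketch only; hordelib resolves these bindings during its own
# ComfyUI bootstrap rather than with a plain import like this.
from comfy import model_management

_comfy_cleanup_models = model_management.cleanup_models  # no bool argument anymore
_comfy_soft_empty_cache = model_management.soft_empty_cache


def free_unused_resources() -> None:
    """Unload models ComfyUI no longer needs, then ask torch to empty its caches."""
    _comfy_cleanup_models()
    _comfy_soft_empty_cache()
```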
2 changes: 1 addition & 1 deletion hordelib/consts.py
@@ -6,7 +6,7 @@
 
 from hordelib.config_path import get_hordelib_path
 
-COMFYUI_VERSION = "839ed3368efd0f61a2b986f57fe9e0698fd08e9f"
+COMFYUI_VERSION = "7a7efe8424d960a95be393a85ca4d94e5892edea"
 """The exact version of ComfyUI to load."""
 
 REMOTE_PROXY = ""
32 changes: 21 additions & 11 deletions hordelib/model_manager/lora.py
@@ -1,4 +1,3 @@
-import copy
 import glob
 import hashlib
 import json
@@ -97,7 +96,6 @@ def __init__(
         self._thread = None
         self.stop_downloading = True
         # Not yet handled, as we need a global reference to search through.
-        self._previous_model_reference = {}  # type: ignore # FIXME: add type
         self._adhoc_loras = set()  # type: ignore # FIXME: add type
         self._download_wait = download_wait
         # If false, this MM will only download SFW loras
@@ -205,7 +203,7 @@ def load_model_database(self) -> None:
                 for version_id in lora["versions"]:
                     self._index_version_ids[version_id] = lora["name"]
             self.model_reference = new_model_reference
-            logger.info("Loaded model reference from disk.")
+            logger.info(f"Loaded model reference from disk with {len(self.model_reference)} lora entries.")
         except json.JSONDecodeError:
             logger.error(f"Could not load {self.models_db_name} model reference from disk! Bad JSON?")
             self.model_reference = {}
@@ -544,7 +542,8 @@ def _download_thread(self, thread_number):
                 lora["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                 self._add_lora_to_reference(lora)
                 if self.is_adhoc_cache_full():
-                    self.delete_oldest_lora()
+                    for _lora_iter in range(self.amount_of_adhoc_loras_to_delete()):
+                        self.delete_oldest_lora()
                 self.save_cached_reference_to_disk()
                 break

@@ -680,7 +679,6 @@ def download_default_loras(self, nsfw=True, timeout=None):
         if not self.are_downloads_complete():
             return
         self.nsfw = nsfw
-        self._previous_model_reference = copy.deepcopy(self.model_reference)
         # TODO: Avoid clearing this out, until we know CivitAI is not dead.
         self.clear_all_references()
         os.makedirs(self.model_folder_path, exist_ok=True)
@@ -891,6 +889,12 @@ def is_default_cache_full(self):
     def is_adhoc_cache_full(self):
         return self.calculate_adhoc_loras_cache() >= self.max_adhoc_disk
 
+    def amount_of_adhoc_loras_to_delete(self):
+        if not self.is_adhoc_cache_full():
+            return 0
+        # If we have exceeded our cache, we delete 1 lora + 1 extra lora per 4G over our cache.
+        return 1 + int((self.calculate_adhoc_loras_cache() - self.max_adhoc_disk) / 4096)
+
     def calculate_download_queue(self):
         total_queue = 0
         for lora in self._download_queue:
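The deletion count grows with how far the cache has overshot its limit: one base deletion, plus one more per full 4096 units (presumably megabytes, i.e. 4 GiB) over. A worked example with hypothetical figures:

```python
# Hypothetical figures, assuming both values are tracked in megabytes.
max_adhoc_disk = 10240        # 10 GiB adhoc cache limit
adhoc_cache_size = 19500      # current usage

overshoot = adhoc_cache_size - max_adhoc_disk  # 9260 MiB over the limit
to_delete = 1 + int(overshoot / 4096)          # 1 base + 2 full 4 GiB increments
print(to_delete)                               # -> 3 loras will be deleted
```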
@@ -955,6 +959,10 @@ def delete_unused_loras(self, timeout=0):
                 logger.warning(f"Expected to delete lora file {lora_filename} but it was not found.")
         return loras_to_delete
 
+    def delete_adhoc_loras_over_limit(self):
+        while self.is_adhoc_cache_full():
+            self.delete_oldest_lora()
+
     def delete_lora_files(self, lora_filename: str):
         filename = os.path.join(self.model_folder_path, lora_filename)
         if not os.path.exists(filename):
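This gives the manager two trimming strategies: the download thread deletes a precomputed batch via `amount_of_adhoc_loras_to_delete()`, while `delete_adhoc_loras_over_limit()` loops until the cache is back under its limit. A sketch of the two side by side, assuming the import path and `LoraModelManager` class name (not verbatim hordelib code):

```python
from hordelib.model_manager.lora import LoraModelManager  # assumed import path


def trim_adhoc_cache(mm: LoraModelManager, batch: bool = True) -> None:
    """Trim the adhoc lora cache using either strategy from this diff."""
    if batch:
        # Batch trim, as in _download_thread above.
        for _ in range(mm.amount_of_adhoc_loras_to_delete()):
            mm.delete_oldest_lora()
    else:
        # Loop until the cache is back under its limit.
        mm.delete_adhoc_loras_over_limit()
```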
@@ -1010,7 +1018,7 @@ def reset_adhoc_loras(self):
         self._adhoc_loras = set()
         unsorted_items = []
         sorted_items = []
-        for plora_key, plora in self._previous_model_reference.items():
+        for plora_key, plora in self.model_reference.items():
             for version in plora.get("versions", {}).values():
                 unsorted_items.append((plora_key, version))
         try:
@@ -1021,24 +1029,26 @@
             )
         except Exception as err:
             logger.error(err)
-        while not self.is_adhoc_cache_full() and len(sorted_items) > 0:
+        while len(sorted_items) > 0:
             prevlora_key, prevversion = sorted_items.pop()
             if prevlora_key in self.model_reference:
                 if prevversion.get("adhoc", True) is False:
                     continue
             # If True, it will initiate a redownload and call _add_lora_to_reference() later
             if not self._check_for_refresh(prevlora_key):
                 if "last_used" not in prevversion:
                     prevversion["last_used"] = now
                 # We create a temp lora dict holding just the one version (the one we want to keep)
                 # The _add_lora_to_reference() will merge versions anyway if we keep more than 1
-                temp_lora = self._previous_model_reference[prevlora_key].copy()
+                temp_lora = self.model_reference[prevlora_key].copy()
                 temp_lora["versions"] = {}
                 temp_lora["versions"][prevversion["version_id"]] = prevversion
                 self._add_lora_to_reference(temp_lora)
                 self._adhoc_loras.add(prevlora_key)
-        self._previous_model_reference = {}
         self.save_cached_reference_to_disk()
-        logger.debug("Finished lora reset")
+        logger.debug(
+            f"Finished lora reset. Added {len(self._adhoc_loras)} adhoc loras "
+            f"with a total size of {self.calculate_adhoc_loras_cache()}",
+        )
 
     def get_lora_metadata(self, url: str) -> dict:
         """Returns parsed Lora details from civitAI
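With `_previous_model_reference` gone, the reset now walks the live `model_reference`, and the cache-fullness guard no longer cuts the loop short; trimming happens separately (for example via `delete_adhoc_loras_over_limit()`, though its call site is outside this diff). A toy illustration of the one-version "temp lora" trick described in the comments, with `lora_entry` and `version` as hypothetical stand-ins:

```python
# Hypothetical stand-ins for one model_reference entry and one of its versions.
lora_entry = {
    "name": "example-lora",
    "versions": {
        "123": {"version_id": "123", "last_used": "2024-12-01 00:00:00"},
        "456": {"version_id": "456", "last_used": "2024-12-10 00:00:00"},
    },
}
version = lora_entry["versions"]["456"]

temp_lora = lora_entry.copy()  # shallow copy, as in the diff
temp_lora["versions"] = {version["version_id"]: version}  # keep exactly one version
# temp_lora can now be passed to _add_lora_to_reference(), which merges versions anyway.
```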
2 changes: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def __init__(self, layer_list):
 class AttentionSharingPatcher(torch.nn.Module):
     def __init__(self, unet, frames=2, use_control=True, rank=256):
         super().__init__()
-        model_management.unload_model_clones(unet)
+        # model_management.unload_model_clones(unet) # this is now handled implicitly in comfyui
 
         units = []
         for i in range(32):
5 changes: 4 additions & 1 deletion hordelib/pipeline_designs/pipeline_stable_cascade_remix.json
@@ -460,7 +460,10 @@
       "title": "clip_vision_encode_0",
       "properties": {
         "Node name for S&R": "CLIPVisionEncode"
-      }
+      },
+      "widgets_values": [
+        "none"
+      ]
     },
     {
       "id": 51,
1 change: 1 addition & 0 deletions hordelib/pipelines/pipeline_stable_cascade_remix.json
@@ -172,6 +172,7 @@
   },
   "50": {
     "inputs": {
+      "crop": "none",
      "clip_vision": [
        "49",
        3
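Both JSON changes accommodate the `crop` widget that `CLIPVisionEncode` exposes in the newer ComfyUI: the design file records it in `widgets_values`, and the runnable pipeline passes it as an explicit input on node `50`. If other cached pipelines needed the same fix, a patch along these lines would work (a sketch; the path and node id come from this diff):

```python
import json

path = "hordelib/pipelines/pipeline_stable_cascade_remix.json"

with open(path) as f:
    pipeline = json.load(f)

# Node "50" is the CLIPVisionEncode node in this pipeline.
pipeline["50"]["inputs"].setdefault("crop", "none")

with open(path, "w") as f:
    json.dump(pipeline, f, indent=2)
```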
