
Commit

fix: ensure adhoc loras can be rotated
db0 authored and tazlin committed Dec 11, 2024
1 parent dec43ce commit 772fb9f
Showing 1 changed file with 17 additions and 10 deletions.

hordelib/model_manager/lora.py
@@ -97,7 +97,6 @@ def __init__(
self._thread = None
self.stop_downloading = True
# Not yet handled, as we need a global reference to search through.
- self._previous_model_reference = {}  # type: ignore # FIXME: add type
self._adhoc_loras = set() # type: ignore # FIXME: add type
self._download_wait = download_wait
# If false, this MM will only download SFW loras
@@ -205,7 +204,7 @@ def load_model_database(self) -> None:
for version_id in lora["versions"]:
self._index_version_ids[version_id] = lora["name"]
self.model_reference = new_model_reference
logger.info("Loaded model reference from disk.")
logger.info(f"Loaded model reference from disk with {len(self.model_reference)} lora entries.")
except json.JSONDecodeError:
logger.error(f"Could not load {self.models_db_name} model reference from disk! Bad JSON?")
self.model_reference = {}
@@ -544,7 +543,8 @@ def _download_thread(self, thread_number):
lora["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._add_lora_to_reference(lora)
if self.is_adhoc_cache_full():
- self.delete_oldest_lora()
+ for _lora_iter in range(self.amount_of_adhoc_loras_to_delete()):
+     self.delete_oldest_lora()
self.save_cached_reference_to_disk()
break

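The hunk above is the core of the fix: when the adhoc cache fills up, the download thread now evicts a batch of the oldest loras rather than exactly one, so the cache can actually rotate back under its limit. Below is a minimal stand-alone sketch of that eviction behaviour, not code from this commit; the names rotate_adhoc_cache, cache_mb, and max_adhoc_mb are hypothetical, and sizes are assumed to be tracked in megabytes (matching the 4096 divisor used further down).

from collections import OrderedDict

def rotate_adhoc_cache(loras: OrderedDict, cache_mb: int, max_adhoc_mb: int) -> int:
    # Mirrors the loop above: compute how many loras to drop, then pop
    # oldest-first. Returns the number of entries actually evicted.
    if cache_mb < max_adhoc_mb:
        return 0
    to_delete = 1 + (cache_mb - max_adhoc_mb) // 4096  # 1 lora + 1 extra per 4 GiB over
    evicted = 0
    for _ in range(to_delete):
        if not loras:
            break
        _name, size_mb = loras.popitem(last=False)  # drop the oldest entry first
        cache_mb -= size_mb
        evicted += 1
    return evicted

Oldest-first eviction here stands in for delete_oldest_lora() in the hunk above.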
@@ -680,7 +680,6 @@ def download_default_loras(self, nsfw=True, timeout=None):
if not self.are_downloads_complete():
return
self.nsfw = nsfw
- self._previous_model_reference = copy.deepcopy(self.model_reference)
# TODO: Avoid clearing this out, until we know CivitAI is not dead.
self.clear_all_references()
os.makedirs(self.model_folder_path, exist_ok=True)
@@ -891,6 +890,12 @@ def is_default_cache_full(self):
def is_adhoc_cache_full(self):
return self.calculate_adhoc_loras_cache() >= self.max_adhoc_disk

+ def amount_of_adhoc_loras_to_delete(self):
+     if not self.is_adhoc_cache_full():
+         return 0
+     # If we have exceeded our cache, we delete 1 lora, plus 1 extra lora per 4G over the limit.
+     # Integer division keeps the result usable by range() in _download_thread().
+     return 1 + ((self.calculate_adhoc_loras_cache() - self._max_top_disk) // 4096)

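Assuming the cache sizes are tracked in megabytes (so the 4096 divisor reads as 4 GiB), a quick worked check of the new formula with hypothetical numbers:

# Worked check of the deletion formula (hypothetical sizes, in MB).
max_adhoc_disk = 10 * 1024  # assume a 10 GiB adhoc budget
for cache in (10 * 1024, 12 * 1024, 19 * 1024):
    overflow = cache - max_adhoc_disk
    print(cache, "->", 1 + overflow // 4096)
# 10240 -> 1   (at the limit: delete one)
# 12288 -> 1   (2 GiB over: still one)
# 19456 -> 3   (9 GiB over: one, plus one per full 4 GiB of overflow)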
def calculate_download_queue(self):
total_queue = 0
for lora in self._download_queue:
@@ -1010,7 +1015,7 @@ def reset_adhoc_loras(self):
self._adhoc_loras = set()
unsorted_items = []
sorted_items = []
- for plora_key, plora in self._previous_model_reference.items():
+ for plora_key, plora in self.model_reference.items():
for version in plora.get("versions", {}).values():
unsorted_items.append((plora_key, version))
try:
@@ -1021,24 +1026,26 @@
)
except Exception as err:
logger.error(err)
- while not self.is_adhoc_cache_full() and len(sorted_items) > 0:
+ while len(sorted_items) > 0:
prevlora_key, prevversion = sorted_items.pop()
if prevlora_key in self.model_reference:
if prevversion.get("adhoc", True) is False:
continue
# If True, it initiates a redownload and calls _add_lora_to_reference() later
if not self._check_for_refresh(prevlora_key):
if "last_used" not in prevversion:
prevversion["last_used"] = now
# We create a temp lora dict holding just the one version (the one we want to keep).
# _add_lora_to_reference() will merge versions anyway if we keep more than one.
- temp_lora = self._previous_model_reference[prevlora_key].copy()
+ temp_lora = self.model_reference[prevlora_key].copy()
temp_lora["versions"] = {}
temp_lora["versions"][prevversion["version_id"]] = prevversion
self._add_lora_to_reference(temp_lora)
self._adhoc_loras.add(prevlora_key)
- self._previous_model_reference = {}
self.save_cached_reference_to_disk()
logger.debug("Finished lora reset")
logger.debug(
f"Finished lora reset. Added {len(self._adhoc_loras)} adhoc loras "
f"with a total size of {self.calculate_adhoc_loras_cache()}"
)

def get_lora_metadata(self, url: str) -> dict:
"""Returns parsed Lora details from civitAI

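Taken together, the reset_adhoc_loras changes make the live model_reference the single source of truth: every adhoc version is re-added in recency order, and trimming is left to the rotation logic in the download thread rather than to the reset loop itself. A compact sketch of that flow follows, with hypothetical names, and under the assumption that versions are sorted by their last_used timestamp (the actual sorted() call is collapsed in the diff above, so the sort key here is a guess):

from datetime import datetime

def reset_adhoc_sketch(model_reference: dict) -> set:
    # Gather every (lora_key, version) pair from the live reference.
    items = [
        (key, version)
        for key, lora in model_reference.items()
        for version in lora.get("versions", {}).values()
    ]
    # Assumed sort key: last_used stored as "%Y-%m-%d %H:%M:%S" strings, oldest first.
    items.sort(key=lambda kv: kv[1].get("last_used", "2000-01-01 00:00:00"))
    adhoc_keys = set()
    while items:
        key, version = items.pop()  # most recently used handled first
        if version.get("adhoc", True) is False:
            continue  # default (non-adhoc) versions are skipped, as in the diff
        version.setdefault("last_used", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        adhoc_keys.add(key)
    return adhoc_keys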