From df9887225e1b09b86b846e5fd530551d145117bb Mon Sep 17 00:00:00 2001 From: tazlin Date: Tue, 9 Jan 2024 09:30:38 -0500 Subject: [PATCH 1/3] fix: try and fallback to on-disk model ref when can't download --- hordelib/shared_model_manager.py | 60 ++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/hordelib/shared_model_manager.py b/hordelib/shared_model_manager.py index 5b0faabc..b83f6425 100644 --- a/hordelib/shared_model_manager.py +++ b/hordelib/shared_model_manager.py @@ -5,12 +5,19 @@ from pathlib import Path import torch +from horde_model_reference import get_model_reference_file_path from horde_model_reference.legacy import LegacyReferenceDownloadManager from loguru import logger from typing_extensions import Self from hordelib.consts import MODEL_CATEGORY_NAMES -from hordelib.model_manager.hyper import ALL_MODEL_MANAGER_TYPES, BaseModelManager, ModelManager +from hordelib.model_manager.base import _temp_reference_lookup +from hordelib.model_manager.hyper import ( + ALL_MODEL_MANAGER_TYPES, + MODEL_MANAGERS_TYPE_LOOKUP, + BaseModelManager, + ModelManager, +) from hordelib.preload import ( ANNOTATOR_MODEL_SHA_LOOKUP, download_all_controlnet_annotators, @@ -70,6 +77,8 @@ def load_model_managers( managers_to_load: Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]] = ALL_MODEL_MANAGER_TYPES, *, multiprocessing_lock: multiprocessing_lock | None = None, + download_legacy_references: bool = True, + overwrite_existing_references: bool = True, ): if cls.manager is None: cls.manager = ModelManager() @@ -78,8 +87,53 @@ def load_model_managers( args_passed.pop("cls") # XXX This is temporary lrdm = LegacyReferenceDownloadManager() - references = lrdm.download_all_legacy_model_references() - for reference in references: + if download_legacy_references: + try: + lrdm.download_all_legacy_model_references(overwrite_existing=overwrite_existing_references) + except Exception as e: + logger.error(f"Failed to download legacy model references: {e}") + logger.error( + "If this continues to happen, " + "github may be down or your internet connection may be having issues.", + ) + + references = {} + for reference in managers_to_load: + parsed_reference = None + if isinstance(reference, MODEL_CATEGORY_NAMES): + parsed_reference = reference + elif isinstance(reference, str): + try: + MODEL_CATEGORY_NAMES(reference) + except ValueError: + logger.warning(f"Invalid model category name: {reference}") + continue + parsed_reference = MODEL_CATEGORY_NAMES(reference) + elif isinstance(reference, type): + for k, v in MODEL_MANAGERS_TYPE_LOOKUP.items(): + if v == reference: + try: + MODEL_CATEGORY_NAMES(k) + except ValueError: + logger.warning(f"Invalid model category name: {k}") + continue + parsed_reference = MODEL_CATEGORY_NAMES(k) + break + + if parsed_reference is None: + logger.warning(f"Invalid model reference: {reference}") + continue + + if parsed_reference not in _temp_reference_lookup: + logger.warning(f"Model reference doesn't require a legacy download: {reference}") + continue + + references[parsed_reference] = get_model_reference_file_path(_temp_reference_lookup[parsed_reference]) + + for reference, path in references.items(): + if path is None and not download_legacy_references: + logger.warning(f"Failed to download legacy reference: {reference}") + continue logger.debug(f"Legacy reference downloaded: {reference}") do_migrations() From 5662e1092eb22d27fec3075232024be1b1a0a065 Mon Sep 17 00:00:00 2001 From: tazlin Date: Tue, 9 Jan 2024 09:36:56 -0500 Subject: [PATCH 2/3] docs: add docstring for `load_model_managers` --- hordelib/shared_model_manager.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/hordelib/shared_model_manager.py b/hordelib/shared_model_manager.py index b83f6425..9817d8c9 100644 --- a/hordelib/shared_model_manager.py +++ b/hordelib/shared_model_manager.py @@ -80,6 +80,21 @@ def load_model_managers( download_legacy_references: bool = True, overwrite_existing_references: bool = True, ): + """Load the model managers specified. + + Args: + managers_to_load (Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]], optional): \ + The model managers to load. \ + Defaults to ALL_MODEL_MANAGER_TYPES. + multiprocessing_lock (multiprocessing_lock | None, optional): If you are using multiprocessing, \ + you should pass a lock here. \ + Defaults to None. + download_legacy_references (bool, optional): If True, this will download all legacy model references. \ + Defaults to True. + overwrite_existing_references (bool, optional): If True, this will overwrite any existing legacy model \ + references that might be already downloaded. \ + Defaults to True. + """ if cls.manager is None: cls.manager = ModelManager() From cedd9744127aeabf846721799d7d00764bc4b19c Mon Sep 17 00:00:00 2001 From: tazlin Date: Tue, 9 Jan 2024 09:37:11 -0500 Subject: [PATCH 3/3] fix: demote log message level for no-download-required refs --- hordelib/shared_model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hordelib/shared_model_manager.py b/hordelib/shared_model_manager.py index 9817d8c9..e5464237 100644 --- a/hordelib/shared_model_manager.py +++ b/hordelib/shared_model_manager.py @@ -140,7 +140,7 @@ def load_model_managers( continue if parsed_reference not in _temp_reference_lookup: - logger.warning(f"Model reference doesn't require a legacy download: {reference}") + logger.debug(f"Model reference doesn't require a legacy download: {reference}") continue references[parsed_reference] = get_model_reference_file_path(_temp_reference_lookup[parsed_reference])