diff --git a/hordelib/shared_model_manager.py b/hordelib/shared_model_manager.py index 5b0faabc..e5464237 100644 --- a/hordelib/shared_model_manager.py +++ b/hordelib/shared_model_manager.py @@ -5,12 +5,19 @@ from pathlib import Path import torch +from horde_model_reference import get_model_reference_file_path from horde_model_reference.legacy import LegacyReferenceDownloadManager from loguru import logger from typing_extensions import Self from hordelib.consts import MODEL_CATEGORY_NAMES -from hordelib.model_manager.hyper import ALL_MODEL_MANAGER_TYPES, BaseModelManager, ModelManager +from hordelib.model_manager.base import _temp_reference_lookup +from hordelib.model_manager.hyper import ( + ALL_MODEL_MANAGER_TYPES, + MODEL_MANAGERS_TYPE_LOOKUP, + BaseModelManager, + ModelManager, +) from hordelib.preload import ( ANNOTATOR_MODEL_SHA_LOOKUP, download_all_controlnet_annotators, @@ -70,7 +77,24 @@ def load_model_managers( managers_to_load: Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]] = ALL_MODEL_MANAGER_TYPES, *, multiprocessing_lock: multiprocessing_lock | None = None, + download_legacy_references: bool = True, + overwrite_existing_references: bool = True, ): + """Load the model managers specified. + + Args: + managers_to_load (Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]], optional): \ + The model managers to load. \ + Defaults to ALL_MODEL_MANAGER_TYPES. + multiprocessing_lock (multiprocessing_lock | None, optional): If you are using multiprocessing, \ + you should pass a lock here. \ + Defaults to None. + download_legacy_references (bool, optional): If True, this will download all legacy model references. \ + Defaults to True. + overwrite_existing_references (bool, optional): If True, this will overwrite any existing legacy model \ + references that might be already downloaded. \ + Defaults to True. + """ if cls.manager is None: cls.manager = ModelManager() @@ -78,8 +102,53 @@ def load_model_managers( args_passed.pop("cls") # XXX This is temporary lrdm = LegacyReferenceDownloadManager() - references = lrdm.download_all_legacy_model_references() - for reference in references: + if download_legacy_references: + try: + lrdm.download_all_legacy_model_references(overwrite_existing=overwrite_existing_references) + except Exception as e: + logger.error(f"Failed to download legacy model references: {e}") + logger.error( + "If this continues to happen, " + "github may be down or your internet connection may be having issues.", + ) + + references = {} + for reference in managers_to_load: + parsed_reference = None + if isinstance(reference, MODEL_CATEGORY_NAMES): + parsed_reference = reference + elif isinstance(reference, str): + try: + MODEL_CATEGORY_NAMES(reference) + except ValueError: + logger.warning(f"Invalid model category name: {reference}") + continue + parsed_reference = MODEL_CATEGORY_NAMES(reference) + elif isinstance(reference, type): + for k, v in MODEL_MANAGERS_TYPE_LOOKUP.items(): + if v == reference: + try: + MODEL_CATEGORY_NAMES(k) + except ValueError: + logger.warning(f"Invalid model category name: {k}") + continue + parsed_reference = MODEL_CATEGORY_NAMES(k) + break + + if parsed_reference is None: + logger.warning(f"Invalid model reference: {reference}") + continue + + if parsed_reference not in _temp_reference_lookup: + logger.debug(f"Model reference doesn't require a legacy download: {reference}") + continue + + references[parsed_reference] = get_model_reference_file_path(_temp_reference_lookup[parsed_reference]) + + for reference, path in references.items(): + if path is None and not download_legacy_references: + logger.warning(f"Failed to download legacy reference: {reference}") + continue logger.debug(f"Legacy reference downloaded: {reference}") do_migrations()