Skip to content

Commit

Permalink
Merge pull request #159 from Haidra-Org/model-ref-changes
Browse files Browse the repository at this point in the history
fix: try and fallback to on-disk model ref when can't download
  • Loading branch information
tazlin authored Jan 9, 2024
2 parents 82b194d + cedd974 commit 9c6501f
Showing 1 changed file with 72 additions and 3 deletions.
75 changes: 72 additions & 3 deletions hordelib/shared_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,19 @@
from pathlib import Path

import torch
from horde_model_reference import get_model_reference_file_path
from horde_model_reference.legacy import LegacyReferenceDownloadManager
from loguru import logger
from typing_extensions import Self

from hordelib.consts import MODEL_CATEGORY_NAMES
from hordelib.model_manager.hyper import ALL_MODEL_MANAGER_TYPES, BaseModelManager, ModelManager
from hordelib.model_manager.base import _temp_reference_lookup
from hordelib.model_manager.hyper import (
ALL_MODEL_MANAGER_TYPES,
MODEL_MANAGERS_TYPE_LOOKUP,
BaseModelManager,
ModelManager,
)
from hordelib.preload import (
ANNOTATOR_MODEL_SHA_LOOKUP,
download_all_controlnet_annotators,
Expand Down Expand Up @@ -70,16 +77,78 @@ def load_model_managers(
managers_to_load: Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]] = ALL_MODEL_MANAGER_TYPES,
*,
multiprocessing_lock: multiprocessing_lock | None = None,
download_legacy_references: bool = True,
overwrite_existing_references: bool = True,
):
"""Load the model managers specified.
Args:
managers_to_load (Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]], optional): \
The model managers to load. \
Defaults to ALL_MODEL_MANAGER_TYPES.
multiprocessing_lock (multiprocessing_lock | None, optional): If you are using multiprocessing, \
you should pass a lock here. \
Defaults to None.
download_legacy_references (bool, optional): If True, this will download all legacy model references. \
Defaults to True.
overwrite_existing_references (bool, optional): If True, this will overwrite any existing legacy model \
references that might be already downloaded. \
Defaults to True.
"""
if cls.manager is None:
cls.manager = ModelManager()

args_passed = locals().copy() # XXX This is temporary
args_passed.pop("cls") # XXX This is temporary

lrdm = LegacyReferenceDownloadManager()
references = lrdm.download_all_legacy_model_references()
for reference in references:
if download_legacy_references:
try:
lrdm.download_all_legacy_model_references(overwrite_existing=overwrite_existing_references)
except Exception as e:
logger.error(f"Failed to download legacy model references: {e}")
logger.error(
"If this continues to happen, "
"github may be down or your internet connection may be having issues.",
)

references = {}
for reference in managers_to_load:
parsed_reference = None
if isinstance(reference, MODEL_CATEGORY_NAMES):
parsed_reference = reference
elif isinstance(reference, str):
try:
MODEL_CATEGORY_NAMES(reference)
except ValueError:
logger.warning(f"Invalid model category name: {reference}")
continue
parsed_reference = MODEL_CATEGORY_NAMES(reference)
elif isinstance(reference, type):
for k, v in MODEL_MANAGERS_TYPE_LOOKUP.items():
if v == reference:
try:
MODEL_CATEGORY_NAMES(k)
except ValueError:
logger.warning(f"Invalid model category name: {k}")
continue
parsed_reference = MODEL_CATEGORY_NAMES(k)
break

if parsed_reference is None:
logger.warning(f"Invalid model reference: {reference}")
continue

if parsed_reference not in _temp_reference_lookup:
logger.debug(f"Model reference doesn't require a legacy download: {reference}")
continue

references[parsed_reference] = get_model_reference_file_path(_temp_reference_lookup[parsed_reference])

for reference, path in references.items():
if path is None and not download_legacy_references:
logger.warning(f"Failed to download legacy reference: {reference}")
continue
logger.debug(f"Legacy reference downloaded: {reference}")

do_migrations()
Expand Down

0 comments on commit 9c6501f

Please sign in to comment.