Skip to content

Commit

Permalink
wip: new mm
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Jun 23, 2024
1 parent c07ee1d commit 756f1c7
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 1 deletion.
4 changes: 4 additions & 0 deletions hordelib/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
ti = auto()
blip = auto()
clip = auto()
miscellaneous = auto()


# Default model managers to load
Expand All @@ -55,6 +56,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
# MODEL_CATEGORY_NAMES.safety_checker: True,
MODEL_CATEGORY_NAMES.lora: True,
MODEL_CATEGORY_NAMES.ti: True,
MODEL_CATEGORY_NAMES.miscellaneous: True,
}
"""The default model managers to load.""" # XXX Clarify

Expand All @@ -68,6 +70,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker,
MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora,
MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti,
MODEL_CATEGORY_NAMES.miscellaneous: MODEL_CATEGORY_NAMES.miscellaneous,
}
"""The name of the json file (without the extension) of the corresponding model database."""

Expand All @@ -81,5 +84,6 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker,
MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora,
MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti,
MODEL_CATEGORY_NAMES.miscellaneous: MODEL_CATEGORY_NAMES.miscellaneous,
}
"""The folder name on disk where the models are stored in AIWORKER_CACHE_HOME."""
8 changes: 8 additions & 0 deletions hordelib/model_manager/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from hordelib.model_manager.esrgan import EsrganModelManager
from hordelib.model_manager.gfpgan import GfpganModelManager
from hordelib.model_manager.lora import LoraModelManager
from hordelib.model_manager.miscellaneous import MiscellaneousModelManager
from hordelib.model_manager.safety_checker import SafetyCheckerModelManager
from hordelib.model_manager.ti import TextualInversionModelManager

Expand All @@ -31,6 +32,7 @@
MODEL_CATEGORY_NAMES.safety_checker: SafetyCheckerModelManager,
MODEL_CATEGORY_NAMES.lora: LoraModelManager,
MODEL_CATEGORY_NAMES.ti: TextualInversionModelManager,
MODEL_CATEGORY_NAMES.miscellaneous: MiscellaneousModelManager,
}
"""A lookup table for the `BaseModelManager` types."""

Expand Down Expand Up @@ -91,6 +93,12 @@ def ti(self) -> TextualInversionModelManager | None:
found_mm = self.get_model_manager_instance(TextualInversionModelManager)
return found_mm if isinstance(found_mm, TextualInversionModelManager) else None

@property
def miscellaneous(self) -> MiscellaneousModelManager | None:
"""The miscellaneous model manager instance. Returns `None` if not loaded."""
found_mm = self.get_model_manager_instance(MiscellaneousModelManager)
return found_mm if isinstance(found_mm, MiscellaneousModelManager) else None

@property
def models(self) -> dict:
"""All model manager's internal dictionaries of models, loaded from model database JSON files."""
Expand Down
15 changes: 15 additions & 0 deletions hordelib/model_manager/miscellaneous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from hordelib.consts import MODEL_CATEGORY_NAMES
from hordelib.model_manager.base import BaseModelManager


class MiscellaneousModelManager(BaseModelManager): # FIXME # TODO?
def __init__(
self,
download_reference=False,
**kwargs,
):
super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.miscellaneous,
download_reference=download_reference,
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from .enums import ResizeMode
from hordelib.shared_model_manager import SharedModelManager


def rgba2rgbfp32(x):
Expand Down Expand Up @@ -94,7 +95,10 @@ def numpy_to_pytorch(x):
y = torch.from_numpy(y).float()
return y


# IMPORTANT: Edited by hordelib devs to use model loader!
# To keep things more maintainable in case we need to update the code
# We let the same incoming variables, but ignore them in favour of our own
# model manager paths
def load_file_from_url(
url: str,
*,
Expand All @@ -106,6 +110,8 @@ def load_file_from_url(
Returns the path to the downloaded file.
"""
from loguru import logger # Debug
logger.error(SharedModelManager.manager.miscellaneous.is_file_available(file_name))
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
Expand Down
3 changes: 3 additions & 0 deletions hordelib/nodes/node_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from hordelib.shared_model_manager import SharedModelManager


# Don't let the name fool you, this class is trying to load all the files that will be necessary
# for a given comfyUI workflow. That includes loras, etc.
# TODO: Rename to HordeWorkflowModelsLoader ;)
class HordeCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
Expand Down

0 comments on commit 756f1c7

Please sign in to comment.