Skip to content

Commit

Permalink
feat: initial baseline model meta load command support
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Mar 5, 2024
1 parent 3d5729c commit 22d0f09
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
30 changes: 30 additions & 0 deletions horde_sdk/ai_horde_worker/model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY
from horde_model_reference.model_reference_manager import ModelReferenceManager
from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord
from loguru import logger

from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient
Expand Down Expand Up @@ -112,6 +113,35 @@ def resolve_all_model_names(self) -> set[str]:
logger.error("No stable diffusion models found in model reference.")
return set()

def resolve_all_models_of_baseline(self, baseline: str) -> set[str]:
"""Get the names of all models of a given baseline defined in the model reference.
Args:
baseline: A string representing the baseline to get models for.
Returns:
A set of strings representing the names of all models of the given baseline.
"""
all_model_references = self._model_reference_manager.get_all_model_references()

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

found_models: set[str] = set()

if sd_model_references is None:
logger.error("No stable diffusion models found in model reference.")
return found_models

for model in sd_model_references.root.values():
if not isinstance(model, StableDiffusion_ModelRecord):
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
continue

if model.baseline == baseline:
found_models.add(model.name)

return found_models

@staticmethod
def resolve_top_n_model_names(
number_of_top_models: int,
Expand Down
8 changes: 8 additions & 0 deletions tests/ai_horde_worker/test_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,11 @@ def test_image_models_unique_results_only(
all_model_names = image_model_load_resolver.resolve_all_model_names()

assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta


def test_resolve_all_models_of_baseline(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_all_models_of_baseline("stable_diffusion_xl")

assert len(resolved_model_names) > 0

0 comments on commit 22d0f09

Please sign in to comment.