diff --git a/horde_sdk/ai_horde_worker/model_meta.py b/horde_sdk/ai_horde_worker/model_meta.py index e636f37..737deab 100644 --- a/horde_sdk/ai_horde_worker/model_meta.py +++ b/horde_sdk/ai_horde_worker/model_meta.py @@ -2,6 +2,7 @@ from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY from horde_model_reference.model_reference_manager import ModelReferenceManager +from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord from loguru import logger from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient @@ -112,6 +113,35 @@ def resolve_all_model_names(self) -> set[str]: logger.error("No stable diffusion models found in model reference.") return set() + def resolve_all_models_of_baseline(self, baseline: str) -> set[str]: + """Get the names of all models of a given baseline defined in the model reference. + + Args: + baseline: A string representing the baseline to get models for. + + Returns: + A set of strings representing the names of all models of the given baseline. + """ + all_model_references = self._model_reference_manager.get_all_model_references() + + sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + found_models: set[str] = set() + + if sd_model_references is None: + logger.error("No stable diffusion models found in model reference.") + return found_models + + for model in sd_model_references.root.values(): + if not isinstance(model, StableDiffusion_ModelRecord): + logger.error(f"Model {model} is not a StableDiffusion_ModelRecord") + continue + + if model.baseline == baseline: + found_models.add(model.name) + + return found_models + @staticmethod def resolve_top_n_model_names( number_of_top_models: int, diff --git a/tests/ai_horde_worker/test_model_meta.py b/tests/ai_horde_worker/test_model_meta.py index 9bff8ab..11e0a14 100644 --- a/tests/ai_horde_worker/test_model_meta.py +++ b/tests/ai_horde_worker/test_model_meta.py @@ -103,3 +103,11 @@ def test_image_models_unique_results_only( all_model_names = image_model_load_resolver.resolve_all_model_names() assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta + + +def test_resolve_all_models_of_baseline( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_all_models_of_baseline("stable_diffusion_xl") + + assert len(resolved_model_names) > 0