diff --git a/horde_sdk/ai_horde_worker/bridge_data.py b/horde_sdk/ai_horde_worker/bridge_data.py index d4be0ee..15a1abf 100644 --- a/horde_sdk/ai_horde_worker/bridge_data.py +++ b/horde_sdk/ai_horde_worker/bridge_data.py @@ -19,7 +19,16 @@ class MetaInstruction(StrEnum): - ALL_REGEX = r"all$|all models+$" + ALL_REGEX = r"all$|all models?$" + + ALL_SDXL_REGEX = r"all sdxl$|all sdxl models?$" + ALL_SD15_REGEX = r"all sd15$|all sd15 models?$" + ALL_SD21_REGEX = r"all sd21$|all sd21 models?$" + + ALL_SFW_REGEX = r"all sfw$|all sfw models?$" + ALL_NSFW_REGEX = r"all nsfw$|all nsfw models?$" + + ALL_INPAINTING_REGEX = r"all inpainting$|all inpainting models?$" TOP_N_REGEX = r"TOP (\d+)" """The regex to use to match the top N models. The number is in a capture group on its own.""" @@ -296,6 +305,7 @@ def validate_model(self) -> ImageWorkerBridgeData: return self _meta_load_instructions: list[str] | None = None + _meta_skip_instructions: list[str] | None = None @property def meta_load_instructions(self) -> list[str] | None: @@ -315,6 +325,24 @@ def handle_meta_instructions(self) -> ImageWorkerBridgeData: return self + @property + def meta_skip_instructions(self) -> list[str] | None: + """The meta skip instructions.""" + return self._meta_skip_instructions + + @model_validator(mode="after") + def handle_meta_skip_instructions(self) -> ImageWorkerBridgeData: + # See if any entries are meta instructions, and if so, remove them and place them in _meta_skip_instructions + for instruction_regex in MetaInstruction.__members__.values(): + for i, model in enumerate(self.image_models_to_skip): + if re.match(instruction_regex, model, re.IGNORECASE): + if self._meta_skip_instructions is None: + self._meta_skip_instructions = [] + self._meta_skip_instructions.append(model) + self.image_models_to_skip.pop(i) + + return self + @field_validator("image_models_to_load") def validate_models_to_load(cls, v: list) -> list: """Validate and parse the models to load.""" diff --git a/horde_sdk/ai_horde_worker/model_meta.py b/horde_sdk/ai_horde_worker/model_meta.py index 737deab..a2757fe 100644 --- a/horde_sdk/ai_horde_worker/model_meta.py +++ b/horde_sdk/ai_horde_worker/model_meta.py @@ -81,6 +81,43 @@ def resolve_meta_instructions( found_bottom_n = True continue + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SDXL_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_xl")) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SD15_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_1")) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SD21_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_512")) + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_768")) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_INPAINTING_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_inpainting_models()) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SFW_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_sfw_model_names()) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_NSFW_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_nsfw_model_names()) + # If no valid meta instruction were found, return None return set(return_list) @@ -113,6 +150,77 @@ def resolve_all_model_names(self) -> set[str]: logger.error("No stable diffusion models found in model reference.") return set() + def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]: + """Get the names of all SFW or NSFW models defined in the model reference. + + Args: + nsfw: A boolean representing whether to get SFW or NSFW models. + + Returns: + A set of strings representing the names of all SFW or NSFW models. + """ + all_model_references = self._model_reference_manager.get_all_model_references() + + sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + found_models: set[str] = set() + + if sd_model_references is None: + logger.error("No stable diffusion models found in model reference.") + return found_models + + for model in sd_model_references.root.values(): + if not isinstance(model, StableDiffusion_ModelRecord): + logger.error(f"Model {model} is not a StableDiffusion_ModelRecord") + continue + + if model.nsfw == nsfw: + found_models.add(model.name) + + return found_models + + def resolve_all_sfw_model_names(self) -> set[str]: + """Get the names of all SFW models defined in the model reference. + + Returns: + A set of strings representing the names of all SFW models. + """ + return self._resolve_sfw_nsfw_model_names(nsfw=False) + + def resolve_all_nsfw_model_names(self) -> set[str]: + """Get the names of all NSFW models defined in the model reference. + + Returns: + A set of strings representing the names of all NSFW models. + """ + return self._resolve_sfw_nsfw_model_names(nsfw=True) + + def resolve_all_inpainting_models(self) -> set[str]: + """Get the names of all inpainting models defined in the model reference. + + Returns: + A set of strings representing the names of all inpainting models. + """ + all_model_references = self._model_reference_manager.get_all_model_references() + + sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + found_models: set[str] = set() + + if sd_model_references is None: + logger.error("No stable diffusion models found in model reference.") + return found_models + + for model in sd_model_references.root.values(): + if not isinstance(model, StableDiffusion_ModelRecord): + logger.error(f"Model {model} is not a StableDiffusion_ModelRecord") + continue + + if model.inpainting: + found_models.add(model.name) + + return found_models + def resolve_all_models_of_baseline(self, baseline: str) -> set[str]: """Get the names of all models of a given baseline defined in the model reference. diff --git a/requirements.txt b/requirements.txt index da57b0c..20fd6bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -horde_model_reference~=0.6.1 +horde_model_reference~=0.6.3 pydantic requests diff --git a/tests/ai_horde_worker/test_model_meta.py b/tests/ai_horde_worker/test_model_meta.py index 11e0a14..d3deee3 100644 --- a/tests/ai_horde_worker/test_model_meta.py +++ b/tests/ai_horde_worker/test_model_meta.py @@ -93,6 +93,81 @@ def test_image_model_load_resolver_multiple_instructions( assert len(resolved_model_names) == 2 +def test_image_model_load_resolved_all_sd15( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sd15"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + + for model_name in resolved_model_names: + assert "SDXL" not in model_name + + assert "Deliberate" in resolved_model_names + + +def test_image_model_load_resolved_all_sd21( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sd21"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + + for model_name in resolved_model_names: + assert "SDXL" not in model_name + assert model_name != "Deliberate" + + +def test_image_model_load_resolved_all_sdxl( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sdxl"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert "AlbedoBase XL (SDXL)" in resolved_model_names + + +def test_image_model_load_resolved_all_inpainting( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all inpainting"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert any("inpainting" in model_name.lower() for model_name in resolved_model_names) + + +def test_image_model_load_resolved_sfw_nsfw( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sfw"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert not any("urpm" in model_name.lower() for model_name in resolved_model_names) + + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all nsfw"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert any("urpm" in model_name.lower() for model_name in resolved_model_names) + + def test_image_models_unique_results_only( image_model_load_resolver: ImageModelLoadResolver, ) -> None: