Skip to content

Commit

Permalink
feat: more model meta instructions; allow meta in skip models
Browse files Browse the repository at this point in the history
`ALL SDXL`, `ALL SD15`, `ALL SD21`, `ALL SFW`, `ALL NSFW`
  • Loading branch information
tazlin committed Mar 5, 2024
1 parent bb8c072 commit 39653a7
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 2 deletions.
30 changes: 29 additions & 1 deletion horde_sdk/ai_horde_worker/bridge_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@


class MetaInstruction(StrEnum):
ALL_REGEX = r"all$|all models+$"
ALL_REGEX = r"all$|all models?$"

ALL_SDXL_REGEX = r"all sdxl$|all sdxl models?$"
ALL_SD15_REGEX = r"all sd15$|all sd15 models?$"
ALL_SD21_REGEX = r"all sd21$|all sd21 models?$"

ALL_SFW_REGEX = r"all sfw$|all sfw models?$"
ALL_NSFW_REGEX = r"all nsfw$|all nsfw models?$"

ALL_INPAINTING_REGEX = r"all inpainting$|all inpainting models?$"

TOP_N_REGEX = r"TOP (\d+)"
"""The regex to use to match the top N models. The number is in a capture group on its own."""
Expand Down Expand Up @@ -296,6 +305,7 @@ def validate_model(self) -> ImageWorkerBridgeData:
return self

_meta_load_instructions: list[str] | None = None
_meta_skip_instructions: list[str] | None = None

@property
def meta_load_instructions(self) -> list[str] | None:
Expand All @@ -315,6 +325,24 @@ def handle_meta_instructions(self) -> ImageWorkerBridgeData:

return self

@property
def meta_skip_instructions(self) -> list[str] | None:
"""The meta skip instructions."""
return self._meta_skip_instructions

@model_validator(mode="after")
def handle_meta_skip_instructions(self) -> ImageWorkerBridgeData:
# See if any entries are meta instructions, and if so, remove them and place them in _meta_skip_instructions
for instruction_regex in MetaInstruction.__members__.values():
for i, model in enumerate(self.image_models_to_skip):
if re.match(instruction_regex, model, re.IGNORECASE):
if self._meta_skip_instructions is None:
self._meta_skip_instructions = []
self._meta_skip_instructions.append(model)
self.image_models_to_skip.pop(i)

return self

@field_validator("image_models_to_load")
def validate_models_to_load(cls, v: list) -> list:
"""Validate and parse the models to load."""
Expand Down
108 changes: 108 additions & 0 deletions horde_sdk/ai_horde_worker/model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,43 @@ def resolve_meta_instructions(
found_bottom_n = True
continue

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SDXL_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_xl"))

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SD15_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_1"))

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SD21_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_512"))
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_768"))

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_INPAINTING_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_inpainting_models())

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SFW_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_sfw_model_names())

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_NSFW_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_nsfw_model_names())

# If no valid meta instruction were found, return None
return set(return_list)

Expand Down Expand Up @@ -113,6 +150,77 @@ def resolve_all_model_names(self) -> set[str]:
logger.error("No stable diffusion models found in model reference.")
return set()

def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]:
"""Get the names of all SFW or NSFW models defined in the model reference.
Args:
nsfw: A boolean representing whether to get SFW or NSFW models.
Returns:
A set of strings representing the names of all SFW or NSFW models.
"""
all_model_references = self._model_reference_manager.get_all_model_references()

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

found_models: set[str] = set()

if sd_model_references is None:
logger.error("No stable diffusion models found in model reference.")
return found_models

for model in sd_model_references.root.values():
if not isinstance(model, StableDiffusion_ModelRecord):
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
continue

if model.nsfw == nsfw:
found_models.add(model.name)

return found_models

def resolve_all_sfw_model_names(self) -> set[str]:
"""Get the names of all SFW models defined in the model reference.
Returns:
A set of strings representing the names of all SFW models.
"""
return self._resolve_sfw_nsfw_model_names(nsfw=False)

def resolve_all_nsfw_model_names(self) -> set[str]:
"""Get the names of all NSFW models defined in the model reference.
Returns:
A set of strings representing the names of all NSFW models.
"""
return self._resolve_sfw_nsfw_model_names(nsfw=True)

def resolve_all_inpainting_models(self) -> set[str]:
"""Get the names of all inpainting models defined in the model reference.
Returns:
A set of strings representing the names of all inpainting models.
"""
all_model_references = self._model_reference_manager.get_all_model_references()

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

found_models: set[str] = set()

if sd_model_references is None:
logger.error("No stable diffusion models found in model reference.")
return found_models

for model in sd_model_references.root.values():
if not isinstance(model, StableDiffusion_ModelRecord):
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
continue

if model.inpainting:
found_models.add(model.name)

return found_models

def resolve_all_models_of_baseline(self, baseline: str) -> set[str]:
"""Get the names of all models of a given baseline defined in the model reference.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
horde_model_reference~=0.6.1
horde_model_reference~=0.6.3

pydantic
requests
Expand Down
75 changes: 75 additions & 0 deletions tests/ai_horde_worker/test_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,81 @@ def test_image_model_load_resolver_multiple_instructions(
assert len(resolved_model_names) == 2


def test_image_model_load_resolved_all_sd15(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sd15"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0

for model_name in resolved_model_names:
assert "SDXL" not in model_name

assert "Deliberate" in resolved_model_names


def test_image_model_load_resolved_all_sd21(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sd21"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0

for model_name in resolved_model_names:
assert "SDXL" not in model_name
assert model_name != "Deliberate"


def test_image_model_load_resolved_all_sdxl(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sdxl"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert "AlbedoBase XL (SDXL)" in resolved_model_names


def test_image_model_load_resolved_all_inpainting(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all inpainting"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert any("inpainting" in model_name.lower() for model_name in resolved_model_names)


def test_image_model_load_resolved_sfw_nsfw(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sfw"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert not any("urpm" in model_name.lower() for model_name in resolved_model_names)

resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all nsfw"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert any("urpm" in model_name.lower() for model_name in resolved_model_names)


def test_image_models_unique_results_only(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
Expand Down

0 comments on commit 39653a7

Please sign in to comment.