Skip to content

Commit

Permalink
Merge pull request #259 from Haidra-Org/main
Browse files Browse the repository at this point in the history
feat: remove large models from model meta commands
  • Loading branch information
tazlin authored Sep 23, 2024
2 parents 3e705dc + 217b760 commit d3efead
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
6 changes: 6 additions & 0 deletions horde_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
if len(dev_key) != 10 and len(dev_key) != 22:
raise ValueError("AI_HORDE_DEV_APIKEY must be the anon key or 22 characters long.")

AI_HORDE_MODEL_META_LARGE_MODELS = os.getenv("AI_HORDE_MODEL_META_LARGE_MODELS")
if AI_HORDE_MODEL_META_LARGE_MODELS:
logger.debug(
f"AI_HORDE_MODEL_META_LARGE_MODELS is {AI_HORDE_MODEL_META_LARGE_MODELS}.",
)


_dev_env_var_warnings()

Expand Down
28 changes: 22 additions & 6 deletions horde_sdk/ai_horde_worker/model_meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re

from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY
from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY
from horde_model_reference.model_reference_manager import ModelReferenceManager
from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord
from loguru import logger
Expand Down Expand Up @@ -128,7 +129,7 @@ def resolve_meta_instructions(
return_list.extend(self.resolve_all_nsfw_model_names())

# If no valid meta instruction were found, return None
return set(return_list)
return self.remove_large_models(set(return_list))

@staticmethod
def meta_instruction_regex_match(instruction: str, target_string: str) -> re.Match[str] | None:
Expand All @@ -140,9 +141,22 @@ def meta_instruction_regex_match(instruction: str, target_string: str) -> re.Mat
Returns:
A Match object if the target string matches the regex pattern, otherwise None.
"""
return re.match(instruction, target_string, re.IGNORECASE)

def remove_large_models(self, models: set[str]) -> set[str]:
"""Remove large models from the input set of models."""
AI_HORDE_MODEL_META_LARGE_MODELS = os.getenv("AI_HORDE_MODEL_META_LARGE_MODELS")
if not AI_HORDE_MODEL_META_LARGE_MODELS:
cascade_models = self.resolve_all_models_of_baseline(STABLE_DIFFUSION_BASELINE_CATEGORY.stable_cascade)
flux_models = self.resolve_all_models_of_baseline(STABLE_DIFFUSION_BASELINE_CATEGORY.flux_1)

logger.debug(f"Removing cascade models: {cascade_models}")
logger.debug(f"Removing flux models: {flux_models}")
models = models - cascade_models - flux_models
return models

def resolve_all_model_names(self) -> set[str]:
"""Get the names of all models defined in the model reference.
Expand All @@ -153,11 +167,13 @@ def resolve_all_model_names(self) -> set[str]:

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

if sd_model_references:
return set(sd_model_references.root.keys())
all_models = set(sd_model_references.root.keys()) if sd_model_references is not None else set()

logger.error("No stable diffusion models found in model reference.")
return set()
all_models = self.remove_large_models(all_models)

if not all_models:
logger.error("No stable diffusion models found in model reference.")
return all_models

def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]:
"""Get the names of all SFW or NSFW models defined in the model reference.
Expand Down
23 changes: 23 additions & 0 deletions tests/ai_horde_worker/test_model_meta_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def test_image_model_load_resolver_all(image_model_load_resolver: ImageModelLoad

assert len(all_model_names) > 0

import os

os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true"

all_model_names_with_large = image_model_load_resolver.resolve_all_model_names()

del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"]

assert len(all_model_names_with_large) > len(all_model_names)


def test_image_model_load_resolver_top_n(
image_model_load_resolver: ImageModelLoadResolver,
Expand Down Expand Up @@ -179,6 +189,19 @@ def test_image_models_unique_results_only(

assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta

import os

os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true"

resolved_models_names_with_large = image_model_load_resolver.resolve_meta_instructions(
["top 1000", "bottom 1000"],
AIHordeAPIManualClient(),
)

del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"]

assert len(resolved_models_names_with_large) >= len(resolved_model_names)


def test_resolve_all_models_of_baseline(
image_model_load_resolver: ImageModelLoadResolver,
Expand Down

0 comments on commit d3efead

Please sign in to comment.