From 69926b7af89d9271cf14447aca1b80875b97c6b2 Mon Sep 17 00:00:00 2001 From: tazlin Date: Tue, 23 Jan 2024 10:57:17 -0500 Subject: [PATCH] fix: show warnings on unknown control_type or post_processors - also prints the randomly generated seed if used --- horde_sdk/ai_horde_api/apimodels/base.py | 49 +++++++++++++++++-- .../ai_horde_api/test_ai_horde_api_models.py | 12 +++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index 069fb0c..6aa13c1 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -9,7 +9,16 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from typing_extensions import override -from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, METADATA_TYPE, METADATA_VALUE, POST_PROCESSOR_ORDER_TYPE +from horde_sdk.ai_horde_api.consts import ( + KNOWN_CONTROLNETS, + KNOWN_FACEFIXERS, + KNOWN_MISC_POST_PROCESSORS, + KNOWN_SAMPLERS, + KNOWN_UPSCALERS, + METADATA_TYPE, + METADATA_VALUE, + POST_PROCESSOR_ORDER_TYPE, +) from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL from horde_sdk.ai_horde_api.fields import JobID, WorkerID from horde_sdk.generic_api.apimodels import HordeRequest, HordeResponseBaseModel @@ -131,7 +140,9 @@ class ImageGenerateParamMixin(BaseModel): """The desired output image width.""" seed_variation: int | None = Field(default=None, ge=1, le=1000) """Deprecated.""" - post_processing: list[str] = Field(default_factory=list) + post_processing: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS] = Field( + default_factory=list, + ) """A list of post-processing models to use.""" post_processing_order: POST_PROCESSOR_ORDER_TYPE = POST_PROCESSOR_ORDER_TYPE.facefixers_first """The order in which to apply post-processing models. @@ -144,7 +155,7 @@ class ImageGenerateParamMixin(BaseModel): """Set to True if you want to use the hires fix.""" clip_skip: int = Field(default=1, ge=1, le=12) """The number of clip layers to skip.""" - control_type: str | None = None + control_type: str | KNOWN_CONTROLNETS | None = None """The type of control net type to use.""" image_is_control: bool | None = None """Set to True if the image is a control image.""" @@ -185,8 +196,36 @@ def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMP def random_seed_if_none(cls, v: str | None) -> str | None: """If the seed is None, generate a random seed.""" if v is None: - logger.debug("Generating random seed") - return str(random.randint(1, 1000000000)) + random_seed = str(random.randint(1, 1000000000)) + logger.debug(f"Using random seed ({random_seed})") + return random_seed + + return v + + @field_validator("post_processing") + def post_processors_must_be_known( + cls, + v: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS], + ) -> list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS]: + """Ensure that the post processors are in this list of supported post processors.""" + for post_processor in v: + if ( + post_processor not in KNOWN_UPSCALERS.__members__ + and post_processor not in KNOWN_FACEFIXERS.__members__ + and post_processor not in KNOWN_MISC_POST_PROCESSORS.__members__ + ): + logger.warning( + f"Unknown post processor {post_processor}. Is your SDK out of date or did the API change?", + ) + return v + + @field_validator("control_type") + def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str | KNOWN_CONTROLNETS | None: + """Ensure that the control type is in this list of supported control types.""" + if v is None: + return None + if v not in KNOWN_CONTROLNETS.__members__: + logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?") return v diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py index b6b21a1..8e113d3 100644 --- a/tests/ai_horde_api/test_ai_horde_api_models.py +++ b/tests/ai_horde_api/test_ai_horde_api_models.py @@ -362,3 +362,15 @@ def test_ImageGenerateJobPopResponse() -> None: assert test_response.has_upscaler is True assert test_response.has_facefixer is True + + test_response = ImageGenerateJobPopResponse( + id=None, + ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))], + payload=ImageGenerateJobPopPayload( + post_processing=["unknown post processor"], + control_type="unknown control type", + sampler_name="unknown sampler", + prompt="A cat in a hat", + ), + skipped=ImageGenerateJobPopSkippedStatus(), + )