From d4c7bde6d737c791581391dd09cb23479cbb4d2b Mon Sep 17 00:00:00 2001 From: tazlin Date: Thu, 25 Jan 2024 10:14:46 -0500 Subject: [PATCH] fix: handle enum validation better - Check for expected types more rigorously with enum values - Fixes the corner case where the enum member name doesn't match the API name (as with `4x_AnimeSharp`) --- .../ai_horde_api/apimodels/alchemy/_async.py | 4 +++- .../ai_horde_api/apimodels/alchemy/_status.py | 4 +++- horde_sdk/ai_horde_api/apimodels/base.py | 18 ++++++++++-------- horde_sdk/ai_horde_api/consts.py | 12 ++++++++++++ .../ai_horde_api/test_ai_horde_api_models.py | 19 +++++++++++++++++++ 5 files changed, 47 insertions(+), 10 deletions(-) diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py index dffaa25..a9a7c1f 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py @@ -68,7 +68,9 @@ class AlchemyAsyncRequestFormItem(BaseModel): @field_validator("name") def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str: - if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__: + if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or ( + not isinstance(v, KNOWN_ALCHEMY_TYPES) + ): logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?") return v diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py index 0f44c6d..11e6a26 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py @@ -68,7 +68,9 @@ class AlchemyFormStatus(BaseModel): @field_validator("form", mode="before") def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str: - if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__: + if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or ( + not isinstance(v, KNOWN_ALCHEMY_TYPES) + ): logger.warning(f"Unknown form type {v}. Is your SDK out of date or did the API change?") return v diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index 6aa13c1..e7cb89c 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -18,6 +18,7 @@ METADATA_TYPE, METADATA_VALUE, POST_PROCESSOR_ORDER_TYPE, + _all_valid_post_processors_names_and_values, ) from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL from horde_sdk.ai_horde_api.fields import JobID, WorkerID @@ -182,8 +183,9 @@ def width_divisible_by_64(cls, value: int) -> int: @field_validator("sampler_name") def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS: """Ensure that the sampler name is in this list of supported samplers.""" - if v not in KNOWN_SAMPLERS.__members__: + if (isinstance(v, str) and v not in KNOWN_SAMPLERS.__members__) or (not isinstance(v, KNOWN_SAMPLERS)): logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?") + return v # @model_validator(mode="after") @@ -208,11 +210,11 @@ def post_processors_must_be_known( v: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS], ) -> list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS]: """Ensure that the post processors are in this list of supported post processors.""" + + _valid_types: list[type] = [str, KNOWN_UPSCALERS, KNOWN_FACEFIXERS, KNOWN_MISC_POST_PROCESSORS] for post_processor in v: - if ( - post_processor not in KNOWN_UPSCALERS.__members__ - and post_processor not in KNOWN_FACEFIXERS.__members__ - and post_processor not in KNOWN_MISC_POST_PROCESSORS.__members__ + if post_processor not in _all_valid_post_processors_names_and_values or ( + type(post_processor) not in _valid_types ): logger.warning( f"Unknown post processor {post_processor}. Is your SDK out of date or did the API change?", @@ -224,7 +226,7 @@ def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str | """Ensure that the control type is in this list of supported control types.""" if v is None: return None - if v not in KNOWN_CONTROLNETS.__members__: + if (isinstance(v, str) and v not in KNOWN_CONTROLNETS.__members__) or (not isinstance(v, KNOWN_CONTROLNETS)): logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?") return v @@ -260,13 +262,13 @@ class GenMetadataEntry(BaseModel): @field_validator("type_") def validate_type(cls, v: str | METADATA_TYPE) -> str | METADATA_TYPE: """Ensure that the type is in this list of supported types.""" - if v not in METADATA_TYPE.__members__: + if (isinstance(v, str) and v not in METADATA_TYPE.__members__) or (not isinstance(v, METADATA_TYPE)): logger.warning(f"Unknown metadata type {v}. Is your SDK out of date or did the API change?") return v @field_validator("value") def validate_value(cls, v: str | METADATA_VALUE) -> str | METADATA_VALUE: """Ensure that the value is in this list of supported values.""" - if v not in METADATA_VALUE.__members__: + if (isinstance(v, str) and v not in METADATA_VALUE.__members__) or (not isinstance(v, METADATA_VALUE)): logger.warning(f"Unknown metadata value {v}. Is your SDK out of date or did the API change?") return v diff --git a/horde_sdk/ai_horde_api/consts.py b/horde_sdk/ai_horde_api/consts.py index 48cd25b..c425e4f 100644 --- a/horde_sdk/ai_horde_api/consts.py +++ b/horde_sdk/ai_horde_api/consts.py @@ -135,6 +135,18 @@ class KNOWN_MISC_POST_PROCESSORS(StrEnum): strip_background = auto() +_all_valid_post_processors_names_and_values = ( + list(KNOWN_UPSCALERS.__members__.keys()) + + list(KNOWN_UPSCALERS.__members__.values()) + + list(KNOWN_FACEFIXERS.__members__.keys()) + + list(KNOWN_FACEFIXERS.__members__.values()) + + list(KNOWN_MISC_POST_PROCESSORS.__members__.keys()) + + list(KNOWN_MISC_POST_PROCESSORS.__members__.values()) +) +"""Used to validate post processor names and values. \ + This is because some post processor names are not valid python variable names.""" + + class POST_PROCESSOR_ORDER_TYPE(StrEnum): """The post processor order types that are known to the API. diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py index 8e113d3..d7a0c49 100644 --- a/tests/ai_horde_api/test_ai_horde_api_models.py +++ b/tests/ai_horde_api/test_ai_horde_api_models.py @@ -374,3 +374,22 @@ def test_ImageGenerateJobPopResponse() -> None: ), skipped=ImageGenerateJobPopSkippedStatus(), ) + test_response = ImageGenerateJobPopResponse( + id=None, + ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))], + payload=ImageGenerateJobPopPayload( + post_processing=["4x_AnimeSharp"], + prompt="A cat in a hat", + ), + skipped=ImageGenerateJobPopSkippedStatus(), + ) + + test_response = ImageGenerateJobPopResponse( + id=None, + ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))], + payload=ImageGenerateJobPopPayload( + post_processing=[KNOWN_UPSCALERS.four_4x_AnimeSharp], + prompt="A cat in a hat", + ), + skipped=ImageGenerateJobPopSkippedStatus(), + )