From be15448c7d5d925a4f755f1ed8541de6230c00b4 Mon Sep 17 00:00:00 2001 From: tazlin Date: Thu, 25 Jan 2024 14:21:07 -0500 Subject: [PATCH] fix: logic issue with `KNOWN_SAMPLERS` check --- horde_sdk/ai_horde_api/apimodels/base.py | 6 ++++-- tests/ai_horde_api/test_ai_horde_api_models.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index e7cb89c..ae5f291 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -183,8 +183,10 @@ def width_divisible_by_64(cls, value: int) -> int: @field_validator("sampler_name") def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS: """Ensure that the sampler name is in this list of supported samplers.""" - if (isinstance(v, str) and v not in KNOWN_SAMPLERS.__members__) or (not isinstance(v, KNOWN_SAMPLERS)): - logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?") + if (isinstance(v, str) and v in KNOWN_SAMPLERS.__members__) or (isinstance(v, KNOWN_SAMPLERS)): + return v + + logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?") return v diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py index d7a0c49..eb18f1e 100644 --- a/tests/ai_horde_api/test_ai_horde_api_models.py +++ b/tests/ai_horde_api/test_ai_horde_api_models.py @@ -47,7 +47,8 @@ def test_ImageGenerateAsyncRequest(ai_horde_api_key: str) -> None: models=["Deliberate"], prompt="test prompt", params=ImageGenerationInputPayload( - sampler_name=KNOWN_SAMPLERS.k_lms, + # sampler_name="DDIM", + sampler_name=KNOWN_SAMPLERS.DDIM, cfg_scale=7.5, denoising_strength=1, seed="123456789", @@ -86,7 +87,7 @@ def test_ImageGenerateAsyncRequest(ai_horde_api_key: str) -> None: assert test_async_request.models == ["Deliberate"] assert test_async_request.prompt == "test prompt" assert test_async_request.params is not None - assert test_async_request.params.sampler_name == "k_lms" + assert test_async_request.params.sampler_name == "DDIM" assert test_async_request.params.cfg_scale == 7.5 assert test_async_request.params.denoising_strength == 1 assert test_async_request.params.seed is not None