Skip to content

Commit

Permalink
fix: check by value for KNOWN_UPSCALERS enum membership
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Mar 9, 2024
1 parent 39653a7 commit cc6ebbf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
5 changes: 4 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def has_upscaler(self) -> bool:
if len(self.payload.post_processing) == 0:
return False

return any(post_processing in KNOWN_UPSCALERS.__members__ for post_processing in self.payload.post_processing)
return any(
post_processing in KNOWN_UPSCALERS.__members__ or post_processing in KNOWN_UPSCALERS._value2member_map_
for post_processing in self.payload.post_processing
)

@property
def has_facefixer(self) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,11 @@ def test_ImageGenerateJobPopResponse() -> None:
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert all(
post_processor in KNOWN_UPSCALERS._value2member_map_
for post_processor in test_response.payload.post_processing
)

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
Expand All @@ -447,6 +452,8 @@ def test_ImageGenerateJobPopResponse() -> None:
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert all(post_processor in KNOWN_UPSCALERS for post_processor in test_response.payload.post_processing)


def test_ImageGenerateJobPopResponse_hashability() -> None:
test_response_ids = ImageGenerateJobPopResponse(
Expand Down

0 comments on commit cc6ebbf

Please sign in to comment.