Skip to content

Commit

Permalink
Merge pull request #24 from Haidra-Org/0.7.9
Browse files Browse the repository at this point in the history
feat: allow flexible `AIHordeAPIAsyncClientSession` context
  • Loading branch information
tazlin authored Sep 17, 2023
2 parents 8f22ee5 + 7733111 commit e28bd71
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 5 deletions.
26 changes: 24 additions & 2 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import base64
import contextlib
import io
import time
import urllib.parse
Expand Down Expand Up @@ -779,9 +780,19 @@ def alchemy_request(
class AIHordeAPIAsyncSimpleClient(BaseAIHordeSimpleClient):
"""An asyncio based simple client for the AI-Horde API. Start with this class if you want asyncio capabilities.."""

def __init__(self, aiohttp_session: aiohttp.ClientSession) -> None:
_horde_client_session: AIHordeAPIAsyncClientSession | None

def __init__(
self,
aiohttp_session: aiohttp.ClientSession | None,
horde_client_session: AIHordeAPIAsyncClientSession | None = None,
) -> None:
"""Create a new instance of the AIHordeAPISimpleClient."""
if aiohttp_session is not None and horde_client_session is not None:
raise ValueError("Only one of aiohttp_session or horde_client_session can be provided")

self._aiohttp_session = aiohttp_session
self._horde_client_session = horde_client_session

async def download_image_from_generation(self, generation: ImageGeneration) -> tuple[PIL.Image.Image, JobID]:
"""Asynchronously convert from base64 or download an image from a response.
Expand Down Expand Up @@ -876,8 +887,19 @@ async def _do_request_with_check(
AIHordeRequestError: If the request failed. The error response is included in the exception.
"""

context: contextlib.AbstractContextManager | AIHordeAPIAsyncClientSession
ai_horde_session: AIHordeAPIAsyncClientSession

if self._horde_client_session is not None:
# Use a dummy context manager to keep the type checker happy
context = contextlib.nullcontext()
ai_horde_session = self._horde_client_session
elif self._aiohttp_session is not None:
ai_horde_session = AIHordeAPIAsyncClientSession(self._aiohttp_session)
context = ai_horde_session

# This session class will cleanup incomplete requests in the event of an exception
async with AIHordeAPIAsyncClientSession(aiohttp_session=self._aiohttp_session) as ai_horde_session:
async with context: # type: ignore
# Submit the initial request
logger.debug(
f"Submitting request: {api_request.log_safe_model_dump()} with timeout {timeout}",
Expand Down
5 changes: 4 additions & 1 deletion horde_sdk/ai_horde_api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ def __str__(self) -> str:

@override
def __eq__(self, other: Any) -> bool:
if isinstance(other, UUID_Identifier):
return self.root == other.root

if isinstance(other, str):
return self.root.__str__() == other

if isinstance(other, uuid.UUID):
return str(self.root) == str(other)
return self.root == other

return False

Expand Down
4 changes: 3 additions & 1 deletion horde_sdk/generic_api/generic_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,8 @@ async def submit_request(
) -> HordeResponseTypeVar | RequestErrorResponse:
# Add the request to the list of awaiting requests.

self._awaiting_requests.append(api_request)
async with self._awaiting_requests_lock:
self._awaiting_requests.append(api_request)

# Submit the request to the API and get the response.
response = await super().submit_request(api_request, expected_response_type)
Expand Down Expand Up @@ -783,6 +784,7 @@ async def _handle_exit_async(
# Log the results of each cleanup request.
for i, cleanup_response in enumerate(cleanup_responses):
logger.info(f"Recovery request {i+1} submitted!")
logger.debug(f"Recovery request {i+1}: {cleanup_requests[i].log_safe_model_dump()}")
logger.debug(f"Recovery response {i+1}: {cleanup_response}")

# Return True to indicate that all requests were handled successfully.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "horde_sdk"
version = "0.7.7"
version = "0.7.9"
description = "A python toolkit for interacting with the horde APIs, services, and ecosystem."
authors = [
{name = "tazlin", email = "[email protected]"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,11 @@ def json_serializer(obj: object) -> object:

with open("docs/api_to_sdk_payload_map.json", "w") as f:
f.write(json.dumps(api_to_sdk_payload_model_map, indent=4, default=json_serializer))
f.write("\n")

with open("docs/api_to_sdk_response_map.json", "w") as f:
f.write(json.dumps(api_to_sdk_response_model_map, indent=4, default=json_serializer))
f.write("\n")


def test_all_ai_horde_model_defs_in_swagger_from_prod_swagger() -> None:
Expand Down

0 comments on commit e28bd71

Please sign in to comment.