Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rc field support for RequestErrorResponse; feat: better __eq__ and __hash__ implementations where appropriate #143

Merged
merged 11 commits into from
Feb 17, 2024
Merged
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.2.1
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
136 changes: 1 addition & 135 deletions docs/examples.md
Original file line number Diff line number Diff line change
@@ -1,137 +1,3 @@
# Example Clients

See `examples/` for a complete list. These examples are all made in mind with your current working directory as `horde_sdk` (e.g., `cd horde_sdk`).

## Simple Client (sync) Example
From `examples/ai_horde_client/aihorde_simple_client_example.py`:

``` python
from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPISimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGeneration


def simple_generate_example() -> None:
simple_client = AIHordeAPISimpleClient()

generations: list[ImageGeneration] = simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=ANON_API_KEY,
prompt="A cat in a hat",
models=["Deliberate"],
),
)

image = simple_client.generation_to_image(generations[0])

image.save("cat_in_hat.webp")

if __name__ == "__main__":
simple_generate_example()
```



```python
import argparse
import asyncio
from collections.abc import Coroutine
from pathlib import Path

import aiohttp
from PIL.Image import Image

from horde_sdk import ANON_API_KEY, RequestErrorResponse
from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncSimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusResponse
from horde_sdk.ai_horde_api.fields import JobID


async def async_one_image_generate_example(
simple_client: AIHordeAPIAsyncSimpleClient,
apikey: str = ANON_API_KEY,
) -> None:
single_generation_response: ImageGenerateStatusResponse
job_id: JobID

single_generation_response, job_id = await simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=apikey,
prompt="A cat in a hat",
models=["Deliberate"],
),
)

if isinstance(single_generation_response, RequestErrorResponse):
print(f"Error: {single_generation_response.message}")
else:
single_image, _ = await simple_client.download_image_from_generation(single_generation_response.generations[0])

example_path = Path("examples/requested_images")
example_path.mkdir(exist_ok=True, parents=True)

single_image.save(example_path / f"{job_id}_simple_async_example.webp")


async def async_multi_image_generate_example(
simple_client: AIHordeAPIAsyncSimpleClient,
apikey: str = ANON_API_KEY,
) -> None:
multi_generation_responses: tuple[
tuple[ImageGenerateStatusResponse, JobID],
tuple[ImageGenerateStatusResponse, JobID],
]
multi_generation_responses = await asyncio.gather(
simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=apikey,
prompt="A cat in a blue hat",
models=["Deliberate"],
),
),
simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=apikey,
prompt="A cat in a red hat",
models=["Deliberate"],
),
),
)

download_image_from_generation_calls: list[Coroutine[None, None, tuple[Image, JobID]]] = []

for status_response, _ in multi_generation_responses:
download_image_from_generation_calls.append(
simple_client.download_image_from_generation(status_response.generations[0]),
)

downloaded_images: list[tuple[Image, JobID]] = await asyncio.gather(*download_image_from_generation_calls)

example_path = Path("examples/requested_images")
example_path.mkdir(exist_ok=True, parents=True)

for image, job_id in downloaded_images:
image.save(example_path / f"{job_id}_simple_async_example.webp")


async def async_simple_generate_example(apikey: str = ANON_API_KEY) -> None:
async with aiohttp.ClientSession() as aiohttp_session:
simple_client = AIHordeAPIAsyncSimpleClient(aiohttp_session)

await async_one_image_generate_example(simple_client, apikey)
await async_multi_image_generate_example(simple_client, apikey)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AI Horde API Manual Client Example")
parser.add_argument(
"--apikey",
type=str,
default=ANON_API_KEY,
help="The API key to use. Defaults to the anon key.",
)
args = parser.parse_args()

# Run the example.
asyncio.run(async_simple_generate_example(args.apikey))

```
See `examples/` (https://github.com/Haidra-Org/horde-sdk/tree/main/examples) for a complete list. These examples are all made in mind with your current working directory as `horde_sdk` (e.g., `cd horde_sdk`).
1 change: 1 addition & 0 deletions horde_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Any model or helper useful for creating or interacting with a horde API."""

# isort: off
# We import dotenv first so that we can use it to load environment variables before importing anything else.
import dotenv
Expand Down
1 change: 1 addition & 0 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Definitions to help interact with the AI-Horde API."""

from __future__ import annotations

import asyncio
Expand Down
1 change: 1 addition & 0 deletions horde_sdk/ai_horde_api/apimodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""All requests, responses and API models defined for the AI Horde API."""

from horde_sdk.ai_horde_api.apimodels._find_user import (
ContributionsDetails,
FindUserRequest,
Expand Down
24 changes: 15 additions & 9 deletions horde_sdk/ai_horde_api/apimodels/_find_user.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from datetime import datetime

from pydantic import BaseModel, Field
from pydantic import Field
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeResponseBaseModel
from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeAPIDataObject, HordeResponseBaseModel


class ContributionsDetails(BaseModel):
class ContributionsDetails(HordeAPIDataObject):
fulfillments: int | None = Field(default=None, description="How many images this user has generated.")
megapixelsteps: float | None = Field(default=None, description="How many megapixelsteps this user has generated.")


class UserKudosDetails(BaseModel):
class UserKudosDetails(HordeAPIDataObject):
accumulated: float | None = Field(0, description="The amount of Kudos accumulated or used for generating images.")
admin: float | None = Field(0, description="The amount of Kudos this user has been given by the Horde admins.")
awarded: float | None = Field(
Expand All @@ -29,33 +29,33 @@ class UserKudosDetails(BaseModel):
)


class MonthlyKudos(BaseModel):
class MonthlyKudos(HordeAPIDataObject):
amount: int | None = Field(default=None, description="How much recurring Kudos this user receives monthly.")
last_received: datetime | None = Field(default=None, description="Last date this user received monthly Kudos.")


class UserThingRecords(BaseModel):
class UserThingRecords(HordeAPIDataObject):
megapixelsteps: float | None = Field(
0,
description="How many megapixelsteps this user has generated or requested.",
)
tokens: int | None = Field(0, description="How many token this user has generated or requested.")


class UserAmountRecords(BaseModel):
class UserAmountRecords(HordeAPIDataObject):
image: int | None = Field(0, description="How many images this user has generated or requested.")
interrogation: int | None = Field(0, description="How many texts this user has generated or requested.")
text: int | None = Field(0, description="How many texts this user has generated or requested.")


class UserRecords(BaseModel):
class UserRecords(HordeAPIDataObject):
contribution: UserThingRecords | None = None
fulfillment: UserAmountRecords | None = None
request: UserAmountRecords | None = None
usage: UserThingRecords | None = None


class UsageDetails(BaseModel):
class UsageDetails(HordeAPIDataObject):
megapixelsteps: float | None = Field(default=None, description="How many megapixelsteps this user has requested.")
requests: int | None = Field(default=None, description="How many images this user has requested.")

Expand Down Expand Up @@ -183,6 +183,12 @@ def get_api_model_name(cls) -> str | None:
"""Whether this user has been invited to join a worker to the horde and how many of them.
When 0, this user cannot add (new) workers to the horde."""

def __eq__(self, other: object) -> bool:
raise NotImplementedError("TODO")

def __hash__(self) -> int:
raise NotImplementedError("TODO")


class FindUserRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
@override
Expand Down
6 changes: 6 additions & 0 deletions horde_sdk/ai_horde_api/apimodels/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def get_timeframe(self, timeframe: StatsModelsTimeframe) -> dict[str, int]:

raise ValueError(f"Invalid timeframe: {timeframe}")

def __eq__(self, other: object) -> bool:
raise NotImplementedError("Cannot compare StatsModelsResponse objects")

def __hash__(self) -> int:
raise NotImplementedError("Cannot hash StatsModelsResponse objects")


class StatsImageModelsRequest(BaseAIHordeRequest):
"""Represents the data needed to make a request to the `/v2/stats/img/models` endpoint."""
Expand Down
11 changes: 6 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/alchemy/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.parse

from loguru import logger
from pydantic import BaseModel, field_validator
from pydantic import field_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.alchemy._status import AlchemyDeleteRequest, AlchemyStatusRequest
Expand All @@ -17,6 +17,7 @@
from horde_sdk.generic_api.apimodels import (
APIKeyAllowedInRequestMixin,
ContainsMessageResponseMixin,
HordeAPIDataObject,
HordeResponse,
HordeResponseBaseModel,
ResponseRequiringFollowUpMixin,
Expand Down Expand Up @@ -63,14 +64,14 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[AlchemyDeleteRequest
return AlchemyDeleteRequest


class AlchemyAsyncRequestFormItem(BaseModel):
class AlchemyAsyncRequestFormItem(HordeAPIDataObject):
name: KNOWN_ALCHEMY_TYPES | str

@field_validator("name")
def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str:
if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or (
not isinstance(v, KNOWN_ALCHEMY_TYPES)
):
if isinstance(v, KNOWN_ALCHEMY_TYPES):
return v
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?")
return v

Expand Down
35 changes: 34 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from loguru import logger
from pydantic import Field, field_validator
from pydantic import Field, field_validator, model_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.alchemy._submit import AlchemyJobSubmitRequest
Expand Down Expand Up @@ -51,6 +53,8 @@ def get_api_model_name(cls) -> str | None:

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, KNOWN_ALCHEMY_TYPES):
return v
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}")
return v
Expand Down Expand Up @@ -130,12 +134,41 @@ def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -

return all_ids

@model_validator(mode="after")
def coerce_list_order(self) -> AlchemyPopResponse:
if self.forms is not None:
logger.debug("Sorting forms by id")
self.forms.sort(key=lambda form: form.id_)

return self

@override
@classmethod
def get_follow_up_request_types(cls) -> list[type[AlchemyJobSubmitRequest]]: # type: ignore[override]
"""Return a list of all the possible follow up request types for this response."""
return [AlchemyJobSubmitRequest]

def __eq__(self, other: object) -> bool:
if not isinstance(other, AlchemyPopResponse):
return False

forms_match = True
skipped_match = True

if self.forms is not None and other.forms is not None:
forms_match = all(form in other.forms for form in self.forms)

if self.skipped is not None:
skipped_match = self.skipped == other.skipped

return forms_match and skipped_match

def __hash__(self) -> int:
if self.forms is None:
return hash(self.skipped)

return hash((tuple([form.id_ for form in self.forms]), self.skipped))


class AlchemyPopRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
"""Represents the data needed to make a request to the `/v2/interrogate/pop` endpoint.
Expand Down
Loading
Loading