diff --git a/docs/response_field_names_and_descriptions.json b/docs/response_field_names_and_descriptions.json index 8a98680..fd7258e 100644 --- a/docs/response_field_names_and_descriptions.json +++ b/docs/response_field_names_and_descriptions.json @@ -233,7 +233,7 @@ } }, "DeleteWorkerResponse": { - "deleted_id_": { + "deleted_id": { "description": "The ID of the deleted worker.", "types": [ "str", @@ -1238,6 +1238,20 @@ "int", "None" ] + }, + "controlnet": { + "description": "If True, this worker supports and allows controlnet requests.", + "types": [ + "bool", + "None" + ] + }, + "sdxl_controlnet": { + "description": "If True, this worker supports and allows sdxl controlnet requests.", + "types": [ + "bool", + "None" + ] } }, "TextGenerateAsyncDryRunResponse": { diff --git a/horde_sdk/__init__.py b/horde_sdk/__init__.py index 2c56a33..b0bb4a2 100644 --- a/horde_sdk/__init__.py +++ b/horde_sdk/__init__.py @@ -76,7 +76,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, ContainsMessageResponseMixin, - HordeAPIDataObject, + HordeAPIData, HordeAPIMessage, HordeAPIObject, HordeRequest, @@ -99,7 +99,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover "APIKeyAllowedInRequestMixin", "HordeRequest", "ContainsMessageResponseMixin", - "HordeAPIDataObject", + "HordeAPIData", "HordeAPIMessage", "HordeAPIObject", "RequestErrorResponse", diff --git a/horde_sdk/ai_horde_api/apimodels/_documents.py b/horde_sdk/ai_horde_api/apimodels/_documents.py index 54c7858..0e3b1f0 100644 --- a/horde_sdk/ai_horde_api/apimodels/_documents.py +++ b/horde_sdk/ai_horde_api/apimodels/_documents.py @@ -7,7 +7,7 @@ from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( - HordeAPIObject, + HordeAPIObjectBaseModel, HordeResponseBaseModel, ) @@ -29,7 +29,7 @@ def get_api_model_name(cls) -> str: return "HordeDocument" -class AIHordeDocumentRequestMixin(HordeAPIObject): +class AIHordeDocumentRequestMixin(HordeAPIObjectBaseModel): format: DocumentFormat | str = DocumentFormat.html """The format of the document to return. Default is markdown.""" @@ -108,7 +108,7 @@ def get_http_method(cls) -> HTTPMethod: @override @classmethod def get_api_endpoint_subpath(cls) -> AI_HORDE_API_ENDPOINT_SUBPATH: - return AI_HORDE_API_ENDPOINT_SUBPATH.vs_documents_terms + return AI_HORDE_API_ENDPOINT_SUBPATH.v2_documents_terms @override @classmethod diff --git a/horde_sdk/ai_horde_api/apimodels/_kudos.py b/horde_sdk/ai_horde_api/apimodels/_kudos.py index 386f05b..68fea6a 100644 --- a/horde_sdk/ai_horde_api/apimodels/_kudos.py +++ b/horde_sdk/ai_horde_api/apimodels/_kudos.py @@ -3,10 +3,10 @@ from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH from horde_sdk.consts import _ANONYMOUS_MODEL, HTTPMethod -from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeResponse +from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeResponseBaseModel -class KudosTransferResponse(HordeResponse): +class KudosTransferResponse(HordeResponseBaseModel): transferred: float | None = None """The amount of Kudos transferred.""" diff --git a/horde_sdk/ai_horde_api/apimodels/_stats.py b/horde_sdk/ai_horde_api/apimodels/_stats.py index 83a0f69..0a8876f 100644 --- a/horde_sdk/ai_horde_api/apimodels/_stats.py +++ b/horde_sdk/ai_horde_api/apimodels/_stats.py @@ -9,7 +9,7 @@ from horde_sdk.ai_horde_api.consts import MODEL_STATE from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH from horde_sdk.consts import HTTPMethod -from horde_sdk.generic_api.apimodels import HordeAPIDataObject, HordeResponseBaseModel +from horde_sdk.generic_api.apimodels import HordeAPIObjectBaseModel, HordeResponseBaseModel from horde_sdk.generic_api.decoration import Unequatable, Unhashable @@ -122,7 +122,12 @@ def get_default_success_response_type(cls) -> type[ImageStatsModelsResponse]: return ImageStatsModelsResponse -class SinglePeriodImgStat(HordeAPIDataObject): +class SinglePeriodImgStat(HordeAPIObjectBaseModel): + """Represents the stats for a single period of image generation. + + v2 API Model: `SinglePeriodImgStat` + """ + images: int | None = Field( default=None, ) @@ -140,6 +145,11 @@ def mps(self) -> int | None: return self.ps // 1_000_000 + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "SinglePeriodImgStat" + class ImageStatsModelsTotalResponse(HordeResponseBaseModel): """Represents the data returned from the `/v2/stats/img/totals` endpoint.""" @@ -253,7 +263,12 @@ def get_default_success_response_type(cls) -> type[TextStatsModelResponse]: return TextStatsModelResponse -class SinglePeriodTxtStat(HordeAPIDataObject): +class SinglePeriodTxtStat(HordeAPIObjectBaseModel): + """Represents the stats for a single period. + + v2 API Model: `SinglePeriodTxtStat` + """ + requests: int | None = Field( default=None, ) @@ -263,6 +278,11 @@ class SinglePeriodTxtStat(HordeAPIDataObject): ) """The number of tokens generated during this period.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "SinglePeriodTxtStat" + @Unhashable class TextStatsModelsTotalResponse(HordeResponseBaseModel): diff --git a/horde_sdk/ai_horde_api/apimodels/_status.py b/horde_sdk/ai_horde_api/apimodels/_status.py index 51f983b..f5790b0 100644 --- a/horde_sdk/ai_horde_api/apimodels/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/_status.py @@ -1,6 +1,6 @@ from collections.abc import Iterator -from pydantic import ConfigDict, Field, RootModel +from pydantic import ConfigDict, Field from typing_extensions import override from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest @@ -9,9 +9,9 @@ from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( ContainsMessageResponseMixin, - HordeAPIObject, - HordeResponse, + HordeAPIObjectBaseModel, HordeResponseBaseModel, + HordeResponseRootModel, ) from horde_sdk.generic_api.decoration import Unhashable @@ -148,7 +148,7 @@ def get_default_success_response_type(cls) -> type[HordePerformanceResponse]: return HordePerformanceResponse -class Newspiece(HordeAPIObject): +class Newspiece(HordeAPIObjectBaseModel): date_published: str | None = Field( default=None, ) @@ -159,6 +159,18 @@ class Newspiece(HordeAPIObject): default=None, ) """The actual piece of news.""" + tags: list[str] | None = Field( + default=None, + ) + """The tags associated with this newspiece.""" + title: str | None = Field( + default=None, + ) + """The title of this newspiece.""" + more_info_urls: list[str] | None = Field( + default=None, + ) + """The URLs for more information about this newspiece.""" @override @classmethod @@ -167,7 +179,7 @@ def get_api_model_name(cls) -> str | None: @Unhashable -class NewsResponse(HordeResponse, RootModel[list[Newspiece]]): +class NewsResponse(HordeResponseRootModel[list[Newspiece]]): root: list[Newspiece] """The underlying list of newspieces.""" @@ -210,7 +222,7 @@ def get_default_success_response_type(cls) -> type[NewsResponse]: return NewsResponse -class ActiveModelLite(HordeAPIObject): +class ActiveModelLite(HordeAPIObjectBaseModel): count: int | None = Field( default=None, ) @@ -256,7 +268,7 @@ def get_api_model_name(cls) -> str | None: @Unhashable -class HordeStatusModelsAllResponse(HordeResponse, RootModel[list[ActiveModel]]): +class HordeStatusModelsAllResponse(HordeResponseRootModel[list[ActiveModel]]): root: list[ActiveModel] """The underlying list of models.""" @@ -325,7 +337,7 @@ def get_query_fields(cls) -> list[str]: @Unhashable -class HordeStatusModelsSingleResponse(HordeResponse, RootModel[list[ActiveModel]]): +class HordeStatusModelsSingleResponse(HordeResponseRootModel[list[ActiveModel]]): # This is a list because of an oversight in the structure of the API response. # FIXME root: list[ActiveModel] @@ -377,7 +389,7 @@ def get_default_success_response_type(cls) -> type[HordeStatusModelsSingleRespon return HordeStatusModelsSingleResponse -class HordeModes(HordeAPIObject): +class HordeModes(HordeAPIObjectBaseModel): maintenance_mode: bool = Field( default=False, ) diff --git a/horde_sdk/ai_horde_api/apimodels/_styles.py b/horde_sdk/ai_horde_api/apimodels/_styles.py index 2349cac..ae79b4a 100644 --- a/horde_sdk/ai_horde_api/apimodels/_styles.py +++ b/horde_sdk/ai_horde_api/apimodels/_styles.py @@ -3,7 +3,7 @@ from pydantic import Field from strenum import StrEnum -from horde_sdk.generic_api.apimodels import HordeAPIDataObject +from horde_sdk.generic_api.apimodels import HordeAPIData class StyleType(StrEnum): @@ -13,7 +13,10 @@ class StyleType(StrEnum): text = auto() -class ResponseModelStylesUser(HordeAPIDataObject): +class ResponseModelStylesUser(HordeAPIData): name: str + """The name of the style.""" id_: str = Field(alias="id") + """The ID of the style.""" type_: StyleType = Field(alias="type") + """The type of the style.""" diff --git a/horde_sdk/ai_horde_api/apimodels/_users.py b/horde_sdk/ai_horde_api/apimodels/_users.py index 8ebb75a..c1d894c 100644 --- a/horde_sdk/ai_horde_api/apimodels/_users.py +++ b/horde_sdk/ai_horde_api/apimodels/_users.py @@ -1,6 +1,6 @@ from datetime import datetime -from pydantic import Field, RootModel +from pydantic import Field from typing_extensions import override from horde_sdk.ai_horde_api.apimodels._styles import ResponseModelStylesUser @@ -10,27 +10,40 @@ from horde_sdk.consts import _ANONYMOUS_MODEL, HTTPMethod from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, - HordeAPIDataObject, - HordeResponse, + HordeAPIObjectBaseModel, HordeResponseBaseModel, + HordeResponseRootModel, RequestSpecifiesUserIDMixin, ) from horde_sdk.generic_api.decoration import Unequatable, Unhashable -class ContributionsDetails(HordeAPIDataObject): - """How many images and megapixelsteps this user has generated.""" +class ContributionsDetails(HordeAPIObjectBaseModel): + """How many images and megapixelsteps this user has generated. + + v2 API Model: ContributionsDetails + """ fulfillments: int | None = Field( default=None, ) + """How many images this user has generated.""" megapixelsteps: float | None = Field( default=None, ) + """How many megapixelsteps this user has generated.""" + + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ContributionsDetails" + +class UserKudosDetails(HordeAPIObjectBaseModel): + """The details of the kudos this user has accumulated, used, sent and received. -class UserKudosDetails(HordeAPIDataObject): - """The details of the kudos this user has accumulated, used, sent and received.""" + v2 API Model: UserKudosDetails + """ accumulated: float | None = Field(0) """The amount of Kudos accumulated or used for generating images.""" @@ -56,24 +69,54 @@ class UserKudosDetails(HordeAPIDataObject): styled: float | None = Field(0) """The amount of Kudos this user has received from styling images.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "UserKudosDetails" + + +class MonthlyKudos(HordeAPIObjectBaseModel): + """The details of the monthly kudos this user receives. + + v2 API Model: MonthlyKudos + """ -class MonthlyKudos(HordeAPIDataObject): amount: int | None = Field(default=None) """How much recurring Kudos this user receives monthly.""" last_received: datetime | None = Field(default=None) """Last date this user received monthly Kudos.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "MonthlyKudos" + + +class UserThingRecords(HordeAPIObjectBaseModel): + """How many images, texts, megapixelsteps and tokens this user has generated or requested. + + v2 API Model: UserThingRecords + """ -class UserThingRecords(HordeAPIDataObject): megapixelsteps: float | None = Field(0) """How many megapixelsteps this user has generated or requested.""" tokens: int | None = Field(0) """How many token this user has generated or requested.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "UserThingRecords" + + +class UserAmountRecords(HordeAPIObjectBaseModel): + """How many images, texts, megapixelsteps and tokens this user has generated or requested. + + v2 API Model: UserAmountRecords + """ -class UserAmountRecords(HordeAPIDataObject): image: int | None = Field(0) """How many images this user has generated or requested.""" @@ -83,26 +126,61 @@ class UserAmountRecords(HordeAPIDataObject): text: int | None = Field(0) """How many texts this user has generated or requested.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "UserAmountRecords" + + +class UserRecords(HordeAPIObjectBaseModel): + """How many images, texts, megapixelsteps, tokens and styles this user has generated, requested or has had used. + + v2 API Model: UserRecords + """ -class UserRecords(HordeAPIDataObject): contribution: UserThingRecords | None = None + """How much this user has contributed.""" fulfillment: UserAmountRecords | None = None + """How much this user has fulfilled.""" request: UserAmountRecords | None = None + """How much this user has requested.""" usage: UserThingRecords | None = None + """How much this user has used.""" style: UserAmountRecords | None = None + """How much this user's styles have been used.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "UserRecords" + + +class UsageDetails(HordeAPIObjectBaseModel): + """How many images and megapixelsteps this user has requested. + + v2 API Model: UsageDetails + """ -class UsageDetails(HordeAPIDataObject): megapixelsteps: float | None = Field(default=None) """How many megapixelsteps this user has requested.""" requests: int | None = Field(default=None) """How many images this user has requested.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "UsageDetails" + @Unhashable @Unequatable -class ActiveGenerations(HordeAPIDataObject): +class ActiveGenerations(HordeAPIObjectBaseModel): + """A list of generations that are currently active for this user. + + v2 API Model: ActiveGenerations + """ + """A list of generations that are currently active for this user.""" text: list[UUID_Identifier] | None = None @@ -114,10 +192,20 @@ class ActiveGenerations(HordeAPIDataObject): alchemy: list[UUID_Identifier] | None = None """The IDs of the alchemy generations that are currently active for this user.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ActiveGenerations" + @Unhashable @Unequatable class UserDetailsResponse(HordeResponseBaseModel): + """The details of a user. + + v2 API Model: UserDetails + """ + @override @classmethod def get_api_model_name(cls) -> str | None: @@ -275,7 +363,12 @@ def get_api_model_name(cls) -> str | None: @Unhashable @Unequatable -class ListUsersDetailsResponse(HordeResponse, RootModel[list[UserDetailsResponse]]): +class ListUsersDetailsResponse(HordeResponseRootModel[list[UserDetailsResponse]]): + """The response for a list of user details. + + v2 API Model: _ANONYMOUS_MODEL + """ + root: list[UserDetailsResponse] """The underlying list of user details.""" @@ -286,6 +379,8 @@ def get_api_model_name(cls) -> str: class ListUsersDetailsRequest(BaseAIHordeRequest): + """Represents a request to list all users.""" + page: int """The page number to request. There are up to 25 users per page.""" @@ -341,7 +436,7 @@ def get_default_success_response_type(cls) -> type[UserDetailsResponse]: return UserDetailsResponse -class _ModifyUserBase(HordeAPIDataObject): +class _ModifyUserBase(HordeAPIObjectBaseModel): admin_comment: str | None = Field( default=None, max_length=500, @@ -457,7 +552,7 @@ class ModifyUserReply(_ModifyUserBase): """The new amount of suspicion this user has.""" -class ModifyUserResponse(HordeResponse, ModifyUserReply): +class ModifyUserResponse(HordeResponseBaseModel, ModifyUserReply): @override @classmethod def get_api_model_name(cls) -> str: diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py index cab6329..eb98b75 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py @@ -17,9 +17,9 @@ from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, ContainsMessageResponseMixin, - HordeAPIDataObject, - HordeResponse, + HordeAPIData, HordeResponseBaseModel, + HordeResponseTypes, ResponseRequiringFollowUpMixin, ) @@ -64,8 +64,9 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[AlchemyDeleteRequest return AlchemyDeleteRequest -class AlchemyAsyncRequestFormItem(HordeAPIDataObject): +class AlchemyAsyncRequestFormItem(HordeAPIData): name: KNOWN_ALCHEMY_TYPES | str + """The name of the form to request.""" @field_validator("name") def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str: @@ -131,7 +132,7 @@ def get_default_success_response_type(cls) -> type[AlchemyAsyncResponse]: @override @classmethod - def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponse]]: + def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponseTypes]]: return { HTTPStatusCode.ACCEPTED: cls.get_default_success_response_type(), } diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py index 76f6f0a..28adacf 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py @@ -14,14 +14,14 @@ from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, - HordeAPIObject, + HordeAPIObjectBaseModel, HordeResponseBaseModel, ResponseRequiringFollowUpMixin, ) # FIXME -class AlchemyFormPayloadStable(HordeAPIObject): +class AlchemyFormPayloadStable(HordeAPIObjectBaseModel): """Currently unsupported. v2 API Model: `ModelInterrogationFormPayloadStable` @@ -46,7 +46,7 @@ def get_api_model_name(cls) -> str | None: """Currently unsupported.""" -class AlchemyPopFormPayload(HordeAPIObject, JobRequestMixin): +class AlchemyPopFormPayload(HordeAPIObjectBaseModel, JobRequestMixin): """v2 API Model: `InterrogationPopFormPayload`.""" @override @@ -79,7 +79,7 @@ def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | st """The URL From which the source image can be downloaded.""" -class NoValidAlchemyFound(HordeAPIObject): +class NoValidAlchemyFound(HordeAPIObjectBaseModel): """v2 API Model: `NoValidInterrogationsFoundStable`.""" @override @@ -112,6 +112,19 @@ def get_api_model_name(cls) -> str | None: ) """How many waiting requests were skipped because they demanded a specific worker.""" + def __eq__(self, other: object) -> bool: + if not isinstance(other, NoValidAlchemyFound): + return False + + return ( + self.bridge_version == other.bridge_version + and self.untrusted == other.untrusted + and self.worker_id == other.worker_id + ) + + def __hash__(self) -> int: + return hash((self.bridge_version, self.untrusted, self.worker_id)) + class AlchemyPopResponse(HordeResponseBaseModel, ResponseRequiringFollowUpMixin): """v2 API Model: `InterrogationPopPayload`.""" diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py index 5b2a6e7..b927f3e 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py @@ -8,7 +8,7 @@ from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, - HordeAPIDataObject, + HordeAPIData, HordeResponseBaseModel, ResponseWithProgressMixin, ) @@ -16,56 +16,73 @@ # FIXME: All vs API models defs? (override get_api_model_name and add to docstrings) -class AlchemyUpscaleResult(HordeAPIDataObject): +class AlchemyUpscaleResult(HordeAPIData): """Represents the result of an upscale job.""" upscaler_used: KNOWN_UPSCALERS | str + """The upscaler used.""" url: str + """The URL of the upscaled image.""" -class AlchemyCaptionResult(HordeAPIDataObject): +class AlchemyCaptionResult(HordeAPIData): """Represents the result of a caption job.""" caption: str + """The resulting caption of the image.""" -class AlchemyNSFWResult(HordeAPIDataObject): +class AlchemyNSFWResult(HordeAPIData): """Represents the result of an NSFW evaluation.""" nsfw: bool + """Whether the image is likely to be NSFW.""" -class AlchemyInterrogationResultItem(HordeAPIDataObject): +class AlchemyInterrogationResultItem(HordeAPIData): """Represents an item in the result of an interrogation job.""" text: str + """The text of the item.""" confidence: float + """The confidence of the item.""" -class AlchemyInterrogationDetails(HordeAPIDataObject): +class AlchemyInterrogationDetails(HordeAPIData): """The details of an interrogation job.""" tags: list[AlchemyInterrogationResultItem] + """The resulting similar tags of the image.""" sites: list[AlchemyInterrogationResultItem] + """The resulting similar sites of the image.""" artists: list[AlchemyInterrogationResultItem] + """The resulting similar artists of the image.""" flavors: list[AlchemyInterrogationResultItem] + """The resulting similar flavors of the image.""" mediums: list[AlchemyInterrogationResultItem] + """The resulting similar mediums of the image.""" movements: list[AlchemyInterrogationResultItem] + """The resulting similar movements of the image.""" techniques: list[AlchemyInterrogationResultItem] + """The resulting similar techniques of the image.""" -class AlchemyInterrogationResult(HordeAPIDataObject): +class AlchemyInterrogationResult(HordeAPIData): """Represents the result of an interrogation job. Use the `interrogation` field for the details.""" interrogation: AlchemyInterrogationDetails + """The details of the interrogation.""" -class AlchemyFormStatus(HordeAPIDataObject): +class AlchemyFormStatus(HordeAPIData): """Represents the status of a form in an interrogation job.""" form: KNOWN_ALCHEMY_TYPES | str + """The form type.""" state: GENERATION_STATE + """The state of the form.""" result: AlchemyInterrogationDetails | AlchemyNSFWResult | AlchemyCaptionResult | AlchemyUpscaleResult | None = None + """The result of the form.""" @field_validator("form", mode="before") def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str: diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index 050b32c..18997a1 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -26,7 +26,12 @@ ) from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL from horde_sdk.ai_horde_api.fields import JobID, WorkerID -from horde_sdk.generic_api.apimodels import HordeAPIDataObject, HordeRequest, HordeResponseBaseModel +from horde_sdk.generic_api.apimodels import ( + HordeAPIData, + HordeAPIObjectBaseModel, + HordeRequest, + HordeResponseBaseModel, +) class BaseAIHordeRequest(HordeRequest): @@ -38,7 +43,7 @@ def get_api_url(cls) -> str: return AI_HORDE_BASE_URL -class JobRequestMixin(HordeAPIDataObject): +class JobRequestMixin(HordeAPIData): """Mix-in class for data relating to any generation jobs.""" id_: JobID = Field(alias="id") @@ -62,7 +67,7 @@ def __hash__(self) -> int: return hash(JobRequestMixin.__name__) + hash(self.id_) -class JobResponseMixin(HordeAPIDataObject): +class JobResponseMixin(HordeAPIData): """Mix-in class for data relating to any generation jobs.""" id_: JobID = Field(alias="id") @@ -78,21 +83,21 @@ def validate_id(cls, v: str | JobID) -> JobID | str: return v -class WorkerRequestMixin(HordeAPIDataObject): +class WorkerRequestMixin(HordeAPIData): """Mix-in class for data relating to worker requests.""" worker_id: str | WorkerID """The UUID of the worker in question for this request.""" -class WorkerRequestNameMixin(HordeAPIDataObject): +class WorkerRequestNameMixin(HordeAPIData): """Mix-in class for data relating to worker requests.""" worker_name: str """The name of the worker in question for this request.""" -class LorasPayloadEntry(HordeAPIDataObject): +class LorasPayloadEntry(HordeAPIObjectBaseModel): """Represents a single lora parameter. v2 API Model: `ModelPayloadLorasStable` @@ -109,13 +114,24 @@ class LorasPayloadEntry(HordeAPIDataObject): is_version: bool = Field(default=False) """If true, will treat the lora name as a version ID.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ModelPayloadLorasStable" + + +class TIPayloadEntry(HordeAPIObjectBaseModel): + """Represents a single textual inversion (embedding) parameter. -class TIPayloadEntry(HordeAPIDataObject): - """Represents a single textual inversion (embedding) parameter.""" + v2 API Model: `ModelPayloadTextualInversionsStable` + """ name: str = Field(min_length=1, max_length=255) + """The name or ID of the textual inversion model to use.""" inject_ti: str | None = None + """Whether to automatically insert the TI into the prompt or negprompt.""" strength: float = Field(default=1, ge=-5, le=5) + """The strength to apply the textual inversion model.""" @field_validator("inject_ti") def validate_inject_ti(cls, v: str | None) -> str | None: @@ -141,8 +157,13 @@ def strength_only_if_inject_ti(self) -> TIPayloadEntry: logger.debug("strength is only valid when inject_ti is set") return self + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ModelPayloadTextualInversionsStable" + -class ExtraSourceImageEntry(HordeAPIDataObject): +class ExtraSourceImageEntry(HordeAPIObjectBaseModel): """Represents a single extra source image. v2 API Model: `ExtraSourceImage` @@ -156,8 +177,13 @@ class ExtraSourceImageEntry(HordeAPIDataObject): strength: float = Field(default=1, ge=-5, le=5) """The strength to apply to this image on various operations.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ExtraSourceImage" + -class ExtraTextEntry(HordeAPIDataObject): +class ExtraTextEntry(HordeAPIObjectBaseModel): """Represents a single extra text. v2 API Model: `ExtraText` @@ -168,8 +194,13 @@ class ExtraTextEntry(HordeAPIDataObject): reference: str = Field(min_length=3) """Reference pointing to how this text is to be used.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ExtraText" -class SingleWarningEntry(HordeAPIDataObject): + +class SingleWarningEntry(HordeAPIObjectBaseModel): """Represents a single warning. v2 API Model: `RequestSingleWarning` @@ -190,12 +221,17 @@ def code_must_be_known(cls, v: str | WarningCode) -> str | WarningCode: return v + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "RequestSingleWarning" + -class ImageGenerateParamMixin(HordeAPIDataObject): +class ImageGenerateParamMixin(HordeAPIObjectBaseModel): """Mix-in class of some of the data included in a request to the `/v2/generate/async` endpoint. Also is the corresponding information returned on a job pop to the `/v2/generate/pop` endpoint. - v2 API Model: `ModelPayloadStable` + v2 API Model: `ModelPayloadRootStable` """ model_config = ( @@ -325,6 +361,11 @@ def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str | return v + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "ModelPayloadRootStable" + class JobSubmitResponse(HordeResponseBaseModel): """The response to a job submission request, indicating the number of kudos gained. @@ -341,7 +382,7 @@ def get_api_model_name(cls) -> str | None: return "GenerationSubmitted" -class GenMetadataEntry(HordeAPIDataObject): +class GenMetadataEntry(HordeAPIObjectBaseModel): """Represents a single generation metadata entry. v2 API Model: `GenerationMetadataStable` @@ -371,3 +412,8 @@ def validate_value(cls, v: str | METADATA_VALUE) -> str | METADATA_VALUE: if isinstance(v, str) and v not in METADATA_VALUE.__members__: logger.warning(f"Unknown metadata value {v}. Is your SDK out of date or did the API change?") return v + + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "GenerationMetadataStable" diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_async.py b/horde_sdk/ai_horde_api/apimodels/generate/_async.py index 014e9e7..63d2617 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_async.py @@ -19,9 +19,8 @@ from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, ContainsMessageResponseMixin, - HordeAPIObject, - HordeResponse, HordeResponseBaseModel, + HordeResponseTypes, RequestUsesWorkerMixin, ResponseRequiringFollowUpMixin, ) @@ -98,7 +97,7 @@ def get_api_model_name(cls) -> str | None: return _ANONYMOUS_MODEL -class ImageGenerationInputPayload(HordeAPIObject, ImageGenerateParamMixin): +class ImageGenerationInputPayload(ImageGenerateParamMixin): """Represents the 'params' field in the `/v2/generate/async` endpoint. v2 API Model: `ModelGenerationInputStable` @@ -194,7 +193,7 @@ def get_default_success_response_type(cls) -> type[ImageGenerateAsyncResponse]: @override @classmethod - def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponse]]: + def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponseTypes]]: return { HTTPStatusCode.OK: ImageGenerateAsyncDryRunResponse, HTTPStatusCode.ACCEPTED: cls.get_default_success_response_type(), diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py index 39b97f3..f85b9e8 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py @@ -26,14 +26,14 @@ from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, - HordeAPIObject, + HordeAPIObjectBaseModel, HordeResponseBaseModel, ResponseRequiringDownloadMixin, ResponseRequiringFollowUpMixin, ) -class NoValidRequestFound(HordeAPIObject): +class NoValidRequestFound(HordeAPIObjectBaseModel): blacklist: int | None = Field(default=None, ge=0) """How many waiting requests were skipped because they demanded a generation with a word that this worker does not accept.""" @@ -86,6 +86,8 @@ class ImageGenerateJobPopSkippedStatus(NoValidRequestFound): """How many waiting requests were skipped because they requested loras.""" controlnet: int = Field(default=0, ge=0) """How many waiting requests were skipped because they requested a controlnet.""" + step_count: int = Field(default=0, ge=0) + """How many waiting requests were skipped because they requested more steps than this worker provides.""" @override @classmethod @@ -95,6 +97,8 @@ def get_api_model_name(cls) -> str | None: class ImageGenerateJobPopPayload(ImageGenerateParamMixin): prompt: str | None = None + """The prompt to use for this image generation.""" + ddim_steps: int = Field(default=25, ge=1, validation_alias=AliasChoices("steps", "ddim_steps")) """The number of image generation steps to perform.""" @@ -426,7 +430,7 @@ def __hash__(self) -> int: return hash(0) -class PopInput(HordeAPIObject): +class PopInput(HordeAPIObjectBaseModel): amount: int | None = Field(1, ge=1, le=20) """The number of jobs to pop at the same time.""" bridge_agent: str | None = Field( diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_progress.py b/horde_sdk/ai_horde_api/apimodels/generate/_progress.py index 333fb74..edbf7d2 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_progress.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_progress.py @@ -1,7 +1,7 @@ -from horde_sdk.generic_api.apimodels import HordeAPIObject, ResponseWithProgressMixin +from horde_sdk.generic_api.apimodels import HordeAPIObjectBaseModel, ResponseWithProgressMixin -class ResponseGenerationProgressInfoMixin(HordeAPIObject): +class ResponseGenerationProgressInfoMixin(HordeAPIObjectBaseModel): finished: int """The amount of finished jobs in this request.""" processing: int diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_status.py b/horde_sdk/ai_horde_api/apimodels/generate/_status.py index f168efe..70a1174 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_status.py @@ -10,10 +10,10 @@ from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH from horde_sdk.ai_horde_api.fields import JobID, WorkerID from horde_sdk.consts import HTTPMethod -from horde_sdk.generic_api.apimodels import HordeAPIObject, HordeResponseBaseModel, ResponseWithProgressMixin +from horde_sdk.generic_api.apimodels import HordeAPIObjectBaseModel, HordeResponseBaseModel, ResponseWithProgressMixin -class Generation(HordeAPIObject): +class Generation(HordeAPIObjectBaseModel): model: str = Field(title="Generation Model") """The model which generated this image.""" state: GENERATION_STATE = Field( diff --git a/horde_sdk/ai_horde_api/apimodels/generate/text/_async.py b/horde_sdk/ai_horde_api/apimodels/generate/text/_async.py index 354dd6a..019aec9 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/text/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/text/_async.py @@ -16,9 +16,9 @@ from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, ContainsMessageResponseMixin, - HordeAPIDataObject, - HordeResponse, + HordeAPIData, HordeResponseBaseModel, + HordeResponseTypes, RequestUsesWorkerMixin, ResponseRequiringFollowUpMixin, ) @@ -83,7 +83,7 @@ def __eq__(self, __value: object) -> bool: @Unhashable -class ModelPayloadRootKobold(HordeAPIDataObject): +class ModelPayloadRootKobold(HordeAPIData): dynatemp_exponent: float | None = Field(1, ge=0.0, le=5.0) """Dynamic temperature exponent value.""" dynatemp_range: float | None = Field(0, ge=0.0, le=5.0) @@ -258,7 +258,7 @@ def get_default_success_response_type(cls) -> type[TextGenerateAsyncResponse]: @override @classmethod - def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponse]]: + def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponseTypes]]: return { HTTPStatusCode.OK: TextGenerateAsyncDryRunResponse, HTTPStatusCode.ACCEPTED: cls.get_default_success_response_type(), diff --git a/horde_sdk/ai_horde_api/apimodels/workers/_workers.py b/horde_sdk/ai_horde_api/apimodels/workers/_workers.py index 98240c0..2fdee30 100644 --- a/horde_sdk/ai_horde_api/apimodels/workers/_workers.py +++ b/horde_sdk/ai_horde_api/apimodels/workers/_workers.py @@ -1,6 +1,6 @@ from collections.abc import Iterator -from pydantic import AliasChoices, Field, RootModel +from pydantic import AliasChoices, Field from typing_extensions import override from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest, WorkerRequestMixin, WorkerRequestNameMixin @@ -10,13 +10,14 @@ from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, - HordeAPIObject, - HordeResponse, + HordeAPIObjectBaseModel, + HordeResponseBaseModel, + HordeResponseRootModel, ) from horde_sdk.generic_api.decoration import Unequatable, Unhashable -class TeamDetailsLite(HordeAPIObject): +class TeamDetailsLite(HordeAPIObjectBaseModel): name: str | None = None """The Name given to this team.""" id_: str | TeamID | None = Field(default=None, alias="id") @@ -28,7 +29,7 @@ def get_api_model_name(cls) -> str | None: return "TeamDetailsLite" -class WorkerKudosDetails(HordeAPIObject): +class WorkerKudosDetails(HordeAPIObjectBaseModel): generated: float | None = None """How much Kudos this worker has received for generating images.""" uptime: int | None = None @@ -41,7 +42,7 @@ def get_api_model_name(cls) -> str | None: @Unhashable -class WorkerDetailItem(HordeAPIObject): +class WorkerDetailItem(HordeAPIObjectBaseModel): type_: WORKER_TYPE = Field(alias="type") """The type of worker.""" name: str @@ -115,6 +116,10 @@ class WorkerDetailItem(HordeAPIObject): """The maximum tokens this worker can read.""" tokens_generated: int | None = Field(default=None, examples=[0]) """How many tokens this worker has generated until now. """ + controlnet: bool | None = Field(default=None, examples=[False]) + """If True, this worker supports and allows controlnet requests.""" + sdxl_controlnet: bool | None = Field(default=None, examples=[False]) + """If True, this worker supports and allows sdxl controlnet requests.""" @override @classmethod @@ -164,7 +169,7 @@ def __eq__(self, other: object) -> bool: @Unhashable @Unequatable -class AllWorkersDetailsResponse(HordeResponse, RootModel[list[WorkerDetailItem]]): +class AllWorkersDetailsResponse(HordeResponseRootModel[list[WorkerDetailItem]]): # @tazlin: The typing of __iter__ in BaseModel seems to assume that RootModel wouldn't also be a parent class. # without a `type: ignore``, mypy feels that this is a bad override. This is probably a sub-optimal solution # on my part with me hoping to come up with a more elegant path in the future. @@ -232,7 +237,7 @@ def is_api_key_required(cls) -> bool: @Unhashable @Unequatable -class SingleWorkerDetailsResponse(HordeResponse, WorkerDetailItem): +class SingleWorkerDetailsResponse(HordeResponseBaseModel, WorkerDetailItem): @override @classmethod def get_api_model_name(cls) -> str | None: @@ -303,7 +308,7 @@ def is_api_key_required(cls) -> bool: return False -class ModifyWorkerResponse(HordeResponse): +class ModifyWorkerResponse(HordeResponseBaseModel): info: str | None = Field(default=None) """The new state of the 'info' var for this worker.""" maintenance: bool | None = Field(default=None) @@ -362,8 +367,8 @@ def get_default_success_response_type(cls) -> type[ModifyWorkerResponse]: return ModifyWorkerResponse -class DeleteWorkerResponse(HordeResponse): - deleted_id_: str | None = None +class DeleteWorkerResponse(HordeResponseBaseModel): + deleted_id: str | None = None """The ID of the deleted worker.""" deleted_name: str | None = None """The Name of the deleted worker.""" diff --git a/horde_sdk/ai_horde_api/endpoints.py b/horde_sdk/ai_horde_api/endpoints.py index 493c41d..ab42a90 100644 --- a/horde_sdk/ai_horde_api/endpoints.py +++ b/horde_sdk/ai_horde_api/endpoints.py @@ -95,7 +95,19 @@ class AI_HORDE_API_ENDPOINT_SUBPATH(GENERIC_API_ENDPOINT_SUBPATH): v2_documents_privacy = "/v2/documents/privacy" v2_documents_sponsors = "/v2/documents/sponsors" - vs_documents_terms = "/v2/documents/terms" + v2_documents_terms = "/v2/documents/terms" + + v2_styles_image_by_name = "/v2/styles/image_by_name/{style_name}" + v2_styles_image_by_id = "/v2/styles/image/{style_id}" + v2_collections_by_name = "/v2/collection_by_name/{collection_name}" + v2_styles_image_example_by_id = "/v2/styles/image/{style_id}/example" + v2_styles_text_by_id = "/v2/styles/text/{style_id}" + v2_styles_image = "/v2/styles/image" + v2_collections_by_id = "/v2/collections/{collection_id}" + v2_styles_image_example_by_id_example = "/v2/styles/image/{style_id}/example/{example_id}" + v2_styles_text = "/v2/styles/text" + v2_styles_text_by_name = "/v2/styles/text_by_name/{style_name}" + v2_collections = "/v2/collections" def get_ai_horde_swagger_url() -> str: @@ -104,3 +116,25 @@ def get_ai_horde_swagger_url() -> str: base_url=AI_HORDE_BASE_URL, path=AI_HORDE_API_ENDPOINT_SUBPATH.swagger, ) + + +def get_admin_only_endpoints() -> set[str]: + """Return all of the endpoints that are admin-only.""" + return { + AI_HORDE_API_ENDPOINT_SUBPATH.v2_status_modes, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_filters, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_filters_regex, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_filters_regex_single, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_operations_block_worker_ipaddr_single, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_operations_ipaddr, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_operations_ipaddr_single, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_kudos_award, + } + + +def get_deprecated_endpoints() -> set[str]: + """Return all of the endpoints that are deprecated.""" + return { + AI_HORDE_API_ENDPOINT_SUBPATH.v2_generate_pop_multi, + AI_HORDE_API_ENDPOINT_SUBPATH.v2_generate_rate_id, + } diff --git a/horde_sdk/generic_api/apimodels.py b/horde_sdk/generic_api/apimodels.py index 9aea272..d66c513 100644 --- a/horde_sdk/generic_api/apimodels.py +++ b/horde_sdk/generic_api/apimodels.py @@ -6,11 +6,11 @@ import base64 import os import uuid -from typing import Any +from typing import Any, TypeVar import aiohttp from loguru import logger -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator from typing_extensions import override from horde_sdk import _default_sslcontext @@ -29,8 +29,17 @@ ) -class HordeAPIObject(BaseModel, abc.ABC): - """Base class for all Horde API data models, requests, or responses.""" +class HordeAPIObject(abc.ABC): + """Base class for all Horde API data models, requests, or responses. + + This is an abstract class that you probably shouldn't inherit from directly. Instead, inherit from one of the + subclasses defined in this module. + + Requests generally would inherit from `HordeRequest`, responses from `HordeResponse`, and data models from + `HordeAPIObjectBaseModel` (if it appears on the API as a published model) or `HordeAPIDataObject` (if it is a + data object that is not specifically defined by the API docs, such as an intermediate class or an anonymous model). + + """ @classmethod @abc.abstractmethod @@ -40,13 +49,57 @@ def get_api_model_name(cls) -> str | None: If none, there is no payload, such as for a GET request. """ - model_config = ConfigDict( - frozen=True, - use_attribute_docstrings=True, + @classmethod + def get_sensitive_fields(cls) -> set[str]: + """Return a set of fields which should be redacted from logs.""" + return {"apikey"} + + def get_extra_fields_to_exclude_from_log(self) -> set[str]: + """Return an additional set of fields to exclude from the log_safe_model_dump method.""" + return set() + + def log_safe_model_dump(self, extra_exclude: set[str] | None = None) -> dict[Any, Any]: + """Return a dict of the model's fields, with any sensitive fields redacted.""" + if extra_exclude is None: + extra_exclude = set() + + if hasattr(self, "model_dump"): + return self.model_dump( # type: ignore + exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log() | extra_exclude, + ) + + logger.warning("Model does not have a model_dump method. Using python native class compatible method.") + logger.debug( + "Generally this should not be relied upon. If you're seeing this and you're a developer for the SDK, " + "consider using pydantic models instead.", + ) + # Its not a pydantic model, use python native class compatible method + return { + key: getattr(self, key) + for key in self.__dict__ + if key not in self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log() | extra_exclude + } + + +class HordeAPIObjectBaseModel(HordeAPIObject, BaseModel): + """Base class for all Horde API data models (leveraging pydantic).""" + + model_config = ( + ConfigDict( + frozen=True, + use_attribute_docstrings=True, + extra="allow", + ) + if not os.getenv("TESTS_ONGOING") + else ConfigDict( + frozen=True, + use_attribute_docstrings=True, + extra="forbid", + ) ) -class HordeAPIDataObject(BaseModel): +class HordeAPIData(BaseModel): """Base class for all Horde API data models which appear as objects within other data models. These are objects which are not specifically defined by the API docs, but (logically or otherwise) are @@ -72,39 +125,63 @@ class HordeAPIDataObject(BaseModel): class HordeAPIMessage(HordeAPIObject): """Represents any request or response from any Horde API.""" - @classmethod - def get_sensitive_fields(cls) -> set[str]: - """Return a set of fields which should be redacted from logs.""" - return {"apikey"} - def get_extra_fields_to_exclude_from_log(self) -> set[str]: - """Return an additional set of fields to exclude from the log_safe_model_dump method.""" - return set() +class HordeResponse(HordeAPIMessage): + """Represents any response from any Horde API.""" - def log_safe_model_dump(self, extra_exclude: set[str] | None = None) -> dict[Any, Any]: - """Return a dict of the model's fields, with any sensitive fields redacted.""" - if extra_exclude is None: - extra_exclude = set() - return self.model_dump( - exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log() | extra_exclude, - ) +T = TypeVar("T") -class HordeResponse(HordeAPIMessage): - """Represents any response from any Horde API.""" +class HordeResponseRootModel(RootModel[T], HordeResponse): + """Base class for all Horde API response data models which model another data type (leveraging pydantic). + + A typical example for using this class would be responses which *are* a list of another data type. + Define subclasses of this class with the type of the data model as the type argument. + + For example: + ```python + class MyDataModel(HordeResponseRootModel[MyData]): + pass + ``` + + """ + + model_config = ( + ConfigDict( + frozen=True, + use_attribute_docstrings=True, + ) + if not os.getenv("TESTS_ONGOING") + else ConfigDict( + frozen=True, + use_attribute_docstrings=True, + ) + ) class HordeResponseBaseModel(HordeResponse, BaseModel): """Base class for all Horde API response data models (leveraging pydantic).""" model_config = ( - ConfigDict(frozen=True, extra="allow") + ConfigDict( + frozen=True, + use_attribute_docstrings=True, + extra="allow", + ) if not os.getenv("TESTS_ONGOING") - else ConfigDict(frozen=True, extra="forbid") + else ConfigDict( + frozen=True, + use_attribute_docstrings=True, + extra="forbid", + ) ) +HordeResponseTypes = HordeResponseRootModel[Any] | HordeResponseBaseModel +"""A type hint for any type of the valid horde response models.""" + + class ResponseRequiringFollowUpMixin(abc.ABC): """Represents any response from any Horde API which requires a follow up request of some kind.""" @@ -157,8 +234,6 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[HordeRequest]: Defaults to `None`, meaning no cleanup request is needed. """ - _cleanup_requests: list[HordeRequest] | None = None - def get_follow_up_failure_cleanup_params(self) -> dict[str, object]: """Return any extra information required from this response to clean up after a failed follow up request. @@ -168,6 +243,8 @@ def get_follow_up_failure_cleanup_params(self) -> dict[str, object]: """ return {} + _cleanup_requests: list[HordeRequest] | None + def get_follow_up_failure_cleanup_request(self) -> list[HordeRequest]: """Return the request for this response to clean up after a failed follow up request.""" if self.ignore_failure(): @@ -222,7 +299,7 @@ def does_target_request_follow_up(self, target_request: HordeRequest) -> bool: return all_match -class ResponseWithProgressMixin(HordeAPIDataObject): +class ResponseWithProgressMixin(HordeAPIData): """Represents any response from any Horde API which contains progress information.""" @abc.abstractmethod @@ -255,7 +332,7 @@ def get_finalize_success_request_type(cls) -> type[HordeRequest] | None: """Return the request type for this response to finalize the job on success, or `None` if not needed.""" -class ResponseRequiringDownloadMixin(HordeAPIDataObject): +class ResponseRequiringDownloadMixin(HordeAPIData): """Represents any response which may require downloading additional data.""" async def download_file_as_base64(self, client_session: aiohttp.ClientSession, url: str) -> str: @@ -290,7 +367,7 @@ def download_additional_data(self) -> None: """Download any additional data required for this response.""" -class ContainsMessageResponseMixin(HordeAPIDataObject): +class ContainsMessageResponseMixin(HordeAPIData): """Represents any response from any Horde API which contains a message.""" message: str = "" @@ -307,6 +384,7 @@ class RequestErrorResponse(HordeResponseBaseModel, ContainsMessageResponseMixin) """This is a catch all for any additional data that may be returned by the API relevant to the error.""" rc: str = "RC_MISSING" + """The return code from the API which maps to a reason for the error.""" @override @classmethod @@ -317,7 +395,19 @@ def get_api_model_name(cls) -> str | None: class HordeRequest(HordeAPIMessage, BaseModel): """Represents any request to any Horde API.""" - model_config = ConfigDict(frozen=True, extra="forbid") + model_config = ( + ConfigDict( + frozen=True, + use_attribute_docstrings=True, + extra="allow", + ) + if not os.getenv("TESTS_ONGOING") + else ConfigDict( + frozen=True, + use_attribute_docstrings=True, + extra="forbid", + ) + ) @classmethod @abc.abstractmethod @@ -352,11 +442,13 @@ def get_api_endpoint_subpath(cls) -> GENERIC_API_ENDPOINT_SUBPATH: @classmethod @abc.abstractmethod - def get_default_success_response_type(cls) -> type[HordeResponse]: + def get_default_success_response_type(cls) -> type[HordeResponseTypes]: """Return the `type` of the response expected in the ordinary case of success.""" @classmethod - def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[HordeResponse]]: + def get_success_status_response_pairs( + cls, + ) -> dict[HTTPStatusCode, type[HordeResponseTypes]]: """Return a dict of HTTP status codes and the expected `HordeResponse`. Defaults to `{HTTPStatusCode.OK: cls.get_expected_response_type()}`, but may be overridden to support other @@ -409,7 +501,7 @@ def get_sensitive_fields(cls) -> set[str]: return {"apikey"} -class APIKeyAllowedInRequestMixin(HordeAPIDataObject): +class APIKeyAllowedInRequestMixin(HordeAPIObjectBaseModel): """Mix-in class to describe an endpoint which may require authentication.""" apikey: str | None = None @@ -439,7 +531,7 @@ def validate_api_key_length(cls, v: str) -> str: return v -class RequestSpecifiesUserIDMixin(HordeAPIDataObject): +class RequestSpecifiesUserIDMixin(HordeAPIData): """Mix-in class to describe an endpoint for which you can specify a user.""" user_id: str @@ -456,7 +548,7 @@ def user_id_is_numeric(cls, value: str) -> str: return value -class RequestUsesWorkerMixin(HordeAPIDataObject): +class RequestUsesWorkerMixin(HordeAPIData): """Mix-in class to describe an endpoint for which you can specify workers.""" trusted_workers: bool = False @@ -483,6 +575,7 @@ class RequestUsesWorkerMixin(HordeAPIDataObject): "HordeRequest", "HordeResponse", "HordeResponseBaseModel", + "HordeResponseRootModel", "ContainsMessageResponseMixin", "HordeAPIObject", "HordeAPIMessage", diff --git a/horde_sdk/generic_api/generic_clients.py b/horde_sdk/generic_api/generic_clients.py index 76fa66b..7c01f55 100644 --- a/horde_sdk/generic_api/generic_clients.py +++ b/horde_sdk/generic_api/generic_clients.py @@ -22,6 +22,8 @@ APIKeyAllowedInRequestMixin, HordeRequest, HordeResponse, + HordeResponseBaseModel, + HordeResponseRootModel, RequestErrorResponse, ResponseRequiringFollowUpMixin, ResponseWithProgressMixin, @@ -52,9 +54,13 @@ class ParsedRawRequest(BaseModel): """The body to be sent with the request, or `None` if no body should be sent.""" +# Can be a BaseModel or RootModel HordeRequestTypeVar = TypeVar("HordeRequestTypeVar", bound=HordeRequest) """TypeVar for the horde request type.""" -HordeResponseTypeVar = TypeVar("HordeResponseTypeVar", bound=HordeResponse) +HordeResponseTypeVar = TypeVar( + "HordeResponseTypeVar", + bound=HordeResponseBaseModel | HordeResponseRootModel[Any], +) """TypeVar for the horde response type.""" @@ -483,7 +489,7 @@ class GenericHordeAPISession(GenericHordeAPIManualClient): or anything labeled as `async` on the API. """ - _pending_follow_ups: list[tuple[HordeRequest, HordeResponse, list[HordeRequest] | None]] + _pending_follow_ups: list[tuple[HordeRequest, HordeResponseBaseModel, list[HordeRequest] | None]] """A `list` of 3-tuples containing the request, response, and a clean-up request for any requests which might need it.""" @@ -596,14 +602,14 @@ def __exit__(self, exc_type: type[BaseException], exc_val: Exception, exc_tb: ob def _handle_exit( self, request_to_follow_up: HordeRequest, # The request that is ending prematurely. - response_to_follow_up: HordeResponse, # The response to the request that is ending prematurely. + response_to_follow_up: HordeResponseBaseModel, # The response to the request that is ending prematurely. cleanup_requests: list[HordeRequest] | None, # The request to clean up after the premature ending, if any. ) -> bool: """Send any follow up requests needed to clean up after a request which is ending prematurely. Args: request_to_follow_up (HordeRequest): The request which is ending prematurely. - response_to_follow_up (HordeResponse): The response to the request which is ending prematurely. + response_to_follow_up (HordeResponseTypeVar): The response to the request which is ending prematurely. cleanup_requests (HordeRequest | None): The request to clean up after the premature ending, if any. Returns: diff --git a/horde_sdk/meta.py b/horde_sdk/meta.py index d25cb46..3163f2f 100644 --- a/horde_sdk/meta.py +++ b/horde_sdk/meta.py @@ -8,10 +8,13 @@ import horde_sdk.ai_horde_api import horde_sdk.ai_horde_api.apimodels import horde_sdk.ai_horde_worker +import horde_sdk.generic_api +import horde_sdk.generic_api.apimodels import horde_sdk.ratings_api import horde_sdk.ratings_api.apimodels from horde_sdk import HordeAPIObject, HordeRequest from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH, get_ai_horde_swagger_url +from horde_sdk.generic_api.apimodels import HordeAPIData from horde_sdk.generic_api.utils.swagger import SwaggerParser @@ -71,7 +74,35 @@ def any_unimported_classes(module: types.ModuleType, super_type: type) -> tuple[ return bool(missing_classes), missing_classes -def all_undefined_classes(module: types.ModuleType) -> dict[str, str]: +def all_undefined_classes(module: types.ModuleType) -> list[str]: + """Return all of the models defined on the API but not in the SDK.""" + module_found_classes = find_subclasses(module, HordeAPIObject) + + defined_api_object_names: set[str] = set() + + for class_type in module_found_classes: + if not issubclass(class_type, HordeAPIObject): + raise TypeError(f"Expected {class_type} to be a HordeAPIObject") + + api_model_name = class_type.get_api_model_name() + if api_model_name is not None: + defined_api_object_names.add(api_model_name) + + undefined_classes: list[str] = [] + + parser = SwaggerParser(swagger_doc_url=get_ai_horde_swagger_url()) + swagger_doc = parser.get_swagger_doc() + + all_api_objects = set(swagger_doc.definitions.keys()) + missing_object_names = all_api_objects - defined_api_object_names + + for object_name in missing_object_names: + undefined_classes.append(object_name) + + return undefined_classes + + +def all_undefined_classes_for_endpoints(module: types.ModuleType) -> dict[str, str]: """Return all of the models defined on the API but not in the SDK.""" module_found_classes = find_subclasses(module, HordeAPIObject) @@ -123,13 +154,12 @@ def all_unaddressed_endpoints_ai_horde() -> set[AI_HORDE_API_ENDPOINT_SUBPATH]: known_paths.remove(AI_HORDE_API_ENDPOINT_SUBPATH.swagger) unaddressed_paths = set() - all_classes = find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeAPIObject) + all_classes = find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeRequest) all_classes_paths = {cls.get_api_endpoint_subpath() for cls in all_classes if issubclass(cls, HordeRequest)} for path in known_paths: if path not in all_classes_paths: - print(f"Unaddressed path: {path}") unaddressed_paths.add(path) return unaddressed_paths @@ -138,6 +168,7 @@ def all_unaddressed_endpoints_ai_horde() -> set[AI_HORDE_API_ENDPOINT_SUBPATH]: def all_models_missing_docstrings() -> set[type]: """Return all of the models that do not have docstrings.""" all_classes = find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeAPIObject) + all_classes += find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeAPIData) missing_docstrings = set() @@ -150,7 +181,10 @@ def all_models_missing_docstrings() -> set[type]: def all_model_and_fields_missing_docstrings() -> dict[type, set[str]]: """Return all of the models' fields that do not have docstrings.""" - all_classes = find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeAPIObject) + all_classes = find_subclasses(horde_sdk.generic_api.apimodels, HordeAPIObject) + all_classes += find_subclasses(horde_sdk.generic_api.apimodels, HordeAPIData) + all_classes += find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeAPIObject) + all_classes += find_subclasses(horde_sdk.ai_horde_api.apimodels, HordeAPIData) missing_docstrings: dict[type, set[str]] = {} diff --git a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py index d6d84d0..ba9547f 100644 --- a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py +++ b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py @@ -9,7 +9,7 @@ from horde_sdk.ai_horde_api.endpoints import get_ai_horde_swagger_url from horde_sdk.consts import _ANONYMOUS_MODEL, HTTPMethod, HTTPStatusCode, get_all_success_status_codes from horde_sdk.generic_api._reflection import get_all_request_types -from horde_sdk.generic_api.apimodels import HordeRequest, HordeResponse +from horde_sdk.generic_api.apimodels import HordeRequest, HordeResponseTypes from horde_sdk.generic_api.endpoints import GENERIC_API_ENDPOINT_SUBPATH from horde_sdk.generic_api.utils.swagger import ( SwaggerDoc, @@ -72,7 +72,7 @@ def all_ai_horde_model_defs_in_swagger(swagger_doc: SwaggerDoc) -> None: swagger_defined_response_examples = swagger_doc.get_all_response_examples() api_to_sdk_payload_model_map: dict[str, dict[HTTPMethod, type[HordeRequest]]] = {} - api_to_sdk_response_model_map: dict[str, dict[HTTPStatusCode, type[HordeResponse]]] = {} + api_to_sdk_response_model_map: dict[str, dict[HTTPStatusCode, type[HordeResponseTypes]]] = {} request_field_names_and_descriptions: dict[str, dict[str, dict[str, str | list[str] | None]]] = {} response_field_names_and_descriptions: dict[str, dict[str, dict[str, str | list[str] | None]]] = {} diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_collection_by_name_collection_name_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_collection_by_name_collection_name_get_200.json index 2f91d08..14c40bf 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_collection_by_name_collection_name_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_collection_by_name_collection_name_get_200.json @@ -9,5 +9,6 @@ "name": "db0#1::style::my awesome style", "id": "00000000-0000-0000-0000-000000000000" } - ] + ], + "use_count": 0 } diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_collections_collection_id_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_collections_collection_id_get_200.json index 2f91d08..14c40bf 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_collections_collection_id_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_collections_collection_id_get_200.json @@ -9,5 +9,6 @@ "name": "db0#1::style::my awesome style", "id": "00000000-0000-0000-0000-000000000000" } - ] + ], + "use_count": 0 } diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_collections_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_collections_get_200.json index 2d2a387..7bfacfd 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_collections_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_collections_get_200.json @@ -10,6 +10,7 @@ "name": "db0#1::style::my awesome style", "id": "00000000-0000-0000-0000-000000000000" } - ] + ], + "use_count": 0 } ] diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_by_name_style_name_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_by_name_style_name_get_200.json index 64d22ee..db63b5e 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_by_name_style_name_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_by_name_style_name_get_200.json @@ -51,6 +51,7 @@ "stable_diffusion" ], "id": "00000000-0000-0000-0000-000000000000", + "use_count": 0, "creator": "db0#1", "examples": [ { diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_get_200.json index e3b00f1..d6b0dd4 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_get_200.json @@ -52,6 +52,7 @@ "stable_diffusion" ], "id": "00000000-0000-0000-0000-000000000000", + "use_count": 0, "creator": "db0#1", "examples": [ { diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_style_id_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_style_id_get_200.json index 64d22ee..db63b5e 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_style_id_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_styles_image_style_id_get_200.json @@ -51,6 +51,7 @@ "stable_diffusion" ], "id": "00000000-0000-0000-0000-000000000000", + "use_count": 0, "creator": "db0#1", "examples": [ { diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_by_name_style_name_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_by_name_style_name_get_200.json index 0637d5a..fba0714 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_by_name_style_name_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_by_name_style_name_get_200.json @@ -38,5 +38,6 @@ "llama3" ], "id": "00000000-0000-0000-0000-000000000000", + "use_count": 0, "creator": "db0#1" } diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_get_200.json index 0dcc4e4..5bbf6a9 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_get_200.json @@ -39,6 +39,7 @@ "llama3" ], "id": "00000000-0000-0000-0000-000000000000", + "use_count": 0, "creator": "db0#1" } ] diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_style_id_get_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_style_id_get_200.json index 0637d5a..fba0714 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_style_id_get_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_styles_text_style_id_get_200.json @@ -38,5 +38,6 @@ "llama3" ], "id": "00000000-0000-0000-0000-000000000000", + "use_count": 0, "creator": "db0#1" } diff --git a/tests/test_verify_api_surface.py b/tests/test_verify_api_surface.py index 63ce7f6..1757145 100644 --- a/tests/test_verify_api_surface.py +++ b/tests/test_verify_api_surface.py @@ -27,7 +27,7 @@ def test_all_ai_horde_api_data_objects_imported() -> None: unimported_classes, missing_imports = horde_sdk.meta.any_unimported_classes( horde_sdk.ai_horde_api.apimodels, - horde_sdk.generic_api.apimodels.HordeAPIDataObject, + horde_sdk.generic_api.apimodels.HordeAPIData, ) missing_import_names = {cls.__name__ for cls in missing_imports} @@ -39,14 +39,52 @@ def test_all_ai_horde_api_data_objects_imported() -> None: ) -@pytest.mark.skip(reason="This test is not yet enforced.") +# @pytest.mark.skip(reason="This test is not yet enforced.") @pytest.mark.object_verify def test_all_ai_horde_api_models_defined() -> None: import horde_sdk.ai_horde_api.apimodels - from horde_sdk.meta import all_undefined_classes + from horde_sdk.meta import all_undefined_classes, all_undefined_classes_for_endpoints undefined_classes = all_undefined_classes(horde_sdk.ai_horde_api.apimodels) + # all_undefined_classes_for_endpoints handles the ones directly referenced by endpoints, so we remove them + undefined_classes_for_endpoints = all_undefined_classes_for_endpoints(horde_sdk.ai_horde_api.apimodels) + for key in undefined_classes_for_endpoints: + if key in undefined_classes: + undefined_classes.remove(key) + + assert ( + "GenerationInputStable" not in undefined_classes + ), "A model which is known to be defined in the SDK was not found. Something critically bad has happened." + + # Pretty print the undefined classes sorted by dict values, NOT by keys + import json + + error_responses = { + "RequestError", + "RequestValidationError", + } + + for error_response in error_responses: + if error_response in undefined_classes: + print(f"Warning: {error_response} is an error response which may not be handled.") + undefined_classes.remove(error_response) + + undefined_classes_sorted = sorted(undefined_classes) + print(json.dumps(undefined_classes_sorted, indent=4)) + + assert not undefined_classes, ( + "The following models are defined in the API but not in the SDK: " f"{undefined_classes}" + ) + + +@pytest.mark.object_verify +def test_all_ai_horde_api_models_defined_for_endpoints() -> None: + import horde_sdk.ai_horde_api.apimodels + from horde_sdk.meta import all_undefined_classes_for_endpoints + + undefined_classes = all_undefined_classes_for_endpoints(horde_sdk.ai_horde_api.apimodels) + assert ( "GenerationInputStable" not in undefined_classes ), "A model which is known to be defined in the SDK was not found. Something critically bad has happened." @@ -83,13 +121,22 @@ def test_all_ai_horde_endpoints_known() -> None: ) -@pytest.mark.skip(reason="This test is not yet enforced.") +# @pytest.mark.skip(reason="This test is not yet enforced.") @pytest.mark.object_verify def test_all_ai_horde_endpoints_addressed() -> None: + from horde_sdk.ai_horde_api.endpoints import get_admin_only_endpoints, get_deprecated_endpoints from horde_sdk.meta import all_unaddressed_endpoints_ai_horde unaddressed_endpoints = all_unaddressed_endpoints_ai_horde() + all_ignored_endpoints = get_admin_only_endpoints() | get_deprecated_endpoints() + + unaddressed_endpoints -= all_ignored_endpoints + + print() + for unaddressed_endpoint in unaddressed_endpoints: + print(f"Unaddressed path: {unaddressed_endpoint}.") + assert not unaddressed_endpoints, ( "The following endpoints are defined in the API but not in the SDK: " f"{unaddressed_endpoints}" )