Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor/fix: resolve the RootModel problem #302

Merged
merged 4 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion docs/response_field_names_and_descriptions.json
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@
}
},
"DeleteWorkerResponse": {
"deleted_id_": {
"deleted_id": {
"description": "The ID of the deleted worker.",
"types": [
"str",
Expand Down Expand Up @@ -1238,6 +1238,20 @@
"int",
"None"
]
},
"controlnet": {
"description": "If True, this worker supports and allows controlnet requests.",
"types": [
"bool",
"None"
]
},
"sdxl_controlnet": {
"description": "If True, this worker supports and allows sdxl controlnet requests.",
"types": [
"bool",
"None"
]
}
},
"TextGenerateAsyncDryRunResponse": {
Expand Down
4 changes: 2 additions & 2 deletions horde_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
from horde_sdk.generic_api.apimodels import (
APIKeyAllowedInRequestMixin,
ContainsMessageResponseMixin,
HordeAPIDataObject,
HordeAPIData,
HordeAPIMessage,
HordeAPIObject,
HordeRequest,
Expand All @@ -99,7 +99,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
"APIKeyAllowedInRequestMixin",
"HordeRequest",
"ContainsMessageResponseMixin",
"HordeAPIDataObject",
"HordeAPIData",
"HordeAPIMessage",
"HordeAPIObject",
"RequestErrorResponse",
Expand Down
6 changes: 3 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import (
HordeAPIObject,
HordeAPIObjectBaseModel,
HordeResponseBaseModel,
)

Expand All @@ -29,7 +29,7 @@ def get_api_model_name(cls) -> str:
return "HordeDocument"


class AIHordeDocumentRequestMixin(HordeAPIObject):
class AIHordeDocumentRequestMixin(HordeAPIObjectBaseModel):
format: DocumentFormat | str = DocumentFormat.html

"""The format of the document to return. Default is markdown."""
Expand Down Expand Up @@ -108,7 +108,7 @@ def get_http_method(cls) -> HTTPMethod:
@override
@classmethod
def get_api_endpoint_subpath(cls) -> AI_HORDE_API_ENDPOINT_SUBPATH:
return AI_HORDE_API_ENDPOINT_SUBPATH.vs_documents_terms
return AI_HORDE_API_ENDPOINT_SUBPATH.v2_documents_terms

@override
@classmethod
Expand Down
4 changes: 2 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/_kudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.consts import _ANONYMOUS_MODEL, HTTPMethod
from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeResponse
from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeResponseBaseModel


class KudosTransferResponse(HordeResponse):
class KudosTransferResponse(HordeResponseBaseModel):
transferred: float | None = None
"""The amount of Kudos transferred."""

Expand Down
26 changes: 23 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from horde_sdk.ai_horde_api.consts import MODEL_STATE
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import HordeAPIDataObject, HordeResponseBaseModel
from horde_sdk.generic_api.apimodels import HordeAPIObjectBaseModel, HordeResponseBaseModel
from horde_sdk.generic_api.decoration import Unequatable, Unhashable


Expand Down Expand Up @@ -122,7 +122,12 @@ def get_default_success_response_type(cls) -> type[ImageStatsModelsResponse]:
return ImageStatsModelsResponse


class SinglePeriodImgStat(HordeAPIDataObject):
class SinglePeriodImgStat(HordeAPIObjectBaseModel):
"""Represents the stats for a single period of image generation.

v2 API Model: `SinglePeriodImgStat`
"""

images: int | None = Field(
default=None,
)
Expand All @@ -140,6 +145,11 @@ def mps(self) -> int | None:

return self.ps // 1_000_000

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "SinglePeriodImgStat"


class ImageStatsModelsTotalResponse(HordeResponseBaseModel):
"""Represents the data returned from the `/v2/stats/img/totals` endpoint."""
Expand Down Expand Up @@ -253,7 +263,12 @@ def get_default_success_response_type(cls) -> type[TextStatsModelResponse]:
return TextStatsModelResponse


class SinglePeriodTxtStat(HordeAPIDataObject):
class SinglePeriodTxtStat(HordeAPIObjectBaseModel):
"""Represents the stats for a single period.

v2 API Model: `SinglePeriodTxtStat`
"""

requests: int | None = Field(
default=None,
)
Expand All @@ -263,6 +278,11 @@ class SinglePeriodTxtStat(HordeAPIDataObject):
)
"""The number of tokens generated during this period."""

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "SinglePeriodTxtStat"


@Unhashable
class TextStatsModelsTotalResponse(HordeResponseBaseModel):
Expand Down
30 changes: 21 additions & 9 deletions horde_sdk/ai_horde_api/apimodels/_status.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterator

from pydantic import ConfigDict, Field, RootModel
from pydantic import ConfigDict, Field
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest
Expand All @@ -9,9 +9,9 @@
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import (
ContainsMessageResponseMixin,
HordeAPIObject,
HordeResponse,
HordeAPIObjectBaseModel,
HordeResponseBaseModel,
HordeResponseRootModel,
)
from horde_sdk.generic_api.decoration import Unhashable

Expand Down Expand Up @@ -148,7 +148,7 @@ def get_default_success_response_type(cls) -> type[HordePerformanceResponse]:
return HordePerformanceResponse


class Newspiece(HordeAPIObject):
class Newspiece(HordeAPIObjectBaseModel):
date_published: str | None = Field(
default=None,
)
Expand All @@ -159,6 +159,18 @@ class Newspiece(HordeAPIObject):
default=None,
)
"""The actual piece of news."""
tags: list[str] | None = Field(
default=None,
)
"""The tags associated with this newspiece."""
title: str | None = Field(
default=None,
)
"""The title of this newspiece."""
more_info_urls: list[str] | None = Field(
default=None,
)
"""The URLs for more information about this newspiece."""

@override
@classmethod
Expand All @@ -167,7 +179,7 @@ def get_api_model_name(cls) -> str | None:


@Unhashable
class NewsResponse(HordeResponse, RootModel[list[Newspiece]]):
class NewsResponse(HordeResponseRootModel[list[Newspiece]]):
root: list[Newspiece]
"""The underlying list of newspieces."""

Expand Down Expand Up @@ -210,7 +222,7 @@ def get_default_success_response_type(cls) -> type[NewsResponse]:
return NewsResponse


class ActiveModelLite(HordeAPIObject):
class ActiveModelLite(HordeAPIObjectBaseModel):
count: int | None = Field(
default=None,
)
Expand Down Expand Up @@ -256,7 +268,7 @@ def get_api_model_name(cls) -> str | None:


@Unhashable
class HordeStatusModelsAllResponse(HordeResponse, RootModel[list[ActiveModel]]):
class HordeStatusModelsAllResponse(HordeResponseRootModel[list[ActiveModel]]):
root: list[ActiveModel]
"""The underlying list of models."""

Expand Down Expand Up @@ -325,7 +337,7 @@ def get_query_fields(cls) -> list[str]:


@Unhashable
class HordeStatusModelsSingleResponse(HordeResponse, RootModel[list[ActiveModel]]):
class HordeStatusModelsSingleResponse(HordeResponseRootModel[list[ActiveModel]]):
# This is a list because of an oversight in the structure of the API response. # FIXME

root: list[ActiveModel]
Expand Down Expand Up @@ -377,7 +389,7 @@ def get_default_success_response_type(cls) -> type[HordeStatusModelsSingleRespon
return HordeStatusModelsSingleResponse


class HordeModes(HordeAPIObject):
class HordeModes(HordeAPIObjectBaseModel):
maintenance_mode: bool = Field(
default=False,
)
Expand Down
7 changes: 5 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import Field
from strenum import StrEnum

from horde_sdk.generic_api.apimodels import HordeAPIDataObject
from horde_sdk.generic_api.apimodels import HordeAPIData


class StyleType(StrEnum):
Expand All @@ -13,7 +13,10 @@ class StyleType(StrEnum):
text = auto()


class ResponseModelStylesUser(HordeAPIDataObject):
class ResponseModelStylesUser(HordeAPIData):
name: str
"""The name of the style."""
id_: str = Field(alias="id")
"""The ID of the style."""
type_: StyleType = Field(alias="type")
"""The type of the style."""
Loading
Loading