diff --git a/docs/api_to_sdk_map.md b/docs/api_to_sdk_map.md index af23beb..19b60a0 100644 --- a/docs/api_to_sdk_map.md +++ b/docs/api_to_sdk_map.md @@ -26,7 +26,7 @@ This is a mapping of the AI-Horde API models (defined at [https://stablehorde.ne | /v2/generate/async | 200 | [ImageGenerateAsyncDryRunResponse][horde_sdk.ai_horde_api.apimodels.generate._async.ImageGenerateAsyncDryRunResponse] | | /v2/generate/async | 202 | [ImageGenerateAsyncResponse][horde_sdk.ai_horde_api.apimodels.generate._async.ImageGenerateAsyncResponse] | | /v2/generate/check/{id} | 200 | [ImageGenerateCheckResponse][horde_sdk.ai_horde_api.apimodels.generate._check.ImageGenerateCheckResponse] | -| /v2/generate/pop | 200 | [ImageGenerateJobResponse][horde_sdk.ai_horde_api.apimodels.generate._pop.ImageGenerateJobResponse] | +| /v2/generate/pop | 200 | [ImageGenerateJobPopResponse][horde_sdk.ai_horde_api.apimodels.generate._pop.ImageGenerateJobPopResponse] | | /v2/generate/status/{id} | 200 | [ImageGenerateStatusResponse][horde_sdk.ai_horde_api.apimodels.generate._status.ImageGenerateStatusResponse] | | /v2/generate/submit | 200 | [JobSubmitResponse][horde_sdk.ai_horde_api.apimodels.base.JobSubmitResponse] | | /v2/interrogate/async | 202 | [AlchemyAsyncResponse][horde_sdk.ai_horde_api.apimodels.alchemy._async.AlchemyAsyncResponse] | diff --git a/docs/api_to_sdk_response_map.json b/docs/api_to_sdk_response_map.json index 2a80be7..8be1605 100644 --- a/docs/api_to_sdk_response_map.json +++ b/docs/api_to_sdk_response_map.json @@ -25,7 +25,7 @@ "200": "horde_sdk.ai_horde_api.apimodels.generate._check.ImageGenerateCheckResponse" }, "/v2/generate/pop": { - "200": "horde_sdk.ai_horde_api.apimodels.generate._pop.ImageGenerateJobResponse" + "200": "horde_sdk.ai_horde_api.apimodels.generate._pop.ImageGenerateJobPopResponse" }, "/v2/generate/submit": { "200": "horde_sdk.ai_horde_api.apimodels.base.JobSubmitResponse" diff --git a/docs/request_field_names_and_descriptions.json b/docs/request_field_names_and_descriptions.json new file mode 100644 index 0000000..5704fc1 --- /dev/null +++ b/docs/request_field_names_and_descriptions.json @@ -0,0 +1,360 @@ +{ + "AlchemyAsyncRequest": [ + [ + "apikey", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ], + [ + "forms", + null + ], + [ + "source_image", + null + ], + [ + "slow_workers", + null + ] + ], + "AlchemyDeleteRequest": [ + [ + "id_", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ], + "AlchemyPopRequest": [ + [ + "apikey", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ], + [ + "name", + null + ], + [ + "priority_usernames", + null + ], + [ + "forms", + null + ] + ], + "AlchemyStatusRequest": [ + [ + "apikey", + null + ], + [ + "id_", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ], + "AllWorkersDetailsRequest": [ + [ + "apikey", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ], + [ + "type_", + null + ] + ], + "DeleteImageGenerateRequest": [ + [ + "id_", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ], + "FindUserRequest": [ + [ + "apikey", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ], + "ImageGenerateAsyncRequest": [ + [ + "trusted_workers", + null + ], + [ + "slow_workers", + null + ], + [ + "workers", + null + ], + [ + "worker_blacklist", + null + ], + [ + "models", + null + ], + [ + "dry_run", + null + ], + [ + "apikey", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ], + [ + "prompt", + null + ], + [ + "params", + null + ], + [ + "nsfw", + null + ], + [ + "censor_nsfw", + null + ], + [ + "r2", + null + ], + [ + "shared", + null + ], + [ + "replacement_filter", + null + ], + [ + "source_image", + null + ], + [ + "source_processing", + null + ], + [ + "source_mask", + null + ] + ], + "ImageGenerateCheckRequest": [ + [ + "id_", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ], + "ImageGenerateJobPopRequest": [ + [ + "apikey", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ], + [ + "name", + null + ], + [ + "priority_usernames", + null + ], + [ + "nsfw", + null + ], + [ + "models", + null + ], + [ + "bridge_version", + null + ], + [ + "bridge_agent", + null + ], + [ + "threads", + null + ], + [ + "require_upfront_kudos", + null + ], + [ + "max_pixels", + null + ], + [ + "blacklist", + null + ], + [ + "allow_img2img", + null + ], + [ + "allow_painting", + null + ], + [ + "allow_unsafe_ipaddr", + null + ], + [ + "allow_post_processing", + null + ], + [ + "allow_controlnet", + null + ], + [ + "allow_lora", + null + ] + ], + "ImageGenerateStatusRequest": [ + [ + "id_", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ], + "ImageGenerationJobSubmitRequest": [ + [ + "apikey", + null + ], + [ + "id_", + null + ], + [ + "accept", + null + ], + [ + "client_agent", + null + ], + [ + "generation", + null + ], + [ + "state", + null + ], + [ + "seed", + null + ], + [ + "censored", + null + ] + ], + "StatsImageModelsRequest": [ + [ + "accept", + null + ], + [ + "client_agent", + null + ] + ] +} diff --git a/docs/response_field_names_and_descriptions.json b/docs/response_field_names_and_descriptions.json new file mode 100644 index 0000000..c9831c3 --- /dev/null +++ b/docs/response_field_names_and_descriptions.json @@ -0,0 +1,354 @@ +{ + "AlchemyAsyncResponse": [ + [ + "message", + null + ], + [ + "id_", + null + ] + ], + "AlchemyStatusResponse": [ + [ + "state", + null + ], + [ + "forms", + null + ], + [ + "state", + null + ], + [ + "forms", + null + ] + ], + "AlchemyPopResponse": [ + [ + "forms", + null + ], + [ + "skipped", + null + ] + ], + "AllWorkersDetailsResponse": [ + [ + "root", + null + ] + ], + "ImageGenerateStatusResponse": [ + [ + "finished", + null + ], + [ + "processing", + null + ], + [ + "restarted", + null + ], + [ + "waiting", + null + ], + [ + "done", + null + ], + [ + "faulted", + null + ], + [ + "wait_time", + null + ], + [ + "queue_position", + null + ], + [ + "kudos", + null + ], + [ + "is_possible", + null + ], + [ + "generations", + null + ], + [ + "shared", + null + ], + [ + "finished", + null + ], + [ + "processing", + null + ], + [ + "restarted", + null + ], + [ + "waiting", + null + ], + [ + "done", + null + ], + [ + "faulted", + null + ], + [ + "wait_time", + null + ], + [ + "queue_position", + null + ], + [ + "kudos", + null + ], + [ + "is_possible", + null + ], + [ + "generations", + null + ], + [ + "shared", + null + ] + ], + "FindUserResponse": [ + [ + "account_age", + "How many seconds since this account was created." + ], + [ + "concurrency", + "How many concurrent generations this user may request." + ], + [ + "contact", + "(Privileged) Contact details for the horde admins to reach the user in case of emergency." + ], + [ + "contributions", + null + ], + [ + "evaluating_kudos", + "(Privileged) The amount of Evaluating Kudos this untrusted user has from generations and uptime. When this number reaches a pre-specified threshold, they automatically become trusted." + ], + [ + "flagged", + "This user has been flagged for suspicious activity." + ], + [ + "id_", + "The user unique ID. It is always an integer." + ], + [ + "kudos", + "The amount of Kudos this user has. The amount of Kudos determines the priority when requesting image generations." + ], + [ + "kudos_details", + null + ], + [ + "moderator", + "This user is a Horde moderator." + ], + [ + "monthly_kudos", + null + ], + [ + "pseudonymous", + "If true, this user has not registered using an oauth service." + ], + [ + "records", + null + ], + [ + "sharedkey_ids", + null + ], + [ + "special", + "(Privileged) This user has been given the Special role." + ], + [ + "suspicious", + "(Privileged) How much suspicion this user has accumulated." + ], + [ + "trusted", + "This user is a trusted member of the Horde." + ], + [ + "usage", + null + ], + [ + "username", + "The user's unique Username. It is a combination of their chosen alias plus their ID." + ], + [ + "vpn", + "(Privileged) This user has been given the VPN role." + ], + [ + "worker_count", + "How many workers this user has created (active or inactive)." + ], + [ + "worker_ids", + null + ], + [ + "worker_invited", + "Whether this user has been invited to join a worker to the horde and how many of them. When 0, this user cannot add (new) workers to the horde." + ] + ], + "ImageGenerateAsyncDryRunResponse": [ + [ + "kudos", + null + ] + ], + "ImageGenerateAsyncResponse": [ + [ + "message", + null + ], + [ + "id_", + null + ], + [ + "kudos", + null + ] + ], + "ImageGenerateCheckResponse": [ + [ + "finished", + null + ], + [ + "processing", + null + ], + [ + "restarted", + null + ], + [ + "waiting", + null + ], + [ + "done", + null + ], + [ + "faulted", + null + ], + [ + "wait_time", + null + ], + [ + "queue_position", + null + ], + [ + "kudos", + null + ], + [ + "is_possible", + null + ] + ], + "ImageGenerateJobPopResponse": [ + [ + "id_", + null + ], + [ + "payload", + null + ], + [ + "skipped", + null + ], + [ + "model", + null + ], + [ + "source_image", + null + ], + [ + "source_processing", + null + ], + [ + "source_mask", + null + ], + [ + "r2_upload", + null + ] + ], + "JobSubmitResponse": [ + [ + "reward", + null + ] + ], + "StatsModelsResponse": [ + [ + "day", + null + ], + [ + "month", + null + ], + [ + "total", + null + ] + ] +} diff --git a/examples/ai_horde_client/aihorde_simple_client_example.py b/examples/ai_horde_client/aihorde_simple_client_example.py index 4780906..905ee41 100644 --- a/examples/ai_horde_client/aihorde_simple_client_example.py +++ b/examples/ai_horde_client/aihorde_simple_client_example.py @@ -15,19 +15,20 @@ def simple_generate_example(api_key: str = ANON_API_KEY) -> None: status_response, job_id = simple_client.image_generate_request( ImageGenerateAsyncRequest( apikey=api_key, + workers=["facf2d67-9e83-4a9e-a7ae-8e555d55af08"], params=ImageGenerationInputPayload( sampler_name=KNOWN_SAMPLERS.k_euler, cfg_scale=4, - width=512, + width=768, height=512, karras=False, - hires_fix=True, + hires_fix=False, clip_skip=1, steps=30, loras=[ LorasPayloadEntry( name="GlowingRunesAI", - model=-1, + model=1, clip=1, inject_trigger="any", # Get a random color trigger ), @@ -67,4 +68,5 @@ def simple_generate_example(api_key: str = ANON_API_KEY) -> None: api_key = args.api_key - simple_generate_example(api_key) + while True: + simple_generate_example(api_key) diff --git a/examples/ai_horde_client/async_aihorde_simple_client_example.py b/examples/ai_horde_client/async_aihorde_simple_client_example.py index cb6faaa..a08cb79 100644 --- a/examples/ai_horde_client/async_aihorde_simple_client_example.py +++ b/examples/ai_horde_client/async_aihorde_simple_client_example.py @@ -8,7 +8,12 @@ from horde_sdk import ANON_API_KEY, RequestErrorResponse from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncSimpleClient -from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusResponse +from horde_sdk.ai_horde_api.apimodels import ( + ImageGenerateAsyncRequest, + ImageGenerateStatusResponse, + ImageGenerationInputPayload, + TIPayloadEntry, +) from horde_sdk.ai_horde_api.fields import JobID @@ -24,6 +29,17 @@ async def async_one_image_generate_example( apikey=apikey, prompt="A cat in a hat", models=["Deliberate"], + params=ImageGenerationInputPayload( + height=512, + width=512, + tis=[ + TIPayloadEntry( + name="72437", + inject_ti="negprompt", + strength=1, + ), + ], + ), ), ) @@ -51,14 +67,16 @@ async def async_multi_image_generate_example( ImageGenerateAsyncRequest( apikey=apikey, prompt="A cat in a blue hat", - models=["Deliberate"], + models=["SDXL 1.0"], + params=ImageGenerationInputPayload(height=1024, width=1024), ), ), simple_client.image_generate_request( ImageGenerateAsyncRequest( apikey=apikey, prompt="A cat in a red hat", - models=["Deliberate"], + models=["SDXL 1.0"], + params=ImageGenerationInputPayload(height=1024, width=1024), ), ), ) @@ -84,7 +102,7 @@ async def async_simple_generate_example(apikey: str = ANON_API_KEY) -> None: simple_client = AIHordeAPIAsyncSimpleClient(aiohttp_session) await async_one_image_generate_example(simple_client, apikey) - await async_multi_image_generate_example(simple_client, apikey) + # await async_multi_image_generate_example(simple_client, apikey) if __name__ == "__main__": diff --git a/horde_sdk/ai_horde_api/apimodels/__init__.py b/horde_sdk/ai_horde_api/apimodels/__init__.py index 86d020c..b934403 100644 --- a/horde_sdk/ai_horde_api/apimodels/__init__.py +++ b/horde_sdk/ai_horde_api/apimodels/__init__.py @@ -35,7 +35,7 @@ AlchemyStatusResponse, AlchemyUpscaleResult, ) -from horde_sdk.ai_horde_api.apimodels.base import LorasPayloadEntry +from horde_sdk.ai_horde_api.apimodels.base import LorasPayloadEntry, TIPayloadEntry from horde_sdk.ai_horde_api.apimodels.generate._async import ( ImageGenerateAsyncDryRunResponse, ImageGenerateAsyncRequest, @@ -46,8 +46,8 @@ from horde_sdk.ai_horde_api.apimodels.generate._pop import ( ImageGenerateJobPopPayload, ImageGenerateJobPopRequest, + ImageGenerateJobPopResponse, ImageGenerateJobPopSkippedStatus, - ImageGenerateJobResponse, ) from horde_sdk.ai_horde_api.apimodels.generate._status import ( DeleteImageGenerateRequest, @@ -94,7 +94,7 @@ "ImageGenerateJobPopRequest", "ImageGenerateJobPopPayload", "ImageGenerateJobPopSkippedStatus", - "ImageGenerateJobResponse", + "ImageGenerateJobPopResponse", "ImageGenerateStatusRequest", "ImageGenerateStatusResponse", "ImageGenerationJobSubmitRequest", @@ -106,6 +106,7 @@ "StatsImageModelsRequest", "StatsModelsResponse", "StatsModelsTimeframe", + "TIPayloadEntry", "UsageDetails", "UserAmountRecords", "UserKudosDetails", diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py index a951b4a..3fb7354 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py @@ -39,7 +39,9 @@ def get_api_model_name(cls) -> str | None: return "RequestInterrogationResponse" @override - def get_follow_up_returned_params(self) -> list[dict[str, object]]: + def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -> list[dict[str, object]]: + if as_python_field_name: + return [{"id_": self.id_}] return [{AIHordePathData.id_: self.id_}] @override diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py index 38f68fa..fe09af9 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py @@ -107,7 +107,7 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[AlchemyJobSubmitRequ return AlchemyJobSubmitRequest @override - def get_follow_up_returned_params(self) -> list[dict[str, object]]: + def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -> list[dict[str, object]]: if not self.forms: return [] all_ids: list[dict[str, object]] = [] @@ -116,7 +116,10 @@ def get_follow_up_returned_params(self) -> list[dict[str, object]]: logger.warning(f"Skipping form {form} as it is not an AlchemyPopFormPayload") continue if form.id_: - all_ids.append({"id": form.id_}) + if as_python_field_name: + all_ids.append({"id_": form.id_}) + else: + all_ids.append({"id": form.id_}) return all_ids diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index b8cb583..32f92df 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -1,11 +1,14 @@ """The base classes for all AI Horde API requests/responses.""" from __future__ import annotations +import random +import uuid + from loguru import logger -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from typing_extensions import override -from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS +from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, POST_PROCESSOR_ORDER_TYPE from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL from horde_sdk.ai_horde_api.fields import JobID, WorkerID from horde_sdk.generic_api.apimodels import HordeRequest, HordeResponseBaseModel @@ -26,6 +29,14 @@ class JobRequestMixin(BaseModel): id_: JobID = Field(alias="id") """The UUID for this job. Use this to post the results in the future.""" + @field_validator("id_", mode="before") + def validate_id(cls, v: str | JobID) -> JobID | str: + if isinstance(v, str) and v == "": + logger.warning("Job ID is empty") + return JobID(root=uuid.uuid4()) + + return v + class JobResponseMixin(BaseModel): # TODO: this model may not actually exist as such in the API """Mix-in class for data relating to any generation jobs.""" @@ -33,6 +44,14 @@ class JobResponseMixin(BaseModel): # TODO: this model may not actually exist as id_: JobID = Field(alias="id") """The UUID for this job.""" + @field_validator("id_", mode="before") + def validate_id(cls, v: str | JobID) -> JobID | str: + if isinstance(v, str) and v == "": + logger.warning("Job ID is empty") + return JobID(root=uuid.uuid4()) + + return v + class WorkerRequestMixin(BaseModel): """Mix-in class for data relating to worker requests.""" @@ -109,6 +128,9 @@ class ImageGenerateParamMixin(BaseModel): """Deprecated.""" post_processing: list[str] = Field(default_factory=list) """A list of post-processing models to use.""" + post_processing_order: POST_PROCESSOR_ORDER_TYPE = POST_PROCESSOR_ORDER_TYPE.facefixers_first + """The order in which to apply post-processing models. + Applying upscalers or removing backgrounds before facefixers costs less kudos.""" karras: bool = True """Set to True if you want to use the Karras scheduling.""" tiling: bool = False @@ -131,11 +153,6 @@ class ImageGenerateParamMixin(BaseModel): """A list of textual inversion (embedding) parameters to use.""" special: dict = Field(default_factory=dict) """Reserved for future use.""" - steps: int = Field(default=25, ge=1, validation_alias=AliasChoices("steps", "ddim_steps")) - """The number of image generation steps to perform.""" - - n: int = Field(default=1, ge=1, le=20, validation_alias=AliasChoices("n", "n_iter")) - """The number of images to generate. Defaults to 1, maximum is 20.""" @field_validator("width", "height", mode="before") def width_divisible_by_64(cls, value: int) -> int: @@ -144,14 +161,6 @@ def width_divisible_by_64(cls, value: int) -> int: raise ValueError("width must be divisible by 64") return value - @field_validator("n", mode="before") - def validate_n(cls, value: int) -> int: - if value == 0: - logger.debug("n (number of images to generate) is not set; defaulting to 1") - return 1 - - return value - use_nsfw_censor: bool = False @field_validator("sampler_name") @@ -161,11 +170,19 @@ def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMP raise ValueError(f"Unknown sampler name {v}") return v - @model_validator(mode="after") - def validate_hires_fix(self) -> ImageGenerateParamMixin: - if self.hires_fix and (self.width < 512 or self.height < 512): - raise ValueError("hires_fix is only valid when width and height are both >= 512") - return self + # @model_validator(mode="after") + # def validate_hires_fix(self) -> ImageGenerateParamMixin: + # if self.hires_fix and (self.width < 512 or self.height < 512): + # raise ValueError("hires_fix is only valid when width and height are both >= 512") + # return self + + @field_validator("seed") + def random_seed_if_none(cls, v: str | None) -> str | None: + """If the seed is None, generate a random seed.""" + if v is None: + logger.debug("Generating random seed") + return str(random.randint(1, 1000000000)) + return v class JobSubmitResponse(HordeResponseBaseModel): diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_async.py b/horde_sdk/ai_horde_api/apimodels/generate/_async.py index f69838e..730140a 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_async.py @@ -1,4 +1,5 @@ -from pydantic import AliasChoices, Field, model_validator +from loguru import logger +from pydantic import AliasChoices, Field, field_validator, model_validator from typing_extensions import override from horde_sdk.ai_horde_api.apimodels.base import ( @@ -37,7 +38,9 @@ class ImageGenerateAsyncResponse( kudos: float @override - def get_follow_up_returned_params(self) -> list[dict[str, object]]: + def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -> list[dict[str, object]]: + if as_python_field_name: + return [{"id_": self.id_}] return [{"id": self.id_}] @classmethod @@ -88,6 +91,14 @@ class ImageGenerationInputPayload(HordeAPIObject, ImageGenerateParamMixin): def get_api_model_name(cls) -> str | None: return "ModelGenerationInputStable" + @field_validator("n", mode="before") + def validate_n(cls, value: int) -> int: + if value == 0: + logger.debug("n (number of images to generate) is not set; defaulting to 1") + return 1 + + return value + class ImageGenerateAsyncRequest( BaseAIHordeRequest, diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py index fcfde51..b168266 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py @@ -1,15 +1,18 @@ +import uuid + import pydantic +from loguru import logger from pydantic import AliasChoices, Field, field_validator from typing_extensions import override from horde_sdk.ai_horde_api.apimodels.base import ( BaseAIHordeRequest, ImageGenerateParamMixin, - JobResponseMixin, ) from horde_sdk.ai_horde_api.apimodels.generate._submit import ImageGenerationJobSubmitRequest from horde_sdk.ai_horde_api.consts import GENERATION_STATE, KNOWN_SOURCE_PROCESSING from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH +from horde_sdk.ai_horde_api.fields import JobID from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( APIKeyAllowedInRequestMixin, @@ -62,7 +65,7 @@ class ImageGenerateJobPopSkippedStatus(pydantic.BaseModel): class ImageGenerateJobPopPayload(ImageGenerateParamMixin): - prompt: str + prompt: str | None = None ddim_steps: int = Field(default=25, ge=1, validation_alias=AliasChoices("steps", "ddim_steps")) """The number of image generation steps to perform.""" @@ -70,20 +73,23 @@ class ImageGenerateJobPopPayload(ImageGenerateParamMixin): """The number of images to generate. Defaults to 1, maximum is 20.""" -class ImageGenerateJobResponse(HordeResponseBaseModel, JobResponseMixin, ResponseRequiringFollowUpMixin): +class ImageGenerateJobPopResponse(HordeResponseBaseModel, ResponseRequiringFollowUpMixin): """Represents the data returned from the `/v2/generate/pop` endpoint. v2 API Model: `GenerationPayloadStable` """ + id_: JobID | None = Field(None, alias="id") + """The UUID for this image generation.""" + payload: ImageGenerateJobPopPayload """The parameters used to generate this image.""" skipped: ImageGenerateJobPopSkippedStatus """The reasons this worker was not issued certain jobs, and the number of jobs for each reason.""" - model: str + model: str | None = None """Which of the available models to use for this request.""" source_image: str | None = None - """The Base64-encoded webp to use for img2img.""" + """The URL or Base64-encoded webp to use for img2img.""" source_processing: str | KNOWN_SOURCE_PROCESSING = KNOWN_SOURCE_PROCESSING.txt2img """If source_image is provided, specifies how to process it.""" source_mask: str | None = None @@ -100,6 +106,14 @@ def source_processing_must_be_known(cls, v: str | KNOWN_SOURCE_PROCESSING) -> st raise ValueError(f"Unknown source processing {v}") return v + @field_validator("id_", mode="before") + def validate_id(cls, v: str | JobID) -> JobID | str: + if isinstance(v, str) and v == "": + logger.warning("Job ID is empty") + return JobID(root=uuid.uuid4()) + + return v + @override @classmethod def get_api_model_name(cls) -> str | None: @@ -116,17 +130,30 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[ImageGenerationJobSu return ImageGenerationJobSubmitRequest @override - def get_follow_up_returned_params(self) -> list[dict[str, object]]: + def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -> list[dict[str, object]]: + if as_python_field_name: + return [{"id_": self.id_}] return [{"id": self.id_}] @override def get_follow_up_failure_cleanup_params(self) -> dict[str, object]: - return {"state": GENERATION_STATE.faulted} # TODO: One day, could I do away with the magic string? + return { + "state": GENERATION_STATE.faulted, + "seed": self.payload.seed, + "generation": "Faulted", + } # TODO: One day, could I do away with the magic string? @override def get_extra_fields_to_exclude_from_log(self) -> set[str]: return {"source_image"} + @override + def ignore_failure(self) -> bool: + if self.id_ is None: + return True + + return super().ignore_failure() + class ImageGenerateJobPopRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin): """Represents the data needed to make a job request from a worker to the /v2/generate/pop endpoint. @@ -168,5 +195,5 @@ def get_api_endpoint_subpath(cls) -> AI_HORDE_API_ENDPOINT_SUBPATH: @override @classmethod - def get_default_success_response_type(cls) -> type[ImageGenerateJobResponse]: - return ImageGenerateJobResponse + def get_default_success_response_type(cls) -> type[ImageGenerateJobPopResponse]: + return ImageGenerateJobPopResponse diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_status.py b/horde_sdk/ai_horde_api/apimodels/generate/_status.py index e57c512..bfb5088 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_status.py @@ -1,4 +1,7 @@ -from pydantic import BaseModel, Field +import uuid + +from loguru import logger +from pydantic import BaseModel, Field, field_validator from typing_extensions import override from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest, JobRequestMixin @@ -34,6 +37,14 @@ class ImageGeneration(BaseModel): censored: bool """When true this image has been censored by the worker's safety filter.""" + @field_validator("id_", mode="before") + def validate_id(cls, v: str | JobID) -> JobID | str: + if isinstance(v, str) and v == "": + logger.warning("Job ID is empty") + return JobID(root=uuid.uuid4()) + + return v + class ImageGenerateStatusResponse( HordeResponseBaseModel, diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_submit.py b/horde_sdk/ai_horde_api/apimodels/generate/_submit.py index 3384611..0fe9157 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_submit.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_submit.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from loguru import logger +from pydantic import model_validator from typing_extensions import override from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest, JobRequestMixin, JobSubmitResponse @@ -13,15 +17,27 @@ class ImageGenerationJobSubmitRequest(BaseAIHordeRequest, JobRequestMixin, APIKe v2 API Model: `SubmitInputStable` """ - generation: str + generation: str = "" """R2 result was uploaded to R2, else the string of the result.""" state: GENERATION_STATE """The state of this generation.""" - seed: str + seed: int = 0 """The seed for this generation.""" censored: bool = False """If True, this resulting image has been censored.""" + @model_validator(mode="after") + def validate_generation(self) -> ImageGenerationJobSubmitRequest: + if self.generation == "": + logger.error("Generation cannot be an empty string.") + logger.error(self) + + if self.seed == 0: + logger.error("Seed cannot be 0.") + logger.error(self) + + return self + @override @classmethod def get_api_model_name(cls) -> str | None: diff --git a/horde_sdk/ai_horde_api/consts.py b/horde_sdk/ai_horde_api/consts.py index 443405b..972ff60 100644 --- a/horde_sdk/ai_horde_api/consts.py +++ b/horde_sdk/ai_horde_api/consts.py @@ -134,6 +134,20 @@ class KNOWN_MISC_POST_PROCESSORS(StrEnum): strip_background = auto() +class POST_PROCESSOR_ORDER_TYPE(StrEnum): + """The post processor order types that are known to the API. + + (facefixers_first, upscalers_first, custom, etc) + """ + + facefixers_first = auto() + upscalers_first = auto() + custom = auto() + + +DEFAULT_POST_PROCESSOR_ORDER = POST_PROCESSOR_ORDER_TYPE.facefixers_first + + class KNOWN_CLIP_BLIP_TYPES(StrEnum): caption = auto() interrogation = auto() diff --git a/horde_sdk/ai_horde_api/fields.py b/horde_sdk/ai_horde_api/fields.py index 1f58086..5aacbcc 100644 --- a/horde_sdk/ai_horde_api/fields.py +++ b/horde_sdk/ai_horde_api/fields.py @@ -6,7 +6,7 @@ import uuid from typing import Any -from pydantic import RootModel, field_validator +from pydantic import RootModel, field_validator, model_serializer from typing_extensions import override @@ -17,7 +17,11 @@ class UUID_Identifier(RootModel[uuid.UUID]): root: uuid.UUID - @field_validator("root", mode="before") + @model_serializer + def ser_model(self) -> str: + return str(self.root) + + @field_validator("root", mode="after") def id_must_be_uuid(cls, v: str | uuid.UUID) -> str | uuid.UUID: """Ensure that the ID is a valid UUID.""" if isinstance(v, uuid.UUID): @@ -45,12 +49,12 @@ def __eq__(self, other: Any) -> bool: return self.root == other.root if isinstance(other, str): - return self.root.__str__() == other + return str(self.root) == other if isinstance(other, uuid.UUID): return self.root == other - return False + raise NotImplementedError(f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}") @override def __hash__(self) -> int: diff --git a/horde_sdk/ai_horde_worker/bridge_data.py b/horde_sdk/ai_horde_worker/bridge_data.py index ec72a91..d1fb954 100644 --- a/horde_sdk/ai_horde_worker/bridge_data.py +++ b/horde_sdk/ai_horde_worker/bridge_data.py @@ -32,7 +32,7 @@ class BaseHordeBridgeData(BaseModel): model_config = ConfigDict(extra="allow") - @model_validator(mode="after") # type: ignore # FIXME: https://github.com/python/mypy/issues/15620 + @model_validator(mode="after") def validate_extra_params_warning(self) -> BaseHordeBridgeData: """Warn on extra parameters being passed.""" if not self.model_extra: @@ -128,9 +128,8 @@ def validate_is_dir(cls, v: str) -> str: class ImageWorkerBridgeData(SharedHordeBridgeData): """The bridge data file for a Dreamer or Alchemist worker.""" - # - # Dreamer - # + extra_stable_diffusion_models_folders: list[str] = Field(default_factory=list) + """A list of extra folders to search for stable diffusion models.""" allow_controlnet: bool = False """Whether to allow the use of ControlNet. This requires img2img to be enabled.""" @@ -263,9 +262,9 @@ class ImageWorkerBridgeData(SharedHordeBridgeData): forms: list[str] = ["caption", "nsfw", "interrogation", "post-process"] """The type of services or processing an alchemist worker will provide.""" - @model_validator(mode="after") # type: ignore # FIXME: https://github.com/python/mypy/issues/15620 - def validate_param_conflict(self) -> ImageWorkerBridgeData: - """Validate that the parameters are not conflicting.""" + @model_validator(mode="after") + def validate_model(self) -> ImageWorkerBridgeData: + """Validate that the parameters are not conflicting and make any fixed adjustments.""" if not self.allow_img2img and self.allow_controlnet: logger.warning( ( @@ -274,6 +273,9 @@ def validate_param_conflict(self) -> ImageWorkerBridgeData: ), ) self.allow_controlnet = False + + self.image_models_to_skip.append("SDXL_beta::stability.ai#6901") # FIXME: no magic strings + return self _meta_load_instructions: list[str] | None = None diff --git a/horde_sdk/ai_horde_worker/locale_info/bridge_data_fields.py b/horde_sdk/ai_horde_worker/locale_info/bridge_data_fields.py index a7c5e68..54fb181 100644 --- a/horde_sdk/ai_horde_worker/locale_info/bridge_data_fields.py +++ b/horde_sdk/ai_horde_worker/locale_info/bridge_data_fields.py @@ -192,5 +192,12 @@ " other worker can pretend to serve it" ), ), + "extra_stable_diffusion_models_folders": _L( + ( + "A list of folders to search for stable diffusion models. " + "This is useful if you want to load models from a folder other than the default " + "or if you want to load models from multiple folders." + ), + ), "test": _L("If set to true, the worker will not actually accept jobs, but will instead just print them out."), } diff --git a/horde_sdk/generic_api/apimodels.py b/horde_sdk/generic_api/apimodels.py index e74ed70..0e6498b 100644 --- a/horde_sdk/generic_api/apimodels.py +++ b/horde_sdk/generic_api/apimodels.py @@ -53,7 +53,7 @@ class ResponseRequiringFollowUpMixin(abc.ABC): """Represents any response from any Horde API which requires a follow up request of some kind.""" @abc.abstractmethod - def get_follow_up_returned_params(self) -> list[dict[str, object]]: + def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -> list[dict[str, object]]: """Return the information required from this response to submit a follow up request. Note that this dict uses the alias field names (as seen on the API), not the python field names. @@ -114,6 +114,9 @@ def get_follow_up_failure_cleanup_params(self) -> dict[str, object]: def get_follow_up_failure_cleanup_request(self) -> list[HordeRequest]: """Return the request for this response to clean up after a failed follow up request.""" + if self.ignore_failure(): + return [] + if self._cleanup_requests is not None: return self._cleanup_requests @@ -135,6 +138,34 @@ def get_follow_up_request_types(cls) -> list[type[HordeRequest]]: """Return a list of all the possible follow up request types for this response.""" return [cls.get_follow_up_default_request_type()] + def ignore_failure(self) -> bool: + """Return if the object is in a state which doesn't require failure follow up.""" + # ImageGenerateJobPopResponse was the use case at the time of writing + return False + + def does_target_request_follow_up(self, target_request: HordeRequest) -> bool: + """Return whether the `target_request` would follow up on this request. + + Args: + target_request (HordeRequest): The request to check if it would follow up on this request. + + Returns: + bool: Whether the `target_request` would follow up on this request. + """ + + follow_up_returned_params = self.get_follow_up_returned_params(as_python_field_name=True) + + if len(follow_up_returned_params) == 0: + logger.warning("No follow up returned params defined for this request") + return False + all_match = True + for param_set in follow_up_returned_params: + for key, value in param_set.items(): + if hasattr(target_request, key) and getattr(target_request, key) != value: + all_match = False + break + return all_match + class ResponseWithProgressMixin(BaseModel): """Represents any response from any Horde API which contains progress information.""" @@ -205,7 +236,7 @@ def get_http_method(cls) -> HTTPMethod: # X_Fields # TODO client_agent: str = Field( - default="horde_sdk:0.7.1:https://githib.com/haidra-org/horde-sdk", + default="horde_sdk:0.7.10:https://githib.com/haidra-org/horde-sdk", # FIXME alias="Client-Agent", ) @@ -267,27 +298,6 @@ def get_requires_follow_up(self) -> bool: return True return False - def does_target_request_follow_up(self, target_request: HordeRequest) -> bool: - """Return whether the `target_request` would follow up on this request. - - Args: - target_request (HordeRequest): The request to check if it would follow up on this request. - - Returns: - bool: Whether the `target_request` would follow up on this request. - """ - if not self.get_requires_follow_up(): - return False - - defined_response_types = self.get_success_status_response_pairs().values() - - for response_type in defined_response_types: - if issubclass(response_type, ResponseRequiringFollowUpMixin): # noqa: SIM102 - if type(target_request) in response_type.get_follow_up_request_types(): - return True - - return False - @override @classmethod def get_sensitive_fields(self) -> set[str]: diff --git a/horde_sdk/generic_api/generic_clients.py b/horde_sdk/generic_api/generic_clients.py index 45f5985..811ab11 100644 --- a/horde_sdk/generic_api/generic_clients.py +++ b/horde_sdk/generic_api/generic_clients.py @@ -167,10 +167,13 @@ def get_specified_data_keys(data_keys: type[StrEnum], api_request: HordeRequest) # Get the endpoint URL from the request and replace any path keys with their corresponding values endpoint_url: str = api_request.get_api_endpoint_url() - for py_field_name, api_field_name in specified_paths.items(): + for py_field_name, api_field_name in list(specified_paths.items()): # Replace the path key with the value from the request # IE: /v2/ratings/{id} -> /v2/ratings/123 + _endpoint_url = endpoint_url endpoint_url = endpoint_url.format_map({api_field_name: str(getattr(api_request, py_field_name))}) + if _endpoint_url == endpoint_url: + specified_paths.pop(py_field_name) # Extract any extra header fields and the request body data from the request extra_header_keys: list[str] = api_request.get_header_fields() @@ -210,6 +213,7 @@ def get_specified_data_keys(data_keys: type[StrEnum], api_request: HordeRequest) # Convert the request body data to a dictionary request_body_data_dict: dict | None = api_request.model_dump( + by_alias=True, exclude_none=True, exclude_unset=True, exclude=all_fields_to_exclude_from_body, @@ -479,9 +483,8 @@ def submit_request( ) else: # TODO: This whole else is duplicated in the asyncio version of this class. Refactor it out. # Check if this request is a cleanup or follow up request for a prior request - # Loop through each item in self._pending_follow_ups list - for index, (prior_request, _prior_response, cleanup_request) in enumerate(self._pending_follow_ups): + for index, (prior_request, prior_response, cleanup_request) in enumerate(self._pending_follow_ups): if api_request is cleanup_request: if not isinstance(response, RequestErrorResponse): self._pending_follow_ups.pop(index) @@ -494,23 +497,32 @@ def submit_request( logger.error(f"Response: {response}") break - # If the response isn't a final follow-up, we don't need to do anything else. - if not isinstance(response, ResponseWithProgressMixin): - continue - if not response.is_final_follow_up(): - continue - if not prior_request.get_requires_follow_up(): + if not isinstance(prior_response, ResponseRequiringFollowUpMixin): continue - # See if the current api_request is a follow-up to the prior_request - if not prior_request.does_target_request_follow_up(api_request): - continue + # If the response isn't a final follow-up, we don't need to do anything else. + if isinstance(response, ResponseWithProgressMixin): + if not response.is_final_follow_up(): + continue + if not prior_request.get_requires_follow_up(): + continue + + # See if the current api_request is a follow-up to the prior_request + if not prior_response.does_target_request_follow_up(api_request): + continue + + # Check if the current response indicates that the job is complete + if response.is_job_complete(prior_request.get_number_of_results_expected()): + # Remove the current item from the _pending_follow_ups list + # This is for the benefit of the __exit__ method (context management) + self._pending_follow_ups.pop(index) + break + else: + if not prior_response.does_target_request_follow_up(api_request): + continue - # Check if the current response indicates that the job is complete - if response.is_job_complete(prior_request.get_number_of_results_expected()): - # Remove the current item from the _pending_follow_ups list - # This is for the benefit of the __exit__ method (context management) self._pending_follow_ups.pop(index) + break return response @@ -563,6 +575,9 @@ def _handle_exit( if not isinstance(response_to_follow_up, ResponseRequiringFollowUpMixin): return True + if response_to_follow_up.ignore_failure(): + return True + # The message to log if an exception occurs. message = ( "An exception occurred while trying to create a recovery request! " @@ -652,44 +667,55 @@ async def submit_request( # Check if the response requires a follow-up request. if isinstance(response, ResponseRequiringFollowUpMixin): # Add the follow-up request to the list of pending follow-ups. - self._pending_follow_ups.append( - (api_request, response, response.get_follow_up_failure_cleanup_request()), - ) + if not response.ignore_failure(): + self._pending_follow_ups.append( + (api_request, response, response.get_follow_up_failure_cleanup_request()), + ) else: # Check if this request is a cleanup or follow up request for a prior request # Loop through each item in self._pending_follow_ups list - for index, (prior_request, _prior_response, cleanup_request) in enumerate(self._pending_follow_ups): + for index, (prior_request, prior_response, cleanup_request) in enumerate(self._pending_follow_ups): if api_request is cleanup_request: if not isinstance(response, RequestErrorResponse): self._pending_follow_ups.pop(index) - else: - logger.error( - "This api request would have followed up on an operation which requires it, but it " - "failed!", - ) - logger.error(f"Request: {api_request}") - logger.error(f"Response: {response}") + break + + logger.error( + "This api request would have followed up on an operation which requires it, but it " + "failed!", + ) + logger.error(f"Request: {api_request.log_safe_model_dump()}") + logger.error(f"Response: {response.log_safe_model_dump()}") break - # If the response isn't a final follow-up, we don't need to do anything else. - if not isinstance(response, ResponseWithProgressMixin): - continue - if not response.is_final_follow_up(): - continue - if not prior_request.get_requires_follow_up(): + if not isinstance(prior_response, ResponseRequiringFollowUpMixin): continue - # See if the current api_request is a follow-up to the prior_request - if not prior_request.does_target_request_follow_up(api_request): - continue + # If the response isn't a final follow-up, we don't need to do anything else. + if isinstance(response, ResponseWithProgressMixin): + if not response.is_final_follow_up(): + continue + if not prior_request.get_requires_follow_up(): + continue + + # See if the current api_request is a follow-up to the prior_request + if not prior_response.does_target_request_follow_up(api_request): + continue + + # Check if the current response indicates that the job is complete + if response.is_job_complete(prior_request.get_number_of_results_expected()): + # Remove the current item from the _pending_follow_ups list + # This is for the benefit of the __exit__ method (context management) + self._pending_follow_ups.pop(index) + break + else: + if not prior_response.does_target_request_follow_up(api_request): + continue - # Check if the current response indicates that the job is complete - if response.is_job_complete(prior_request.get_number_of_results_expected()): - # Remove the current item from the _pending_follow_ups list - # This is for the benefit of the __exit__ method (context management) self._pending_follow_ups.pop(index) + break # Return the response from the API. return response @@ -755,6 +781,9 @@ async def _handle_exit_async( if not isinstance(response_to_follow_up, ResponseRequiringFollowUpMixin): return True + if response_to_follow_up.ignore_failure(): + return True + # If we get here, we need to create a follow-up request to clean up after the premature ending. message = ( "An exception occurred while trying to create a recovery request! " @@ -779,10 +808,14 @@ async def _handle_exit_async( ) for cleanup_request in cleanup_requests ], + return_exceptions=True, ) # Log the results of each cleanup request. for i, cleanup_response in enumerate(cleanup_responses): + if isinstance(cleanup_response, Exception): + logger.error(f"Recovery request {i+1} failed!") + logger.info(f"Recovery request {i+1} submitted!") logger.debug(f"Recovery request {i+1}: {cleanup_requests[i].log_safe_model_dump()}") logger.debug(f"Recovery response {i+1}: {cleanup_response}") diff --git a/horde_sdk/localize.py b/horde_sdk/localize.py index 514b972..b6abfe1 100644 --- a/horde_sdk/localize.py +++ b/horde_sdk/localize.py @@ -3,4 +3,4 @@ def _L(s: str) -> str: """Indicate that the string is displayed to the user and should be localized.""" - return s + return str(s) diff --git a/horde_sdk/logging.py b/horde_sdk/logging.py index db745a0..71e0bc1 100644 --- a/horde_sdk/logging.py +++ b/horde_sdk/logging.py @@ -66,6 +66,7 @@ def is_trace_log(record: dict) -> bool: "diagnose": True, }, ] + PROGRESS_LOGGER_LABEL = "PROGRESS" """The label for request progress log messages. Less severity than INFO.""" COMPLETE_LOGGER_LABEL = "COMPLETE" @@ -88,4 +89,7 @@ def is_trace_log(record: dict) -> bool: if parsed_verbosity is not None: verbosity = parsed_verbosity -logger.configure(handlers=handler_config) +set_logger_handlers = os.getenv("HORDE_SDK_SET_DEFAULT_LOG_HANDLERS") + +if set_logger_handlers: + logger.configure(handlers=handler_config) diff --git a/pyproject.toml b/pyproject.toml index c000f91..e58b6f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "horde_sdk" -version = "0.7.9" +version = "0.7.11" description = "A python toolkit for interacting with the horde APIs, services, and ecosystem." authors = [ {name = "tazlin", email = "tazlin.on.github@gmail.com"}, diff --git a/tests/ai_horde_api/test_ai_horde_generate_api_calls.py b/tests/ai_horde_api/test_ai_horde_generate_api_calls.py index b3885ca..acc7edf 100644 --- a/tests/ai_horde_api/test_ai_horde_generate_api_calls.py +++ b/tests/ai_horde_api/test_ai_horde_generate_api_calls.py @@ -12,6 +12,7 @@ KNOWN_ALCHEMY_TYPES, AlchemyAsyncRequest, AlchemyAsyncRequestFormItem, + AlchemyStatusResponse, ImageGenerateAsyncRequest, ImageGenerateAsyncResponse, ImageGenerateStatusResponse, @@ -19,6 +20,12 @@ ImageGenerationInputPayload, LorasPayloadEntry, ) +from horde_sdk.ai_horde_api.consts import ( + KNOWN_FACEFIXERS, + KNOWN_MISC_POST_PROCESSORS, + KNOWN_UPSCALERS, + POST_PROCESSOR_ORDER_TYPE, +) from horde_sdk.ai_horde_api.fields import JobID from horde_sdk.generic_api.apimodels import RequestErrorResponse @@ -73,6 +80,103 @@ def test_simple_client_image_generate( assert image is not None + def test_simple_client_image_generate_with_post_process( + self, + ai_horde_api_key: str, + ) -> None: + """Test that a simple image generation request can be submitted and cancelled.""" + simple_client = AIHordeAPISimpleClient() + + pp_image_gen_request = ImageGenerateAsyncRequest( + apikey=ai_horde_api_key, + prompt="a cat in a hat", + params=ImageGenerationInputPayload( + seed="1234", + n=1, + post_processing=[KNOWN_UPSCALERS.RealESRGAN_x2plus], + ), + models=["Deliberate"], + ) + + image_generate_status_respons, job_id = simple_client.image_generate_request(pp_image_gen_request) + + if isinstance(image_generate_status_respons.generations, RequestErrorResponse): + raise AssertionError(image_generate_status_respons.generations.message) + + assert len(image_generate_status_respons.generations) == 1 + + image = simple_client.download_image_from_generation(image_generate_status_respons.generations[0]) + + assert image is not None + + def test_simple_client_image_generate_with_post_process_costly_order( + self, + ai_horde_api_key: str, + ) -> None: + """Test that a simple image generation request can be submitted and cancelled.""" + simple_client = AIHordeAPISimpleClient() + + pp_image_gen_request = ImageGenerateAsyncRequest( + apikey=ai_horde_api_key, + prompt="a cat in a hat", + params=ImageGenerationInputPayload( + seed="1234", + n=1, + post_processing=[ + KNOWN_UPSCALERS.RealESRGAN_x2plus, + KNOWN_FACEFIXERS.CodeFormers, + KNOWN_MISC_POST_PROCESSORS.strip_background, + ], + post_processing_order=POST_PROCESSOR_ORDER_TYPE.custom, + ), + models=["Deliberate"], + ) + + image_generate_status_respons, job_id = simple_client.image_generate_request(pp_image_gen_request) + + if isinstance(image_generate_status_respons.generations, RequestErrorResponse): + raise AssertionError(image_generate_status_respons.generations.message) + + assert len(image_generate_status_respons.generations) == 1 + + image = simple_client.download_image_from_generation(image_generate_status_respons.generations[0]) + + assert image is not None + + def test_simple_client_image_generate_with_post_process_fix_costly_order( + self, + ai_horde_api_key: str, + ) -> None: + """Test that a simple image generation request can be submitted and cancelled.""" + simple_client = AIHordeAPISimpleClient() + + pp_image_gen_request = ImageGenerateAsyncRequest( + apikey=ai_horde_api_key, + prompt="a cat in a hat", + params=ImageGenerationInputPayload( + seed="1234", + n=1, + post_processing=[ + KNOWN_UPSCALERS.RealESRGAN_x2plus, + KNOWN_FACEFIXERS.CodeFormers, + KNOWN_MISC_POST_PROCESSORS.strip_background, + ], + post_processing_order=POST_PROCESSOR_ORDER_TYPE.facefixers_first, + ), + models=["Deliberate"], + ) + + image_generate_status_respons, job_id = simple_client.image_generate_request(pp_image_gen_request) + + if isinstance(image_generate_status_respons.generations, RequestErrorResponse): + raise AssertionError(image_generate_status_respons.generations.message) + + assert len(image_generate_status_respons.generations) == 1 + + image = simple_client.download_image_from_generation(image_generate_status_respons.generations[0]) + + assert image is not None + def test_simple_client_image_generate_no_apikey_specified( self, ) -> None: @@ -108,7 +212,7 @@ def test_simple_client_image_generate_loras( params=ImageGenerationInputPayload( seed="1234", n=1, - loras=[LorasPayloadEntry(name="48139", model=1, clip=1)], + loras=[LorasPayloadEntry(name="76693", model=1, clip=1)], ), models=["Deliberate"], ) @@ -185,17 +289,58 @@ def test_simple_client_alchemy_basic( ) -> None: simple_client = AIHordeAPISimpleClient() - simple_client.alchemy_request( + result, jobid = simple_client.alchemy_request( alchemy_request=AlchemyAsyncRequest( forms=[ AlchemyAsyncRequestFormItem( name=KNOWN_ALCHEMY_TYPES.caption, ), + AlchemyAsyncRequestFormItem( + name=KNOWN_ALCHEMY_TYPES.RealESRGAN_x4plus, + ), ], source_image=default_testing_image_base64, ), ) + assert result is not None + + @pytest.mark.asyncio + async def test_simple_client_async_alchemy_basic_flood( + self, + default_testing_image_base64: str, + ) -> None: + # Perform 15 requests in parallel + async with aiohttp.ClientSession() as aiohttp_session: + simple_client = AIHordeAPIAsyncSimpleClient(aiohttp_session) + + async def submit_request() -> AlchemyStatusResponse: + result, jobid = await simple_client.alchemy_request( + alchemy_request=AlchemyAsyncRequest( + forms=[ + AlchemyAsyncRequestFormItem( + name=KNOWN_ALCHEMY_TYPES.caption, + ), + AlchemyAsyncRequestFormItem( + name=KNOWN_ALCHEMY_TYPES.RealESRGAN_x4plus, + ), + ], + source_image=default_testing_image_base64, + ), + ) + + if isinstance(result, RequestErrorResponse): + raise AssertionError(result.message) + + return result + + # Run 15 concurrent requests using asyncio + tasks = [asyncio.create_task(submit_request()) for _ in range(15)] + all_responses: list[AlchemyStatusResponse | None] = await asyncio.gather(*tasks) + + # Check that all requests were successful + assert len([response for response in all_responses if response]) == 15 + @pytest.mark.asyncio async def test_simple_client_async_image_generate_multiple( self, diff --git a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py index ab93391..7b8d282 100644 --- a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py +++ b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py @@ -31,6 +31,9 @@ def all_ai_horde_model_defs_in_swagger(swagger_doc: SwaggerDoc) -> None: api_to_sdk_payload_model_map: dict[str, dict[HTTPMethod, type[HordeRequest]]] = {} api_to_sdk_response_model_map: dict[str, dict[HTTPStatusCode, type[HordeResponse]]] = {} + request_field_names_and_descriptions: dict[str, list[tuple[str, str | None]]] = {} + response_field_names_and_descriptions: dict[str, list[tuple[str, str | None]]] = {} + for request_type in all_request_types: endpoint_subpath: GENERIC_API_ENDPOINT_SUBPATH = request_type.get_api_endpoint_subpath() assert endpoint_subpath, f"Failed to get endpoint subpath for {request_type.__name__}" @@ -78,10 +81,11 @@ def all_ai_horde_model_defs_in_swagger(swagger_doc: SwaggerDoc) -> None: api_to_sdk_payload_model_map[endpoint_subpath][request_type.get_http_method()] = request_type - field_names_and_descriptions: list[tuple[str, str]] = [] for field_name, field_info in request_type.model_fields.items(): - if field_info.description: - field_names_and_descriptions.append((field_name, field_info.description)) + if request_type.__name__ not in request_field_names_and_descriptions: + request_field_names_and_descriptions[request_type.__name__] = [] + + request_field_names_and_descriptions[request_type.__name__].append((field_name, field_info.description)) endpoint_success_http_status_codes: list[HTTPStatusCode] = [ success_code @@ -99,6 +103,18 @@ def all_ai_horde_model_defs_in_swagger(swagger_doc: SwaggerDoc) -> None: api_to_sdk_response_model_map[endpoint_subpath] = request_type.get_success_status_response_pairs() + for response_type in request_type.get_success_status_response_pairs().values(): + for field_name, field_info in response_type.model_fields.items(): + if response_type.__name__ not in response_field_names_and_descriptions: + response_field_names_and_descriptions[response_type.__name__] = [] + + if field_info.description is not None: + response_field_names_and_descriptions[response_type.__name__].append( + (field_name, field_info.description), + ) + else: + response_field_names_and_descriptions[response_type.__name__].append((field_name, None)) + def json_serializer(obj: object) -> object: if isinstance(obj, str): return obj @@ -115,6 +131,12 @@ def json_serializer(obj: object) -> object: f.write(json.dumps(api_to_sdk_response_model_map, indent=4, default=json_serializer)) f.write("\n") + with open("docs/request_field_names_and_descriptions.json", "w") as f: + f.write(json.dumps(request_field_names_and_descriptions, indent=4, default=json_serializer)) + + with open("docs/response_field_names_and_descriptions.json", "w") as f: + f.write(json.dumps(response_field_names_and_descriptions, indent=4, default=json_serializer)) + def test_all_ai_horde_model_defs_in_swagger_from_prod_swagger() -> None: swagger_doc: SwaggerDoc | None = None diff --git a/tests/ai_horde_worker/test_model_meta.py b/tests/ai_horde_worker/test_model_meta.py index b2e90af..0c769bd 100644 --- a/tests/ai_horde_worker/test_model_meta.py +++ b/tests/ai_horde_worker/test_model_meta.py @@ -102,4 +102,4 @@ def test_image_models_unique_results_only( ) all_model_names = image_model_load_resolver.resolve_all_model_names() - assert len(resolved_model_names) == len(all_model_names) + assert len(resolved_model_names) >= len(all_model_names) diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json index fbe195c..36b5297 100644 --- a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json @@ -11,6 +11,7 @@ "post_processing": [ "GFPGAN" ], + "post_processing_order": "facefixers_first", "karras": false, "tiling": false, "hires_fix": false, diff --git a/tests/test_data/ai_horde_api/example_responses/_v2_generate_pop_post_200.json b/tests/test_data/ai_horde_api/example_responses/_v2_generate_pop_post_200.json index 139498a..9bf18ce 100644 --- a/tests/test_data/ai_horde_api/example_responses/_v2_generate_pop_post_200.json +++ b/tests/test_data/ai_horde_api/example_responses/_v2_generate_pop_post_200.json @@ -10,6 +10,7 @@ "post_processing": [ "GFPGAN" ], + "post_processing_order": "facefixers_first", "karras": false, "tiling": false, "hires_fix": false,