From 0655d59bab3225050b1234f31ae8a1f8ee7ab0eb Mon Sep 17 00:00:00 2001 From: Divided by Zer0 Date: Sun, 15 Sep 2024 13:22:56 +0200 Subject: [PATCH] feat: support for seek by worker name (#243) * feat: support for seek by worker name * feat: include `worker_name` in simple client worker details lookup * chore: add worker name filter to existing worker details example --------- Co-authored-by: tazlin --- examples/ai_horde_client/workers.py | 31 +++++++++++++++++-- horde_sdk/ai_horde_api/ai_horde_clients.py | 9 ++++-- .../apimodels/workers/_workers.py | 9 ++++-- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/examples/ai_horde_client/workers.py b/examples/ai_horde_client/workers.py index 5548dba..9ee8d01 100644 --- a/examples/ai_horde_client/workers.py +++ b/examples/ai_horde_client/workers.py @@ -12,10 +12,21 @@ ) -def all_workers(api_key: str, simple_client: AIHordeAPISimpleClient, filename: str) -> None: +def all_workers( + api_key: str, + simple_client: AIHordeAPISimpleClient, + filename: str, + *, + worker_name: str | None = None, +) -> None: all_workers_response: AllWorkersDetailsResponse - all_workers_response = simple_client.workers_all_details() + all_workers_response = simple_client.workers_all_details(worker_name=worker_name) + + if worker_name is None: + logger.info("Getting details for all workers.") + else: + logger.info(f"Getting details for worker with name: {worker_name}") if all_workers_response is None: raise ValueError("No workers returned in the response.") @@ -101,6 +112,13 @@ def set_maintenance_mode( help="The worker ID to get details for.", ) + group.add_argument( + "--worker_name", + "-n", + type=str, + help="The worker name to get details for.", + ) + group2 = parser.add_mutually_exclusive_group() group2.add_argument( "--maintenance-mode-on", @@ -123,7 +141,14 @@ def set_maintenance_mode( simple_client = AIHordeAPISimpleClient() - if args.all: + if args.worker_name: + all_workers( + api_key=args.apikey, + simple_client=simple_client, + filename=args.filename, + worker_name=args.worker_name, + ) + elif args.all: all_workers( api_key=args.apikey, simple_client=simple_client, diff --git a/horde_sdk/ai_horde_api/ai_horde_clients.py b/horde_sdk/ai_horde_api/ai_horde_clients.py index 4babb34..a9d0060 100644 --- a/horde_sdk/ai_horde_api/ai_horde_clients.py +++ b/horde_sdk/ai_horde_api/ai_horde_clients.py @@ -955,6 +955,7 @@ def text_generate_request_dry_run( def workers_all_details( self, + worker_name: str | None = None, ) -> AllWorkersDetailsResponse: """Get all the details for all workers. @@ -962,7 +963,10 @@ def workers_all_details( WorkersAllDetailsResponse: The response from the API. """ with AIHordeAPIClientSession() as horde_session: - response = horde_session.submit_request(AllWorkersDetailsRequest(), AllWorkersDetailsResponse) + response = horde_session.submit_request( + AllWorkersDetailsRequest(name=worker_name), + AllWorkersDetailsResponse, + ) if isinstance(response, RequestErrorResponse): raise AIHordeRequestError(response) @@ -1643,6 +1647,7 @@ async def text_generate_request_dry_run( async def workers_all_details( self, + worker_name: str | None = None, ) -> AllWorkersDetailsResponse: """Get all the details for all workers. @@ -1651,7 +1656,7 @@ async def workers_all_details( """ if self._horde_client_session is not None: response = await self._horde_client_session.submit_request( - AllWorkersDetailsRequest(), + AllWorkersDetailsRequest(name=worker_name), AllWorkersDetailsResponse, ) else: diff --git a/horde_sdk/ai_horde_api/apimodels/workers/_workers.py b/horde_sdk/ai_horde_api/apimodels/workers/_workers.py index 6d18cb1..4009257 100644 --- a/horde_sdk/ai_horde_api/apimodels/workers/_workers.py +++ b/horde_sdk/ai_horde_api/apimodels/workers/_workers.py @@ -189,10 +189,15 @@ def get_api_model_name(cls) -> str | None: class AllWorkersDetailsRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin): - """Returns information on all works. If a moderator API key is specified, it will return additional information.""" + """Returns information on all workers. + + If a moderator API key is specified, it will return additional information. + """ type_: WORKER_TYPE = Field(WORKER_TYPE.all, alias="type") """Filter workers by type. Default is 'all' which returns all workers.""" + name: str | None = Field(None) + """Returns a worker matching the exact name provided. Case insensitive.""" @override @classmethod @@ -217,7 +222,7 @@ def get_default_success_response_type(cls) -> type[AllWorkersDetailsResponse]: @override @classmethod def get_query_fields(cls) -> list[str]: - return ["type_"] + return ["type_", "name"] @classmethod def is_api_key_required(cls) -> bool: