Skip to content

Commit

Permalink
feat: support for seek by worker name (#243)
Browse files Browse the repository at this point in the history
* feat: support for seek by worker name

* feat: include `worker_name` in simple client worker details lookup

* chore: add worker name filter to existing worker details example

---------

Co-authored-by: tazlin <[email protected]>
  • Loading branch information
db0 and tazlin authored Sep 15, 2024
1 parent 372400b commit 0655d59
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
31 changes: 28 additions & 3 deletions examples/ai_horde_client/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@
)


def all_workers(api_key: str, simple_client: AIHordeAPISimpleClient, filename: str) -> None:
def all_workers(
api_key: str,
simple_client: AIHordeAPISimpleClient,
filename: str,
*,
worker_name: str | None = None,
) -> None:
all_workers_response: AllWorkersDetailsResponse

all_workers_response = simple_client.workers_all_details()
all_workers_response = simple_client.workers_all_details(worker_name=worker_name)

if worker_name is None:
logger.info("Getting details for all workers.")
else:
logger.info(f"Getting details for worker with name: {worker_name}")

if all_workers_response is None:
raise ValueError("No workers returned in the response.")
Expand Down Expand Up @@ -101,6 +112,13 @@ def set_maintenance_mode(
help="The worker ID to get details for.",
)

group.add_argument(
"--worker_name",
"-n",
type=str,
help="The worker name to get details for.",
)

group2 = parser.add_mutually_exclusive_group()
group2.add_argument(
"--maintenance-mode-on",
Expand All @@ -123,7 +141,14 @@ def set_maintenance_mode(

simple_client = AIHordeAPISimpleClient()

if args.all:
if args.worker_name:
all_workers(
api_key=args.apikey,
simple_client=simple_client,
filename=args.filename,
worker_name=args.worker_name,
)
elif args.all:
all_workers(
api_key=args.apikey,
simple_client=simple_client,
Expand Down
9 changes: 7 additions & 2 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,14 +955,18 @@ def text_generate_request_dry_run(

def workers_all_details(
self,
worker_name: str | None = None,
) -> AllWorkersDetailsResponse:
"""Get all the details for all workers.
Returns:
WorkersAllDetailsResponse: The response from the API.
"""
with AIHordeAPIClientSession() as horde_session:
response = horde_session.submit_request(AllWorkersDetailsRequest(), AllWorkersDetailsResponse)
response = horde_session.submit_request(
AllWorkersDetailsRequest(name=worker_name),
AllWorkersDetailsResponse,
)

if isinstance(response, RequestErrorResponse):
raise AIHordeRequestError(response)
Expand Down Expand Up @@ -1643,6 +1647,7 @@ async def text_generate_request_dry_run(

async def workers_all_details(
self,
worker_name: str | None = None,
) -> AllWorkersDetailsResponse:
"""Get all the details for all workers.
Expand All @@ -1651,7 +1656,7 @@ async def workers_all_details(
"""
if self._horde_client_session is not None:
response = await self._horde_client_session.submit_request(
AllWorkersDetailsRequest(),
AllWorkersDetailsRequest(name=worker_name),
AllWorkersDetailsResponse,
)
else:
Expand Down
9 changes: 7 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/workers/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,15 @@ def get_api_model_name(cls) -> str | None:


class AllWorkersDetailsRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
"""Returns information on all works. If a moderator API key is specified, it will return additional information."""
"""Returns information on all workers.
If a moderator API key is specified, it will return additional information.
"""

type_: WORKER_TYPE = Field(WORKER_TYPE.all, alias="type")
"""Filter workers by type. Default is 'all' which returns all workers."""
name: str | None = Field(None)
"""Returns a worker matching the exact name provided. Case insensitive."""

@override
@classmethod
Expand All @@ -217,7 +222,7 @@ def get_default_success_response_type(cls) -> type[AllWorkersDetailsResponse]:
@override
@classmethod
def get_query_fields(cls) -> list[str]:
return ["type_"]
return ["type_", "name"]

@classmethod
def is_api_key_required(cls) -> bool:
Expand Down

0 comments on commit 0655d59

Please sign in to comment.