Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: api side worker test; fix: misc fixes #221

Merged
merged 11 commits into from
Jul 20, 2024
3 changes: 2 additions & 1 deletion .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ jobs:
build:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
TESTS_ONGOING: 1
HORDE_SDK_TESTING: 1
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
TESTS_ONGOING: 1
HORDE_SDK_TESTING: 1
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
runs-on: ubuntu-latest
strategy:
Expand Down
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ repos:
pass_filenames: false
additional_dependencies: [
pytest,
pydantic,
pydantic==2.7.4,
types-Pillow,
types-requests,
types-pytz,
types-setuptools,
types-urllib3,
types-aiofiles,
StrEnum
StrEnum,
horde_model_reference==0.8.1,
]
10 changes: 5 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class ImageGenerateParamMixin(HordeAPIDataObject):
karras: bool = True
"""Set to True if you want to use the Karras scheduling."""
tiling: bool = False
"""Deprecated."""
"""Set to True if you want to use seamless tiling."""
hires_fix: bool = False
"""Set to True if you want to use the hires fix."""
hires_fix_denoising_strength: float | None = Field(default=None, ge=0, le=1)
Expand All @@ -234,17 +234,17 @@ class ImageGenerateParamMixin(HordeAPIDataObject):
"""Set to True if you want the ControlNet map returned instead of a generated image."""
facefixer_strength: float | None = Field(default=None, ge=0, le=1)
"""The strength of the facefixer model."""
loras: list[LorasPayloadEntry] = Field(default_factory=list)
loras: list[LorasPayloadEntry] | None = None
"""A list of lora parameters to use."""
tis: list[TIPayloadEntry] = Field(default_factory=list)
tis: list[TIPayloadEntry] | None = None
"""A list of textual inversion (embedding) parameters to use."""
extra_texts: list[ExtraTextEntry] = Field(default_factory=list)
extra_texts: list[ExtraTextEntry] | None = None
"""A list of extra texts and prompts to use in the comfyUI workflow."""
workflow: str | KNOWN_WORKFLOWS | None = None
"""The specific comfyUI workflow to use."""
transparent: bool | None = None
"""When true, will generate an image with a transparent background"""
special: dict[Any, Any] = Field(default_factory=dict)
special: dict[Any, Any] | None = None
"""Reserved for future use."""
use_nsfw_censor: bool = False
"""If the request is SFW, and the worker accidentally generates NSFW, it will send back a censored image."""
Expand Down
17 changes: 12 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,15 @@ def validate_id(cls, v: str | JobID) -> JobID | str:

return v

_ids_present: bool = False

@property
def ids_present(self) -> bool:
"""Whether or not the IDs are present."""
return self._ids_present

@model_validator(mode="after")
def ids_present(self) -> ImageGenerateJobPopResponse:
def validate_ids_present(self) -> ImageGenerateJobPopResponse:
"""Ensure that either id_ or ids is present."""
if self.model is None:
if self.skipped.is_empty():
Expand All @@ -270,6 +277,8 @@ def ids_present(self) -> ImageGenerateJobPopResponse:
logger.debug("Sorting IDs")
self.ids.sort()

self._ids_present = True

return self

@override
Expand Down Expand Up @@ -418,11 +427,9 @@ class PopInput(HordeAPIObject):
max_length=1000,
)
"""The worker name, version and website."""
models: list[str] | None = None
models: list[str]
"""The models this worker can generate."""
name: str | None = Field(
None,
)
name: str
"""The Name of the Worker."""
nsfw: bool | None = Field(
False,
Expand Down
9 changes: 7 additions & 2 deletions horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,14 @@ def get_extra_fields_to_exclude_from_log(self) -> set[str]:
"""Return an additional set of fields to exclude from the log_safe_model_dump method."""
return set()

def log_safe_model_dump(self) -> dict[Any, Any]:
def log_safe_model_dump(self, extra_exclude: set[str] | None = None) -> dict[Any, Any]:
"""Return a dict of the model's fields, with any sensitive fields redacted."""
return self.model_dump(exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log())
if extra_exclude is None:
extra_exclude = set()

return self.model_dump(
exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log() | extra_exclude,
)


class HordeResponse(HordeAPIMessage):
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ exclude = [
concurrency = ["gevent"]

[tool.pytest.ini_options]
# You can use `and`, `or`, `not` and parentheses to filter with the `-m` option
markers = [
# "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"object_verify: marks tests that verify the API object structure and layout",
"api_side_ci: indicates that the test is intended to run during CI for the API",
]
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
horde_model_reference~=0.7.0
horde_model_reference~=0.8.1

pydantic
pydantic==2.7.4
requests
StrEnum
loguru
Expand Down
174 changes: 174 additions & 0 deletions tests/ai_horde_api/test_ai_worker_roundtrip_api_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import asyncio

import aiohttp
import PIL.Image
import pytest
import yarl
from loguru import logger

from horde_sdk.ai_horde_api.ai_horde_clients import (
AIHordeAPIAsyncClientSession,
AIHordeAPIAsyncSimpleClient,
)
from horde_sdk.ai_horde_api.apimodels import (
ImageGenerateAsyncRequest,
ImageGenerateJobPopRequest,
ImageGenerateJobPopResponse,
ImageGenerateStatusResponse,
ImageGenerationJobSubmitRequest,
JobSubmitResponse,
)
from horde_sdk.ai_horde_api.consts import (
GENERATION_STATE,
)
from horde_sdk.ai_horde_api.fields import JobID


class TestImageWorkerRoundtrip:
async def fake_worker_checkin(
self,
aiohttp_session: aiohttp.ClientSession,
horde_client_session: AIHordeAPIAsyncClientSession,
image_gen_request: ImageGenerateAsyncRequest,
) -> None:
assert image_gen_request.params is not None

effective_resolution = (image_gen_request.params.width * image_gen_request.params.height) * 2

job_pop_request = ImageGenerateJobPopRequest(
name="fake CI worker",
bridge_agent="AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen",
max_pixels=effective_resolution,
models=image_gen_request.models,
)

job_pop_response = await horde_client_session.submit_request(
job_pop_request,
job_pop_request.get_default_success_response_type(),
)

assert isinstance(job_pop_response, ImageGenerateJobPopResponse)
logger.info(f"{job_pop_response.log_safe_model_dump({'skipped'})}")

assert not job_pop_response.ids_present
assert job_pop_response.skipped is not None

logger.info(f"Checked in as fake worker ({effective_resolution}): {job_pop_response.skipped}")

async def fake_worker(
self,
aiohttp_session: aiohttp.ClientSession,
horde_client_session: AIHordeAPIAsyncClientSession,
image_gen_request: ImageGenerateAsyncRequest,
) -> None:
assert image_gen_request.params is not None

effective_resolution = (image_gen_request.params.width * image_gen_request.params.height) * 2

job_pop_request = ImageGenerateJobPopRequest(
name="fake CI worker",
bridge_agent="AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen",
max_pixels=effective_resolution,
models=image_gen_request.models,
)

max_tries = 5
try_wait = 5
current_try = 0

while True:
job_pop_response = await horde_client_session.submit_request(
job_pop_request,
job_pop_request.get_default_success_response_type(),
)

assert isinstance(job_pop_response, ImageGenerateJobPopResponse)
logger.info(f"{job_pop_response.log_safe_model_dump({'skipped'})}")
logger.info(f"Checked in as fake worker ({effective_resolution}): {job_pop_response.skipped}")

if not job_pop_response.ids_present:
if current_try >= max_tries:
raise RuntimeError("Max tries exceeded")

logger.info(f"Waiting {try_wait} seconds before retrying")
await asyncio.sleep(try_wait)
current_try += 1
continue

# We're going to send a blank image base64 encoded
fake_image = PIL.Image.new(
"RGB",
(image_gen_request.params.width, image_gen_request.params.height),
(255, 255, 255),
)

fake_image_bytes = fake_image.tobytes()

r2_url = job_pop_response.r2_upload

assert r2_url is not None

async with aiohttp_session.put(
yarl.URL(r2_url, encoded=True),
data=fake_image_bytes,
skip_auto_headers=["content-type"],
timeout=aiohttp.ClientTimeout(total=10),
) as response:
assert response.status == 200

assert job_pop_response.ids is not None
assert len(job_pop_response.ids) == 1

job_submit_request = ImageGenerationJobSubmitRequest(
id=job_pop_response.ids[0],
state=GENERATION_STATE.ok,
generation="R2",
seed="1312",
)

job_submit_response = await horde_client_session.submit_request(
job_submit_request,
job_submit_request.get_default_success_response_type(),
)

assert isinstance(job_submit_response, JobSubmitResponse)
assert job_submit_response.reward is not None and job_submit_response.reward > 0

break

@pytest.mark.api_side_ci
@pytest.mark.asyncio
async def test_basic_image_roundtrip(self, simple_image_gen_request: ImageGenerateAsyncRequest) -> None:
aiohttp_session = aiohttp.ClientSession()
horde_client_session = AIHordeAPIAsyncClientSession(aiohttp_session)

async with aiohttp_session, horde_client_session:
simple_client = AIHordeAPIAsyncSimpleClient(horde_client_session=horde_client_session)

await self.fake_worker_checkin(aiohttp_session, horde_client_session, simple_image_gen_request)

image_gen_task = asyncio.create_task(simple_client.image_generate_request(simple_image_gen_request))

fake_worker_task = asyncio.create_task(
self.fake_worker(
aiohttp_session,
horde_client_session,
simple_image_gen_request,
),
)

await asyncio.gather(image_gen_task, fake_worker_task)

image_gen_response, job_id = image_gen_task.result()

assert isinstance(image_gen_response, ImageGenerateStatusResponse)
assert isinstance(job_id, JobID)

assert len(image_gen_response.generations) == 1

generation = image_gen_response.generations[0]
assert generation.seed == "1312"
assert generation.img is not None
assert not generation.gen_metadata

assert generation.censored is False
18 changes: 17 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pathlib

import pytest
from loguru import logger

os.environ["TESTS_ONGOING"] = "1"

Expand All @@ -15,6 +16,21 @@ def check_tests_ongoing_env_var() -> None:
"""Checks that the TESTS_ONGOING environment variable is set."""
assert os.getenv("TESTS_ONGOING", None) is not None, "TESTS_ONGOING environment variable not set"

AI_HORDE_TESTING = os.getenv("AI_HORDE_TESTING", None)
HORDE_SDK_TESTING = os.getenv("HORDE_SDK_TESTING", None)
if AI_HORDE_TESTING is None and HORDE_SDK_TESTING is None:
logger.warning(
"Neither AI_HORDE_TESTING nor HORDE_SDK_TESTING environment variables are set. "
"Is this a local development test run? If so, set AI_HORDE_TESTING=1 or HORDE_SDK_TESTING=1 to suppress "
"this warning",
)

if AI_HORDE_TESTING is not None:
logger.info("AI_HORDE_TESTING environment variable set.")

if HORDE_SDK_TESTING is not None:
logger.info("HORDE_SDK_TESTING environment variable set.")


@pytest.fixture(scope="session")
def ai_horde_api_key() -> str:
Expand All @@ -30,7 +46,7 @@ def simple_image_gen_request(ai_horde_api_key: str) -> ImageGenerateAsyncRequest
prompt="a cat in a hat",
models=["Deliberate"],
params=ImageGenerationInputPayload(
steps=1,
steps=5,
n=1,
),
)
Expand Down
6 changes: 4 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ skip_empty = True
description = base evironment
passenv =
AIWORKER_CACHE_HOME
HORDE_SDK_TESTING
AI_HORDE_TESTING
TESTS_ONGOING

[testenv:pre-commit]
Expand All @@ -33,7 +35,7 @@ deps =
requests
-r requirements.txt
commands =
pytest tests {posargs} --cov
pytest tests {posargs} --cov -m "not api_side_ci"


[testenv:tests-no-api-calls]
Expand All @@ -48,4 +50,4 @@ deps =
requests
-r requirements.txt
commands =
pytest tests {posargs} --ignore-glob=*api_calls.py --cov
pytest tests {posargs} --ignore-glob=*api_calls.py -m "not api_side_ci"
Loading