Skip to content

Commit

Permalink
feat: support downloading source_image/masks/extra_images
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Mar 24, 2024
1 parent fb87565 commit 30e208b
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,6 @@ out.json

examples/requested_images/*.*
_version.py

tests/testing_result_images/*
!tests/testing_result_images/.results_go_here
2 changes: 1 addition & 1 deletion horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ExtraSourceImageEntry(HordeAPIDataObject):
"""

image: str = Field(min_length=1)
"""The URL of the image to download."""
"""The URL of the image to download, or the base64 string once downloaded."""
strength: float = Field(default=1, ge=-5, le=5)
"""The strength to apply to this image on various operations."""

Expand Down
2 changes: 1 addition & 1 deletion horde_sdk/ai_horde_api/apimodels/generate/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,4 @@ def get_number_of_results_expected(self) -> int:

@override
def get_extra_fields_to_exclude_from_log(self) -> set[str]:
return {"source_image"}
return {"source_image", "source_mask", "extra_source_images"}
123 changes: 120 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import uuid

import aiohttp
from loguru import logger
from pydantic import AliasChoices, Field, field_validator, model_validator
from typing_extensions import override
Expand All @@ -25,6 +27,7 @@
APIKeyAllowedInRequestMixin,
HordeAPIDataObject,
HordeResponseBaseModel,
ResponseRequiringDownloadMixin,
ResponseRequiringFollowUpMixin,
)

Expand Down Expand Up @@ -85,7 +88,11 @@ class ImageGenerateJobPopPayload(ImageGenerateParamMixin):
"""The number of images to generate. Defaults to 1, maximum is 20."""


class ImageGenerateJobPopResponse(HordeResponseBaseModel, ResponseRequiringFollowUpMixin):
class ImageGenerateJobPopResponse(
HordeResponseBaseModel,
ResponseRequiringFollowUpMixin,
ResponseRequiringDownloadMixin,
):
"""Represents the data returned from the `/v2/generate/pop` endpoint.
v2 API Model: `GenerationPayloadStable`
Expand All @@ -104,14 +111,20 @@ class ImageGenerateJobPopResponse(HordeResponseBaseModel, ResponseRequiringFollo
"""Which of the available models to use for this request."""
source_image: str | None = None
"""The URL or Base64-encoded webp to use for img2img."""
_downloaded_source_image: str | None = None
"""The downloaded source image (as base64), if any. This is not part of the API response."""
source_processing: str | KNOWN_SOURCE_PROCESSING = KNOWN_SOURCE_PROCESSING.txt2img
"""If source_image is provided, specifies how to process it."""
source_mask: str | None = None
"""If img_processing is set to 'inpainting' or 'outpainting', this parameter can be optionally provided as the
mask of the areas to inpaint. If this arg is not passed, the inpainting/outpainting mask has to be embedded as
alpha channel."""
_downloaded_source_mask: str | None = None
"""The downloaded source mask (as base64), if any. This is not part of the API response."""
extra_source_images: list[ExtraSourceImageEntry] | None = None
"""Additional uploaded images which can be used for further operations."""
"""Additional uploaded images (as base64) which can be used for further operations."""
_downloaded_extra_source_images: list[ExtraSourceImageEntry] | None = None
"""The downloaded extra source images, if any. This is not part of the API response."""
r2_upload: str | None = None
"""(Obsolete) The r2 upload link to use to upload this image."""
r2_uploads: list[str] | None = None
Expand Down Expand Up @@ -182,7 +195,7 @@ def get_follow_up_failure_cleanup_params(self) -> dict[str, object]:

@override
def get_extra_fields_to_exclude_from_log(self) -> set[str]:
return {"source_image"}
return {"source_image", "source_mask", "extra_source_images"}

@override
def ignore_failure(self) -> bool:
Expand Down Expand Up @@ -210,6 +223,110 @@ def has_facefixer(self) -> bool:

return any(post_processing in KNOWN_FACEFIXERS.__members__ for post_processing in self.payload.post_processing)

def get_downloaded_source_image(self) -> str | None:
"""Get the downloaded source image."""
return self._downloaded_source_image

def get_downloaded_source_mask(self) -> str | None:
"""Get the downloaded source mask."""
return self._downloaded_source_mask

def get_downloaded_extra_source_images(self) -> list[ExtraSourceImageEntry] | None:
"""Get the downloaded extra source images."""
return (
self._downloaded_extra_source_images.copy() if self._downloaded_extra_source_images is not None else None
)

def async_download_source_image(self, client_session: aiohttp.ClientSession) -> asyncio.Task:
"""Download the source image concurrently."""

# If the source image is not set, there is nothing to download.
if self.source_image is None:
return asyncio.create_task(asyncio.sleep(0))

# If the source image is not a URL, it is already a base64 string.
if not self.source_image.startswith("http"):
self._downloaded_source_image = self.source_image
return asyncio.create_task(asyncio.sleep(0))

return asyncio.create_task(
self.download_file_to_field_as_base64(client_session, self.source_image, "_downloaded_source_image"),
)

def async_download_source_mask(self, client_session: aiohttp.ClientSession) -> asyncio.Task:
"""Download the source mask concurrently."""

# If the source mask is not set, there is nothing to download.
if self.source_mask is None:
return asyncio.create_task(asyncio.sleep(0))

# If the source mask is not a URL, it is already a base64 string.
if not self.source_mask.startswith("http"):
self._downloaded_source_mask = self.source_mask
return asyncio.create_task(asyncio.sleep(0))

return asyncio.create_task(
self.download_file_to_field_as_base64(client_session, self.source_mask, "_downloaded_source_mask"),
)

async def async_download_extra_source_images(
self,
client_session: aiohttp.ClientSession,
) -> list[ExtraSourceImageEntry] | None:
"""Download all extra source images concurrently."""

if self.extra_source_images is None or len(self.extra_source_images) == 0:
logger.info("No extra source images to download.")
return None

if self._downloaded_extra_source_images is None:
self._downloaded_extra_source_images = []
else:
logger.warning("Extra source images already downloaded.")
return self._downloaded_extra_source_images

tasks: list[asyncio.Task] = []

for extra_source_image in self.extra_source_images:
if extra_source_image.image is None:
continue

if not extra_source_image.image.startswith("http"):
self._downloaded_extra_source_images.append(extra_source_image)
continue

tasks.append(
asyncio.create_task(
self.download_file_as_base64(client_session, extra_source_image.image),
),
)

results = await asyncio.gather(*tasks, return_exceptions=True)

for result, extra_source_image in zip(results, self.extra_source_images, strict=True):
if isinstance(result, Exception) or not isinstance(result, str):
logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
continue

self._downloaded_extra_source_images.append(
ExtraSourceImageEntry(image=result, strength=extra_source_image.strength),
)

return self._downloaded_extra_source_images.copy()

@override
async def async_download_additional_data(self, client_session: aiohttp.ClientSession) -> None:
"""Download all additional images concurrently."""
await asyncio.gather(
self.async_download_source_image(client_session),
self.async_download_source_mask(client_session),
self.async_download_extra_source_images(client_session),
)

@override
def download_additional_data(self) -> None:
raise NotImplementedError("This method is not yet implemented. Use async_download_additional_data instead.")

def __eq__(self, other: object) -> bool:
if not isinstance(other, ImageGenerateJobPopResponse):
return False
Expand Down
37 changes: 37 additions & 0 deletions horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from __future__ import annotations

import abc
import base64
import os
import uuid

import aiohttp
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import override
Expand Down Expand Up @@ -218,6 +220,41 @@ def get_finalize_success_request_type(cls) -> type[HordeRequest] | None:
"""Return the request type for this response to finalize the job on success, or `None` if not needed."""


class ResponseRequiringDownloadMixin(BaseModel):
"""Represents any response which may require downloading additional data."""

async def download_file_as_base64(self, client_session: aiohttp.ClientSession, url: str) -> str:
"""Download a file and return the value as a base64 string."""
async with client_session.get(url) as response:
response.raise_for_status()
return base64.b64encode(await response.read()).decode("utf-8")

async def download_file_to_field_as_base64(
self,
client_session: aiohttp.ClientSession,
url: str,
field_name: str,
) -> None:
"""Download a file from a URL and save it to the field.
Args:
client_session (aiohttp.ClientSession): The aiohttp client session to use for the download.
url (str): The URL to download the file from.
field_name (str): The name of the field to save the file to.
"""
async with client_session.get(url) as response:
response.raise_for_status()
setattr(self, field_name, base64.b64encode(await response.read()).decode("utf-8"))

@abc.abstractmethod
async def async_download_additional_data(self, client_session: aiohttp.ClientSession) -> None:
"""Asynchronously download any additional data required for this response."""

@abc.abstractmethod
def download_additional_data(self) -> None:
"""Download any additional data required for this response."""


class ContainsMessageResponseMixin(BaseModel):
"""Represents any response from any Horde API which contains a message."""

Expand Down
73 changes: 73 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
"""Unit tests for AI-Horde API models."""

import base64
import io
import json
from uuid import UUID

import aiohttp
import PIL.Image
import pytest

from horde_sdk.ai_horde_api.apimodels import (
KNOWN_ALCHEMY_TYPES,
AlchemyPopFormPayload,
AlchemyPopResponse,
ImageGenerateAsyncResponse,
)
from horde_sdk.ai_horde_api.apimodels._find_user import (
ContributionsDetails,
Expand Down Expand Up @@ -594,6 +599,61 @@ def test_ImageGenerateJobPopResponse_hashability() -> None:
assert test_response_multiple_ids_2 in combined_container_multiple_ids


@pytest.mark.asyncio
async def test_ImageGenerateJobPop_download_addtl_data() -> None:
from horde_sdk.ai_horde_api.apimodels import ExtraSourceImageEntry

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=[KNOWN_UPSCALERS.RealESRGAN_x2plus],
prompt="A cat in a hat",
),
model="Deliberate",
source_image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/0.jpg",
source_mask="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/1.jpg",
extra_source_images=[
ExtraSourceImageEntry(
image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/2.jpg",
strength=1.0,
),
ExtraSourceImageEntry(
image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/3.jpg",
strength=2.0,
),
],
skipped=ImageGenerateJobPopSkippedStatus(),
)

client_session = aiohttp.ClientSession()

await test_response.async_download_additional_data(client_session)

assert test_response._downloaded_source_image is not None
assert test_response._downloaded_source_mask is not None
assert test_response._downloaded_extra_source_images is not None
assert len(test_response._downloaded_extra_source_images) == 2

downloaded_source_image = test_response.get_downloaded_source_image()
assert downloaded_source_image is not None
assert PIL.Image.open(io.BytesIO(base64.b64decode(downloaded_source_image)))

downloaded_source_mask = test_response.get_downloaded_source_mask()
assert downloaded_source_mask is not None
assert PIL.Image.open(io.BytesIO(base64.b64decode(downloaded_source_mask)))

downloaded_extra_source_images = test_response.get_downloaded_extra_source_images()
assert downloaded_extra_source_images is not None
assert len(downloaded_extra_source_images) == 2
for extra_source_image in downloaded_extra_source_images:
assert extra_source_image is not None
assert PIL.Image.open(io.BytesIO(base64.b64decode(extra_source_image.image)))

assert downloaded_extra_source_images[0].strength == 1.0
assert downloaded_extra_source_images[1].strength == 2.0


def test_AlchemyPopResponse() -> None:
test_alchemy_pop_response = AlchemyPopResponse(
forms=[
Expand Down Expand Up @@ -721,3 +781,16 @@ def test_problem_payload() -> None:
problem_payload = json.loads(json_from_api)

ImageGenerateJobPopResponse.model_validate(problem_payload)


def test_problem_gen_request_response() -> None:

example_json = """
{
"id": "00000000-0000-0000-0000-000000000000",
"kudos": 8.0
}"""

json_from_api = json.loads(example_json)

ImageGenerateAsyncResponse.model_validate(json_from_api)
Loading

0 comments on commit 30e208b

Please sign in to comment.