Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle extra image retries #171

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class ExtraSourceImageEntry(HordeAPIDataObject):
v2 API Model: `ExtraSourceImage`
"""

original_url: str | None = None
"""The URL of the original image after it was downloaded."""

image: str = Field(min_length=1)
"""The URL of the image to download, or the base64 string once downloaded."""
strength: float = Field(default=1, ge=-5, le=5)
Expand Down
95 changes: 68 additions & 27 deletions horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import uuid
from urllib.parse import urlparse

import aiohttp
from loguru import logger
Expand Down Expand Up @@ -245,7 +246,7 @@ def async_download_source_image(self, client_session: aiohttp.ClientSession) ->
return asyncio.create_task(asyncio.sleep(0))

# If the source image is not a URL, it is already a base64 string.
if not self.source_image.startswith("http"):
if urlparse(self.source_image).scheme not in ["http", "https"]:
self._downloaded_source_image = self.source_image
return asyncio.create_task(asyncio.sleep(0))

Expand All @@ -272,6 +273,8 @@ def async_download_source_mask(self, client_session: aiohttp.ClientSession) -> a
async def async_download_extra_source_images(
self,
client_session: aiohttp.ClientSession,
*,
max_retries: int = 5,
) -> list[ExtraSourceImageEntry] | None:
"""Download all extra source images concurrently."""

Expand All @@ -285,32 +288,70 @@ async def async_download_extra_source_images(
logger.warning("Extra source images already downloaded.")
return self._downloaded_extra_source_images

tasks: list[asyncio.Task] = []

for extra_source_image in self.extra_source_images:
if extra_source_image.image is None:
continue

if not extra_source_image.image.startswith("http"):
self._downloaded_extra_source_images.append(extra_source_image)
continue

tasks.append(
asyncio.create_task(
self.download_file_as_base64(client_session, extra_source_image.image),
),
)

results = await asyncio.gather(*tasks, return_exceptions=True)

for result, extra_source_image in zip(results, self.extra_source_images, strict=True):
if isinstance(result, Exception) or not isinstance(result, str):
logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
continue

self._downloaded_extra_source_images.append(
ExtraSourceImageEntry(image=result, strength=extra_source_image.strength),
)
attempts = 0
while attempts < max_retries:
tasks: list[asyncio.Task] = []

for extra_source_image in self.extra_source_images:
if extra_source_image.image is None:
continue

if urlparse(extra_source_image.image).scheme not in ["http", "https"]:
self._downloaded_extra_source_images.append(extra_source_image)
tasks.append(asyncio.create_task(asyncio.sleep(0)))
continue

if any(
extra_source_image.image == downloaded_extra_source_image.original_url
for downloaded_extra_source_image in self._downloaded_extra_source_images
):
logger.debug(f"Extra source image {extra_source_image.image} already downloaded.")
tasks.append(asyncio.create_task(asyncio.sleep(0)))
continue

tasks.append(
asyncio.create_task(
self.download_file_as_base64(client_session, extra_source_image.image),
),
)

results = await asyncio.gather(*tasks, return_exceptions=True)

for result, extra_source_image in zip(results, self.extra_source_images, strict=True):
if isinstance(result, Exception) or not isinstance(result, str):
logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
continue

self._downloaded_extra_source_images.append(
ExtraSourceImageEntry(
image=result,
strength=extra_source_image.strength,
original_url=extra_source_image.image,
),
)

if len(self._downloaded_extra_source_images) == len(self.extra_source_images):
break

attempts += 1

# If there are any entries in _downloaded_extra_source_images,
# make sure the order matches the order of the original list.
if (
self.extra_source_images is not None
and self._downloaded_extra_source_images is not None
and len(self._downloaded_extra_source_images) > 0
):

def _sort_key(x: ExtraSourceImageEntry) -> int:
if self.extra_source_images is not None:
for i, extra_source_image in enumerate(self.extra_source_images):
if extra_source_image.image == x.original_url:
return i

return 0

self._downloaded_extra_source_images.sort(key=_sort_key)

return self._downloaded_extra_source_images.copy()

Expand Down
4 changes: 4 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,10 @@ async def test_ImageGenerateJobPop_download_addtl_data() -> None:
assert len(downloaded_extra_source_images) == 2
for extra_source_image in downloaded_extra_source_images:
assert extra_source_image is not None
assert extra_source_image.original_url is not None
assert extra_source_image.original_url.startswith(
"https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/",
)
assert PIL.Image.open(io.BytesIO(base64.b64decode(extra_source_image.image)))

assert downloaded_extra_source_images[0].strength == 1.0
Expand Down
Loading