Skip to content

Commit

Permalink
feat: handle extra image retries
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Mar 25, 2024
1 parent 30e208b commit 4589cf2
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 27 deletions.
3 changes: 3 additions & 0 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class ExtraSourceImageEntry(HordeAPIDataObject):
v2 API Model: `ExtraSourceImage`
"""

original_url: str | None = None
"""The URL the image was downloaded from, or None if the image was not fetched from a URL."""

image: str = Field(min_length=1)
"""The URL of the image to download, or the base64 string once downloaded."""
strength: float = Field(default=1, ge=-5, le=5)
Expand Down
95 changes: 68 additions & 27 deletions horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import uuid
from urllib.parse import urlparse

import aiohttp
from loguru import logger
Expand Down Expand Up @@ -245,7 +246,7 @@ def async_download_source_image(self, client_session: aiohttp.ClientSession) ->
return asyncio.create_task(asyncio.sleep(0))

# If the source image is not a URL, it is already a base64 string.
if not self.source_image.startswith("http"):
if urlparse(self.source_image).scheme not in ["http", "https"]:
self._downloaded_source_image = self.source_image
return asyncio.create_task(asyncio.sleep(0))

Expand All @@ -272,6 +273,8 @@ def async_download_source_mask(self, client_session: aiohttp.ClientSession) -> a
async def async_download_extra_source_images(
self,
client_session: aiohttp.ClientSession,
*,
max_retries: int = 5,
) -> list[ExtraSourceImageEntry] | None:
"""Download all extra source images concurrently."""

Expand All @@ -285,32 +288,70 @@ async def async_download_extra_source_images(
logger.warning("Extra source images already downloaded.")
return self._downloaded_extra_source_images

tasks: list[asyncio.Task] = []

for extra_source_image in self.extra_source_images:
if extra_source_image.image is None:
continue

if not extra_source_image.image.startswith("http"):
self._downloaded_extra_source_images.append(extra_source_image)
continue

tasks.append(
asyncio.create_task(
self.download_file_as_base64(client_session, extra_source_image.image),
),
)

results = await asyncio.gather(*tasks, return_exceptions=True)

for result, extra_source_image in zip(results, self.extra_source_images, strict=True):
if isinstance(result, Exception) or not isinstance(result, str):
logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
continue

self._downloaded_extra_source_images.append(
ExtraSourceImageEntry(image=result, strength=extra_source_image.strength),
)
attempts = 0
while attempts < max_retries:
tasks: list[asyncio.Task] = []

for extra_source_image in self.extra_source_images:
if extra_source_image.image is None:
continue

if urlparse(extra_source_image.image).scheme not in ["http", "https"]:
self._downloaded_extra_source_images.append(extra_source_image)
tasks.append(asyncio.create_task(asyncio.sleep(0)))
continue

if any(
extra_source_image.image == downloaded_extra_source_image.original_url
for downloaded_extra_source_image in self._downloaded_extra_source_images
):
logger.debug(f"Extra source image {extra_source_image.image} already downloaded.")
tasks.append(asyncio.create_task(asyncio.sleep(0)))
continue

tasks.append(
asyncio.create_task(
self.download_file_as_base64(client_session, extra_source_image.image),
),
)

results = await asyncio.gather(*tasks, return_exceptions=True)

for result, extra_source_image in zip(results, self.extra_source_images, strict=True):
if isinstance(result, Exception) or not isinstance(result, str):
logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
continue

self._downloaded_extra_source_images.append(
ExtraSourceImageEntry(
image=result,
strength=extra_source_image.strength,
original_url=extra_source_image.image,
),
)

if len(self._downloaded_extra_source_images) == len(self.extra_source_images):
break

attempts += 1

# If there are any entries in _downloaded_extra_source_images,
# make sure the order matches the order of the original list.
if (
self.extra_source_images is not None
and self._downloaded_extra_source_images is not None
and len(self._downloaded_extra_source_images) > 0
):

def _sort_key(x: ExtraSourceImageEntry) -> int:
if self.extra_source_images is not None:
for i, extra_source_image in enumerate(self.extra_source_images):
if extra_source_image.image == x.original_url:
return i

return 0

self._downloaded_extra_source_images.sort(key=_sort_key)

return self._downloaded_extra_source_images.copy()

Expand Down
4 changes: 4 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,10 @@ async def test_ImageGenerateJobPop_download_addtl_data() -> None:
assert len(downloaded_extra_source_images) == 2
for extra_source_image in downloaded_extra_source_images:
assert extra_source_image is not None
assert extra_source_image.original_url is not None
assert extra_source_image.original_url.startswith(
"https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/",
)
assert PIL.Image.open(io.BytesIO(base64.b64decode(extra_source_image.image)))

assert downloaded_extra_source_images[0].strength == 1.0
Expand Down

0 comments on commit 4589cf2

Please sign in to comment.