From 4589cf2e5f50ba64d8286d0ab6770ca08281fdc1 Mon Sep 17 00:00:00 2001
From: tazlin
Date: Sun, 24 Mar 2024 22:04:06 -0400
Subject: [PATCH] feat: handle extra image retries

---
Reviewer notes (this section is not part of the commit message):
- Line structure of the patch has been restored so it can be parsed again.
- The retry loop now tracks finished entries by their index in
  `extra_source_images`. Inline (non-URL) entries are therefore no longer
  appended again on every attempt, placeholder tasks no longer produce
  spurious "Error downloading" log lines, and the final order follows the
  request without a separate sort step.
- The post-image blob id on the `_pop.py` index line predates this revision.

 horde_sdk/ai_horde_api/apimodels/base.py     |  3 +
 .../ai_horde_api/apimodels/generate/_pop.py  | 71 ++++++++++++-------
 .../ai_horde_api/test_ai_horde_api_models.py |  4 ++
 3 files changed, 51 insertions(+), 27 deletions(-)

diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py
index bf6c2c7..986ae71 100644
--- a/horde_sdk/ai_horde_api/apimodels/base.py
+++ b/horde_sdk/ai_horde_api/apimodels/base.py
@@ -131,6 +131,9 @@ class ExtraSourceImageEntry(HordeAPIDataObject):
     v2 API Model: `ExtraSourceImage`
     """
 
+    original_url: str | None = None
+    """The URL of the original image after it was downloaded."""
+
     image: str = Field(min_length=1)
     """The URL of the image to download, or the base64 string once downloaded."""
     strength: float = Field(default=1, ge=-5, le=5)
diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py
index 2a0a97c..9c620ca 100644
--- a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py
+++ b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import uuid
+from urllib.parse import urlparse
 
 import aiohttp
 from loguru import logger
@@ -245,7 +246,7 @@ def async_download_source_image(self, client_session: aiohttp.ClientSession) ->
             return asyncio.create_task(asyncio.sleep(0))
 
         # If the source image is not a URL, it is already a base64 string.
-        if not self.source_image.startswith("http"):
+        if urlparse(self.source_image).scheme not in ["http", "https"]:
             self._downloaded_source_image = self.source_image
             return asyncio.create_task(asyncio.sleep(0))
 
@@ -272,6 +273,8 @@ def async_download_source_mask(self, client_session: aiohttp.ClientSession) -> a
     async def async_download_extra_source_images(
         self,
         client_session: aiohttp.ClientSession,
+        *,
+        max_retries: int = 5,
     ) -> list[ExtraSourceImageEntry] | None:
         """Download all extra source images concurrently."""
 
@@ -285,32 +288,46 @@ async def async_download_extra_source_images(
             logger.warning("Extra source images already downloaded.")
             return self._downloaded_extra_source_images
 
-        tasks: list[asyncio.Task] = []
-
-        for extra_source_image in self.extra_source_images:
-            if extra_source_image.image is None:
-                continue
-
-            if not extra_source_image.image.startswith("http"):
-                self._downloaded_extra_source_images.append(extra_source_image)
-                continue
-
-            tasks.append(
-                asyncio.create_task(
-                    self.download_file_as_base64(client_session, extra_source_image.image),
-                ),
-            )
-
-        results = await asyncio.gather(*tasks, return_exceptions=True)
-
-        for result, extra_source_image in zip(results, self.extra_source_images, strict=True):
-            if isinstance(result, Exception) or not isinstance(result, str):
-                logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
-                continue
-
-            self._downloaded_extra_source_images.append(
-                ExtraSourceImageEntry(image=result, strength=extra_source_image.strength),
-            )
+        # Track finished entries by their position in `extra_source_images` so that a retry
+        # never duplicates an entry and the original order is kept without a separate sort.
+        resolved: dict[int, ExtraSourceImageEntry] = {}
+
+        for _ in range(max_retries):
+            pending: dict[int, asyncio.Task] = {}
+
+            for index, extra_source_image in enumerate(self.extra_source_images):
+                if index in resolved or extra_source_image.image is None:
+                    continue
+
+                # If the image is not a URL, it is already a base64 string.
+                if urlparse(extra_source_image.image).scheme not in ["http", "https"]:
+                    resolved[index] = extra_source_image
+                    continue
+
+                pending[index] = asyncio.create_task(
+                    self.download_file_as_base64(client_session, extra_source_image.image),
+                )
+
+            # Nothing left to fetch: every entry is resolved or has no image to download.
+            if not pending:
+                break
+
+            results = await asyncio.gather(*pending.values(), return_exceptions=True)
+
+            for index, result in zip(pending, results, strict=True):
+                extra_source_image = self.extra_source_images[index]
+
+                if not isinstance(result, str):
+                    logger.error(f"Error downloading extra source image {extra_source_image.image}: {result}")
+                    continue
+
+                resolved[index] = ExtraSourceImageEntry(
+                    image=result,
+                    strength=extra_source_image.strength,
+                    original_url=extra_source_image.image,
+                )
+
+        self._downloaded_extra_source_images.extend(resolved[index] for index in sorted(resolved))
 
         return self._downloaded_extra_source_images.copy()
 
diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py
index 4e85aed..9d9b671 100644
--- a/tests/ai_horde_api/test_ai_horde_api_models.py
+++ b/tests/ai_horde_api/test_ai_horde_api_models.py
@@ -648,6 +648,10 @@ async def test_ImageGenerateJobPop_download_addtl_data() -> None:
     assert len(downloaded_extra_source_images) == 2
     for extra_source_image in downloaded_extra_source_images:
         assert extra_source_image is not None
+        assert extra_source_image.original_url is not None
+        assert extra_source_image.original_url.startswith(
+            "https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/",
+        )
         assert PIL.Image.open(io.BytesIO(base64.b64decode(extra_source_image.image)))
 
     assert downloaded_extra_source_images[0].strength == 1.0