diff --git a/src/leonardo_api/leonardo_async.py b/src/leonardo_api/leonardo_async.py index 25f0e9e..de7cb7f 100644 --- a/src/leonardo_api/leonardo_async.py +++ b/src/leonardo_api/leonardo_async.py @@ -5,20 +5,21 @@ Copyright (c) 2023. All rights reserved. Created: 28.08.2023 -Last Modified: 30.09.2023 +Last Modified: 24.11.2023 Description: This file contains asynchronous implementation for Leonardo.ai API """ -import asyncio import json import logging +import mimetypes import os from typing import Optional import aiofiles import aiohttp +import asyncio from .logger_config import setup_logger @@ -29,7 +30,7 @@ class Leonardo: Parameters: auth_token (str): Auth Bearer token. Required. - logger (logging.Logger, optional): default logger. . Default will be initialized if not provided. + logger (logging.Logger, optional): default logger. Default will be initialized if not provided. """ def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): @@ -37,32 +38,60 @@ def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): Constructs all the necessary attributes for the Leonardo object. :param auth_token: Auth Bearer token. Required. - :param logger: default logger. . Default will be initialized if not provided. + :type auth_token: str + :param logger: default logger. Default will be initialized if not provided. + :type logger: logging.Logger, optional """ - self.___session = aiohttp.ClientSession(headers={"Authorization": f"Bearer {auth_token}"}) + self.___auth_token = auth_token self.___logger = logger if logger else setup_logger("Leonardo", "leonardo_async.log") - self.___get_headers = {"content-type": "application/json"} - self.___post_headers = {"accept": "application/json", "content-type": "application/json"} self.___logger.debug("Leonardo init complete") + async def ___get_client_session(self, request_type: str = "get", empty: bool = False): + """ + This method returns aiohttp.ClientSession with headers. + + :param request_type: type of request: "get" or "post" + :type request_type: str + :param empty: is True if headers will be empty + :type empty: bool + :return: client session with headers + :rtype: aiohttp.ClientSession + """ + headers = {} + if not empty: + headers = {"Authorization": f"Bearer {self.___auth_token}"} + if request_type.lower() == "get" or request_type.lower() == "delete": + headers.update({"content-type": "application/json"}) + if request_type.lower() == "post": + headers.update({"accept": "application/json", "content-type": "application/json"}) + return aiohttp.ClientSession(headers=headers) + async def get_user_info(self): """ This endpoint will return your user information, including your user ID. + + :return: user info + :rtype: dict + + Raises: + Exception: if error occurred while getting user info """ url = "https://cloud.leonardo.ai/api/rest/v1/me" self.___logger.debug(f"Requesting user info: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"User info: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while getting user info: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def post_generations( self, @@ -90,25 +119,49 @@ async def post_generations( This endpoint will generate images. :param prompt: The prompt used to generate images. + :type prompt: str :param negative_prompt: The negative prompt used for the image generation. + :type negative_prompt: str, optional :param model_id: The model ID used for the image generation. + :type model_id: str, optional :param sd_version: The base version of stable diffusion to use if not using a custom model. + :type sd_version: str, optional :param num_images: The number of images to generate. Default is 4. + :type num_images: int, optional :param width: The width of the images. Default is 512px. + :type width: int, optional :param height: The height of the images. Default is 512px. + :type height: int, optional :param num_inference_steps: The number of inference steps for generation. Must be from 40 to 60. Default is 40. + :type num_inference_steps: int, optional :param guidance_scale: How strongly the generation should reflect the prompt. Number from 1 to 20. Default is 7. + :type guidance_scale: int, optional :param init_generation_image_id: The ID of an existing image to use in image2image. + :type init_generation_image_id: str, optional :param init_image_id: The ID of an Init Image to use in image2image. + :type init_image_id: str, optional :param init_strength: How strongly the generated images should reflect the original image in image2image. + :type init_strength: float, optional :param scheduler: The scheduler to generate images with. + :type scheduler: str, optional :param preset_style: The style to generate images with. + :type preset_style: str, optional :param tiling: Whether the generated images should tile on all axis. Default is False. + :type tiling: bool, optional :param public: Whether the generated images should show in the community feed. Default is False. + :type public: bool, optional :param prompt_magic: Enable to use Prompt Magic. Default is True. + :type prompt_magic: bool, optional :param control_net: Enable to use ControlNet. Requires an init image to be provided. Requires a model based on SD v1.5. Default is False. + :type control_net: bool, optional :param control_net_type: The type of ControlNet to use. + :type control_net_type: str, optional + :return: generation response + :rtype: str + + Raises: + Exception: if error occurred while post generations """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/generations" @@ -134,36 +187,46 @@ async def post_generations( "controlNetType": control_net_type, } self.___logger.debug(f"Requesting post generations: POST {url} with payload: {payload}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Post generations: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while post generations: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_single_generation(self, generation_id: str): """ This endpoint will provide information about a specific generation. :param generation_id: The ID of the generation to return. + :type generation_id: str + :return: generation info + :rtype: dict + + Raises: + Exception: if error occurred while get single generation """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}" self.___logger.debug(f"Requested single generations: GET {url} with generation_id={generation_id}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Single generations: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while get single generations: {str(error)}") + if not session.closed: + await session.close() raise async def delete_single_generation(self, generation_id: str): @@ -171,20 +234,27 @@ async def delete_single_generation(self, generation_id: str): This endpoint deletes a specific generation. :param generation_id: The ID of the generation to delete. + :type generation_id: str + :return: generation info + :rtype: dict + + Raises: + Exception: if error occurred while delete single generation """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}" self.___logger.debug(f"Delete generations with generation_id={generation_id}: DELETE {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Generations {generation_id} has been deleted: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while delete generation: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10): """ @@ -193,27 +263,41 @@ async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: in :param user_id: The ID of the user. :param offset: The offset for pagination. :param limit: The limit for pagination. + :return: generations + :rtype: dict + + Raises: + Exception: if error occurred while get generations by user """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/user/{user_id}" params = {"offset": offset, "limit": limit} self.___logger.debug(f"Requested generations for {user_id} with params {params}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, params=params, headers=headers_copy) as response: + async with session.get(url, params=params) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Generations for user {user_id} are: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while obtaining user's generations: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def upload_init_image(self, file_path: str): """ This endpoint returns pre-signed details to upload an init image to S3. :param file_path: The path to the image file. + :type file_path: str + :return: generation_id + :rtype: str + + Raises: + ValueError: if invalid file extension + Exception: if error occurred while upload init image """ valid_extensions = ["png", "jpg", "jpeg", "webp"] extension = os.path.splitext(file_path)[1].strip(".") @@ -223,50 +307,62 @@ async def upload_init_image(self, file_path: str): url = "https://cloud.leonardo.ai/api/rest/v1/init-image" payload = {"extension": extension} self.___logger.debug(f"Init image {file_path} upload requested with payload = {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() data = await response.json() - self.___logger.debug(f"Init image {file_path} initiated: {data}") + await session.close() + self.___logger.debug(f"Init image {file_path} initiated as: {data['uploadInitImage']['url']}") + generation_id = data["uploadInitImage"]["id"] upload_url = data["uploadInitImage"]["url"] fields = json.loads(data["uploadInitImage"]["fields"]) + self.___logger.debug(f"Init image {file_path} uploading with as binary: POST {upload_url}") async with aiofiles.open(file_path, "rb") as file: file_data = await file.read() - - fields.update({"file": file_data}) - self.___logger.debug(f"Init image {file_path} uploading with as binary: POST {upload_url}") - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(upload_url, data=fields, headers=headers_copy) as response: - response.raise_for_status() - response_text = await response.text() - self.___logger.debug(f"Init image {file_path} has been uploaded: {response_text}") - return response_text + data = aiohttp.FormData() + for key, value in fields.items(): + data.add_field(key, value) + data.add_field("file", file_data, filename=file_path, content_type=mimetypes.guess_type(file_path)[0]) + session = await self.___get_client_session("post", empty=True) + async with session.post(upload_url, data=data) as response: + response.raise_for_status() + self.___logger.debug(f"Init image {file_path} has been uploaded, generation_id is: {generation_id}") + await session.close() + return generation_id except Exception as error: self.___logger.error(f"Error occurred while upload init image: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_single_init_image(self, image_id: str): """ This endpoint will return a single init image. :param image_id: The ID of the init image to return. + :type image_id: str + :return: init image + :rtype: dict + + Raises: + Exception: if error occurred while get single init image """ url = f"https://cloud.leonardo.ai/api/rest/v1/init-image/{image_id}" self.___logger.debug(f"Requested single image with image_id={image_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Single image provided: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while obtain single init image: {str(error)}") + if not session.closed: + await session.close() raise async def delete_init_image(self, image_id: str): @@ -274,61 +370,82 @@ async def delete_init_image(self, image_id: str): This endpoint deletes an init image. :param image_id: The ID of the init image to delete. + :type image_id: str + + Raises: + Exception: if error occurred while delete init image """ url = f"https://cloud.leonardo.ai/api/rest/v1/init-image/{image_id}" self.___logger.debug(f"Requested to delete single image with image_id={image_id}: DELETE {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Single image deleted: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while deleting init image: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def create_upscale(self, image_id: str): """ This endpoint will create an upscale for the provided image ID. :param image_id: The ID of the image to upscale. + :type image_id: str + :return: upscale info + :rtype: dict + + Raises: + Exception: if error occurred while create upscale """ url = "https://cloud.leonardo.ai/api/rest/v1/variations/upscale" payload = {"id": image_id} self.___logger.debug(f"Requested to upscale image with payload {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Upscale created: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while up-scaling image: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_variation_by_id(self, generation_id: str): """ This endpoint will get the variation by ID. :param generation_id: The ID of the variation to get. + :type generation_id: str + :return: variation info + :rtype: dict + + Raises: + Exception: if error occurred while get variation by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/variations/{generation_id}" self.___logger.debug(f"Requested to obtain variation by id {generation_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Get variation by ID: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while get variation by id: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def create_dataset(self, name: str, description: Optional[str] = None): """ @@ -336,39 +453,52 @@ async def create_dataset(self, name: str, description: Optional[str] = None): :param name: The name of the dataset. :param description: A description for the dataset. + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while create dataset """ url = "https://cloud.leonardo.ai/api/rest/v1/datasets" payload = {"name": name, "description": description} self.___logger.debug(f"Requested to create dataset with payload {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Dataset has been created: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while create dataset: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_dataset_by_id(self, dataset_id: str): """ This endpoint gets the specific dataset. :param dataset_id: The ID of the dataset to return. + :return: dataset info + + Raises: + Exception: if error occurred while get dataset by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}" self.___logger.debug(f"Requested to obtain dataset dataset_id={dataset_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - async with self.___session.get(url, headers=headers_copy.update(self.___get_headers)) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while get dataset: {str(error)}") + if not session.closed: + await session.close() raise async def delete_dataset_by_id(self, dataset_id: str): @@ -376,20 +506,27 @@ async def delete_dataset_by_id(self, dataset_id: str): This endpoint deletes the specific dataset. :param dataset_id: The ID of the dataset to delete. + :type dataset_id: str + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while delete dataset by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}" self.___logger.debug(f"Requested to delete dataset dataset_id={dataset_id}: DELETE {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Dataset with dataset_id={dataset_id} has been deleted: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while delete dataset: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def upload_dataset_image(self, dataset_id: str, file_path: str): """ @@ -397,6 +534,11 @@ async def upload_dataset_image(self, dataset_id: str, file_path: str): :param dataset_id: The ID of the dataset to which the image will be uploaded. :param file_path: The path to the image file. + :return: dataset info + :rtype: dict + + Raises: + ValueError: if invalid file extension """ # pylint: disable=too-many-locals valid_extensions = ["png", "jpg", "jpeg", "webp"] @@ -408,34 +550,39 @@ async def upload_dataset_image(self, dataset_id: str, file_path: str): payload = {"extension": extension} self.___logger.debug(f"Requested to upload dataset_id={dataset_id} from {file_path}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() data = await response.json() + await session.close() self.___logger.debug( - f"Dataset with dataset_id={dataset_id} started to upload from {file_path}:" f" {response}" + f"Dataset with dataset_id={dataset_id} started to upload from {file_path} as " + f"{data['uploadDatasetImage']['url']}" ) upload_url = data["uploadDatasetImage"]["url"] fields = json.loads(data["uploadDatasetImage"]["fields"]) + self.___logger.debug(f"Uploading dataset_id={dataset_id} from {file_path}: POST {url}") async with aiofiles.open(file_path, "rb") as file: file_data = await file.read() - - fields.update({"file": file_data}) - self.___logger.debug(f"Uploading dataset_id={dataset_id} from {file_path}: POST {url}") - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(upload_url, data=fields, headers=headers_copy) as response: - response.raise_for_status() - response_text = await response.text() - self.___logger.debug( - f"Dataset with dataset_id={dataset_id} uploaded using {file_path}:" f" {response_text}" - ) - return response_text + data = aiohttp.FormData() + for key, value in fields.items(): + data.add_field(key, value) + data.add_field("file", file_data, filename=file_path, content_type=mimetypes.guess_type(file_path)[0]) + session = await self.___get_client_session("post", empty=True) + async with session.post(upload_url, data=fields) as response: + response.raise_for_status() + response_text = await response.text() + self.___logger.debug( + f"Dataset with dataset_id={dataset_id} uploaded using {file_path}:" f" {response_text}" + ) + await session.close() + return response_text except Exception as error: self.___logger.error(f"Error occurred uploading dataset: {str(error)}") + if not session.closed: + await session.close() raise async def upload_generated_image_to_dataset(self, dataset_id: str, generated_image_id: str): @@ -444,25 +591,31 @@ async def upload_generated_image_to_dataset(self, dataset_id: str, generated_ima :param dataset_id: The ID of the dataset to upload the image to. :param generated_image_id: The ID of the image to upload to the dataset. + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while upload generated image to dataset """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}/upload/gen" payload = {"generatedImageId": generated_image_id} self.___logger.debug( f"Requested to upload generated_image_id={generated_image_id} " f"to dataset_id={dataset_id}: POST {url}" ) + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() - response = await response.json() self.___logger.debug( f"Image with image_id={generated_image_id} has been uploaded to " f"dataset_id={dataset_id}: {response}" ) + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while upload generated image to dataset: {str(error)}") + if not session.closed: + await session.close() raise async def train_custom_model( @@ -489,6 +642,11 @@ async def train_custom_model( :param resolution: The resolution for training. Must be 512 or 768. :param sd_version: The base version of stable diffusion to use if not using a custom model. :param strength: When training using the PIXEL_ART model type, this influences the training strength. + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while train custom model """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/models" @@ -504,36 +662,43 @@ async def train_custom_model( "strength": strength, } self.___logger.debug(f"Requested to train custom model with payload {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() - post_response = await response.text() - self.___logger.debug(f"Custom modal has been trained: {post_response}") - return post_response + self.___logger.debug(f"Custom modal has been trained: {response}") + await session.close() + return response except Exception as error: self.___logger.error(f"Error training custom model: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_custom_model_by_id(self, model_id: str): """ This endpoint gets the specific custom model. :param model_id: The ID of the custom model to return. + :return: custom model info + + Raises: + Exception: if error occurred while get custom model by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/models/{model_id}" self.___logger.debug(f"Requested to obtain custom model by model_id={model_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response_text = await response.text() self.___logger.debug(f"Custom modal has been trained: {response_text}") + await session.close() return response_text except Exception as error: self.___logger.error(f"Error obtaining custom model: {str(error)}") + if not session.closed: + await session.close() raise async def delete_custom_model_by_id(self, model_id: str): @@ -541,34 +706,48 @@ async def delete_custom_model_by_id(self, model_id: str): This endpoint will delete a specific custom model. :param model_id: The ID of the model to delete. + :type model_id: str + + Raises: + Exception: if error occurred while delete custom model by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/models/{model_id}" self.___logger.debug(f"Requested to delete custom model by model_id={model_id}: GET {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response_text = await response.text() - - self.___logger.debug(f"Custom modal has been deleted: {response_text}") - return response_text + self.___logger.debug(f"Custom modal has been deleted: {response}") + await session.close() + return response except Exception as error: self.___logger.error(f"Error delete custom model: {str(error)}") - raise + if not session.closed: + await session.close() + raise error - async def wait_for_image_generation(self, generation_id, image_index=0, poll_interval=5, timeout=120): + async def wait_for_image_generation( + self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 + ): """ This method waits for the completion of image generation. :param generation_id: The ID of the generation to check. + :type generation_id: str :param image_index: (Optional) The index of the specific image to wait for. Default is 0. + :type image_index: int, optional :param poll_interval: (Optional) The time interval in seconds between each check. Default is 5 seconds. + :type poll_interval: int, optional :param timeout: (Optional) Waiting timeout. Default is 120 seconds. - + :type timeout: int, optional + :return: The completed image(s) once generation is complete. + :rtype: dict :raises IndexError: If an invalid image_index is provided. + :raises TimeoutError: If the image has not been generated in timeout seconds. - :return: The completed image(s) once generation is complete. + Raises: + TimeoutError: if image has not been generated in timeout seconds + IndexError: if incorrect image index """ timeout_counter = 0 while True: