From 2fbdfc7bb050e45976355b68b59564a965797195 Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Fri, 24 Nov 2023 14:34:15 +0100 Subject: [PATCH 1/6] Bump to 0.0.9 --- src/leonardo_api/leonardo_async.py | 401 +++++++++++++++++++++-------- 1 file changed, 290 insertions(+), 111 deletions(-) diff --git a/src/leonardo_api/leonardo_async.py b/src/leonardo_api/leonardo_async.py index 25f0e9e..de7cb7f 100644 --- a/src/leonardo_api/leonardo_async.py +++ b/src/leonardo_api/leonardo_async.py @@ -5,20 +5,21 @@ Copyright (c) 2023. All rights reserved. Created: 28.08.2023 -Last Modified: 30.09.2023 +Last Modified: 24.11.2023 Description: This file contains asynchronous implementation for Leonardo.ai API """ -import asyncio import json import logging +import mimetypes import os from typing import Optional import aiofiles import aiohttp +import asyncio from .logger_config import setup_logger @@ -29,7 +30,7 @@ class Leonardo: Parameters: auth_token (str): Auth Bearer token. Required. - logger (logging.Logger, optional): default logger. . Default will be initialized if not provided. + logger (logging.Logger, optional): default logger. Default will be initialized if not provided. """ def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): @@ -37,32 +38,60 @@ def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): Constructs all the necessary attributes for the Leonardo object. :param auth_token: Auth Bearer token. Required. - :param logger: default logger. . Default will be initialized if not provided. + :type auth_token: str + :param logger: default logger. Default will be initialized if not provided. + :type logger: logging.Logger, optional """ - self.___session = aiohttp.ClientSession(headers={"Authorization": f"Bearer {auth_token}"}) + self.___auth_token = auth_token self.___logger = logger if logger else setup_logger("Leonardo", "leonardo_async.log") - self.___get_headers = {"content-type": "application/json"} - self.___post_headers = {"accept": "application/json", "content-type": "application/json"} self.___logger.debug("Leonardo init complete") + async def ___get_client_session(self, request_type: str = "get", empty: bool = False): + """ + This method returns aiohttp.ClientSession with headers. + + :param request_type: type of request: "get" or "post" + :type request_type: str + :param empty: is True if headers will be empty + :type empty: bool + :return: client session with headers + :rtype: aiohttp.ClientSession + """ + headers = {} + if not empty: + headers = {"Authorization": f"Bearer {self.___auth_token}"} + if request_type.lower() == "get" or request_type.lower() == "delete": + headers.update({"content-type": "application/json"}) + if request_type.lower() == "post": + headers.update({"accept": "application/json", "content-type": "application/json"}) + return aiohttp.ClientSession(headers=headers) + async def get_user_info(self): """ This endpoint will return your user information, including your user ID. + + :return: user info + :rtype: dict + + Raises: + Exception: if error occurred while getting user info """ url = "https://cloud.leonardo.ai/api/rest/v1/me" self.___logger.debug(f"Requesting user info: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"User info: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while getting user info: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def post_generations( self, @@ -90,25 +119,49 @@ async def post_generations( This endpoint will generate images. :param prompt: The prompt used to generate images. + :type prompt: str :param negative_prompt: The negative prompt used for the image generation. + :type negative_prompt: str, optional :param model_id: The model ID used for the image generation. + :type model_id: str, optional :param sd_version: The base version of stable diffusion to use if not using a custom model. + :type sd_version: str, optional :param num_images: The number of images to generate. Default is 4. + :type num_images: int, optional :param width: The width of the images. Default is 512px. + :type width: int, optional :param height: The height of the images. Default is 512px. + :type height: int, optional :param num_inference_steps: The number of inference steps for generation. Must be from 40 to 60. Default is 40. + :type num_inference_steps: int, optional :param guidance_scale: How strongly the generation should reflect the prompt. Number from 1 to 20. Default is 7. + :type guidance_scale: int, optional :param init_generation_image_id: The ID of an existing image to use in image2image. + :type init_generation_image_id: str, optional :param init_image_id: The ID of an Init Image to use in image2image. + :type init_image_id: str, optional :param init_strength: How strongly the generated images should reflect the original image in image2image. + :type init_strength: float, optional :param scheduler: The scheduler to generate images with. + :type scheduler: str, optional :param preset_style: The style to generate images with. + :type preset_style: str, optional :param tiling: Whether the generated images should tile on all axis. Default is False. + :type tiling: bool, optional :param public: Whether the generated images should show in the community feed. Default is False. + :type public: bool, optional :param prompt_magic: Enable to use Prompt Magic. Default is True. + :type prompt_magic: bool, optional :param control_net: Enable to use ControlNet. Requires an init image to be provided. Requires a model based on SD v1.5. Default is False. + :type control_net: bool, optional :param control_net_type: The type of ControlNet to use. + :type control_net_type: str, optional + :return: generation response + :rtype: str + + Raises: + Exception: if error occurred while post generations """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/generations" @@ -134,36 +187,46 @@ async def post_generations( "controlNetType": control_net_type, } self.___logger.debug(f"Requesting post generations: POST {url} with payload: {payload}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Post generations: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while post generations: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_single_generation(self, generation_id: str): """ This endpoint will provide information about a specific generation. :param generation_id: The ID of the generation to return. + :type generation_id: str + :return: generation info + :rtype: dict + + Raises: + Exception: if error occurred while get single generation """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}" self.___logger.debug(f"Requested single generations: GET {url} with generation_id={generation_id}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Single generations: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while get single generations: {str(error)}") + if not session.closed: + await session.close() raise async def delete_single_generation(self, generation_id: str): @@ -171,20 +234,27 @@ async def delete_single_generation(self, generation_id: str): This endpoint deletes a specific generation. :param generation_id: The ID of the generation to delete. + :type generation_id: str + :return: generation info + :rtype: dict + + Raises: + Exception: if error occurred while delete single generation """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}" self.___logger.debug(f"Delete generations with generation_id={generation_id}: DELETE {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Generations {generation_id} has been deleted: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while delete generation: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10): """ @@ -193,27 +263,41 @@ async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: in :param user_id: The ID of the user. :param offset: The offset for pagination. :param limit: The limit for pagination. + :return: generations + :rtype: dict + + Raises: + Exception: if error occurred while get generations by user """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/user/{user_id}" params = {"offset": offset, "limit": limit} self.___logger.debug(f"Requested generations for {user_id} with params {params}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, params=params, headers=headers_copy) as response: + async with session.get(url, params=params) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Generations for user {user_id} are: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while obtaining user's generations: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def upload_init_image(self, file_path: str): """ This endpoint returns pre-signed details to upload an init image to S3. :param file_path: The path to the image file. + :type file_path: str + :return: generation_id + :rtype: str + + Raises: + ValueError: if invalid file extension + Exception: if error occurred while upload init image """ valid_extensions = ["png", "jpg", "jpeg", "webp"] extension = os.path.splitext(file_path)[1].strip(".") @@ -223,50 +307,62 @@ async def upload_init_image(self, file_path: str): url = "https://cloud.leonardo.ai/api/rest/v1/init-image" payload = {"extension": extension} self.___logger.debug(f"Init image {file_path} upload requested with payload = {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() data = await response.json() - self.___logger.debug(f"Init image {file_path} initiated: {data}") + await session.close() + self.___logger.debug(f"Init image {file_path} initiated as: {data['uploadInitImage']['url']}") + generation_id = data["uploadInitImage"]["id"] upload_url = data["uploadInitImage"]["url"] fields = json.loads(data["uploadInitImage"]["fields"]) + self.___logger.debug(f"Init image {file_path} uploading with as binary: POST {upload_url}") async with aiofiles.open(file_path, "rb") as file: file_data = await file.read() - - fields.update({"file": file_data}) - self.___logger.debug(f"Init image {file_path} uploading with as binary: POST {upload_url}") - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(upload_url, data=fields, headers=headers_copy) as response: - response.raise_for_status() - response_text = await response.text() - self.___logger.debug(f"Init image {file_path} has been uploaded: {response_text}") - return response_text + data = aiohttp.FormData() + for key, value in fields.items(): + data.add_field(key, value) + data.add_field("file", file_data, filename=file_path, content_type=mimetypes.guess_type(file_path)[0]) + session = await self.___get_client_session("post", empty=True) + async with session.post(upload_url, data=data) as response: + response.raise_for_status() + self.___logger.debug(f"Init image {file_path} has been uploaded, generation_id is: {generation_id}") + await session.close() + return generation_id except Exception as error: self.___logger.error(f"Error occurred while upload init image: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_single_init_image(self, image_id: str): """ This endpoint will return a single init image. :param image_id: The ID of the init image to return. + :type image_id: str + :return: init image + :rtype: dict + + Raises: + Exception: if error occurred while get single init image """ url = f"https://cloud.leonardo.ai/api/rest/v1/init-image/{image_id}" self.___logger.debug(f"Requested single image with image_id={image_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Single image provided: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while obtain single init image: {str(error)}") + if not session.closed: + await session.close() raise async def delete_init_image(self, image_id: str): @@ -274,61 +370,82 @@ async def delete_init_image(self, image_id: str): This endpoint deletes an init image. :param image_id: The ID of the init image to delete. + :type image_id: str + + Raises: + Exception: if error occurred while delete init image """ url = f"https://cloud.leonardo.ai/api/rest/v1/init-image/{image_id}" self.___logger.debug(f"Requested to delete single image with image_id={image_id}: DELETE {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Single image deleted: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while deleting init image: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def create_upscale(self, image_id: str): """ This endpoint will create an upscale for the provided image ID. :param image_id: The ID of the image to upscale. + :type image_id: str + :return: upscale info + :rtype: dict + + Raises: + Exception: if error occurred while create upscale """ url = "https://cloud.leonardo.ai/api/rest/v1/variations/upscale" payload = {"id": image_id} self.___logger.debug(f"Requested to upscale image with payload {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Upscale created: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while up-scaling image: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_variation_by_id(self, generation_id: str): """ This endpoint will get the variation by ID. :param generation_id: The ID of the variation to get. + :type generation_id: str + :return: variation info + :rtype: dict + + Raises: + Exception: if error occurred while get variation by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/variations/{generation_id}" self.___logger.debug(f"Requested to obtain variation by id {generation_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Get variation by ID: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while get variation by id: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def create_dataset(self, name: str, description: Optional[str] = None): """ @@ -336,39 +453,52 @@ async def create_dataset(self, name: str, description: Optional[str] = None): :param name: The name of the dataset. :param description: A description for the dataset. + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while create dataset """ url = "https://cloud.leonardo.ai/api/rest/v1/datasets" payload = {"name": name, "description": description} self.___logger.debug(f"Requested to create dataset with payload {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Dataset has been created: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while create dataset: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_dataset_by_id(self, dataset_id: str): """ This endpoint gets the specific dataset. :param dataset_id: The ID of the dataset to return. + :return: dataset info + + Raises: + Exception: if error occurred while get dataset by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}" self.___logger.debug(f"Requested to obtain dataset dataset_id={dataset_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - async with self.___session.get(url, headers=headers_copy.update(self.___get_headers)) as response: + async with session.get(url) as response: response.raise_for_status() response = await response.json() self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while get dataset: {str(error)}") + if not session.closed: + await session.close() raise async def delete_dataset_by_id(self, dataset_id: str): @@ -376,20 +506,27 @@ async def delete_dataset_by_id(self, dataset_id: str): This endpoint deletes the specific dataset. :param dataset_id: The ID of the dataset to delete. + :type dataset_id: str + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while delete dataset by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}" self.___logger.debug(f"Requested to delete dataset dataset_id={dataset_id}: DELETE {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response = await response.json() self.___logger.debug(f"Dataset with dataset_id={dataset_id} has been deleted: {response}") + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while delete dataset: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def upload_dataset_image(self, dataset_id: str, file_path: str): """ @@ -397,6 +534,11 @@ async def upload_dataset_image(self, dataset_id: str, file_path: str): :param dataset_id: The ID of the dataset to which the image will be uploaded. :param file_path: The path to the image file. + :return: dataset info + :rtype: dict + + Raises: + ValueError: if invalid file extension """ # pylint: disable=too-many-locals valid_extensions = ["png", "jpg", "jpeg", "webp"] @@ -408,34 +550,39 @@ async def upload_dataset_image(self, dataset_id: str, file_path: str): payload = {"extension": extension} self.___logger.debug(f"Requested to upload dataset_id={dataset_id} from {file_path}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() data = await response.json() + await session.close() self.___logger.debug( - f"Dataset with dataset_id={dataset_id} started to upload from {file_path}:" f" {response}" + f"Dataset with dataset_id={dataset_id} started to upload from {file_path} as " + f"{data['uploadDatasetImage']['url']}" ) upload_url = data["uploadDatasetImage"]["url"] fields = json.loads(data["uploadDatasetImage"]["fields"]) + self.___logger.debug(f"Uploading dataset_id={dataset_id} from {file_path}: POST {url}") async with aiofiles.open(file_path, "rb") as file: file_data = await file.read() - - fields.update({"file": file_data}) - self.___logger.debug(f"Uploading dataset_id={dataset_id} from {file_path}: POST {url}") - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(upload_url, data=fields, headers=headers_copy) as response: - response.raise_for_status() - response_text = await response.text() - self.___logger.debug( - f"Dataset with dataset_id={dataset_id} uploaded using {file_path}:" f" {response_text}" - ) - return response_text + data = aiohttp.FormData() + for key, value in fields.items(): + data.add_field(key, value) + data.add_field("file", file_data, filename=file_path, content_type=mimetypes.guess_type(file_path)[0]) + session = await self.___get_client_session("post", empty=True) + async with session.post(upload_url, data=fields) as response: + response.raise_for_status() + response_text = await response.text() + self.___logger.debug( + f"Dataset with dataset_id={dataset_id} uploaded using {file_path}:" f" {response_text}" + ) + await session.close() + return response_text except Exception as error: self.___logger.error(f"Error occurred uploading dataset: {str(error)}") + if not session.closed: + await session.close() raise async def upload_generated_image_to_dataset(self, dataset_id: str, generated_image_id: str): @@ -444,25 +591,31 @@ async def upload_generated_image_to_dataset(self, dataset_id: str, generated_ima :param dataset_id: The ID of the dataset to upload the image to. :param generated_image_id: The ID of the image to upload to the dataset. + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while upload generated image to dataset """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}/upload/gen" payload = {"generatedImageId": generated_image_id} self.___logger.debug( f"Requested to upload generated_image_id={generated_image_id} " f"to dataset_id={dataset_id}: POST {url}" ) + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() - response = await response.json() self.___logger.debug( f"Image with image_id={generated_image_id} has been uploaded to " f"dataset_id={dataset_id}: {response}" ) + await session.close() return response except Exception as error: self.___logger.error(f"Error occurred while upload generated image to dataset: {str(error)}") + if not session.closed: + await session.close() raise async def train_custom_model( @@ -489,6 +642,11 @@ async def train_custom_model( :param resolution: The resolution for training. Must be 512 or 768. :param sd_version: The base version of stable diffusion to use if not using a custom model. :param strength: When training using the PIXEL_ART model type, this influences the training strength. + :return: dataset info + :rtype: dict + + Raises: + Exception: if error occurred while train custom model """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/models" @@ -504,36 +662,43 @@ async def train_custom_model( "strength": strength, } self.___logger.debug(f"Requested to train custom model with payload {payload}: POST {url}") + session = await self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - async with self.___session.post(url, json=payload, headers=headers_copy) as response: + async with session.post(url, json=payload) as response: response.raise_for_status() - post_response = await response.text() - self.___logger.debug(f"Custom modal has been trained: {post_response}") - return post_response + self.___logger.debug(f"Custom modal has been trained: {response}") + await session.close() + return response except Exception as error: self.___logger.error(f"Error training custom model: {str(error)}") - raise + if not session.closed: + await session.close() + raise error async def get_custom_model_by_id(self, model_id: str): """ This endpoint gets the specific custom model. :param model_id: The ID of the custom model to return. + :return: custom model info + + Raises: + Exception: if error occurred while get custom model by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/models/{model_id}" self.___logger.debug(f"Requested to obtain custom model by model_id={model_id}: GET {url}") + session = await self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.get(url, headers=headers_copy) as response: + async with session.get(url) as response: response.raise_for_status() response_text = await response.text() self.___logger.debug(f"Custom modal has been trained: {response_text}") + await session.close() return response_text except Exception as error: self.___logger.error(f"Error obtaining custom model: {str(error)}") + if not session.closed: + await session.close() raise async def delete_custom_model_by_id(self, model_id: str): @@ -541,34 +706,48 @@ async def delete_custom_model_by_id(self, model_id: str): This endpoint will delete a specific custom model. :param model_id: The ID of the model to delete. + :type model_id: str + + Raises: + Exception: if error occurred while delete custom model by id """ url = f"https://cloud.leonardo.ai/api/rest/v1/models/{model_id}" self.___logger.debug(f"Requested to delete custom model by model_id={model_id}: GET {url}") + session = await self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - async with self.___session.delete(url, headers=headers_copy) as response: + async with session.delete(url) as response: response.raise_for_status() - response_text = await response.text() - - self.___logger.debug(f"Custom modal has been deleted: {response_text}") - return response_text + self.___logger.debug(f"Custom modal has been deleted: {response}") + await session.close() + return response except Exception as error: self.___logger.error(f"Error delete custom model: {str(error)}") - raise + if not session.closed: + await session.close() + raise error - async def wait_for_image_generation(self, generation_id, image_index=0, poll_interval=5, timeout=120): + async def wait_for_image_generation( + self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 + ): """ This method waits for the completion of image generation. :param generation_id: The ID of the generation to check. + :type generation_id: str :param image_index: (Optional) The index of the specific image to wait for. Default is 0. + :type image_index: int, optional :param poll_interval: (Optional) The time interval in seconds between each check. Default is 5 seconds. + :type poll_interval: int, optional :param timeout: (Optional) Waiting timeout. Default is 120 seconds. - + :type timeout: int, optional + :return: The completed image(s) once generation is complete. + :rtype: dict :raises IndexError: If an invalid image_index is provided. + :raises TimeoutError: If the image has not been generated in timeout seconds. - :return: The completed image(s) once generation is complete. + Raises: + TimeoutError: if image has not been generated in timeout seconds + IndexError: if incorrect image index """ timeout_counter = 0 while True: From 6375c5f6105f58d2e16ff2794efdc046be03959f Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Fri, 24 Nov 2023 14:37:05 +0100 Subject: [PATCH 2/6] Bump to 0.0.9 --- src/leonardo_api/leonardo_async.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/leonardo_api/leonardo_async.py b/src/leonardo_api/leonardo_async.py index de7cb7f..875e9ac 100644 --- a/src/leonardo_api/leonardo_async.py +++ b/src/leonardo_api/leonardo_async.py @@ -17,9 +17,9 @@ import os from typing import Optional +import asyncio import aiofiles import aiohttp -import asyncio from .logger_config import setup_logger @@ -286,7 +286,7 @@ async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: in await session.close() raise error - async def upload_init_image(self, file_path: str): + async def upload_init_image(self, file_path: str): # pylint: disable=too-many-locals """ This endpoint returns pre-signed details to upload an init image to S3. From dae36114f8f893c7d969996b5777a76ed874a12b Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Fri, 24 Nov 2023 16:00:26 +0100 Subject: [PATCH 3/6] Bump to 0.0.10 --- CHANGELOG.md | 6 + src/leonardo_api/leonardo_async.py | 101 ++++--- src/leonardo_api/leonardo_sync.py | 418 +++++++++++++++++++---------- 3 files changed, 347 insertions(+), 178 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f7fa1ef..cb7d956 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,3 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed headers update methods (possible async bug fix) + +## [0.0.10] - 2023-11-24 + +### Fixed +- Fixed image upload methods (headers should be purged before poking s3) +- Fixed session headers update to much more generic diff --git a/src/leonardo_api/leonardo_async.py b/src/leonardo_api/leonardo_async.py index 875e9ac..5ae4ae3 100644 --- a/src/leonardo_api/leonardo_async.py +++ b/src/leonardo_api/leonardo_async.py @@ -17,9 +17,9 @@ import os from typing import Optional -import asyncio import aiofiles import aiohttp +import asyncio from .logger_config import setup_logger @@ -33,7 +33,7 @@ class Leonardo: logger (logging.Logger, optional): default logger. Default will be initialized if not provided. """ - def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): + def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None) -> None: """ Constructs all the necessary attributes for the Leonardo object. @@ -46,7 +46,7 @@ def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): self.___logger = logger if logger else setup_logger("Leonardo", "leonardo_async.log") self.___logger.debug("Leonardo init complete") - async def ___get_client_session(self, request_type: str = "get", empty: bool = False): + async def ___get_client_session(self, request_type: str = "get", empty: bool = False) -> aiohttp.ClientSession: """ This method returns aiohttp.ClientSession with headers. @@ -66,7 +66,7 @@ async def ___get_client_session(self, request_type: str = "get", empty: bool = F headers.update({"accept": "application/json", "content-type": "application/json"}) return aiohttp.ClientSession(headers=headers) - async def get_user_info(self): + async def get_user_info(self) -> dict: """ This endpoint will return your user information, including your user ID. @@ -114,7 +114,7 @@ async def post_generations( prompt_magic: bool = True, control_net: bool = False, control_net_type: Optional[str] = None, - ): + ) -> str: """ This endpoint will generate images. @@ -201,7 +201,7 @@ async def post_generations( await session.close() raise error - async def get_single_generation(self, generation_id: str): + async def get_single_generation(self, generation_id: str) -> dict: """ This endpoint will provide information about a specific generation. @@ -229,14 +229,14 @@ async def get_single_generation(self, generation_id: str): await session.close() raise - async def delete_single_generation(self, generation_id: str): + async def delete_single_generation(self, generation_id: str) -> aiohttp.ClientResponse: """ This endpoint deletes a specific generation. :param generation_id: The ID of the generation to delete. :type generation_id: str :return: generation info - :rtype: dict + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while delete single generation @@ -256,13 +256,16 @@ async def delete_single_generation(self, generation_id: str): await session.close() raise error - async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10): + async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10) -> dict: """ This endpoint returns all generations by a specific user. :param user_id: The ID of the user. + :type user_id: str :param offset: The offset for pagination. + :type offset: int :param limit: The limit for pagination. + :type limit: int :return: generations :rtype: dict @@ -286,7 +289,7 @@ async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: in await session.close() raise error - async def upload_init_image(self, file_path: str): # pylint: disable=too-many-locals + async def upload_init_image(self, file_path: str) -> str: # pylint: disable=too-many-locals """ This endpoint returns pre-signed details to upload an init image to S3. @@ -337,7 +340,7 @@ async def upload_init_image(self, file_path: str): # pylint: disable=too-many-l await session.close() raise error - async def get_single_init_image(self, image_id: str): + async def get_single_init_image(self, image_id: str) -> dict: """ This endpoint will return a single init image. @@ -365,12 +368,14 @@ async def get_single_init_image(self, image_id: str): await session.close() raise - async def delete_init_image(self, image_id: str): + async def delete_init_image(self, image_id: str) -> aiohttp.ClientResponse: """ This endpoint deletes an init image. :param image_id: The ID of the init image to delete. :type image_id: str + :return: init image + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while delete init image @@ -390,7 +395,7 @@ async def delete_init_image(self, image_id: str): await session.close() raise error - async def create_upscale(self, image_id: str): + async def create_upscale(self, image_id: str) -> dict: """ This endpoint will create an upscale for the provided image ID. @@ -419,7 +424,7 @@ async def create_upscale(self, image_id: str): await session.close() raise error - async def get_variation_by_id(self, generation_id: str): + async def get_variation_by_id(self, generation_id: str) -> dict: """ This endpoint will get the variation by ID. @@ -447,14 +452,16 @@ async def get_variation_by_id(self, generation_id: str): await session.close() raise error - async def create_dataset(self, name: str, description: Optional[str] = None): + async def create_dataset(self, name: str, description: Optional[str] = None) -> aiohttp.ClientResponse: """ This endpoint creates a new dataset. :param name: The name of the dataset. + :type name: str :param description: A description for the dataset. + :type description: str, optional :return: dataset info - :rtype: dict + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while create dataset @@ -475,12 +482,14 @@ async def create_dataset(self, name: str, description: Optional[str] = None): await session.close() raise error - async def get_dataset_by_id(self, dataset_id: str): + async def get_dataset_by_id(self, dataset_id: str) -> dict: """ This endpoint gets the specific dataset. :param dataset_id: The ID of the dataset to return. + :type dataset_id: str :return: dataset info + :rtype: dict Raises: Exception: if error occurred while get dataset by id @@ -501,14 +510,14 @@ async def get_dataset_by_id(self, dataset_id: str): await session.close() raise - async def delete_dataset_by_id(self, dataset_id: str): + async def delete_dataset_by_id(self, dataset_id: str) -> aiohttp.ClientResponse: """ This endpoint deletes the specific dataset. :param dataset_id: The ID of the dataset to delete. :type dataset_id: str - :return: dataset info - :rtype: dict + :return: response + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while delete dataset by id @@ -528,14 +537,16 @@ async def delete_dataset_by_id(self, dataset_id: str): await session.close() raise error - async def upload_dataset_image(self, dataset_id: str, file_path: str): + async def upload_dataset_image(self, dataset_id: str, file_path: str) -> aiohttp.ClientResponse: """ This endpoint returns pre-signed details to upload a dataset image to S3. :param dataset_id: The ID of the dataset to which the image will be uploaded. + :type dataset_id: str :param file_path: The path to the image file. + :type file_path: str :return: dataset info - :rtype: dict + :rtype: aiohttp.ClientResponse Raises: ValueError: if invalid file extension @@ -573,26 +584,27 @@ async def upload_dataset_image(self, dataset_id: str, file_path: str): session = await self.___get_client_session("post", empty=True) async with session.post(upload_url, data=fields) as response: response.raise_for_status() - response_text = await response.text() - self.___logger.debug( - f"Dataset with dataset_id={dataset_id} uploaded using {file_path}:" f" {response_text}" - ) + self.___logger.debug(f"Dataset with dataset_id={dataset_id} uploaded using {file_path}: {response}") await session.close() - return response_text + return response except Exception as error: self.___logger.error(f"Error occurred uploading dataset: {str(error)}") if not session.closed: await session.close() raise - async def upload_generated_image_to_dataset(self, dataset_id: str, generated_image_id: str): + async def upload_generated_image_to_dataset( + self, dataset_id: str, generated_image_id: str + ) -> aiohttp.ClientResponse: """ This endpoint will upload a previously generated image to the dataset. :param dataset_id: The ID of the dataset to upload the image to. + :type dataset_id: str :param generated_image_id: The ID of the image to upload to the dataset. + :type generated_image_id: str :return: dataset info - :rtype: dict + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while upload generated image to dataset @@ -629,21 +641,30 @@ async def train_custom_model( resolution: int = 512, sd_version: Optional[str] = None, strength: str = "MEDIUM", - ): + ) -> aiohttp.ClientResponse: """ This endpoint will train a new custom model. :param name: The name of the model. + :type name: str :param description: The description of the model. + :type description: str, optional :param dataset_id: The ID of the dataset to train the model on. + :type dataset_id: str :param instance_prompt: The instance prompt to use during training. + :type instance_prompt: str :param model_type: The category the most accurately reflects the model. + :type model_type: str, optional :param nsfw: mark for NSFW model. Default is False. + :type nsfw: bool, optional :param resolution: The resolution for training. Must be 512 or 768. + :type resolution: int, optional :param sd_version: The base version of stable diffusion to use if not using a custom model. + :type sd_version: str, optional :param strength: When training using the PIXEL_ART model type, this influences the training strength. + :type strength: str, optional :return: dataset info - :rtype: dict + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while train custom model @@ -675,12 +696,14 @@ async def train_custom_model( await session.close() raise error - async def get_custom_model_by_id(self, model_id: str): + async def get_custom_model_by_id(self, model_id: str) -> dict: """ This endpoint gets the specific custom model. :param model_id: The ID of the custom model to return. + :type model_id: str :return: custom model info + :rtype: dict Raises: Exception: if error occurred while get custom model by id @@ -691,22 +714,24 @@ async def get_custom_model_by_id(self, model_id: str): try: async with session.get(url) as response: response.raise_for_status() - response_text = await response.text() - self.___logger.debug(f"Custom modal has been trained: {response_text}") + response = await response.json() + self.___logger.debug(f"Custom modal has been trained: {response}") await session.close() - return response_text + return response except Exception as error: self.___logger.error(f"Error obtaining custom model: {str(error)}") if not session.closed: await session.close() raise - async def delete_custom_model_by_id(self, model_id: str): + async def delete_custom_model_by_id(self, model_id: str) -> aiohttp.ClientResponse: """ This endpoint will delete a specific custom model. :param model_id: The ID of the model to delete. - :type model_id: str + :type model_id: aiohttp.ClientResponse + :return: custom model info + :rtype: aiohttp.ClientResponse Raises: Exception: if error occurred while delete custom model by id @@ -728,7 +753,7 @@ async def delete_custom_model_by_id(self, model_id: str): async def wait_for_image_generation( self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 - ): + ) -> dict: """ This method waits for the completion of image generation. diff --git a/src/leonardo_api/leonardo_sync.py b/src/leonardo_api/leonardo_sync.py index 466765f..c928962 100644 --- a/src/leonardo_api/leonardo_sync.py +++ b/src/leonardo_api/leonardo_sync.py @@ -5,7 +5,7 @@ Copyright (c) 2023. All rights reserved. Created: 29.08.2023 -Last Modified: 30.08.2023 +Last Modified: 24.11.2023 Description: This file contains synchronous implementation for Leonardo.ai API @@ -31,37 +31,64 @@ class Leonardo: logger (logging.Logger, optional): default logger. Default will be initialized if not provided. """ - def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None): + def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None) -> None: """ Constructs all the necessary attributes for the Leonardo object. :param auth_token: Auth Bearer token. Required. + :type auth_token: str :param logger: default logger. Default will be initialized if not provided. + :type logger: logging.Logger, optional """ - self.___session = requests.Session() - self.___session.headers.update({"Authorization": f"Bearer {auth_token}"}) + self.___auth_token = auth_token self.___logger = logger if logger else setup_logger("Leonardo", "leonardo_async.log") - self.___get_headers = {"content-type": "application/json"} - self.___post_headers = {"accept": "application/json", "content-type": "application/json"} self.___logger.debug("Leonardo init complete") - def get_user_info(self): + def ___get_client_session(self, request_type: str = "get", empty: bool = False) -> requests.Session: + """ + This method returns requests.Session() with headers. + + :param request_type: type of request: "get" or "post" + :type request_type: str + :param empty: is True if headers will be empty + :type empty: bool + :return: client session with headers + :rtype: requests.Session + """ + headers = {} + if not empty: + headers = {"Authorization": f"Bearer {self.___auth_token}"} + if request_type.lower() == "get" or request_type.lower() == "delete": + headers.update({"content-type": "application/json"}) + if request_type.lower() == "post": + headers.update({"accept": "application/json", "content-type": "application/json"}) + session = requests.Session() + session.headers.update(headers) + return session + + def get_user_info(self) -> dict: """ This endpoint will return your user information, including your user ID. + + :return: The user information. + :rtype: dict + + Raises: + Exception: If an error occurs while getting user info. """ url = "https://cloud.leonardo.ai/api/rest/v1/me" self.___logger.debug(f"Requesting user info: GET {url}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, headers=headers_copy) + response = session.get(url) response.raise_for_status() response = response.json() self.___logger.debug(f"User info: {response}") + session.close() return response except Exception as error: self.___logger.error(f"Error occurred while getting user info: {str(error)}") - raise + raise error def post_generations( self, @@ -84,30 +111,51 @@ def post_generations( prompt_magic: bool = True, control_net: bool = False, control_net_type: Optional[str] = None, - ): + ) -> dict: """ This endpoint will generate images. :param prompt: The prompt used to generate images. + :type prompt: str :param negative_prompt: The negative prompt used for the image generation. + :type negative_prompt: str, optional :param model_id: The model ID used for the image generation. + :type model_id: str, optional :param sd_version: The base version of stable diffusion to use if not using a custom model. + :type sd_version: str, optional :param num_images: The number of images to generate. Default is 4. + :type num_images: int, optional :param width: The width of the images. Default is 512px. + :type width: int, optional :param height: The height of the images. Default is 512px. + :type height: int, optional :param num_inference_steps: The number of inference steps for generation. Must be from 40 to 60. Default is 40. + :type num_inference_steps: int, optional :param guidance_scale: How strongly the generation should reflect the prompt. Number from 1 to 20. Default is 7. + :type guidance_scale: int, optional :param init_generation_image_id: The ID of an existing image to use in image2image. + :type init_generation_image_id: str, optional :param init_image_id: The ID of an Init Image to use in image2image. + :type init_image_id: str, optional :param init_strength: How strongly the generated images should reflect the original image in image2image. + :type init_strength: float, optional :param scheduler: The scheduler to generate images with. + :type scheduler: str, optional :param preset_style: The style to generate images with. + :type preset_style: str, optional :param tiling: Whether the generated images should tile on all axis. Default is False. + :type tiling: bool, optional :param public: Whether the generated images should show in the community feed. Default is False. + :type public: bool, optional :param prompt_magic: Enable to use Prompt Magic. Default is True. + :type prompt_magic: bool, optional :param control_net: Enable to use ControlNet. Requires an init image to be provided. Requires a model based on SD v1.5. Default is False. + :type control_net: bool, optional :param control_net_type: The type of ControlNet to use. + :type control_net_type: str, optional + :return: The generation response. + :rtype: dict """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/generations" @@ -133,86 +181,109 @@ def post_generations( "controlNetType": control_net_type, } self.___logger.debug(f"Requesting post generations: POST {url} with payload: {payload}") + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() response = response.json() self.___logger.debug(f"Post generations: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while post generations: {str(error)}") - raise + raise error - def get_single_generation(self, generation_id: str): + def get_single_generation(self, generation_id: str) -> dict: """ This endpoint will provide information about a specific generation. :param generation_id: The ID of the generation to return. + :type generation_id: str + :return: The generation information. + :rtype: dict + + Raises: + Exception: If an error occurs while getting generation info. """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}" self.___logger.debug(f"Requested single generations: GET {url} with generation_id={generation_id}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, headers=headers_copy) + response = session.get(url) response.raise_for_status() response = response.json() self.___logger.debug(f"Single generations: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while get single generations: {str(error)}") - raise + raise error - def delete_single_generation(self, generation_id: str): + def delete_single_generation(self, generation_id: str) -> requests.Response: """ This endpoint deletes a specific generation. :param generation_id: The ID of the generation to delete. + :type generation_id: str + :return: The response from the delete request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while deleting generation. """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}" self.___logger.debug(f"Delete generations with generation_id={generation_id}: DELETE {url}") + session = self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.delete(url, headers=headers_copy) + response = session.delete(url) response.raise_for_status() - response = response.json() self.___logger.debug(f"Generations {generation_id} has been deleted: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while delete generation: {str(error)}") - raise + raise error - def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10): + def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10) -> dict: """ This endpoint returns all generations by a specific user. :param user_id: The ID of the user. + :type user_id: str :param offset: The offset for pagination. + :type offset: int :param limit: The limit for pagination. + :type limit: int + :return: The generations for the user. + :rtype: dict + + Raises: + Exception: If an error occurs while obtaining user's generations. """ url = f"https://cloud.leonardo.ai/api/rest/v1/generations/user/{user_id}" params = {"offset": offset, "limit": limit} self.___logger.debug(f"Requested generations for {user_id} with params {params}: GET {url}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, params=params, headers=headers_copy) + response = session.get(url, params=params) response.raise_for_status() response = response.json() self.___logger.debug(f"Generations for user {user_id} are: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while obtaining user's generations: {str(error)}") - raise + raise error - def upload_init_image(self, file_path: str): + def upload_init_image(self, file_path: str) -> str: """ This endpoint returns pre-signed details to upload an init image to S3. :param file_path: The path to the image file. + :type: str + :return: The generation ID of the uploaded image. + :rtype: str + :raises ValueError: If an invalid file extension is provided. + + Raises: + ValueError: If an invalid file extension is provided. + Exception: If an error occurs while uploading the image. """ valid_extensions = ["png", "jpg", "jpeg", "webp"] extension = os.path.splitext(file_path)[1].strip(".") @@ -222,183 +293,218 @@ def upload_init_image(self, file_path: str): url = "https://cloud.leonardo.ai/api/rest/v1/init-image" payload = {"extension": extension} self.___logger.debug(f"Init image {file_path} upload requested with payload = {payload}: POST {url}") + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() data = response.json() self.___logger.debug(f"Init image {file_path} initiated: {data}") - upload_url = data["uploadInitImage"]["url"] fields = json.loads(data["uploadInitImage"]["fields"]) + generation_id = data["uploadInitImage"]["id"] + session.close() + self.___logger.debug(f"Init image {file_path} uploading as binary: POST {upload_url}") + session = self.___get_client_session("post", empty=True) with open(file_path, "rb") as file: file_data = file.read() - - fields.update({"file": file_data}) - - self.___logger.debug(f"Init image {file_path} uploading as binary: POST {upload_url}") - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(upload_url, data=fields, headers=headers_copy) - response.raise_for_status() - response_text = response.text - self.___logger.debug(f"Init image {file_path} has been uploaded: {response_text}") - return response_text + fields.update({"file": file_data}) + response = session.post(upload_url, data=fields) + response.raise_for_status() + self.___logger.debug(f"Init image {file_path} has been uploaded with generation_id={generation_id}") + return generation_id except Exception as error: self.___logger.error(f"Error occurred while upload init image: {str(error)}") - raise + raise error - def get_single_init_image(self, image_id: str): + def get_single_init_image(self, image_id: str) -> dict: """ This endpoint will return a single init image. :param image_id: The ID of the init image to return. + :type image_id: str + :return: The init image. + :rtype: dict """ url = f"https://cloud.leonardo.ai/api/rest/v1/init-image/{image_id}" self.___logger.debug(f"Requested single image with image_id={image_id}: GET {url}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, headers=headers_copy) + response = session.get(url) response.raise_for_status() response = response.json() self.___logger.debug(f"Single image provided: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while obtain single init image: {str(error)}") - raise + raise error - def delete_init_image(self, image_id: str): + def delete_init_image(self, image_id: str) -> requests.Response: """ This endpoint deletes an init image. :param image_id: The ID of the init image to delete. + :type image_id: str + :return: The response from the delete request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while deleting init image. """ url = f"https://cloud.leonardo.ai/api/rest/v1/init-image/{image_id}" self.___logger.debug(f"Requested to delete single image with image_id={image_id}: DELETE {url}") + session = self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.delete(url, headers=headers_copy) + response = session.delete(url) response.raise_for_status() - response = response.json() self.___logger.debug(f"Single image deleted: {response}") + session.close() return response except Exception as error: self.___logger.error(f"Error occurred while deleting init image: {str(error)}") - raise + raise error - def create_upscale(self, image_id: str): + def create_upscale(self, image_id: str) -> requests.Response: """ This endpoint will create an upscale for the provided image ID. :param image_id: The ID of the image to upscale. + :type image_id: str + :return: The response from the create request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while creating upscale. """ url = "https://cloud.leonardo.ai/api/rest/v1/variations/upscale" payload = {"id": image_id} self.___logger.debug(f"Requested to upscale image with payload {payload}: POST {url}") + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() - response = response.json() self.___logger.debug(f"Upscale created: {response}") + session.close() return response except Exception as error: self.___logger.error(f"Error occurred while up-scaling image: {str(error)}") - raise + raise error - def get_variation_by_id(self, generation_id: str): + def get_variation_by_id(self, generation_id: str) -> dict: """ This endpoint will get the variation by ID. :param generation_id: The ID of the variation to get. + :type generation_id: str + :return: The variation. + :rtype: dict + + Raises: + Exception: If an error occurs while getting variation. """ url = f"https://cloud.leonardo.ai/api/rest/v1/variations/{generation_id}" self.___logger.debug(f"Requested to obtain variation by id {generation_id}: GET {url}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, headers=headers_copy) + response = session.get(url) response.raise_for_status() response = response.json() self.___logger.debug(f"Get variation by ID: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while get variation by id: {str(error)}") - raise + raise error - def create_dataset(self, name: str, description: Optional[str] = None): + def create_dataset(self, name: str, description: Optional[str] = None) -> requests.Response: """ This endpoint creates a new dataset. :param name: The name of the dataset. + :type name: str :param description: A description for the dataset. + :type description: str, optional + :return: The response from the create request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while creating dataset. """ url = "https://cloud.leonardo.ai/api/rest/v1/datasets" payload = {"name": name, "description": description} self.___logger.debug(f"Requested to create dataset with payload {payload}: POST {url}") + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() - response = response.json() self.___logger.debug(f"Dataset has been created: {response}") + session.close() return response except Exception as error: self.___logger.error(f"Error occurred while create dataset: {str(error)}") - raise + raise error - def get_dataset_by_id(self, dataset_id: str): + def get_dataset_by_id(self, dataset_id: str) -> dict: """ This endpoint gets the specific dataset. :param dataset_id: The ID of the dataset to return. + :type dataset_id: str + :return: The dataset. + :rtype: dict + + Raises: + Exception: If an error occurs while getting dataset. """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}" self.___logger.debug(f"Requested to obtain dataset dataset_id={dataset_id}: GET {url}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, headers=headers_copy) + response = session.get(url) response.raise_for_status() response = response.json() self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while get dataset: {str(error)}") - raise + raise error - def delete_dataset_by_id(self, dataset_id: str): + def delete_dataset_by_id(self, dataset_id: str) -> requests.Response: """ This endpoint deletes the specific dataset. :param dataset_id: The ID of the dataset to delete. + :type dataset_id: str + :return: The response from the delete request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while deleting dataset. """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}" self.___logger.debug(f"Requested to delete dataset dataset_id={dataset_id}: DELETE {url}") + session = self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.delete(url, headers=headers_copy) + response = session.delete(url) response.raise_for_status() - response = response.json() self.___logger.debug(f"Dataset with dataset_id={dataset_id} has been deleted: {response}") return response except Exception as error: self.___logger.error(f"Error occurred while delete dataset: {str(error)}") - raise + raise error - def upload_dataset_image(self, dataset_id: str, file_path: str): + def upload_dataset_image(self, dataset_id: str, file_path: str) -> requests.Response: """ This endpoint returns pre-signed details to upload a dataset image to S3. :param dataset_id: The ID of the dataset to which the image will be uploaded. + :type dataset_id: str :param file_path: The path to the image file. + :type file_path: str + :return: The response from the upload request. + :rtype: requests.Response + + Raises: + ValueError: If an invalid file extension is provided. """ # pylint: disable=too-many-locals valid_extensions = ["png", "jpg", "jpeg", "webp"] @@ -410,63 +516,65 @@ def upload_dataset_image(self, dataset_id: str, file_path: str): payload = {"extension": extension} self.___logger.debug(f"Requested to upload dataset_id={dataset_id} from {file_path}: POST {url}") + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() data = response.json() self.___logger.debug( f"Dataset with dataset_id={dataset_id} started to upload from {file_path}:" f" {response}" ) - upload_url = data["uploadDatasetImage"]["url"] fields = json.loads(data["uploadDatasetImage"]["fields"]) + dataset_id = data["uploadDatasetImage"]["datasetId"] + self.___logger.debug(f"Uploading dataset_id={dataset_id} from {file_path}: POST {url}") + session = self.___get_client_session("post", empty=True) with open(file_path, "rb") as file: file_data = file.read() - - fields.update({"file": file_data}) - - self.___logger.debug(f"Uploading dataset_id={dataset_id} from {file_path}: POST {url}") - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(upload_url, data=fields, headers=headers_copy) - response.raise_for_status() - response_text = response.text - self.___logger.debug( - f"Dataset with dataset_id={dataset_id} uploaded using {file_path}: " f"{response_text}" - ) - return response_text + fields.update({"file": file_data}) + response = session.post(upload_url, data=fields) + response.raise_for_status() + self.___logger.debug( + f"Dataset with dataset_id={dataset_id} uploaded using {file_path}" + ) + session.close() + return response except Exception as error: self.___logger.error(f"Error occurred uploading dataset: {str(error)}") - raise + raise error - def upload_generated_image_to_dataset(self, dataset_id: str, generated_image_id: str): + def upload_generated_image_to_dataset(self, dataset_id: str, generated_image_id: str) -> requests.Response: """ This endpoint will upload a previously generated image to the dataset. :param dataset_id: The ID of the dataset to upload the image to. + :type dataset_id: str :param generated_image_id: The ID of the image to upload to the dataset. + :type generated_image_id: str + :return: The response from the upload request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while uploading image to dataset. """ url = f"https://cloud.leonardo.ai/api/rest/v1/datasets/{dataset_id}/upload/gen" payload = {"generatedImageId": generated_image_id} self.___logger.debug( f"Requested to upload generated_image_id={generated_image_id} " f"to dataset_id={dataset_id}: POST {url}" ) + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() - response = response.json() self.___logger.debug( f"Image with image_id={generated_image_id} has been uploaded to " f"dataset_id={dataset_id}: {response}" ) + session.close() return response except Exception as error: self.___logger.error(f"Error occurred while upload generated image to dataset: {str(error)}") - raise + raise error def train_custom_model( self, @@ -479,19 +587,30 @@ def train_custom_model( resolution: int = 512, sd_version: Optional[str] = None, strength: str = "MEDIUM", - ): + ) -> requests.Response: """ This endpoint will train a new custom model. :param name: The name of the model. + :type name: str :param description: The description of the model. + :type description: str, optional :param dataset_id: The ID of the dataset to train the model on. + :type dataset_id: str :param instance_prompt: The instance prompt to use during training. + :type instance_prompt: str :param model_type: The category the most accurately reflects the model. + :type model_type: str :param nsfw: mark for NSFW model. Default is False. + :type nsfw: bool, optional :param resolution: The resolution for training. Must be 512 or 768. + :type resolution: int, optional :param sd_version: The base version of stable diffusion to use if not using a custom model. + :type sd_version: str, optional :param strength: When training using the PIXEL_ART model type, this influences the training strength. + :type strength: str, optional + :return: The response from the create request. + :rtype: requests.Response """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/models" @@ -507,71 +626,90 @@ def train_custom_model( "strength": strength, } self.___logger.debug(f"Requested to train custom model with payload {payload}: POST {url}") + session = self.___get_client_session("post") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___post_headers) - response = self.___session.post(url, json=payload, headers=headers_copy) + response = session.post(url, json=payload) response.raise_for_status() - response_text = response.text - self.___logger.debug(f"Custom modal has been trained: {response_text}") - return response_text + self.___logger.debug(f"Custom modal has been trained: {response}") + session.close() + return response except Exception as error: self.___logger.error(f"Error training custom model: {str(error)}") - raise + raise error - def get_custom_model_by_id(self, model_id: str): + def get_custom_model_by_id(self, model_id: str) -> dict: """ This endpoint gets the specific custom model. :param model_id: The ID of the custom model to return. + :type model_id: str + :return: The custom model. + :rtype: dict + + Raises: + Exception: If an error occurs while getting custom model. """ url = f"https://cloud.leonardo.ai/api/rest/v1/models/{model_id}" self.___logger.debug(f"Requested to obtain custom model by model_id={model_id}: GET {url}") + session = self.___get_client_session("get") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.get(url, headers=headers_copy) + response = session.get(url) response.raise_for_status() - response_text = response.text - self.___logger.debug(f"Custom modal has been trained: {response_text}") - return response_text + response = response.json() + self.___logger.debug(f"Custom modal has been trained: {response}") + session.close() + return response except Exception as error: self.___logger.error(f"Error obtaining custom model: {str(error)}") - raise + raise error - def delete_custom_model_by_id(self, model_id: str): + def delete_custom_model_by_id(self, model_id: str) -> requests.Response: """ This endpoint will delete a specific custom model. :param model_id: The ID of the model to delete. + :type model_id: str + :return: The response from the delete request. + :rtype: requests.Response + + Raises: + Exception: If an error occurs while deleting custom model. """ url = f"https://cloud.leonardo.ai/api/rest/v1/models/{model_id}" self.___logger.debug(f"Requested to delete custom model by model_id={model_id}: GET {url}") + session = self.___get_client_session("delete") try: - headers_copy = dict(self.___session.headers) - headers_copy.update(self.___get_headers) - response = self.___session.delete(url, headers=headers_copy) + response = session.delete(url) response.raise_for_status() - response_text = response.text - self.___logger.debug(f"Custom modal has been deleted: {response_text}") - return response_text + self.___logger.debug(f"Custom modal has been deleted: {response}") + return response except Exception as error: self.___logger.error(f"Error delete custom model: {str(error)}") - raise + raise error - def wait_for_image_generation(self, generation_id, image_index=0, poll_interval=5, timeout=120): + def wait_for_image_generation( + self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 + ) -> dict: """ This method waits for the completion of image generation. :param generation_id: The ID of the generation to check. + :type generation_id: str :param image_index: (Optional) The index of the specific image to wait for. If None, waits for all images to complete. + :type image_index: int, optional :param poll_interval: (Optional) The time interval in seconds between each check. Default is 5 seconds. + :type poll_interval: int, optional :param timeout: (Optional) Waiting timeout. Default is 120 seconds. - + :type timeout: int, optional + :raises TimeoutError: If the image(s) have not been generated within the timeout. :raises IndexError: If an invalid image_index is provided. - :return: The completed image(s) once generation is complete. + :rtype: dict + + Raises: + TimeoutError: If the image(s) have not been generated within the timeout. + IndexError: If an invalid image_index is provided. """ timeout_counter = 0 while True: From 98ddb1017deb3511ae436be4e15f38c165a83ad5 Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Fri, 24 Nov 2023 16:00:48 +0100 Subject: [PATCH 4/6] Bump to 0.0.10 --- pyproject.toml | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5eca574..fdd543e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "leonardo_api" -version = "0.0.9" +version = "0.0.10" authors = [ { name="Iliya Vereshchagin", email="i.vereshchagin@gmail.com" }, ] diff --git a/setup.cfg b/setup.cfg index 2b0828e..a801492 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = leonardo_api -version = attr: leonardo_api.0.0.9 +version = attr: leonardo_api.0.0.10 author = Iliya Vereshchagin author_email = i.vereshchagin@gmail.com maintainer = Iliya Vereshchagin From 39dd3e07a78309bd14168094cca15522c535128d Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Fri, 24 Nov 2023 16:52:04 +0100 Subject: [PATCH 5/6] Bump to 0.0.10 --- src/leonardo_api/leonardo_async.py | 62 +++++++++++++++--------------- src/leonardo_api/leonardo_sync.py | 58 ++++++++++++++-------------- 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/src/leonardo_api/leonardo_async.py b/src/leonardo_api/leonardo_async.py index 5ae4ae3..e190b3f 100644 --- a/src/leonardo_api/leonardo_async.py +++ b/src/leonardo_api/leonardo_async.py @@ -83,10 +83,10 @@ async def get_user_info(self) -> dict: try: async with session.get(url) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"User info: {response}") + response_dict = await response.json() + self.___logger.debug(f"User info: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while getting user info: {str(error)}") if not session.closed: @@ -114,7 +114,7 @@ async def post_generations( prompt_magic: bool = True, control_net: bool = False, control_net_type: Optional[str] = None, - ) -> str: + ) -> dict: """ This endpoint will generate images. @@ -158,7 +158,7 @@ async def post_generations( :param control_net_type: The type of ControlNet to use. :type control_net_type: str, optional :return: generation response - :rtype: str + :rtype: dict Raises: Exception: if error occurred while post generations @@ -191,10 +191,10 @@ async def post_generations( try: async with session.post(url, json=payload) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Post generations: {response}") + response_dict = await response.json() + self.___logger.debug(f"Post generations: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while post generations: {str(error)}") if not session.closed: @@ -219,10 +219,10 @@ async def get_single_generation(self, generation_id: str) -> dict: try: async with session.get(url) as response: response.raise_for_status() - response = await response.json() + response_dict = await response.json() self.___logger.debug(f"Single generations: {response}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while get single generations: {str(error)}") if not session.closed: @@ -279,10 +279,10 @@ async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: in try: async with session.get(url, params=params) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Generations for user {user_id} are: {response}") + response_dict = await response.json() + self.___logger.debug(f"Generations for user {user_id} are: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while obtaining user's generations: {str(error)}") if not session.closed: @@ -358,10 +358,10 @@ async def get_single_init_image(self, image_id: str) -> dict: try: async with session.get(url) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Single image provided: {response}") + response_dict = await response.json() + self.___logger.debug(f"Single image provided: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while obtain single init image: {str(error)}") if not session.closed: @@ -395,7 +395,7 @@ async def delete_init_image(self, image_id: str) -> aiohttp.ClientResponse: await session.close() raise error - async def create_upscale(self, image_id: str) -> dict: + async def create_upscale(self, image_id: str) -> aiohttp.ClientResponse: """ This endpoint will create an upscale for the provided image ID. @@ -414,10 +414,10 @@ async def create_upscale(self, image_id: str) -> dict: try: async with session.post(url, json=payload) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Upscale created: {response}") + response_dict = await response.json() + self.___logger.debug(f"Upscale created: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while up-scaling image: {str(error)}") if not session.closed: @@ -442,10 +442,10 @@ async def get_variation_by_id(self, generation_id: str) -> dict: try: async with session.get(url) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Get variation by ID: {response}") + response_dict = await response.json() + self.___logger.debug(f"Get variation by ID: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while get variation by id: {str(error)}") if not session.closed: @@ -500,10 +500,10 @@ async def get_dataset_by_id(self, dataset_id: str) -> dict: try: async with session.get(url) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response}") + response_dict = await response.json() + self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while get dataset: {str(error)}") if not session.closed: @@ -714,10 +714,10 @@ async def get_custom_model_by_id(self, model_id: str) -> dict: try: async with session.get(url) as response: response.raise_for_status() - response = await response.json() - self.___logger.debug(f"Custom modal has been trained: {response}") + response_dict = await response.json() + self.___logger.debug(f"Custom modal has been trained: {response_dict}") await session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error obtaining custom model: {str(error)}") if not session.closed: @@ -753,7 +753,7 @@ async def delete_custom_model_by_id(self, model_id: str) -> aiohttp.ClientRespon async def wait_for_image_generation( self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 - ) -> dict: + ) -> aiohttp.ClientResponse: """ This method waits for the completion of image generation. @@ -766,7 +766,7 @@ async def wait_for_image_generation( :param timeout: (Optional) Waiting timeout. Default is 120 seconds. :type timeout: int, optional :return: The completed image(s) once generation is complete. - :rtype: dict + :rtype: aiohttp.ClientResponse :raises IndexError: If an invalid image_index is provided. :raises TimeoutError: If the image has not been generated in timeout seconds. diff --git a/src/leonardo_api/leonardo_sync.py b/src/leonardo_api/leonardo_sync.py index c928962..b5f9b6b 100644 --- a/src/leonardo_api/leonardo_sync.py +++ b/src/leonardo_api/leonardo_sync.py @@ -82,10 +82,10 @@ def get_user_info(self) -> dict: try: response = session.get(url) response.raise_for_status() - response = response.json() - self.___logger.debug(f"User info: {response}") + response_dict = response.json() + self.___logger.debug(f"User info: {response_dict}") session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while getting user info: {str(error)}") raise error @@ -111,7 +111,7 @@ def post_generations( prompt_magic: bool = True, control_net: bool = False, control_net_type: Optional[str] = None, - ) -> dict: + ) -> requests.Response: """ This endpoint will generate images. @@ -155,7 +155,7 @@ def post_generations( :param control_net_type: The type of ControlNet to use. :type control_net_type: str, optional :return: The generation response. - :rtype: dict + :rtype: requests.Response """ # pylint: disable=too-many-locals url = "https://cloud.leonardo.ai/api/rest/v1/generations" @@ -185,9 +185,9 @@ def post_generations( try: response = session.post(url, json=payload) response.raise_for_status() - response = response.json() - self.___logger.debug(f"Post generations: {response}") - return response + response_dict = response.json() + self.___logger.debug(f"Post generations: {response_dict}") + return response_dict except Exception as error: self.___logger.error(f"Error occurred while post generations: {str(error)}") raise error @@ -210,9 +210,9 @@ def get_single_generation(self, generation_id: str) -> dict: try: response = session.get(url) response.raise_for_status() - response = response.json() - self.___logger.debug(f"Single generations: {response}") - return response + response_dict = response.json() + self.___logger.debug(f"Single generations: {response_dict}") + return response_dict except Exception as error: self.___logger.error(f"Error occurred while get single generations: {str(error)}") raise error @@ -264,9 +264,9 @@ def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10 try: response = session.get(url, params=params) response.raise_for_status() - response = response.json() + response_dict = response.json() self.___logger.debug(f"Generations for user {user_id} are: {response}") - return response + return response_dict except Exception as error: self.___logger.error(f"Error occurred while obtaining user's generations: {str(error)}") raise error @@ -332,9 +332,9 @@ def get_single_init_image(self, image_id: str) -> dict: try: response = session.get(url) response.raise_for_status() - response = response.json() - self.___logger.debug(f"Single image provided: {response}") - return response + response_dict = response.json() + self.___logger.debug(f"Single image provided: {response_dict}") + return response_dict except Exception as error: self.___logger.error(f"Error occurred while obtain single init image: {str(error)}") raise error @@ -390,14 +390,14 @@ def create_upscale(self, image_id: str) -> requests.Response: self.___logger.error(f"Error occurred while up-scaling image: {str(error)}") raise error - def get_variation_by_id(self, generation_id: str) -> dict: + def get_variation_by_id(self, generation_id: str) -> requests.Response: """ This endpoint will get the variation by ID. :param generation_id: The ID of the variation to get. :type generation_id: str :return: The variation. - :rtype: dict + :rtype: requests.Response Raises: Exception: If an error occurs while getting variation. @@ -408,9 +408,9 @@ def get_variation_by_id(self, generation_id: str) -> dict: try: response = session.get(url) response.raise_for_status() - response = response.json() - self.___logger.debug(f"Get variation by ID: {response}") - return response + response_dict = response.json() + self.___logger.debug(f"Get variation by ID: {response_dict}") + return response_dict except Exception as error: self.___logger.error(f"Error occurred while get variation by id: {str(error)}") raise error @@ -461,9 +461,9 @@ def get_dataset_by_id(self, dataset_id: str) -> dict: try: response = session.get(url) response.raise_for_status() - response = response.json() - self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response}") - return response + response_dict = response.json() + self.___logger.debug(f"Dataset with dataset_id={dataset_id} provided: {response_dict}") + return response_dict except Exception as error: self.___logger.error(f"Error occurred while get dataset: {str(error)}") raise error @@ -655,10 +655,10 @@ def get_custom_model_by_id(self, model_id: str) -> dict: try: response = session.get(url) response.raise_for_status() - response = response.json() - self.___logger.debug(f"Custom modal has been trained: {response}") + response_dict = response.json() + self.___logger.debug(f"Custom modal has been trained: {response_dict}") session.close() - return response + return response_dict except Exception as error: self.___logger.error(f"Error obtaining custom model: {str(error)}") raise error @@ -689,7 +689,7 @@ def delete_custom_model_by_id(self, model_id: str) -> requests.Response: def wait_for_image_generation( self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 - ) -> dict: + ) -> requests.Response: """ This method waits for the completion of image generation. @@ -705,7 +705,7 @@ def wait_for_image_generation( :raises TimeoutError: If the image(s) have not been generated within the timeout. :raises IndexError: If an invalid image_index is provided. :return: The completed image(s) once generation is complete. - :rtype: dict + :rtype: requests.Response Raises: TimeoutError: If the image(s) have not been generated within the timeout. From 6f91e809e3e1c7ef9f737f2a1b723f97c2622c7c Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Fri, 24 Nov 2023 16:55:09 +0100 Subject: [PATCH 6/6] Bump to 0.0.10 --- src/leonardo_api/leonardo_async.py | 3 ++- src/leonardo_api/leonardo_sync.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/leonardo_api/leonardo_async.py b/src/leonardo_api/leonardo_async.py index e190b3f..1cca270 100644 --- a/src/leonardo_api/leonardo_async.py +++ b/src/leonardo_api/leonardo_async.py @@ -17,9 +17,10 @@ import os from typing import Optional +import asyncio import aiofiles import aiohttp -import asyncio + from .logger_config import setup_logger diff --git a/src/leonardo_api/leonardo_sync.py b/src/leonardo_api/leonardo_sync.py index b5f9b6b..246d6d1 100644 --- a/src/leonardo_api/leonardo_sync.py +++ b/src/leonardo_api/leonardo_sync.py @@ -535,9 +535,7 @@ def upload_dataset_image(self, dataset_id: str, file_path: str) -> requests.Resp fields.update({"file": file_data}) response = session.post(upload_url, data=fields) response.raise_for_status() - self.___logger.debug( - f"Dataset with dataset_id={dataset_id} uploaded using {file_path}" - ) + self.___logger.debug(f"Dataset with dataset_id={dataset_id} uploaded using {file_path}") session.close() return response except Exception as error: @@ -688,7 +686,7 @@ def delete_custom_model_by_id(self, model_id: str) -> requests.Response: raise error def wait_for_image_generation( - self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 + self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120 ) -> requests.Response: """ This method waits for the completion of image generation.