Skip to content

Commit

Permalink
Bump to 0.0.10
Browse files Browse the repository at this point in the history
  • Loading branch information
wwakabobik committed Nov 24, 2023
1 parent 6375c5f commit dae3611
Show file tree
Hide file tree
Showing 3 changed files with 347 additions and 178 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Fixed headers update methods (possible async bug fix)

## [0.0.10] - 2023-11-24

### Fixed
- Fixed image upload methods (headers should be purged before poking s3)
- Fixed session headers update to much more generic
101 changes: 63 additions & 38 deletions src/leonardo_api/leonardo_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import os
from typing import Optional

import asyncio
import aiofiles
import aiohttp
import asyncio

from .logger_config import setup_logger

Expand All @@ -33,7 +33,7 @@ class Leonardo:
logger (logging.Logger, optional): default logger. Default will be initialized if not provided.
"""

def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None):
def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None) -> None:
"""
Constructs all the necessary attributes for the Leonardo object.
Expand All @@ -46,7 +46,7 @@ def __init__(self, auth_token: str, logger: Optional[logging.Logger] = None):
self.___logger = logger if logger else setup_logger("Leonardo", "leonardo_async.log")
self.___logger.debug("Leonardo init complete")

async def ___get_client_session(self, request_type: str = "get", empty: bool = False):
async def ___get_client_session(self, request_type: str = "get", empty: bool = False) -> aiohttp.ClientSession:
"""
This method returns aiohttp.ClientSession with headers.
Expand All @@ -66,7 +66,7 @@ async def ___get_client_session(self, request_type: str = "get", empty: bool = F
headers.update({"accept": "application/json", "content-type": "application/json"})
return aiohttp.ClientSession(headers=headers)

async def get_user_info(self):
async def get_user_info(self) -> dict:
"""
This endpoint will return your user information, including your user ID.
Expand Down Expand Up @@ -114,7 +114,7 @@ async def post_generations(
prompt_magic: bool = True,
control_net: bool = False,
control_net_type: Optional[str] = None,
):
) -> str:
"""
This endpoint will generate images.
Expand Down Expand Up @@ -201,7 +201,7 @@ async def post_generations(
await session.close()
raise error

async def get_single_generation(self, generation_id: str):
async def get_single_generation(self, generation_id: str) -> dict:
"""
This endpoint will provide information about a specific generation.
Expand Down Expand Up @@ -229,14 +229,14 @@ async def get_single_generation(self, generation_id: str):
await session.close()
raise

async def delete_single_generation(self, generation_id: str):
async def delete_single_generation(self, generation_id: str) -> aiohttp.ClientResponse:
"""
This endpoint deletes a specific generation.
:param generation_id: The ID of the generation to delete.
:type generation_id: str
:return: generation info
:rtype: dict
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while delete single generation
Expand All @@ -256,13 +256,16 @@ async def delete_single_generation(self, generation_id: str):
await session.close()
raise error

async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10):
async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: int = 10) -> dict:
"""
This endpoint returns all generations by a specific user.
:param user_id: The ID of the user.
:type user_id: str
:param offset: The offset for pagination.
:type offset: int
:param limit: The limit for pagination.
:type limit: int
:return: generations
:rtype: dict
Expand All @@ -286,7 +289,7 @@ async def get_generations_by_user(self, user_id: str, offset: int = 0, limit: in
await session.close()
raise error

async def upload_init_image(self, file_path: str): # pylint: disable=too-many-locals
async def upload_init_image(self, file_path: str) -> str: # pylint: disable=too-many-locals
"""
This endpoint returns pre-signed details to upload an init image to S3.
Expand Down Expand Up @@ -337,7 +340,7 @@ async def upload_init_image(self, file_path: str): # pylint: disable=too-many-l
await session.close()
raise error

async def get_single_init_image(self, image_id: str):
async def get_single_init_image(self, image_id: str) -> dict:
"""
This endpoint will return a single init image.
Expand Down Expand Up @@ -365,12 +368,14 @@ async def get_single_init_image(self, image_id: str):
await session.close()
raise

async def delete_init_image(self, image_id: str):
async def delete_init_image(self, image_id: str) -> aiohttp.ClientResponse:
"""
This endpoint deletes an init image.
:param image_id: The ID of the init image to delete.
:type image_id: str
:return: init image
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while delete init image
Expand All @@ -390,7 +395,7 @@ async def delete_init_image(self, image_id: str):
await session.close()
raise error

async def create_upscale(self, image_id: str):
async def create_upscale(self, image_id: str) -> dict:
"""
This endpoint will create an upscale for the provided image ID.
Expand Down Expand Up @@ -419,7 +424,7 @@ async def create_upscale(self, image_id: str):
await session.close()
raise error

async def get_variation_by_id(self, generation_id: str):
async def get_variation_by_id(self, generation_id: str) -> dict:
"""
This endpoint will get the variation by ID.
Expand Down Expand Up @@ -447,14 +452,16 @@ async def get_variation_by_id(self, generation_id: str):
await session.close()
raise error

async def create_dataset(self, name: str, description: Optional[str] = None):
async def create_dataset(self, name: str, description: Optional[str] = None) -> aiohttp.ClientResponse:
"""
This endpoint creates a new dataset.
:param name: The name of the dataset.
:type name: str
:param description: A description for the dataset.
:type description: str, optional
:return: dataset info
:rtype: dict
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while create dataset
Expand All @@ -475,12 +482,14 @@ async def create_dataset(self, name: str, description: Optional[str] = None):
await session.close()
raise error

async def get_dataset_by_id(self, dataset_id: str):
async def get_dataset_by_id(self, dataset_id: str) -> dict:
"""
This endpoint gets the specific dataset.
:param dataset_id: The ID of the dataset to return.
:type dataset_id: str
:return: dataset info
:rtype: dict
Raises:
Exception: if error occurred while get dataset by id
Expand All @@ -501,14 +510,14 @@ async def get_dataset_by_id(self, dataset_id: str):
await session.close()
raise

async def delete_dataset_by_id(self, dataset_id: str):
async def delete_dataset_by_id(self, dataset_id: str) -> aiohttp.ClientResponse:
"""
This endpoint deletes the specific dataset.
:param dataset_id: The ID of the dataset to delete.
:type dataset_id: str
:return: dataset info
:rtype: dict
:return: response
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while delete dataset by id
Expand All @@ -528,14 +537,16 @@ async def delete_dataset_by_id(self, dataset_id: str):
await session.close()
raise error

async def upload_dataset_image(self, dataset_id: str, file_path: str):
async def upload_dataset_image(self, dataset_id: str, file_path: str) -> aiohttp.ClientResponse:
"""
This endpoint returns pre-signed details to upload a dataset image to S3.
:param dataset_id: The ID of the dataset to which the image will be uploaded.
:type dataset_id: str
:param file_path: The path to the image file.
:type file_path: str
:return: dataset info
:rtype: dict
:rtype: aiohttp.ClientResponse
Raises:
ValueError: if invalid file extension
Expand Down Expand Up @@ -573,26 +584,27 @@ async def upload_dataset_image(self, dataset_id: str, file_path: str):
session = await self.___get_client_session("post", empty=True)
async with session.post(upload_url, data=fields) as response:
response.raise_for_status()
response_text = await response.text()
self.___logger.debug(
f"Dataset with dataset_id={dataset_id} uploaded using {file_path}:" f" {response_text}"
)
self.___logger.debug(f"Dataset with dataset_id={dataset_id} uploaded using {file_path}: {response}")
await session.close()
return response_text
return response
except Exception as error:
self.___logger.error(f"Error occurred uploading dataset: {str(error)}")
if not session.closed:
await session.close()
raise

async def upload_generated_image_to_dataset(self, dataset_id: str, generated_image_id: str):
async def upload_generated_image_to_dataset(
self, dataset_id: str, generated_image_id: str
) -> aiohttp.ClientResponse:
"""
This endpoint will upload a previously generated image to the dataset.
:param dataset_id: The ID of the dataset to upload the image to.
:type dataset_id: str
:param generated_image_id: The ID of the image to upload to the dataset.
:type generated_image_id: str
:return: dataset info
:rtype: dict
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while upload generated image to dataset
Expand Down Expand Up @@ -629,21 +641,30 @@ async def train_custom_model(
resolution: int = 512,
sd_version: Optional[str] = None,
strength: str = "MEDIUM",
):
) -> aiohttp.ClientResponse:
"""
This endpoint will train a new custom model.
:param name: The name of the model.
:type name: str
:param description: The description of the model.
:type description: str, optional
:param dataset_id: The ID of the dataset to train the model on.
:type dataset_id: str
:param instance_prompt: The instance prompt to use during training.
:type instance_prompt: str
:param model_type: The category the most accurately reflects the model.
:type model_type: str, optional
:param nsfw: mark for NSFW model. Default is False.
:type nsfw: bool, optional
:param resolution: The resolution for training. Must be 512 or 768.
:type resolution: int, optional
:param sd_version: The base version of stable diffusion to use if not using a custom model.
:type sd_version: str, optional
:param strength: When training using the PIXEL_ART model type, this influences the training strength.
:type strength: str, optional
:return: dataset info
:rtype: dict
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while train custom model
Expand Down Expand Up @@ -675,12 +696,14 @@ async def train_custom_model(
await session.close()
raise error

async def get_custom_model_by_id(self, model_id: str):
async def get_custom_model_by_id(self, model_id: str) -> dict:
"""
This endpoint gets the specific custom model.
:param model_id: The ID of the custom model to return.
:type model_id: str
:return: custom model info
:rtype: dict
Raises:
Exception: if error occurred while get custom model by id
Expand All @@ -691,22 +714,24 @@ async def get_custom_model_by_id(self, model_id: str):
try:
async with session.get(url) as response:
response.raise_for_status()
response_text = await response.text()
self.___logger.debug(f"Custom modal has been trained: {response_text}")
response = await response.json()
self.___logger.debug(f"Custom modal has been trained: {response}")
await session.close()
return response_text
return response
except Exception as error:
self.___logger.error(f"Error obtaining custom model: {str(error)}")
if not session.closed:
await session.close()
raise

async def delete_custom_model_by_id(self, model_id: str):
async def delete_custom_model_by_id(self, model_id: str) -> aiohttp.ClientResponse:
"""
This endpoint will delete a specific custom model.
:param model_id: The ID of the model to delete.
:type model_id: str
:type model_id: aiohttp.ClientResponse
:return: custom model info
:rtype: aiohttp.ClientResponse
Raises:
Exception: if error occurred while delete custom model by id
Expand All @@ -728,7 +753,7 @@ async def delete_custom_model_by_id(self, model_id: str):

async def wait_for_image_generation(
self, generation_id: str, image_index: int = 0, poll_interval: int = 5, timeout: int = 120
):
) -> dict:
"""
This method waits for the completion of image generation.
Expand Down
Loading

0 comments on commit dae3611

Please sign in to comment.