From 2a5a2fcf8759f11f342d1ca77a27c89dbf28cfad Mon Sep 17 00:00:00 2001 From: wwakabobik Date: Wed, 22 Nov 2023 17:23:47 +0100 Subject: [PATCH] v0.4 --- examples/image_generation/gpt_functions.py | 1 + examples/image_generation/image_test.py | 185 ++++++++++++++++++ examples/image_generation/leonardo_test.py | 2 +- .../llm_api_comparison/wrapped_llm_test.py | 11 +- requirements.txt | 5 +- utils/discord_interactions.py | 4 +- utils/discord_watcher.py | 1 + 7 files changed, 195 insertions(+), 14 deletions(-) create mode 100644 examples/image_generation/image_test.py diff --git a/examples/image_generation/gpt_functions.py b/examples/image_generation/gpt_functions.py index 7e0bcb1..1325714 100644 --- a/examples/image_generation/gpt_functions.py +++ b/examples/image_generation/gpt_functions.py @@ -10,6 +10,7 @@ Description: This file contains testing functions for ChatGPT function calling using DALLE and Leonardo experiments """ + import json from io import BytesIO diff --git a/examples/image_generation/image_test.py b/examples/image_generation/image_test.py new file mode 100644 index 0000000..f9ab715 --- /dev/null +++ b/examples/image_generation/image_test.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +""" +Filename: image_test.py +Author: Iliya Vereshchagin +Copyright (c) 2023. All rights reserved. + +Created: 21.11.2023 +Last Modified: 21.11.2023 + +Description: +This file contains testing functions for image generation. +""" + +import json +import re +import ssl +import time +from pprint import pprint + +import aiofiles +import aiohttp +import asyncio +from ablt_python_api import ABLTApi_async as ABLTApi +from leonardo_api import LeonardoAsync as Leonardo +from openai_python_api.dalle import DALLE + +from examples.creds import oai_token, oai_organization, leonardo_token, ablt_token, discord_midjourney_payload +from utils.discord_interactions import DiscordInteractions + +# Initialize the APIs +ssl_context = ssl.create_default_context() +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE +dalle = DALLE(auth_token=oai_token, organization=oai_organization) +leonardo = Leonardo(auth_token=leonardo_token) +ablt = ABLTApi(bearer_token=ablt_token, ssl_context=ssl_context) + + +async def midjourney_wrapper(prompt): + """ + Wrapper for midjourney testing. + + :param prompt: The prompt to use for the function. + """ + discord = DiscordInteractions( + token=discord_midjourney_payload["auth_token"], + application_id=discord_midjourney_payload["application_id"], + guild_id=discord_midjourney_payload["guild_id"], + channel_id=discord_midjourney_payload["channel_id"], + session_id=discord_midjourney_payload["session_id"], + version=discord_midjourney_payload["version"], + interaction_id=discord_midjourney_payload["interaction_id"], + ) + await discord.post_interaction(my_text_prompt=prompt) + return find_and_clear(log_file="discord_watcher.log") + + +async def leonardo_wrapper(prompt): + response = await leonardo.post_generations( + prompt=prompt, + num_images=1, + model_id="1e60896f-3c26-4296-8ecc-53e2afecc132", + width=1024, + height=1024, + prompt_magic=True, + ) + response = await leonardo.wait_for_image_generation(generation_id=response["sdGenerationJob"]["generationId"]) + return json.dumps(response["url"]) + + +def find_and_clear(log_file): + """ + Find and clear the log file. + + :param log_file: The log file to use for the function. + :type log_file: str + :return: The attachment found in the log file. + :rtype: str + """ + for _ in range(12): + with open(log_file, "r+") as file: + lines = file.readlines() + for line in reversed(lines): + match = re.search(r"Found an attachment: (.*)", line) + if match: + file.truncate(0) + return match.group(1) + time.sleep(5) + return None + + +async def save_image_from_url(url, file_path): + """ + Save image from url to file. + + :param url: The url to use for the function. + :type url: str + :param file_path: The file path to use for the function. + :type file_path: str + """ + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + f = await aiofiles.open(file_path, mode="wb") + await f.write(await response.read()) + await f.close() + print(f"Image successfully saved to {file_path}") + return file_path + print(f"Unable to save image. HTTP response code: {response.status}") + return None + + +async def generate_image(): + prompts = ( + "beautiful and scary necromancer girl riding white unicorn", + "draw a character that is a toast-mascot in cartoon style", + "ai robots are fighting against humans in style of Pieter Bruegel", + ) + image_list = [] + for index, prompt in enumerate(prompts): + midjourney_prompt = await ablt.chat( + bot_slug="maina", + prompt=f"Please write a midjourney prompt with aspect ratio 1:1, realistic style: '{prompt}'. " + f"Give me the prompt only, without any comments and descriptions. " + f"Just prompt output for midjourney.", + stream=False, + ).__anext__() + dalle_prompt = await ablt.chat( + bot_slug="maina", + prompt=f"Please write a dalle3 prompt: '{prompt}'. " + f"Give me the prompt only, without any comments and descriptions. Just prompt output.", + stream=False, + ).__anext__() + midjourney_prompt = midjourney_prompt.replace("`", "").replace("n", "") + leonardo_image_url_coro = leonardo_wrapper(dalle_prompt) + dalle3_image_url_coro = dalle.create_image_url(dalle_prompt) + midjourney_image_url_coro = midjourney_wrapper(midjourney_prompt) + leonardo_image_url, dalle3_image_url, midjourney_image_url = await asyncio.gather( + leonardo_image_url_coro, dalle3_image_url_coro, midjourney_image_url_coro + ) + leonardo_image_coro = save_image_from_url(leonardo_image_url[0], f"leonardo_image_{index}.png") + dalle3_image_coro = save_image_from_url(dalle3_image_url[0], f"dalle3_image_{index}.png") + midjourney_image_coro = save_image_from_url(midjourney_image_url, f"midjourney_image_{index}.png") + leonardo_image, dalle3_image, midjourney_image = await asyncio.gather( + leonardo_image_coro, dalle3_image_coro, midjourney_image_coro + ) + image_list.append( + { + "images": {"leonardo": leonardo_image, "dalle3": dalle3_image, "midjourney": midjourney_image}, + "url": {"leonardo": leonardo_image_url, "dalle3": dalle3_image_url, "midjourney": midjourney_image_url}, + } + ) + return image_list + + +async def get_dalle_variations(image_list): + """ + Get variations from dalle3 images. + + :param image_list: The image list to use for the function. + :type image_list: list + :return: The variations from dalle3 images. + :rtype: list + """ + variations = [] + dalle.default_model = None # disable dall-e-3 because isn't supported for variations yet + for index, images in enumerate(image_list): + file_path = images["images"]["dalle3"] + # you may also use dalle.create_variation_from_url_and_get_url(url), but it's won't work for dalle3 urls + with open(file_path, "rb") as file: + url = await dalle.create_variation_from_file_and_get_url(file) + image = await save_image_from_url(url, f"dalle3_variation_{index}.png") + variations.append({"url": url, "image": image}) + return variations + + +async def main(): + """Main function.""" + image_list = await generate_image() + pprint(image_list) + dalle_variations = await get_dalle_variations(image_list) + pprint(dalle_variations) + + +asyncio.run(main()) diff --git a/examples/image_generation/leonardo_test.py b/examples/image_generation/leonardo_test.py index 53c2b0f..fb73ac9 100644 --- a/examples/image_generation/leonardo_test.py +++ b/examples/image_generation/leonardo_test.py @@ -27,7 +27,7 @@ async def main(): prompt = "a beautiful necromancer witch resurrects skeletons against the backdrop of a burning ruined castle" response = await leonardo.post_generations( prompt=prompt, - num_images=1, + num_images=2, negative_prompt="bright colors, good characters, positive", model_id="e316348f-7773-490e-adcd-46757c738eb7", width=1024, diff --git a/examples/llm_api_comparison/wrapped_llm_test.py b/examples/llm_api_comparison/wrapped_llm_test.py index ef1c570..f135cd8 100644 --- a/examples/llm_api_comparison/wrapped_llm_test.py +++ b/examples/llm_api_comparison/wrapped_llm_test.py @@ -14,11 +14,10 @@ from ablt_python_api import ABLTApi from examples.creds import ablt_token -from utils.llm_timer_wrapper import TimeMetricsWrapperSync +from examples.llm_api_comparison.ablt_models import unique_models from examples.llm_api_comparison.csv_saver import save_to_csv from examples.llm_api_comparison.llm_questions import llm_questions -from examples.llm_api_comparison.ablt_models import unique_models - +from utils.llm_timer_wrapper import TimeMetricsWrapperSync # Initialize LLM with tokens ablt = ABLTApi(ablt_token, ssl_verify=False) @@ -47,11 +46,7 @@ def main(): while True: try: response = check_chat_ablt_response(prompt, model) - save_to_csv( - file_name="llm_wrapped.csv", - model_name=model, - question=prompt, - metrics=response) + save_to_csv(file_name="llm_wrapped.csv", model_name=model, question=prompt, metrics=response) error_counter = 5 break except Exception as error: diff --git a/requirements.txt b/requirements.txt index 059f872..c6303dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,8 +27,9 @@ pytest-xdist==3.5.0 cohere==4.36 llamaapi==0.1.36 # My AI APIs -leonardo-api==0.0.7 -openai-python-api==0.0.5 +leonardo-api==0.0.8 +openai-python-api==0.0.6 ablt-python-api==0.0.2 # Discord py-cord==2.4.1 +openai==1.3.4 diff --git a/utils/discord_interactions.py b/utils/discord_interactions.py index 11f5bac..a492bb5 100644 --- a/utils/discord_interactions.py +++ b/utils/discord_interactions.py @@ -36,7 +36,6 @@ async def post_interaction(self, my_text_prompt, **kwargs): :type my_text_prompt: str :param kwargs: The parameters for the interaction. :return: The response from the interaction. - :rtype: dict """ params = {**self.default_params, **kwargs} @@ -63,6 +62,5 @@ async def post_interaction(self, my_text_prompt, **kwargs): async with aiohttp.ClientSession() as session: async with session.post(self.url, json=payload_data, headers=self.headers) as resp: - if resp.status != 200: + if resp.status != 200 and resp.status != 204: raise ValueError(f"Request failed with status code {resp.status}") - return await resp.json() diff --git a/utils/discord_watcher.py b/utils/discord_watcher.py index 46b60c1..c5985d9 100644 --- a/utils/discord_watcher.py +++ b/utils/discord_watcher.py @@ -59,3 +59,4 @@ async def on_message(self, message): return embed.to_dict() else: self.___logger.debug('Found a message from the target user, but content is not ready yet...') +