From 4092dcdcd8d7e32d1d69e8f2b1b4f3d8b57bcbec Mon Sep 17 00:00:00 2001 From: pm3310 Date: Sat, 17 Feb 2024 13:02:43 +0000 Subject: [PATCH] Fix linting --- sagify/commands/llm.py | 20 ++++++++++++-------- sagify/llm_gateway/main.py | 2 ++ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sagify/commands/llm.py b/sagify/commands/llm.py index 9016cee..e63b557 100644 --- a/sagify/commands/llm.py +++ b/sagify/commands/llm.py @@ -42,11 +42,11 @@ _MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME = { 'stabilityai-stable-diffusion-v2': ( - 'model-txt2img-stabilityai-stable-diffusion-v2', + 'model-txt2img-stabilityai-stable-diffusion-v2', 'https://huggingface.co/stabilityai/stable-diffusion-2' ), 'stabilityai-stable-diffusion-v2-1-base': ( - 'model-txt2img-stabilityai-stable-diffusion-v2-1-base', + 'model-txt2img-stabilityai-stable-diffusion-v2-1-base', 'https://huggingface.co/stabilityai/stable-diffusion-2-1-base' ), 'stabilityai-stable-diffusion-v2-fp16': ( @@ -101,7 +101,7 @@ ('ml.p3.8xlarge', 'https://instances.vantage.sh/aws/ec2/p3.8xlarge'), ('ml.p3.16xlarge', 'https://instances.vantage.sh/aws/ec2/p3.16xlarge'), ] - + @click.group() def llm(): @@ -110,6 +110,7 @@ def llm(): """ pass + @llm.command() def platforms(): """ @@ -119,6 +120,7 @@ def platforms(): logger.info(" - OpenAI: https://platform.openai.com/docs/overview") logger.info(" - AWS Sagemaker: https://aws.amazon.com/sagemaker") + @llm.command() @click.option( '--all', @@ -184,7 +186,7 @@ def sagemaker_models(all, chat_completions, image_creations, embeddings): logger.info(" - Instance Type: {}".format(instance_type)) logger.info(" Instance URL: {}".format(instance_url)) logger.info("\n") - + if embeddings: logger.info("\nEmbeddings:") for model_id, (model_name, model_url) in _MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME.items(): @@ -315,7 +317,7 @@ def start( list(_MAPPING_CHAT_COMPLETIONS_MODEL_ID_TO_MODEL_NAME.keys()) ) ) - + if default_config['chat_completions']['instance_type'] not in _VALID_INSTANCE_TYPES_PER_CHAT_COMPLETIONS_MODEL[ _MAPPING_CHAT_COMPLETIONS_MODEL_ID_TO_MODEL_NAME[default_config['chat_completions']['model']][0] ]: @@ -349,7 +351,7 @@ def start( list(_MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME.keys()) ) ) - + if default_config['image_creations']['instance_type'] not in _VALID_INSTANCE_TYPES_PER_IMAGE_CREATIONS_MODEL[ _MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME[default_config['image_creations']['model']][0] ]: @@ -360,7 +362,7 @@ def start( ] ) ) - + image_endpoint_name, _ = api_cloud.foundation_model_deploy( model_id=_MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME[default_config['image_creations']['model']][0], model_version='1.*', @@ -383,7 +385,7 @@ def start( list(_MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME.keys()) ) ) - + if default_config['embeddings']['instance_type'] not in _VALID_EMBEDDINGS_INSTANCE_TYPES: raise ValueError( "Invalid instance type for embeddings model. Available instance types: {}".format( @@ -509,6 +511,7 @@ def stop( logger.info("{}".format(e)) sys.exit(-1) + @llm.command() def start_local_gateway(): """ @@ -519,6 +522,7 @@ def start_local_gateway(): from sagify.llm_gateway.main import start_server start_server() + llm.add_command(platforms) llm.add_command(sagemaker_models) llm.add_command(start) diff --git a/sagify/llm_gateway/main.py b/sagify/llm_gateway/main.py index 19f993b..c854865 100644 --- a/sagify/llm_gateway/main.py +++ b/sagify/llm_gateway/main.py @@ -14,8 +14,10 @@ app.include_router(api_router) app.add_exception_handler(InternalServerError, internal_server_error_handler) + def start_server(): uvicorn.run("sagify.llm_gateway.main:app", port=8080, host="0.0.0.0") + if __name__ == "__main__": start_server()