From 4092dcdcd8d7e32d1d69e8f2b1b4f3d8b57bcbec Mon Sep 17 00:00:00 2001
From: pm3310
Date: Sat, 17 Feb 2024 13:02:43 +0000
Subject: [PATCH] Fix linting
---
sagify/commands/llm.py | 20 ++++++++++++--------
sagify/llm_gateway/main.py | 2 ++
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/sagify/commands/llm.py b/sagify/commands/llm.py
index 9016cee..e63b557 100644
--- a/sagify/commands/llm.py
+++ b/sagify/commands/llm.py
@@ -42,11 +42,11 @@
_MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME = {
'stabilityai-stable-diffusion-v2': (
- 'model-txt2img-stabilityai-stable-diffusion-v2',
+ 'model-txt2img-stabilityai-stable-diffusion-v2',
'https://huggingface.co/stabilityai/stable-diffusion-2'
),
'stabilityai-stable-diffusion-v2-1-base': (
- 'model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ 'model-txt2img-stabilityai-stable-diffusion-v2-1-base',
'https://huggingface.co/stabilityai/stable-diffusion-2-1-base'
),
'stabilityai-stable-diffusion-v2-fp16': (
@@ -101,7 +101,7 @@
('ml.p3.8xlarge', 'https://instances.vantage.sh/aws/ec2/p3.8xlarge'),
('ml.p3.16xlarge', 'https://instances.vantage.sh/aws/ec2/p3.16xlarge'),
]
-
+
@click.group()
def llm():
@@ -110,6 +110,7 @@ def llm():
"""
pass
+
@llm.command()
def platforms():
"""
@@ -119,6 +120,7 @@ def platforms():
logger.info(" - OpenAI: https://platform.openai.com/docs/overview")
logger.info(" - AWS Sagemaker: https://aws.amazon.com/sagemaker")
+
@llm.command()
@click.option(
'--all',
@@ -184,7 +186,7 @@ def sagemaker_models(all, chat_completions, image_creations, embeddings):
logger.info(" - Instance Type: {}".format(instance_type))
logger.info(" Instance URL: {}".format(instance_url))
logger.info("\n")
-
+
if embeddings:
logger.info("\nEmbeddings:")
for model_id, (model_name, model_url) in _MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME.items():
@@ -315,7 +317,7 @@ def start(
list(_MAPPING_CHAT_COMPLETIONS_MODEL_ID_TO_MODEL_NAME.keys())
)
)
-
+
if default_config['chat_completions']['instance_type'] not in _VALID_INSTANCE_TYPES_PER_CHAT_COMPLETIONS_MODEL[
_MAPPING_CHAT_COMPLETIONS_MODEL_ID_TO_MODEL_NAME[default_config['chat_completions']['model']][0]
]:
@@ -349,7 +351,7 @@ def start(
list(_MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME.keys())
)
)
-
+
if default_config['image_creations']['instance_type'] not in _VALID_INSTANCE_TYPES_PER_IMAGE_CREATIONS_MODEL[
_MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME[default_config['image_creations']['model']][0]
]:
@@ -360,7 +362,7 @@ def start(
]
)
)
-
+
image_endpoint_name, _ = api_cloud.foundation_model_deploy(
model_id=_MAPPING_IMAGE_CREATION_MODEL_ID_TO_MODEL_NAME[default_config['image_creations']['model']][0],
model_version='1.*',
@@ -383,7 +385,7 @@ def start(
list(_MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME.keys())
)
)
-
+
if default_config['embeddings']['instance_type'] not in _VALID_EMBEDDINGS_INSTANCE_TYPES:
raise ValueError(
"Invalid instance type for embeddings model. Available instance types: {}".format(
@@ -509,6 +511,7 @@ def stop(
logger.info("{}".format(e))
sys.exit(-1)
+
@llm.command()
def start_local_gateway():
"""
@@ -519,6 +522,7 @@ def start_local_gateway():
from sagify.llm_gateway.main import start_server
start_server()
+
llm.add_command(platforms)
llm.add_command(sagemaker_models)
llm.add_command(start)
diff --git a/sagify/llm_gateway/main.py b/sagify/llm_gateway/main.py
index 19f993b..c854865 100644
--- a/sagify/llm_gateway/main.py
+++ b/sagify/llm_gateway/main.py
@@ -14,8 +14,10 @@
app.include_router(api_router)
app.add_exception_handler(InternalServerError, internal_server_error_handler)
+
def start_server():
uvicorn.run("sagify.llm_gateway.main:app", port=8080, host="0.0.0.0")
+
if __name__ == "__main__":
start_server()