From b9ae95bb4ee682988381ba645f222961b98a692b Mon Sep 17 00:00:00 2001
From: pm3310
Date: Fri, 19 Jan 2024 21:36:47 +0000
Subject: [PATCH 1/2] Simplify deployment of foundation model
---
sagify/sagemaker/sagemaker.py | 39 ++++++-----------------------------
1 file changed, 6 insertions(+), 33 deletions(-)
diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py
index cf961f3..3f051ca 100644
--- a/sagify/sagemaker/sagemaker.py
+++ b/sagify/sagemaker/sagemaker.py
@@ -8,8 +8,8 @@
import sagemaker.huggingface
import sagemaker.xgboost
import sagemaker.sklearn.model
-from sagemaker import image_uris, model_uris, payloads
-from sagemaker.predictor import Predictor
+from sagemaker import payloads
+from sagemaker.jumpstart.model import JumpStartModel
from six.moves.urllib.parse import urlparse
import boto3
@@ -604,44 +604,17 @@ def deploy_foundation_model(
:return: [str], endpoint name
"""
- deploy_image_uri = image_uris.retrieve(
- region=self.aws_region,
- framework=None, # automatically inferred from model_id
- image_scope="inference",
- model_id=model_id,
- model_version=model_version,
- instance_type=instance_type,
- sagemaker_session=self.sagemaker_session
- )
-
- model_uri = model_uris.retrieve(
+ model = JumpStartModel(
model_id=model_id,
model_version=model_version,
- model_scope="inference",
region=self.aws_region,
- sagemaker_session=self.sagemaker_session
- )
-
- # Increase the maximum response size from the endpoint
- env = {
- "MMS_MAX_RESPONSE_SIZE": "20000000",
- }
-
- model = sage.Model(
- image_uri=deploy_image_uri,
- model_data=model_uri,
- role=self.role,
- predictor_cls=Predictor,
- name=endpoint_name,
- env=env,
- sagemaker_session=self.sagemaker_session
+ sagemaker_session=self.sagemaker_session,
+ tolerate_deprecated_model=True,
+ tolerate_vulnerable_model=True
)
-
model_predictor = model.deploy(
initial_instance_count=instance_count,
instance_type=instance_type,
- predictor_cls=Predictor,
- endpoint_name=endpoint_name,
tags=tags,
accept_eula=True
)
From 844e61c8663e53e8acf84d7952ebcd12d500720c Mon Sep 17 00:00:00 2001
From: pm3310
Date: Fri, 19 Jan 2024 21:37:46 +0000
Subject: [PATCH 2/2] LLM infra commands
---
README.md | 78 +++++++
docs/index.md | 78 +++++++
sagify/__main__.py | 2 +
sagify/commands/llm.py | 294 ++++++++++++++++++++++++
sagify/sagemaker/sagemaker.py | 7 +
tests/commands/test_llm.py | 409 ++++++++++++++++++++++++++++++++++
6 files changed, 868 insertions(+)
create mode 100644 sagify/commands/llm.py
create mode 100644 tests/commands/test_llm.py
diff --git a/README.md b/README.md
index 1b7c272..721680f 100644
--- a/README.md
+++ b/README.md
@@ -967,3 +967,81 @@ This command deploys a Foundation model without code.
`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
`--endpoint-name ENDPOINT_NAME`: Optional name for the SageMaker endpoint
+
+
+### LLM Start Infrastructure
+
+#### Name
+
+Command to start LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm start --all --chat-completions --image-creations --embeddings [--config EC2_CONFIG_FILE] --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It spins up the endpoints for chat completions, image creation and embeddings.
+
+#### Required Flags
+
+`--all`: Start infrastructure for all services. If this flag is used the flags `--chat-completions`, `--image-creations`, `--embeddings` are ignored.
+
+`--chat-completions`: Start infrastructure for chat completions.
+
+`--image-creations`: Start infrastructure for image creations.
+
+`--embeddings`: Start infrastructure for embeddings.
+
+`--config EC2_CONFIG_FILE`: Path to config file to override foundation models, ec2 instance types and/or number of instances.
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm start command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm start command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
+
+
+### LLM Stop Infrastructure
+
+#### Name
+
+Command to stop LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm stop --all --chat-completions --image-creations --embeddings --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It stops all or some of the services that are running.
+
+#### Required Flags
+
+`--all`: Stop infrastructure for all services. If this flag is used the flags `--chat-completions`, `--image-creations`, `--embeddings` are ignored.
+
+`--chat-completions`: Stop infrastructure for chat completions.
+
+`--image-creations`: Stop infrastructure for image creations.
+
+`--embeddings`: Stop infrastructure for embeddings.
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm stop command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm stop command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
diff --git a/docs/index.md b/docs/index.md
index efcdb81..4181ac2 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1674,3 +1674,81 @@ This command deploys a Foundation model without code.
`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
`--endpoint-name ENDPOINT_NAME`: Optional name for the SageMaker endpoint
+
+
+### LLM Start Infrastructure
+
+#### Name
+
+Command to start LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm start --all --chat-completions --image-creations --embeddings [--config EC2_CONFIG_FILE] --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It spins up the endpoints for chat completions, image creation and embeddings.
+
+#### Required Flags
+
+`--all`: Start infrastructure for all services. If this flag is used the flags `--chat-completions`, `--image-creations`, `--embeddings` are ignored.
+
+`--chat-completions`: Start infrastructure for chat completions.
+
+`--image-creations`: Start infrastructure for image creations.
+
+`--embeddings`: Start infrastructure for embeddings.
+
+`--config EC2_CONFIG_FILE`: Path to config file to override foundation models, ec2 instance types and/or number of instances.
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm start command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm start command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
+
+
+### LLM Stop Infrastructure
+
+#### Name
+
+Command to stop LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm stop --all --chat-completions --image-creations --embeddings --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It stops all or some of the services that are running.
+
+#### Required Flags
+
+`--all`: Stop infrastructure for all services. If this flag is used the flags `--chat-completions`, `--image-creations`, `--embeddings` are ignored.
+
+`--chat-completions`: Stop infrastructure for chat completions.
+
+`--image-creations`: Stop infrastructure for image creations.
+
+`--embeddings`: Stop infrastructure for embeddings.
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm stop command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm stop command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
diff --git a/sagify/__main__.py b/sagify/__main__.py
index 16e2748..67ae5c8 100644
--- a/sagify/__main__.py
+++ b/sagify/__main__.py
@@ -6,6 +6,7 @@
from sagify.commands.build import build
from sagify.commands.cloud import cloud
from sagify.commands.initialize import init
+from sagify.commands.llm import llm
from sagify.commands.local import local
from sagify.commands.push import push
from sagify.commands.configure import configure
@@ -31,6 +32,7 @@ def add_commands(cli):
cli.add_command(push)
cli.add_command(cloud)
cli.add_command(configure)
+ cli.add_command(llm)
add_commands(cli)
diff --git a/sagify/commands/llm.py b/sagify/commands/llm.py
new file mode 100644
index 0000000..34cb237
--- /dev/null
+++ b/sagify/commands/llm.py
@@ -0,0 +1,294 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, unicode_literals
+
+import json
+import sys
+
+import click
+
+from sagify.api import cloud as api_cloud
+from sagify.commands import ASCII_LOGO
+from sagify.commands.custom_validators.validators import validate_tags
+from sagify.log import logger
+from sagify.sagemaker import sagemaker
+
+click.disable_unicode_literals_warning = True
+
+
+@click.group()
+def llm():
+ """
+ Commands for LLM (Large Language Model) operations
+ """
+ pass
+
+
+@llm.command()
+@click.option(
+ '--all',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for all services.'
+)
+@click.option(
+ '--chat-completions',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for chat completions.'
+)
+@click.option(
+ '--image-creations',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for image creations.'
+)
+@click.option(
+ '--embeddings',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for embeddings.'
+)
+@click.option('--config', required=False, type=click.File('r'), help='Path to config file.')
+@click.option(
+ u"-a", u"--aws-tags",
+ callback=validate_tags,
+ required=False,
+ default=None,
+ help='Tags for labeling a training job of the form "tag1=value1;tag2=value2". For more, see '
+ 'https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.'
+)
+@click.option(
+ u"--aws-profile",
+ required=True,
+ help="The AWS profile to use for the foundation model deploy command"
+)
+@click.option(
+ u"--aws-region",
+ required=True,
+ help="The AWS region to use for the foundation model deploy command"
+)
+@click.option(
+ u"-r",
+ u"--iam-role-arn",
+ required=False,
+ help="The AWS role to use for the foundation model deploy command"
+)
+@click.option(
+ u"-x",
+ u"--external-id",
+ required=False,
+ help="Optional external id used when using an IAM role"
+)
+def start(
+ all,
+ chat_completions,
+ image_creations,
+ embeddings,
+ config,
+ aws_tags,
+ aws_profile,
+ aws_region,
+ iam_role_arn,
+ external_id
+):
+ """
+ Command to start LLM infrastructure
+ """
+ logger.info(ASCII_LOGO)
+ logger.info("Starting LLM infrastructure. It will take ~15-30 mins...\n")
+
+ # Default configuration
+ default_config = {
+ 'chat_completions': {
+ 'model': 'meta-textgeneration-llama-2-7b-f',
+ 'model_version': '1.*',
+ 'instance_type': 'ml.g5.2xlarge',
+ 'num_instances': 1,
+ },
+ 'image_creations': {
+ 'model': 'model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ 'model_version': '1.*',
+ 'instance_type': 'ml.p3.2xlarge',
+ 'num_instances': 1,
+ },
+ 'embeddings': {
+ 'model': 'huggingface-sentencesimilarity-gte-small',
+ 'model_version': '1.*',
+ 'instance_type': 'ml.g5.2xlarge',
+ 'num_instances': 1,
+ },
+ }
+
+ # Load the config file if provided
+ if config:
+ custom_config = json.load(config)
+ default_config.update(custom_config)
+
+ try:
+ if all:
+ chat_completions, image_creations, embeddings = True, True, True
+
+ llm_infra_config = {
+ 'chat_completions_endpoint': None,
+ 'image_creations_endpoint': None,
+ 'embeddings_endpoint': None,
+ }
+
+ if chat_completions:
+ chat_endpoint_name, _ = api_cloud.foundation_model_deploy(
+ model_id=default_config['chat_completions']['model'],
+ model_version=default_config['chat_completions']['model_version'],
+ num_instances=default_config['chat_completions']['num_instances'],
+ ec2_type=default_config['chat_completions']['instance_type'],
+ aws_region=aws_region,
+ aws_profile=aws_profile,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags
+ )
+ llm_infra_config['chat_completions_endpoint'] = chat_endpoint_name
+
+ logger.info("Chat Completions Endpoint Name: {}".format(chat_endpoint_name))
+
+ if image_creations:
+ image_endpoint_name, _ = api_cloud.foundation_model_deploy(
+ model_id=default_config['image_creations']['model'],
+ model_version=default_config['image_creations']['model_version'],
+ num_instances=default_config['image_creations']['num_instances'],
+ ec2_type=default_config['image_creations']['instance_type'],
+ aws_region=aws_region,
+ aws_profile=aws_profile,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags
+ )
+ llm_infra_config['image_creations_endpoint'] = image_endpoint_name
+
+ logger.info("Image Creations Endpoint Name: {}".format(image_endpoint_name))
+
+ if embeddings:
+ embeddings_endpoint_name, _ = api_cloud.foundation_model_deploy(
+ model_id=default_config['embeddings']['model'],
+ model_version=default_config['embeddings']['model_version'],
+ num_instances=default_config['embeddings']['num_instances'],
+ ec2_type=default_config['embeddings']['instance_type'],
+ aws_region=aws_region,
+ aws_profile=aws_profile,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags
+ )
+ llm_infra_config['embeddings_endpoint'] = embeddings_endpoint_name
+
+ logger.info("Embeddings Endpoint Name: {}".format(embeddings_endpoint_name))
+
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump(llm_infra_config, f)
+ except ValueError as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+
+@click.command()
+@click.option(
+ '--all',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for all services.'
+)
+@click.option(
+ '--chat-completions',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for chat completions.'
+)
+@click.option(
+ '--image-creations',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for image creations.'
+)
+@click.option(
+ '--embeddings',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for embeddings.'
+)
+@click.option(
+ u"--aws-profile",
+ required=True,
+ help="The AWS profile to use for the foundation model deploy command"
+)
+@click.option(
+ u"--aws-region",
+ required=True,
+ help="The AWS region to use for the foundation model deploy command"
+)
+@click.option(
+ u"-r",
+ u"--iam-role-arn",
+ required=False,
+ help="The AWS role to use for the train command"
+)
+@click.option(
+ u"-x",
+ u"--external-id",
+ required=False,
+ help="Optional external id used when using an IAM role"
+)
+def stop(
+ all,
+ chat_completions,
+ image_creations,
+ embeddings,
+ aws_profile,
+ aws_region,
+ iam_role_arn,
+ external_id
+):
+ """
+ Command to stop LLM infrastructure
+ """
+ logger.info(ASCII_LOGO)
+ logger.info("Stopping LLM infrastructure...\n")
+
+ sagemaker_client = sagemaker.SageMakerClient(aws_profile, aws_region, iam_role_arn, external_id)
+ try:
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ endpoints_to_stop = []
+ if all:
+ endpoints_to_stop = ['chat_completions_endpoint', 'image_creations_endpoint', 'embeddings_endpoint']
+ else:
+ if chat_completions:
+ endpoints_to_stop.append('chat_completions_endpoint')
+ if image_creations:
+ endpoints_to_stop.append('image_creations_endpoint')
+ if embeddings:
+ endpoints_to_stop.append('embeddings_endpoint')
+
+ for _endpoint in endpoints_to_stop:
+ if llm_infra_config[_endpoint]:
+ try:
+ sagemaker_client.shutdown_endpoint(llm_infra_config[_endpoint])
+ except Exception as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+ logger.info("LLM infrastructure stopped successfully.")
+ except FileNotFoundError as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+
+llm.add_command(start)
+llm.add_command(stop)
diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py
index 3f051ca..8e470c9 100644
--- a/sagify/sagemaker/sagemaker.py
+++ b/sagify/sagemaker/sagemaker.py
@@ -728,6 +728,13 @@ def query_endpoint(model_predictor, payload, content_type, accept):
return example_query_code_snippet
+ def shutdown_endpoint(self, endpoint_name):
+ """
+ Shuts down a SageMaker endpoint.
+ :param endpoint_name: [str], name of the endpoint to be shut down
+ """
+ self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
+
@staticmethod
def _get_s3_bucket(s3_dir):
"""
diff --git a/tests/commands/test_llm.py b/tests/commands/test_llm.py
new file mode 100644
index 0000000..0d078f4
--- /dev/null
+++ b/tests/commands/test_llm.py
@@ -0,0 +1,409 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+try:
+ from unittest.mock import patch, call
+except ImportError:
+ from mock import patch, call
+
+from click.testing import CliRunner
+from sagify.__main__ import cli
+
+
+class TestLlmStart(object):
+ def test_start_all_happy_case(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('chat_completions_endpoint', 'some code snippet 1'),
+ ('image_creations_endpoint', 'some code snippet 2'),
+ ('embeddings_endpoint', 'some code snippet 3'),
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--all',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 3
+ mocked_foundation_model_deploy.assert_has_calls(
+ [
+ call(
+ model_id='meta-textgeneration-llama-2-7b-f',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ ),
+ call(
+ model_id='model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.p3.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ ),
+ call(
+ model_id='huggingface-sentencesimilarity-gte-small',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+ ]
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is not None
+ assert llm_infra_config['image_creations_endpoint'] is not None
+ assert llm_infra_config['embeddings_endpoint'] is not None
+
+ assert result.exit_code == 0
+
+ def test_start_chat_completions_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('chat_completions_endpoint', 'some code snippet 1')
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--chat-completions',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 1
+ mocked_foundation_model_deploy.assert_called_with(
+ model_id='meta-textgeneration-llama-2-7b-f',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is not None
+ assert llm_infra_config['image_creations_endpoint'] is None
+ assert llm_infra_config['embeddings_endpoint'] is None
+
+ assert result.exit_code == 0
+
+ def test_start_image_creations_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('image_creations_endpoint', 'some code snippet 2')
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--image-creations',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 1
+ mocked_foundation_model_deploy.assert_called_with(
+ model_id='model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.p3.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is None
+ assert llm_infra_config['image_creations_endpoint'] is not None
+ assert llm_infra_config['embeddings_endpoint'] is None
+
+ assert result.exit_code == 0
+
+ def test_start_embeddings_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('embeddings_endpoint', 'some code snippet 3')
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--embeddings',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 1
+ mocked_foundation_model_deploy.assert_called_with(
+ model_id='huggingface-sentencesimilarity-gte-small',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is None
+ assert llm_infra_config['image_creations_endpoint'] is None
+ assert llm_infra_config['embeddings_endpoint'] is not None
+
+ assert result.exit_code == 0
+
+
+class TestLlmStop(object):
+ def test_stop_all_happy_case(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--all',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ mocked_sagemaker_client.assert_called_with(
+ 'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
+ )
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 3
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_has_calls(
+ [
+ call('endpoint1'),
+ call('endpoint2'),
+ call('endpoint3')
+ ]
+ )
+
+ assert result.exit_code == 0
+
+ def test_stop_chat_completions_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--chat-completions',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ mocked_sagemaker_client.assert_called_with(
+ 'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
+ )
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 1
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_called_with(
+ 'endpoint1'
+ )
+
+ assert result.exit_code == 0
+
+ def test_stop_image_creations_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--image-creations',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ mocked_sagemaker_client.assert_called_with(
+ 'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
+ )
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 1
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_called_with(
+ 'endpoint2'
+ )
+
+ assert result.exit_code == 0
+
+ def test_stop_embeddings_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--embeddings',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ mocked_sagemaker_client.assert_called_with(
+ 'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
+ )
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 1
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_called_with(
+ 'endpoint3'
+ )
+
+ assert result.exit_code == 0
+
+ def test_stop_missing_config_file(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 0
+ assert result.exit_code == -1
+
+ def test_stop_endpoint_shutdown_error(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ mocked_sagemaker_client.return_value.shutdown_endpoint.side_effect = Exception('Endpoint shutdown error')
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--all',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 1
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_called_with('endpoint1')
+
+ assert result.exit_code == -1