From 55122880e07305a5d13aad1726fc8b31d5840cbf Mon Sep 17 00:00:00 2001
From: pm3310
Date: Fri, 19 Jan 2024 21:37:46 +0000
Subject: [PATCH] LLM infra commands
---
README.md | 70 ++++++++
docs/index.md | 70 ++++++++
sagify/__main__.py | 2 +
sagify/commands/llm.py | 247 +++++++++++++++++++++++++++
sagify/sagemaker/sagemaker.py | 7 +
tests/commands/test_llm.py | 307 ++++++++++++++++++++++++++++++++++
6 files changed, 703 insertions(+)
create mode 100644 sagify/commands/llm.py
create mode 100644 tests/commands/test_llm.py
diff --git a/README.md b/README.md
index 1b7c272..8fd6aa9 100644
--- a/README.md
+++ b/README.md
@@ -967,3 +967,73 @@ This command deploys a Foundation model without code.
`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
`--endpoint-name ENDPOINT_NAME`: Optional name for the SageMaker endpoint
+
+
+### LLM Start Infrastructure
+
+#### Name
+
+Command to start LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm start [--all] [--chat-completions] [--image-creations] [--embeddings] [--config EC2_CONFIG_FILE] --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It spins up the endpoints for chat completions, image creation and embeddings.
+
+#### Flags
+
+`--all`: Start infrastructure for all services.
+
+`--chat-completions`: Start infrastructure for chat completions.
+
+`--image-creations`: Start infrastructure for image creations.
+
+`--embeddings`: Start infrastructure for embeddings.
+
+`--config EC2_CONFIG_FILE`: Path to config file to override foundation models, ec2 instance types and/or number of instances.
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm start command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm start command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
+
+
+### LLM Stop Infrastructure
+
+#### Name
+
+Command to stop LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm stop --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It stops all the services that are running.
+
+#### Required Flags
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm stop command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm stop command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index efcdb81..6ed8a34 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1674,3 +1674,73 @@ This command deploys a Foundation model without code.
`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
`--endpoint-name ENDPOINT_NAME`: Optional name for the SageMaker endpoint
+
+
+### LLM Start Infrastructure
+
+#### Name
+
+Command to start LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm start [--all] [--chat-completions] [--image-creations] [--embeddings] [--config EC2_CONFIG_FILE] --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It spins up the endpoints for chat completions, image creation and embeddings.
+
+#### Flags
+
+`--all`: Start infrastructure for all services.
+
+`--chat-completions`: Start infrastructure for chat completions.
+
+`--image-creations`: Start infrastructure for image creations.
+
+`--embeddings`: Start infrastructure for embeddings.
+
+`--config EC2_CONFIG_FILE`: Path to config file to override foundation models, ec2 instance types and/or number of instances.
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm start command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm start command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
+
+
+### LLM Stop Infrastructure
+
+#### Name
+
+Command to stop LLM infrastructure
+
+#### Synopsis
+```sh
+sagify llm stop --aws-profile AWS_PROFILE --aws-region AWS_REGION [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID]
+```
+
+#### Description
+
+It stops all the services that are running.
+
+#### Required Flags
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the llm stop command
+
+`--aws-region AWS_REGION`: The AWS region to use for the llm stop command
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling a training job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for deploying with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
diff --git a/sagify/__main__.py b/sagify/__main__.py
index 16e2748..67ae5c8 100644
--- a/sagify/__main__.py
+++ b/sagify/__main__.py
@@ -6,6 +6,7 @@
from sagify.commands.build import build
from sagify.commands.cloud import cloud
from sagify.commands.initialize import init
+from sagify.commands.llm import llm
from sagify.commands.local import local
from sagify.commands.push import push
from sagify.commands.configure import configure
@@ -31,6 +32,7 @@ def add_commands(cli):
cli.add_command(push)
cli.add_command(cloud)
cli.add_command(configure)
+ cli.add_command(llm)
add_commands(cli)
diff --git a/sagify/commands/llm.py b/sagify/commands/llm.py
new file mode 100644
index 0000000..65b51bf
--- /dev/null
+++ b/sagify/commands/llm.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, unicode_literals
+
+import json
+import sys
+
+import click
+
+from sagify.api import cloud as api_cloud
+from sagify.commands import ASCII_LOGO
+from sagify.commands.custom_validators.validators import validate_tags
+from sagify.log import logger
+from sagify.config.config import ConfigManager
+from sagify.sagemaker import sagemaker
+
+click.disable_unicode_literals_warning = True
+
+
+@click.group()
+def llm():
+ """
+ Commands for LLM (Large Language Model) operations
+ """
+ pass
+
+
+@llm.command()
+@click.option(
+ '--all',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for all services.'
+)
+@click.option(
+ '--chat-completions',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for chat completions.'
+)
+@click.option(
+ '--image-creations',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for image creations.'
+)
+@click.option(
+ '--embeddings',
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help='Start infrastructure for embeddings.'
+)
+@click.option('--config', required=False, type=click.File('r'), help='Path to config file.')
+@click.option(
+ u"-a", u"--aws-tags",
+ callback=validate_tags,
+ required=False,
+ default=None,
+ help='Tags for labeling a training job of the form "tag1=value1;tag2=value2". For more, see '
+ 'https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.'
+)
+@click.option(
+ u"--aws-profile",
+ required=True,
+ help="The AWS profile to use for the foundation model deploy command"
+)
+@click.option(
+ u"--aws-region",
+ required=True,
+ help="The AWS region to use for the foundation model deploy command"
+)
+@click.option(
+ u"-r",
+ u"--iam-role-arn",
+ required=False,
+ help="The AWS role to use for the foundation model deploy command"
+)
+@click.option(
+ u"-x",
+ u"--external-id",
+ required=False,
+ help="Optional external id used when using an IAM role"
+)
+def start(
+ all,
+ chat_completions,
+ image_creations,
+ embeddings,
+ config,
+ aws_tags,
+ aws_profile,
+ aws_region,
+ iam_role_arn,
+ external_id
+):
+ """
+ Command to start LLM infrastructure
+ """
+ logger.info(ASCII_LOGO)
+ logger.info("Starting LLM infrastructure. It will take ~15-30 mins...\n")
+
+ # Default configuration
+ default_config = {
+ 'chat_completions': {
+ 'model': 'meta-textgeneration-llama-2-7b-f',
+ 'model_version': '1.*',
+ 'instance_type': 'ml.g5.2xlarge',
+ 'num_instances': 1,
+ },
+ 'image_creations': {
+ 'model': 'model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ 'model_version': '1.*',
+ 'instance_type': 'ml.p3.2xlarge',
+ 'num_instances': 1,
+ },
+ 'embeddings': {
+ 'model': 'huggingface-sentencesimilarity-gte-small',
+ 'model_version': '1.*',
+ 'instance_type': 'ml.g5.2xlarge',
+ 'num_instances': 1,
+ },
+ }
+
+ # Load the config file if provided
+ if config:
+ custom_config = json.load(config)
+ default_config.update(custom_config)
+
+ try:
+ if all:
+ chat_completions, image_creations, embeddings = True, True, True
+
+ llm_infra_config = {
+ 'chat_completions_endpoint': None,
+ 'image_creations_endpoint': None,
+ 'embeddings_endpoint': None,
+ }
+
+ if chat_completions:
+ chat_endpoint_name, _ = api_cloud.foundation_model_deploy(
+ model_id=default_config['chat_completions']['model'],
+ model_version=default_config['chat_completions']['model_version'],
+ num_instances=default_config['chat_completions']['num_instances'],
+ ec2_type=default_config['chat_completions']['instance_type'],
+ aws_region=aws_region,
+ aws_profile=aws_profile,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags
+ )
+ llm_infra_config['chat_completions_endpoint'] = chat_endpoint_name
+
+ logger.info("Chat Completions Endpoint Name: {}".format(chat_endpoint_name))
+
+ if image_creations:
+ image_endpoint_name, _ = api_cloud.foundation_model_deploy(
+ model_id=default_config['image_creations']['model'],
+ model_version=default_config['image_creations']['model_version'],
+ num_instances=default_config['image_creations']['num_instances'],
+ ec2_type=default_config['image_creations']['instance_type'],
+ aws_region=aws_region,
+ aws_profile=aws_profile,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags
+ )
+ llm_infra_config['image_creations_endpoint'] = image_endpoint_name
+
+ logger.info("Image Creations Endpoint Name: {}".format(image_endpoint_name))
+
+ if embeddings:
+ embeddings_endpoint_name, _ = api_cloud.foundation_model_deploy(
+ model_id=default_config['embeddings']['model'],
+ model_version=default_config['embeddings']['model_version'],
+ num_instances=default_config['embeddings']['num_instances'],
+ ec2_type=default_config['embeddings']['instance_type'],
+ aws_region=aws_region,
+ aws_profile=aws_profile,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags
+ )
+ llm_infra_config['embeddings_endpoint'] = embeddings_endpoint_name
+
+ logger.info("Embeddings Endpoint Name: {}".format(embeddings_endpoint_name))
+
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump(llm_infra_config, f)
+ except ValueError as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+
+@click.command()
+@click.option(
+ u"--aws-profile",
+ required=True,
+ help="The AWS profile to use for the foundation model deploy command"
+)
+@click.option(
+ u"--aws-region",
+ required=True,
+ help="The AWS region to use for the foundation model deploy command"
+)
+@click.option(
+ u"-r",
+ u"--iam-role-arn",
+ required=False,
+ help="The AWS role to use for the train command"
+)
+@click.option(
+ u"-x",
+ u"--external-id",
+ required=False,
+ help="Optional external id used when using an IAM role"
+)
+def stop(aws_profile, aws_region, iam_role_arn, external_id):
+ """
+ Command to stop LLM infrastructure
+ """
+ logger.info(ASCII_LOGO)
+ logger.info("Stopping LLM infrastructure...\n")
+
+ sagemaker_client = sagemaker.SageMakerClient(aws_profile, aws_region, iam_role_arn, external_id)
+ try:
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ for _endpoint in ['chat_completions_endpoint', 'image_creations_endpoint', 'embeddings_endpoint']:
+ if llm_infra_config[_endpoint]:
+ try:
+ sagemaker_client.shutdown_endpoint(llm_infra_config[_endpoint])
+ except Exception as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+ logger.info("LLM infrastructure stopped successfully.")
+ except FileNotFoundError as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+
+llm.add_command(start)
+llm.add_command(stop)
diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py
index 3f051ca..8e470c9 100644
--- a/sagify/sagemaker/sagemaker.py
+++ b/sagify/sagemaker/sagemaker.py
@@ -728,6 +728,13 @@ def query_endpoint(model_predictor, payload, content_type, accept):
return example_query_code_snippet
+ def shutdown_endpoint(self, endpoint_name):
+ """
+ Shuts down a SageMaker endpoint.
+ :param endpoint_name: [str], name of the endpoint to be shut down
+ """
+ self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
+
@staticmethod
def _get_s3_bucket(s3_dir):
"""
diff --git a/tests/commands/test_llm.py b/tests/commands/test_llm.py
new file mode 100644
index 0000000..b1d09d6
--- /dev/null
+++ b/tests/commands/test_llm.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+try:
+ from unittest.mock import patch, call
+except ImportError:
+ from mock import patch, call
+
+from click.testing import CliRunner
+from sagify.__main__ import cli
+
+
+class TestLlmStart(object):
+ def test_start_all_happy_case(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('chat_completions_endpoint', 'some code snippet 1'),
+ ('image_creations_endpoint', 'some code snippet 2'),
+ ('embeddings_endpoint', 'some code snippet 3'),
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--all',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 3
+ mocked_foundation_model_deploy.assert_has_calls(
+ [
+ call(
+ model_id='meta-textgeneration-llama-2-7b-f',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ ),
+ call(
+ model_id='model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.p3.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ ),
+ call(
+ model_id='huggingface-sentencesimilarity-gte-small',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+ ]
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is not None
+ assert llm_infra_config['image_creations_endpoint'] is not None
+ assert llm_infra_config['embeddings_endpoint'] is not None
+
+ assert result.exit_code == 0
+
+ def test_start_chat_completions_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('chat_completions_endpoint', 'some code snippet 1')
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--chat-completions',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 1
+ mocked_foundation_model_deploy.assert_called_with(
+ model_id='meta-textgeneration-llama-2-7b-f',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is not None
+ assert llm_infra_config['image_creations_endpoint'] is None
+ assert llm_infra_config['embeddings_endpoint'] is None
+
+ assert result.exit_code == 0
+
+ def test_start_image_creations_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('image_creations_endpoint', 'some code snippet 2')
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--image-creations',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 1
+ mocked_foundation_model_deploy.assert_called_with(
+ model_id='model-txt2img-stabilityai-stable-diffusion-v2-1-base',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.p3.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is None
+ assert llm_infra_config['image_creations_endpoint'] is not None
+ assert llm_infra_config['embeddings_endpoint'] is None
+
+ assert result.exit_code == 0
+
+ def test_start_embeddings_only(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.api.cloud.foundation_model_deploy'
+ ) as mocked_foundation_model_deploy:
+ mocked_foundation_model_deploy.side_effect = [
+ ('embeddings_endpoint', 'some code snippet 3')
+ ]
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'start',
+ '--embeddings',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production'
+ ]
+ )
+
+ assert mocked_foundation_model_deploy.call_count == 1
+ mocked_foundation_model_deploy.assert_called_with(
+ model_id='huggingface-sentencesimilarity-gte-small',
+ model_version='1.*',
+ num_instances=1,
+ ec2_type='ml.g5.2xlarge',
+ aws_region='us-east-1',
+ aws_profile='sagemaker-production',
+ aws_role=None,
+ external_id=None,
+ tags=None
+ )
+
+ assert os.path.isfile('.sagify_llm_infra.json')
+
+ with open('.sagify_llm_infra.json', 'r') as f:
+ llm_infra_config = json.load(f)
+
+ assert llm_infra_config['chat_completions_endpoint'] is None
+ assert llm_infra_config['image_creations_endpoint'] is None
+ assert llm_infra_config['embeddings_endpoint'] is not None
+
+ assert result.exit_code == 0
+
+
+class TestLlmStop(object):
+ def test_stop_happy_case(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ # from unittest.mock import MagicMock
+
+ # mocked_sagemaker_client.return_value = MagicMock()
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ mocked_sagemaker_client.assert_called_with(
+ 'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
+ )
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 3
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_has_calls(
+ [
+ call('endpoint1'),
+ call('endpoint2'),
+ call('endpoint3')
+ ]
+ )
+
+ assert result.exit_code == 0
+
+ def test_stop_missing_config_file(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ # mocked_sagemaker_client.return_value = MagicMock()
+ with runner.isolated_filesystem():
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 0
+ assert result.exit_code == -1
+
+ def test_stop_endpoint_shutdown_error(self):
+ runner = CliRunner()
+ with patch(
+ 'sagify.commands.llm.sagemaker.SageMakerClient'
+ ) as mocked_sagemaker_client:
+ # mocked_sagemaker_client.return_value = MagicMock()
+ mocked_sagemaker_client.return_value.shutdown_endpoint.side_effect = Exception('Endpoint shutdown error')
+ with runner.isolated_filesystem():
+ with open('.sagify_llm_infra.json', 'w') as f:
+ json.dump({
+ 'chat_completions_endpoint': 'endpoint1',
+ 'image_creations_endpoint': 'endpoint2',
+ 'embeddings_endpoint': 'endpoint3'
+ }, f)
+
+ result = runner.invoke(
+ cli=cli,
+ args=[
+ 'llm', 'stop',
+ '--aws-region', 'us-east-1',
+ '--aws-profile', 'sagemaker-production',
+ '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
+ '--external-id', '123456'
+ ]
+ )
+
+ assert mocked_sagemaker_client.return_value.shutdown_endpoint.call_count == 1
+ mocked_sagemaker_client.return_value.shutdown_endpoint.assert_called_with('endpoint1')
+
+ assert result.exit_code == -1