From b9ae95bb4ee682988381ba645f222961b98a692b Mon Sep 17 00:00:00 2001
From: pm3310
Date: Fri, 19 Jan 2024 21:36:47 +0000
Subject: [PATCH] Simplify deployment of foundation model
---
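A minimal sketch of the JumpStart deployment path this change switches to, for
context while reviewing. It assumes a recent SageMaker Python SDK with
JumpStart support; the model id, region, and instance type below are
illustrative placeholders, not values taken from this patch.

    from sagemaker.jumpstart.model import JumpStartModel

    # Resolve container image, model artifacts, and environment from the
    # JumpStart catalog instead of calling image_uris/model_uris by hand.
    model = JumpStartModel(
        model_id="huggingface-llm-falcon-7b-instruct-bf16",  # placeholder model id
        model_version="*",                                   # latest catalog version
        region="us-east-1",                                  # placeholder region
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    # Deploy and invoke; accept_eula is required for EULA-gated models.
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.2xlarge",  # placeholder instance type
        accept_eula=True,
    )
    print(predictor.predict({"inputs": "What is SageMaker JumpStart?"}))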
sagify/sagemaker/sagemaker.py | 39 ++++++-----------------------------
1 file changed, 6 insertions(+), 33 deletions(-)
diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py
index cf961f3..3f051ca 100644
--- a/sagify/sagemaker/sagemaker.py
+++ b/sagify/sagemaker/sagemaker.py
@@ -8,8 +8,8 @@
 import sagemaker.huggingface
 import sagemaker.xgboost
 import sagemaker.sklearn.model
-from sagemaker import image_uris, model_uris, payloads
-from sagemaker.predictor import Predictor
+from sagemaker import payloads
+from sagemaker.jumpstart.model import JumpStartModel
 from six.moves.urllib.parse import urlparse
 import boto3
@@ -604,44 +604,17 @@ def deploy_foundation_model(
         :return: [str], endpoint name
         """
-        deploy_image_uri = image_uris.retrieve(
-            region=self.aws_region,
-            framework=None, # automatically inferred from model_id
-            image_scope="inference",
-            model_id=model_id,
-            model_version=model_version,
-            instance_type=instance_type,
-            sagemaker_session=self.sagemaker_session
-        )
-
-        model_uri = model_uris.retrieve(
+        model = JumpStartModel(
             model_id=model_id,
             model_version=model_version,
-            model_scope="inference",
             region=self.aws_region,
-            sagemaker_session=self.sagemaker_session
-        )
-
-        # Increase the maximum response size from the endpoint
-        env = {
-            "MMS_MAX_RESPONSE_SIZE": "20000000",
-        }
-
-        model = sage.Model(
-            image_uri=deploy_image_uri,
-            model_data=model_uri,
-            role=self.role,
-            predictor_cls=Predictor,
-            name=endpoint_name,
-            env=env,
-            sagemaker_session=self.sagemaker_session
+            sagemaker_session=self.sagemaker_session,
+            tolerate_deprecated_model=True,
+            tolerate_vulnerable_model=True
         )
-
         model_predictor = model.deploy(
             initial_instance_count=instance_count,
             instance_type=instance_type,
-            predictor_cls=Predictor,
-            endpoint_name=endpoint_name,
             tags=tags,
             accept_eula=True
         )