From ef190c9917cf38a84b121092631356e8fdb844be Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:55:59 +0100 Subject: [PATCH] deployment artifact type --- classifier-e2e/pipelines/deploy.py | 6 +-- classifier-e2e/run_skip_basics.ipynb | 44 +++++++++++++++---- classifier-e2e/steps/deploy_endpoint.py | 25 +++++++++-- classifier-e2e/steps/misc_endpoint.py | 6 +-- .../utils/sagemaker_materializer.py | 18 +++++++- 5 files changed, 79 insertions(+), 20 deletions(-) diff --git a/classifier-e2e/pipelines/deploy.py b/classifier-e2e/pipelines/deploy.py index 8b67dad0..f9b77fc1 100644 --- a/classifier-e2e/pipelines/deploy.py +++ b/classifier-e2e/pipelines/deploy.py @@ -19,7 +19,7 @@ def deploy(shutdown_endpoint_after_predicting: bool = True): preprocess_pipeline=preprocess_pipeline, target="target", ) - endpoint_name = deploy_endpoint() - predict_on_endpoint(endpoint_name, df_inference) + predictor = deploy_endpoint() + predict_on_endpoint(predictor, df_inference) if shutdown_endpoint_after_predicting: - shutdown_endpoint(endpoint_name, after=["predict_on_endpoint"]) + shutdown_endpoint(predictor, after=["predict_on_endpoint"]) diff --git a/classifier-e2e/run_skip_basics.ipynb b/classifier-e2e/run_skip_basics.ipynb index 7c4b6e9b..95d969f7 100644 --- a/classifier-e2e/run_skip_basics.ipynb +++ b/classifier-e2e/run_skip_basics.ipynb @@ -1018,7 +1018,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 123, "id": "496adffa", "metadata": {}, "outputs": [ @@ -1030,15 +1030,21 @@ "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0msagemaker\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0msagemaker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_uris\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mretrieve\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0msagemaker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPredictor\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mzenml\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_step_context\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mzenml\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_step_context\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mArtifactConfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_artifact_metadata\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mdatetime\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdatetime\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maws\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_aws_config\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menable_cache\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"sagemaker_endpoint_name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mPredictor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mArtifactConfig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"sagemaker_endpoint_name\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_deployment_artifact\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mrole\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mregion\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_aws_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_model_version\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", @@ -1069,7 +1075,16 @@ "\u001b[0;34m\u001b[0m \u001b[0minstance_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"ml.m5.large\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mendpoint_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mendpoint_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" + "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mlog_artifact_metadata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m\"endpoint_name\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m\"image_uri\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mimage_uri\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m\"role_arn\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mrole\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mPredictor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] } ], @@ -1089,7 +1104,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 124, "id": "9dfb5642", "metadata": {}, "outputs": [ @@ -1118,10 +1133,10 @@ "\u001b[0;34m\u001b[0m \u001b[0mpreprocess_pipeline\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpreprocess_pipeline\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"target\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mendpoint_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mpredict_on_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_inference\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mpredictor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mpredict_on_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredictor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_inference\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshutdown_endpoint_after_predicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mshutdown_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mafter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"predict_on_endpoint\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" + "\u001b[0;34m\u001b[0m \u001b[0mshutdown_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredictor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mafter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"predict_on_endpoint\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] } ], @@ -1153,6 +1168,19 @@ ")(shutdown_endpoint_after_predicting=True)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "41cec4dc", + "metadata": {}, + "outputs": [], + "source": [ + "# explore created endpoint\n", + "run_metadata = client.get_model_version(\"breast_cancer_classifier\", \"production\").get_artifact(\"sagemaker_endpoint\").run_metadata\n", + "for k,v in run_metadata.items():\n", + " print(k, v.value)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/classifier-e2e/steps/deploy_endpoint.py b/classifier-e2e/steps/deploy_endpoint.py index 3e549926..7f5ad977 100644 --- a/classifier-e2e/steps/deploy_endpoint.py +++ b/classifier-e2e/steps/deploy_endpoint.py @@ -2,15 +2,23 @@ import sagemaker from sagemaker.image_uris import retrieve +from sagemaker import Predictor -from zenml import step, get_step_context +from zenml import step, get_step_context, ArtifactConfig, log_artifact_metadata from datetime import datetime from utils.aws import get_aws_config +from utils.sagemaker_materializer import SagemakerPredictorMaterializer -@step(enable_cache=False) -def deploy_endpoint() -> Annotated[str, "sagemaker_endpoint_name"]: +@step( + enable_cache=False, + output_materializers=[SagemakerPredictorMaterializer], +) +def deploy_endpoint() -> Annotated[ + Predictor, + ArtifactConfig(name="sagemaker_endpoint", is_deployment_artifact=True), +]: role, session, region = get_aws_config() model = get_step_context().model._get_model_version() @@ -41,4 +49,13 @@ def deploy_endpoint() -> Annotated[str, "sagemaker_endpoint_name"]: instance_type="ml.m5.large", endpoint_name=endpoint_name, ) - return endpoint_name + + log_artifact_metadata( + { + "endpoint_name": endpoint_name, + "image_uri": image_uri, + "role_arn": role, + } + ) + + return Predictor(endpoint_name=endpoint_name) diff --git a/classifier-e2e/steps/misc_endpoint.py b/classifier-e2e/steps/misc_endpoint.py index 6bbd0198..aa6c7c8d 100644 --- a/classifier-e2e/steps/misc_endpoint.py +++ b/classifier-e2e/steps/misc_endpoint.py @@ -8,9 +8,8 @@ @step def predict_on_endpoint( - endpoint_name: str, dataset: pd.DataFrame + predictor: Predictor, dataset: pd.DataFrame ) -> Annotated[pd.Series, "real_time_predictions"]: - predictor = Predictor(endpoint_name=endpoint_name) predictions = predictor.predict( data=dataset.to_csv(header=False, index=False), initial_args={"ContentType": "text/csv"}, @@ -22,6 +21,5 @@ def predict_on_endpoint( @step -def shutdown_endpoint(endpoint_name: str): - predictor = Predictor(endpoint_name=endpoint_name) +def shutdown_endpoint(predictor: Predictor): predictor.delete_endpoint() diff --git a/classifier-e2e/utils/sagemaker_materializer.py b/classifier-e2e/utils/sagemaker_materializer.py index 33fdf353..aa886af9 100644 --- a/classifier-e2e/utils/sagemaker_materializer.py +++ b/classifier-e2e/utils/sagemaker_materializer.py @@ -4,13 +4,14 @@ from zenml.enums import ArtifactType from zenml.io import fileio from zenml.materializers.base_materializer import BaseMaterializer +from zenml.materializers.built_in_materializer import BuiltInMaterializer from sklearn.linear_model import SGDClassifier from xgboost import XGBClassifier import tarfile import tempfile import joblib from sklearn.base import ClassifierMixin - +from sagemaker import Predictor class SagemakerMaterializer(BaseMaterializer): ASSOCIATED_TYPES = (ClassifierMixin,) @@ -70,3 +71,18 @@ def save(self, my_obj: ClassifierMixin) -> None: overwrite=True, ) fileio.remove(os.path.join(tempfile.gettempdir(), "model.tar.gz")) + + +class SagemakerPredictorMaterializer(BaseMaterializer): + ASSOCIATED_TYPES = (Predictor,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.SERVICE + + def load( + self, data_type: Type[Predictor] + ) -> Predictor: + """Read from artifact store.""" + return Predictor(endpoint_name=BuiltInMaterializer(self.uri).load(str)) + + def save(self, my_obj: Predictor) -> None: + """Write to artifact store.""" + BuiltInMaterializer(self.uri).save(my_obj.endpoint_name)