Skip to content

Commit

Permalink
deployment artifact type
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Feb 15, 2024
1 parent a274c72 commit ef190c9
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 20 deletions.
6 changes: 3 additions & 3 deletions classifier-e2e/pipelines/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def deploy(shutdown_endpoint_after_predicting: bool = True):
preprocess_pipeline=preprocess_pipeline,
target="target",
)
endpoint_name = deploy_endpoint()
predict_on_endpoint(endpoint_name, df_inference)
predictor = deploy_endpoint()
predict_on_endpoint(predictor, df_inference)
if shutdown_endpoint_after_predicting:
shutdown_endpoint(endpoint_name, after=["predict_on_endpoint"])
shutdown_endpoint(predictor, after=["predict_on_endpoint"])
44 changes: 36 additions & 8 deletions classifier-e2e/run_skip_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 123,
"id": "496adffa",
"metadata": {},
"outputs": [
Expand All @@ -1030,15 +1030,21 @@
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0msagemaker\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0msagemaker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_uris\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mretrieve\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0msagemaker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPredictor\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mzenml\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_step_context\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mzenml\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_step_context\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mArtifactConfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_artifact_metadata\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mdatetime\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdatetime\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maws\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_aws_config\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menable_cache\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"sagemaker_endpoint_name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAnnotated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mPredictor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mArtifactConfig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"sagemaker_endpoint_name\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_deployment_artifact\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mrole\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mregion\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_aws_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_model_version\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
Expand Down Expand Up @@ -1069,7 +1075,16 @@
"\u001b[0;34m\u001b[0m \u001b[0minstance_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"ml.m5.large\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mendpoint_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mendpoint_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n"
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlog_artifact_metadata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"endpoint_name\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"image_uri\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mimage_uri\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"role_arn\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mrole\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mPredictor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n"
]
}
],
Expand All @@ -1089,7 +1104,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 124,
"id": "9dfb5642",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -1118,10 +1133,10 @@
"\u001b[0;34m\u001b[0m \u001b[0mpreprocess_pipeline\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpreprocess_pipeline\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"target\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mendpoint_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mpredict_on_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_inference\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mpredictor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeploy_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mpredict_on_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredictor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_inference\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshutdown_endpoint_after_predicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mshutdown_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mendpoint_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mafter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"predict_on_endpoint\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n"
"\u001b[0;34m\u001b[0m \u001b[0mshutdown_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredictor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mafter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"predict_on_endpoint\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n"
]
}
],
Expand Down Expand Up @@ -1153,6 +1168,19 @@
")(shutdown_endpoint_after_predicting=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41cec4dc",
"metadata": {},
"outputs": [],
"source": [
"# explore created endpoint\n",
"run_metadata = client.get_model_version(\"breast_cancer_classifier\", \"production\").get_artifact(\"sagemaker_endpoint\").run_metadata\n",
"for k,v in run_metadata.items():\n",
" print(k, v.value)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
25 changes: 21 additions & 4 deletions classifier-e2e/steps/deploy_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

import sagemaker
from sagemaker.image_uris import retrieve
from sagemaker import Predictor

from zenml import step, get_step_context
from zenml import step, get_step_context, ArtifactConfig, log_artifact_metadata
from datetime import datetime

from utils.aws import get_aws_config
from utils.sagemaker_materializer import SagemakerPredictorMaterializer


@step(enable_cache=False)
def deploy_endpoint() -> Annotated[str, "sagemaker_endpoint_name"]:
@step(
enable_cache=False,
output_materializers=[SagemakerPredictorMaterializer],
)
def deploy_endpoint() -> Annotated[
Predictor,
ArtifactConfig(name="sagemaker_endpoint", is_deployment_artifact=True),
]:
role, session, region = get_aws_config()

model = get_step_context().model._get_model_version()
Expand Down Expand Up @@ -41,4 +49,13 @@ def deploy_endpoint() -> Annotated[str, "sagemaker_endpoint_name"]:
instance_type="ml.m5.large",
endpoint_name=endpoint_name,
)
return endpoint_name

log_artifact_metadata(
{
"endpoint_name": endpoint_name,
"image_uri": image_uri,
"role_arn": role,
}
)

return Predictor(endpoint_name=endpoint_name)
6 changes: 2 additions & 4 deletions classifier-e2e/steps/misc_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

@step
def predict_on_endpoint(
endpoint_name: str, dataset: pd.DataFrame
predictor: Predictor, dataset: pd.DataFrame
) -> Annotated[pd.Series, "real_time_predictions"]:
predictor = Predictor(endpoint_name=endpoint_name)
predictions = predictor.predict(
data=dataset.to_csv(header=False, index=False),
initial_args={"ContentType": "text/csv"},
Expand All @@ -22,6 +21,5 @@ def predict_on_endpoint(


@step
def shutdown_endpoint(endpoint_name: str):
predictor = Predictor(endpoint_name=endpoint_name)
def shutdown_endpoint(predictor: Predictor):
predictor.delete_endpoint()
18 changes: 17 additions & 1 deletion classifier-e2e/utils/sagemaker_materializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from zenml.enums import ArtifactType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.built_in_materializer import BuiltInMaterializer
from sklearn.linear_model import SGDClassifier
from xgboost import XGBClassifier
import tarfile
import tempfile
import joblib
from sklearn.base import ClassifierMixin

from sagemaker import Predictor

class SagemakerMaterializer(BaseMaterializer):
ASSOCIATED_TYPES = (ClassifierMixin,)
Expand Down Expand Up @@ -70,3 +71,18 @@ def save(self, my_obj: ClassifierMixin) -> None:
overwrite=True,
)
fileio.remove(os.path.join(tempfile.gettempdir(), "model.tar.gz"))


class SagemakerPredictorMaterializer(BaseMaterializer):
ASSOCIATED_TYPES = (Predictor,)
ASSOCIATED_ARTIFACT_TYPE = ArtifactType.SERVICE

def load(
self, data_type: Type[Predictor]
) -> Predictor:
"""Read from artifact store."""
return Predictor(endpoint_name=BuiltInMaterializer(self.uri).load(str))

def save(self, my_obj: Predictor) -> None:
"""Write to artifact store."""
BuiltInMaterializer(self.uri).save(my_obj.endpoint_name)

0 comments on commit ef190c9

Please sign in to comment.