Skip to content

Commit

Permalink
Fixes to the huggingface sagemaker example (#143)
Browse files Browse the repository at this point in the history
* log invocation url and start with small instance

* fix types

* fix artifact logging function
  • Loading branch information
wjayesh authored Oct 31, 2024
1 parent 3e39a93 commit f5702b6
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
19 changes: 17 additions & 2 deletions huggingface-sagemaker/steps/deploying/sagemaker_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# limitations under the License.
#

import os
from typing import Optional

from gradio.aws_helper import get_sagemaker_role, get_sagemaker_session
from sagemaker.huggingface import HuggingFaceModel
from typing_extensions import Annotated
from zenml import get_step_context, step
from zenml import get_step_context, log_artifact_metadata, step
from zenml.logger import get_logger

# Initialize logger
Expand All @@ -35,7 +36,7 @@ def deploy_hf_to_sagemaker(
pytorch_version: str = "1.13.1",
py_version: str = "py39",
hf_task: str = "text-classification",
instance_type: str = "ml.g5.2xlarge",
instance_type: str = "ml.t2.medium",
container_startup_health_check_timeout: int = 300,
) -> Annotated[str, "sagemaker_endpoint_name"]:
"""
Expand Down Expand Up @@ -83,4 +84,18 @@ def deploy_hf_to_sagemaker(
)
endpoint_name = predictor.endpoint_name
logger.info(f"Model deployed to SageMaker: {endpoint_name}")

# get region from env variable
region = os.environ["AWS_REGION"] or "eu-central-1"
invocation_url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations"

log_artifact_metadata(
artifact_name="sagemaker_endpoint_name",
metadata={
"invocation_url": invocation_url,
"endpoint_name": endpoint_name,
},
)


return endpoint_name
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

@step
def promote_metric_compare_promoter(
latest_metrics: Dict[str, str],
current_metrics: Dict[str, str],
latest_metrics: Dict[str, float],
current_metrics: Dict[str, float],
metric_to_compare: str = "accuracy",
):
"""Try to promote trained model.
Expand Down
4 changes: 3 additions & 1 deletion huggingface-sagemaker/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def model_trainer(
eval_results = trainer.evaluate(metric_key_prefix="")

# Log the evaluation results in model control plane
log_artifact_metadata(output_name="model", metrics=eval_results)
log_artifact_metadata(
artifact_name="model", metadata={"metrics": eval_results}
)

return model, tokenizer

0 comments on commit f5702b6

Please sign in to comment.