Skip to content

Commit

Permalink
Merge branch 'fix/update-mcp-syntax-0.55.0' of github.com:zenml-io/ze…
Browse files Browse the repository at this point in the history
…nml-projects into fix/update-mcp-syntax-0.55.0
  • Loading branch information
htahir1 committed Jan 24, 2024
2 parents ad112f2 + 346c062 commit 4b057b8
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
6 changes: 3 additions & 3 deletions huggingface-sagemaker/steps/deploying/sagemaker_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def deploy_hf_to_sagemaker(
Args:
repo_name: The name of the repo to create/use on huggingface.
"""
# If repo_id and revision are not provided, get them from the model version
# If repo_id and revision are not provided, get them from the model
# Otherwise, use the provided values.
if repo_id is None or revision is None:
context = get_step_context()
mv = context.model
deployment_metadata = mv.get_data_artifact(name="huggingface_url").run_metadata
zenml_model = context.model
deployment_metadata = zenml_model.get_data_artifact(name="huggingface_url").run_metadata
repo_id = deployment_metadata["repo_id"].value
revision = deployment_metadata["revision"].value

Expand Down
2 changes: 1 addition & 1 deletion huggingface-sagemaker/steps/deploying/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def save_model_to_deploy():
logger.info(
f" Loading latest version of the model for stage {pipeline_extra['target_env']}..."
)
# Get the current model version
# Get the current model
current_zenml_model = get_step_context().model

# Load model and tokenizer from Model Control Plane
Expand Down
10 changes: 5 additions & 5 deletions huggingface-sagemaker/steps/promotion/promote_get_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def promote_get_metrics() -> (
version: Version of the model to be retrieved.
Returns:
Metric value for a given model version.
Metric value for a given model.
"""
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
pipeline_extra = get_step_context().pipeline_run.config.extra
zenml_client = Client()

# Get current model version metric in current run
# Get current model metric in current run
current_zenml_model = get_step_context().model
current_metrics = current_zenml_model.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {current_metrics}")
logger.info(f"Current model metrics are {current_metrics}")

# Get latest saved model version metric in target environment
try:
Expand All @@ -69,9 +69,9 @@ def promote_get_metrics() -> (
latest_metrics = (
latest_zenml_model.get_model_artifact("model").run_metadata["metrics"].value
)
logger.info(f"Current model version metrics are {latest_metrics}")
logger.info(f"Current model metrics are {latest_metrics}")
else:
logger.info("No currently promoted model version found.")
logger.info("No currently promoted model found.")
latest_metrics = current_metrics
### YOUR CODE ENDS HERE ###

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ def promote_metric_compare_promoter(
should_promote = True

if latest_metrics == current_metrics:
logger.info("No current model version found - promoting latest")
logger.info("No current model found - promoting latest")
else:
logger.info(
f"Latest model metric={latest_metrics[metric_to_compare]:.6f}\n"
f"Current model metric={current_metrics[metric_to_compare]:.6f}"
)
if latest_metrics[metric_to_compare] < current_metrics[metric_to_compare]:
logger.info(
"Current model versions outperformed latest versions - promoting current"
"Current model outperformed latest model - promoting current"
)

else:
logger.info(
"Latest model versions outperformed current versions - keeping latest"
"Latest model outperformed current model - keeping latest"
)
should_promote = False

Expand All @@ -82,6 +82,6 @@ def promote_metric_compare_promoter(
zenml_model.set_stage(pipeline_extra["target_env"], force=True)

logger.info(
f"Promoted current model version to {pipeline_extra['target_env']} environment"
f"Promoted current model to {pipeline_extra['target_env']} environment"
)
### YOUR CODE ENDS HERE ###
2 changes: 1 addition & 1 deletion stack-showcase/steps/model_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def model_evaluator(
for message in messages:
logger.warning(message)

artifact = get_step_context().model_version.get_artifact("model")
artifact = get_step_context().model.get_artifact("model")

log_artifact_metadata(
metadata={"train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc)},
Expand Down

0 comments on commit 4b057b8

Please sign in to comment.