Skip to content

Commit

Permalink
Merge pull request #6 from zenml-io/fix/update-template-with-new-names
Browse files Browse the repository at this point in the history
update model version names
  • Loading branch information
safoinme authored Nov 28, 2023
2 parents 0bd65ef + 4f04b29 commit 0cf04fa
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 28 deletions.
14 changes: 7 additions & 7 deletions template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
{{product_name}}_deploy_pipeline,
)
from zenml.logger import get_logger
from zenml.model.model_config import ModelConfig
from zenml.model.model_version import ModelVersion
from zenml.enums import ModelStages

logger = get_logger(__name__)
Expand Down Expand Up @@ -182,15 +182,15 @@ def main(
"weight_decay": weight_decay,
}

model_config = ModelConfig(
model_version = ModelVersion(
name=zenml_model_name,
license="{{open_source_license}}",
description="Show case Model Control Plane.",
delete_new_version_on_failure=True,
tags=["sentiment_analysis", "huggingface"],
)

pipeline_args["model_config"] = model_config
pipeline_args["model_version"] = model_version

pipeline_args[
"run_name"
Expand All @@ -201,8 +201,8 @@ def main(
# Execute Promoting Pipeline
if promoting_pipeline:
run_args_promoting = {}
model_config = ModelConfig(name=zenml_model_name, version=ModelStages.LATEST)
pipeline_args["model_config"] = model_config
model_version = ModelVersion(name=zenml_model_name, version=ModelStages.LATEST)
pipeline_args["model_version"] = model_version
pipeline_args[
"run_name"
] = f"{{product_name}}_promoting_pipeline_run_{dt.now().strftime('%Y_%m_%d_%H_%M_%S')}"
Expand All @@ -212,11 +212,11 @@ def main(
if deploying_pipeline:
pipeline_args["enable_cache"] = False
# Deploying pipeline has new ZenML model config
model_config = ModelConfig(
model_version = ModelVersion(
name=zenml_model_name,
version=ModelStages("{{target_environment}}"),
)
pipeline_args["model_config"] = model_config
pipeline_args["model_version"] = model_version
run_args_deploying = {
"title": deployment_app_title,
"description": deployment_app_description,
Expand Down
9 changes: 3 additions & 6 deletions template/steps/deploying/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


from zenml import get_step_context, step
from zenml.client import Client
from zenml.logger import get_logger

# Initialize logger
Expand All @@ -25,17 +24,15 @@ def save_model_to_deploy():
"""
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
pipeline_extra = get_step_context().pipeline_run.config.extra
zenml_client = Client()

logger.info(
f" Loading latest version of the model for stage {pipeline_extra['target_env']}..."
)
# Get latest saved model version in target environment
latest_version = get_step_context().model_config._get_model_version()
latest_version = get_step_context().model_version._get_model_version()

# Load model and tokenizer from Model Control Plane
model = latest_version.get_model_object(name="model").load()
tokenizer = latest_version.get_model_object(name="tokenizer").load()
model = latest_version.load_artifact(name="model")
tokenizer = latest_version.load_artifact(name="tokenizer")
# Save the model and tokenizer locally
model_path = "./gradio/model" # replace with the actual path
tokenizer_path = "./gradio/tokenizer" # replace with the actual path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ def promote_get_metrics(
zenml_client = Client()

# Get current model version metric in current run
model_config = get_step_context().model_config
current_version = model_config._get_model_version()
current_metrics = current_version.get_model_object(name="model").metadata["metrics"].value
model_version = get_step_context().model_version
current_version = model_version._get_model_version()
current_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {current_metrics}")

# Get latest saved model version metric in target environment
try:
latest_version = zenml_client.get_model_version(
model_name_or_id=model_config.name,
model_name_or_id=model_version.name,
model_version_name_or_number_or_id=ModelStages(pipeline_extra["target_env"]),
)
except KeyError:
latest_version = None
if latest_version:
latest_metrics = current_version.get_model_object(name="model").metadata["metrics"].value
latest_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {latest_metrics}")
else:
logger.info("No currently promoted model version found.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def promote_metric_compare_promoter(
should_promote = False

if should_promote:
model_config = get_step_context().model_config
model_version = model_config._get_model_version()
model_version = get_step_context().model_version
model_version = model_version._get_model_version()
model_version.set_stage(pipeline_extra["target_env"], force=True)

logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def promote_current():
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
pipeline_extra = get_step_context().pipeline_run.config.extra
logger.info(f"Promoting current model version")
model_config = get_step_context().model_config
model_version = model_config._get_model_version()
model_version = get_step_context().model_version
model_version = model_version._get_model_version()
model_version.set_stage(pipeline_extra["target_env"], force=True)
logger.info(
f"Current model version promoted to {pipeline_extra['target_env']}"
Expand Down
10 changes: 6 additions & 4 deletions template/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
TrainingArguments,
AutoModelForSequenceClassification,
)
from zenml import log_artifact_metadata, step
from zenml import ArtifactConfig, log_artifact_metadata, step
from zenml.client import Client
from zenml.integrations.mlflow.experiment_trackers import MLFlowExperimentTracker
from zenml.logger import get_logger
from zenml.model import ModelArtifactConfig
from utils.misc import compute_metrics

# Initialize logger
Expand Down Expand Up @@ -47,7 +46,7 @@ def model_trainer(
eval_batch_size: Optional[int] = 16,
weight_decay: Optional[float] = 0.01,
mlflow_model_name: Optional[str] = "sentiment_analysis",
) -> Tuple[Annotated[PreTrainedModel, "model", ModelArtifactConfig(overwrite=True)], Annotated[PreTrainedTokenizerBase, "tokenizer", ModelArtifactConfig(overwrite=True)]]:
) -> Tuple[Annotated[PreTrainedModel, ArtifactConfig(name="model", is_model_artifact=True)], Annotated[PreTrainedTokenizerBase, ArtifactConfig(name="tokenizer", is_model_artifact=True)]]:
"""
Configure and train a model on the training dataset.
Expand Down Expand Up @@ -136,7 +135,10 @@ def model_trainer(
eval_results = trainer.evaluate(metric_key_prefix="")

# Log the evaluation results in model control plane
log_artifact_metadata(output_name="model", metrics=eval_results)
log_artifact_metadata(
metadata={"metrics": eval_results},
artifact_name="model",
)
### YOUR CODE ENDS HERE ###

return model, tokenizer
2 changes: 1 addition & 1 deletion template/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# {% include 'template/license_header' %}

from typing import Dict, Tuple, List
from typing import Dict, List, Tuple

import numpy as np
from datasets import load_metric
Expand Down
2 changes: 1 addition & 1 deletion tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def generate_and_run_project(
cloud_of_choice: str = "gcp",
dataset: str = "airline_reviews",
zenml_model_name: str = "sentiment_analysis",

):
"""Generate and run the starter project with different options."""

Expand Down Expand Up @@ -162,6 +161,7 @@ def test_latest_promotion(
tmp_path_factory=tmp_path_factory, metric_compare_promotion=False
)


def test_production_environment(
clean_zenml_client,
tmp_path_factory: pytest.TempPathFactory,
Expand Down

0 comments on commit 0cf04fa

Please sign in to comment.