Skip to content

Commit

Permalink
OSSK-342
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Jan 12, 2024
1 parent bd996ed commit 840a215
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ jobs:
with:
stack-name: ${{ matrix.stack-name }}
python-version: ${{ matrix.python-version }}
ref-zenml: ${{ inputs.ref-zenml || 'main' }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/OSSK-342-rename-model-version-to-a-model' }}
ref-template: ${{ inputs.ref-template || github.ref }}
14 changes: 7 additions & 7 deletions template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
{{product_name}}_deploy_pipeline,
)
from zenml.logger import get_logger
from zenml.model.model_version import ModelVersion
from zenml.model.model import Model
from zenml.enums import ModelStages

logger = get_logger(__name__)
Expand Down Expand Up @@ -182,15 +182,15 @@ def main(
"weight_decay": weight_decay,
}

model_version = ModelVersion(
model = Model(
name=zenml_model_name,
license="{{open_source_license}}",
description="Show case Model Control Plane.",
delete_new_version_on_failure=True,
tags=["sentiment_analysis", "huggingface"],
)

pipeline_args["model_version"] = model_version
pipeline_args["model"] = model

pipeline_args[
"run_name"
Expand All @@ -201,8 +201,8 @@ def main(
# Execute Promoting Pipeline
if promoting_pipeline:
run_args_promoting = {}
model_version = ModelVersion(name=zenml_model_name, version=ModelStages.LATEST)
pipeline_args["model_version"] = model_version
model = Model(name=zenml_model_name, version=ModelStages.LATEST)
pipeline_args["model"] = model
pipeline_args[
"run_name"
] = f"{{product_name}}_promoting_pipeline_run_{dt.now().strftime('%Y_%m_%d_%H_%M_%S')}"
Expand All @@ -212,11 +212,11 @@ def main(
if deploying_pipeline:
pipeline_args["enable_cache"] = False
# Deploying pipeline has new ZenML model config
model_version = ModelVersion(
model = Model(
name=zenml_model_name,
version=ModelStages("{{target_environment}}"),
)
pipeline_args["model_version"] = model_version
pipeline_args["model"] = model
run_args_deploying = {
"title": deployment_app_title,
"description": deployment_app_description,
Expand Down
2 changes: 1 addition & 1 deletion template/steps/deploying/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def save_model_to_deploy():
f" Loading latest version of the model for stage {pipeline_extra['target_env']}..."
)
# Get latest saved model version in target environment
latest_version = get_step_context().model_version
latest_version = get_step_context().model

# Load model and tokenizer from Model Control Plane
model = latest_version.load_artifact(name="model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ def promote_get_metrics(
zenml_client = Client()

# Get current model version metric in current run
model_version = get_step_context().model_version
current_version = model_version._get_model_version()
model = get_step_context().model
current_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {current_metrics}")

# Get latest saved model version metric in target environment
try:
latest_version = zenml_client.get_model_version(
model_name_or_id=model_version.name,
model_name_or_id=model.name,
model_version_name_or_number_or_id=ModelStages(pipeline_extra["target_env"]),
)
except KeyError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def promote_metric_compare_promoter(
should_promote = False

if should_promote:
model_version = get_step_context().model_version
model_version = model_version._get_model_version()
model_version.set_stage(pipeline_extra["target_env"], force=True)
model = get_step_context().model
model.set_stage(pipeline_extra["target_env"], force=True)

logger.info(
f"Promoted current model version to {pipeline_extra['target_env']} environment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def promote_current():
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
pipeline_extra = get_step_context().pipeline_run.config.extra
logger.info(f"Promoting current model version")
model_version = get_step_context().model_version
model_version = model_version._get_model_version()
model_version.set_stage(pipeline_extra["target_env"], force=True)
model = get_step_context().model
model.set_stage(pipeline_extra["target_env"], force=True)
logger.info(
f"Current model version promoted to {pipeline_extra['target_env']}"
)
Expand Down

0 comments on commit 840a215

Please sign in to comment.