Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename model version to a model #10

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ jobs:
with:
stack-name: ${{ matrix.stack-name }}
python-version: ${{ matrix.python-version }}
ref-zenml: ${{ inputs.ref-zenml || 'main' }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/OSSK-342-rename-model-version-to-a-model' }}
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
ref-template: ${{ inputs.ref-template || github.ref }}
14 changes: 7 additions & 7 deletions template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
{{product_name}}_deploy_pipeline,
)
from zenml.logger import get_logger
from zenml.model.model_version import ModelVersion
from zenml.model.model import Model
from zenml.enums import ModelStages

logger = get_logger(__name__)
Expand Down Expand Up @@ -182,15 +182,15 @@ def main(
"weight_decay": weight_decay,
}

model_version = ModelVersion(
model = Model(
name=zenml_model_name,
license="{{open_source_license}}",
description="Show case Model Control Plane.",
delete_new_version_on_failure=True,
tags=["sentiment_analysis", "huggingface"],
)

pipeline_args["model_version"] = model_version
pipeline_args["model"] = model

pipeline_args[
"run_name"
Expand All @@ -201,8 +201,8 @@ def main(
# Execute Promoting Pipeline
if promoting_pipeline:
run_args_promoting = {}
model_version = ModelVersion(name=zenml_model_name, version=ModelStages.LATEST)
pipeline_args["model_version"] = model_version
model = Model(name=zenml_model_name, version=ModelStages.LATEST)
pipeline_args["model"] = model
pipeline_args[
"run_name"
] = f"{{product_name}}_promoting_pipeline_run_{dt.now().strftime('%Y_%m_%d_%H_%M_%S')}"
Expand All @@ -212,11 +212,11 @@ def main(
if deploying_pipeline:
pipeline_args["enable_cache"] = False
# Deploying pipeline has new ZenML model config
model_version = ModelVersion(
model = Model(
name=zenml_model_name,
version=ModelStages("{{target_environment}}"),
)
pipeline_args["model_version"] = model_version
pipeline_args["model"] = model
run_args_deploying = {
"title": deployment_app_title,
"description": deployment_app_description,
Expand Down
2 changes: 1 addition & 1 deletion template/steps/deploying/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def save_model_to_deploy():
f" Loading latest version of the model for stage {pipeline_extra['target_env']}..."
)
# Get latest saved model version in target environment
latest_version = get_step_context().model_version
latest_version = get_step_context().model

# Load model and tokenizer from Model Control Plane
model = latest_version.load_artifact(name="model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,21 @@ def promote_get_metrics(
zenml_client = Client()

# Get current model version metric in current run
model_version = get_step_context().model_version
current_version = model_version._get_model_version()
current_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
model = get_step_context().model
current_metrics = model.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {current_metrics}")

# Get latest saved model version metric in target environment
try:
latest_version = zenml_client.get_model_version(
model_name_or_id=model_version.name,
model_name_or_id=model.name,
model_version_name_or_number_or_id=ModelStages(pipeline_extra["target_env"]),
)
except KeyError:
latest_version = None
if latest_version:
latest_metrics = current_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Current model version metrics are {latest_metrics}")
latest_metrics = latest_version.get_model_artifact("model").run_metadata["metrics"].value
logger.info(f"Latest model version metrics are {latest_metrics}")
else:
logger.info("No currently promoted model version found.")
latest_metrics = current_metrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def promote_metric_compare_promoter(
should_promote = False

if should_promote:
model_version = get_step_context().model_version
model_version = model_version._get_model_version()
model_version.set_stage(pipeline_extra["target_env"], force=True)
model = get_step_context().model
model.set_stage(pipeline_extra["target_env"], force=True)

logger.info(
f"Promoted current model version to {pipeline_extra['target_env']} environment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def promote_current():
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
pipeline_extra = get_step_context().pipeline_run.config.extra
logger.info(f"Promoting current model version")
model_version = get_step_context().model_version
model_version = model_version._get_model_version()
model_version.set_stage(pipeline_extra["target_env"], force=True)
model = get_step_context().model
model.set_stage(pipeline_extra["target_env"], force=True)
logger.info(
f"Current model version promoted to {pipeline_extra['target_env']}"
)
Expand Down