From 110466e0c898dde9535be8e139bca536a85f88bd Mon Sep 17 00:00:00 2001
From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com>
Date: Fri, 17 Nov 2023 10:59:21 +0100
Subject: [PATCH] move extra to parameters

---
 template/configs/train_config.yaml |  7 +++----
 template/pipelines/training.py     | 25 ++++++++++++++++++-------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/template/configs/train_config.yaml b/template/configs/train_config.yaml
index 5a946db..84ad0df 100644
--- a/template/configs/train_config.yaml
+++ b/template/configs/train_config.yaml
@@ -20,16 +20,12 @@ steps:
     parameters:
       name: {{ product_name }}
 {%- if metric_compare_promotion %}
-  compute_performance_metrics_on_current_data:
-    parameters:
-      target_env: {{ target_environment }}
   promote_with_metric_compare:
 {%- else %}
   promote_latest_version:
 {%- endif %}
     parameters:
       mlflow_model_name: {{ product_name }}
-      target_env: {{ target_environment }}
   notify_on_success:
     parameters:
       notify_on_success: False
@@ -56,6 +52,9 @@ model_version:
 # pipeline level extra configurations
 extra:
   notify_on_failure: True
+# pipeline level parameters
+parameters:
+  target_env: {{ target_environment }}
 {%- if hyperparameters_tuning %}
 # This set contains all the model configurations that you want
 # to evaluate during hyperparameter tuning stage.
diff --git a/template/pipelines/training.py b/template/pipelines/training.py
index bd1a647..e6a1b71 100644
--- a/template/pipelines/training.py
+++ b/template/pipelines/training.py
@@ -1,7 +1,7 @@
 # {% include 'template/license_header' %}
 
 
-from typing import List, Optional
+from typing import List, Optional, Any, Dict
 import random
 
 from steps import (
@@ -23,7 +23,7 @@
     promote_latest_version,
 {%- endif %}
 )
-from zenml import pipeline, get_pipeline_context
+from zenml import pipeline
 from zenml.logger import get_logger
 
 {%- if hyperparameters_tuning %}
@@ -38,6 +38,12 @@
 
 @pipeline(on_failure=notify_on_failure)
 def {{product_name}}_training(
+{%- if hyperparameters_tuning %}
+    model_search_space: Dict[str,Any],
+{%- else %}
+    model_configuration: Dict[str,Any],
+{%- endif %}
+    target_env: str,
     test_size: float = 0.2,
     drop_na: Optional[bool] = None,
     normalize: Optional[bool] = None,
@@ -54,6 +60,12 @@ def {{product_name}}_training(
     trains and evaluates a model.
 
     Args:
+{%- if hyperparameters_tuning %}
+        model_search_space: Search space for hyperparameter tuning
+{%- else %}
+        model_configuration: Configuration of the model to train
+{%- endif %}
+        target_env: The environment to promote the model to
         test_size: Size of holdout set for training 0.0..1.0
         drop_na: If `True` NA values will be removed from dataset
         normalize: If `True` dataset will be normalized with MinMaxScaler
@@ -62,12 +74,10 @@ def {{product_name}}_training(
         min_test_accuracy: Threshold to stop execution if test set accuracy is lower
         fail_on_accuracy_quality_gates: If `True` and `min_train_accuracy` or
             `min_test_accuracy` are not met - execution will be interrupted early
-
     """
     ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
     # Link all the steps together by calling them and passing the output
     # of one step as the input of the next step.
-    pipeline_extra = get_pipeline_context().extra
     ########## ETL stage ##########
     raw_data, target, _ = data_loader(random_state=random.randint(0,100))
     dataset_trn, dataset_tst = train_data_splitter(
@@ -86,7 +96,7 @@ def {{product_name}}_training(
     ########## Hyperparameter tuning stage ##########
     after = []
     search_steps_prefix = "hp_tuning_search_"
-    for config_name,model_search_configuration in pipeline_extra["model_search_space"].items():
+    for config_name,model_search_configuration in model_search_space.items():
         step_name = f"{search_steps_prefix}{config_name}"
         hp_tuning_single_search(
             id=step_name,
@@ -100,7 +110,6 @@ def {{product_name}}_training(
         after.append(step_name)
     best_model = hp_tuning_select_best_model(step_names=after, after=after)
 {%- else %}
-    model_configuration = pipeline_extra["model_configuration"]
     best_model = get_model_from_config(
         model_package=model_configuration["model_package"],
         model_class=model_configuration["model_class"],
@@ -130,16 +139,18 @@ def {{product_name}}_training(
 {%- if metric_compare_promotion %}
     latest_metric,current_metric = compute_performance_metrics_on_current_data(
         dataset_tst=dataset_tst,
+        target_env=target_env,
         after=["model_evaluator"]
     )
     promote_with_metric_compare(
         latest_metric=latest_metric,
         current_metric=current_metric,
+        target_env=target_env,
     )
     last_step = "promote_with_metric_compare"
 {%- else %}
-    promote_latest_version(after=["model_evaluator"])
+    promote_latest_version(target_env=target_env,after=["model_evaluator"])
     last_step = "promote_latest_version"
 {%- endif %}
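
For context on the pattern this patch adopts: values such as target_env move out of the pipeline-level extra dict (read at runtime via get_pipeline_context().extra) and become ordinary pipeline function parameters, which the top-level parameters: block of the run config can set. Below is a minimal sketch of that wiring, not part of the patch; the training/promote names and inline values are illustrative, and loading them from YAML assumes ZenML's with_options(config_path=...) API.

# Minimal sketch, assuming ZenML >= 0.40-style decorators; names are hypothetical.
from typing import Any, Dict

from zenml import pipeline, step


@step
def promote(target_env: str) -> None:
    # The step receives the value as an ordinary argument instead of the
    # pipeline reading it from get_pipeline_context().extra.
    print(f"Promoting model to: {target_env}")


@pipeline
def training(model_configuration: Dict[str, Any], target_env: str) -> None:
    # Pipeline-level parameters can be supplied from the top-level
    # `parameters:` block of a run config such as train_config.yaml.
    promote(target_env=target_env)


if __name__ == "__main__":
    # Direct call with explicit values...
    training(model_configuration={"model_class": "SVC"}, target_env="staging")
    # ...or load them from a config file (illustrative path):
    # training.with_options(config_path="train_config.yaml")()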