Merge branch 'main' into feature/OSS-2487-rework-the-template-for-model-control-plane
avishniakov committed Oct 19, 2023
2 parents 7cb71f3 + 80aaa6c commit 11b9a80
Showing 3 changed files with 179 additions and 1 deletion.
78 changes: 78 additions & 0 deletions template/config.yaml
@@ -0,0 +1,78 @@
# {% include 'template/license_header' %}

settings:
docker:
requirements:
- aws
{%- if data_quality_checks %}
- evidently
{%- endif %}
- kubeflow
- kubernetes
- mlflow
- sklearn
- slack
extra:
mlflow_model_name: e2e_use_case_model
{%- if target_environment == 'production' %}
target_env: Production
{%- else %}
target_env: Staging
{%- endif %}
notify_on_success: False
notify_on_failure: True
{%- if hyperparameters_tuning %}
# This set contains all the models that you want to evaluate
# during the hyperparameter tuning stage.
model_search_space:
random_forest:
model_package: sklearn.ensemble
model_class: RandomForestClassifier
search_grid:
criterion:
- gini
- entropy
max_depth:
- 2
- 4
- 6
- 8
- 10
- 12
min_samples_leaf:
range:
start: 1
end: 10
n_estimators:
range:
start: 50
end: 500
step: 25
decision_tree:
model_package: sklearn.tree
model_class: DecisionTreeClassifier
search_grid:
criterion:
- gini
- entropy
max_depth:
- 2
- 4
- 6
- 8
- 10
- 12
min_samples_leaf:
range:
start: 1
end: 10
{%- else %}
# This model configuration will be used for the training stage.
model_configuration:
model_package: sklearn.tree
model_class: DecisionTreeClassifier
params:
criterion: gini
max_depth: 5
min_samples_leaf: 3
{%- endif %}
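The search-space block above is plain data: each entry names a model class by import path (`model_package` plus `model_class`) and a `search_grid` whose values are either explicit lists or `range` specifications. A minimal sketch of how one entry could be expanded into an sklearn estimator and a parameter grid — the helper name `build_search_space` is hypothetical, and the step that actually consumes this config is not part of this diff:

```python
# Sketch only (not part of this commit): expand a model_search_space entry
# into an estimator instance and a GridSearchCV-style parameter grid.
import importlib
from typing import Any, Dict, Tuple


def build_search_space(entry: Dict[str, Any]) -> Tuple[Any, Dict[str, list]]:
    # Resolve e.g. sklearn.ensemble.RandomForestClassifier from strings.
    module = importlib.import_module(entry["model_package"])
    model_class = getattr(module, entry["model_class"])

    grid: Dict[str, list] = {}
    for param, values in entry["search_grid"].items():
        if isinstance(values, dict) and "range" in values:
            r = values["range"]
            # `step` is optional in the YAML above; note Python's range is
            # end-exclusive, the template's actual semantics may differ.
            grid[param] = list(range(r["start"], r["end"], r.get("step", 1)))
        else:
            grid[param] = list(values)
    return model_class(), grid


# Hypothetical usage once the YAML is parsed into a dict `cfg`:
#   estimator, grid = build_search_space(cfg["model_search_space"]["random_forest"])
#   search = sklearn.model_selection.GridSearchCV(estimator, grid, cv=5)
```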
95 changes: 95 additions & 0 deletions template/main.py
@@ -0,0 +1,95 @@
# {% include 'template/license_header' %}

from datetime import datetime as dt
import os
from typing import Optional

from zenml.artifacts.external_artifact import ExternalArtifact
from zenml.logger import get_logger

from pipelines import {{product_name}}_batch_inference, {{product_name}}_training

logger = get_logger(__name__)


def main(
no_cache: bool = False,
no_drop_na: bool = False,
no_normalize: bool = False,
drop_columns: Optional[str] = None,
test_size: float = 0.2,
min_train_accuracy: float = 0.8,
min_test_accuracy: float = 0.8,
fail_on_accuracy_quality_gates: bool = False,
only_inference: bool = False,
):
"""Main entry point for the pipeline execution.
This entrypoint is where everything comes together:
* configuring pipeline with the required parameters
(some of which may come from command line arguments)
* launching the pipeline
Args:
no_cache: If `True` cache will be disabled.
no_drop_na: If `True` NA values will not be dropped from the dataset.
no_normalize: If `True` normalization will not be done for the dataset.
drop_columns: List of comma-separated names of columns to drop from the dataset.
test_size: Percentage of records from the training dataset to go into the test dataset.
min_train_accuracy: Minimum acceptable accuracy on the train set.
min_test_accuracy: Minimum acceptable accuracy on the test set.
fail_on_accuracy_quality_gates: If `True` and any of minimal accuracy
thresholds are violated - the pipeline will fail. If `False` thresholds will
not affect the pipeline.
only_inference: If `True` only inference pipeline will be triggered.
"""

# Run a pipeline with the required parameters. This executes
# all steps in the pipeline in the correct order using the orchestrator
# stack component that is configured in your active ZenML stack.
pipeline_args = {
"config_path":os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"config.yaml",
)
}
if no_cache:
pipeline_args["enable_cache"] = False

if not only_inference:
# Execute Training Pipeline
run_args_train = {
"drop_na": not no_drop_na,
"normalize": not no_normalize,
"random_seed": 42,
"test_size": test_size,
"min_train_accuracy": min_train_accuracy,
"min_test_accuracy": min_test_accuracy,
"fail_on_accuracy_quality_gates": fail_on_accuracy_quality_gates,
}
if drop_columns:
run_args_train["drop_columns"] = drop_columns.split(",")

pipeline_args[
"run_name"
] = f"{{product_name}}_training_run_{dt.now().strftime('%Y_%m_%d_%H_%M_%S')}"
{{product_name}}_training.with_options(**pipeline_args)(**run_args_train)
logger.info("Training pipeline finished successfully!")

# Execute Batch Inference Pipeline
run_args_inference = {}
pipeline_args[
"run_name"
] = f"{{product_name}}_batch_inference_run_{dt.now().strftime('%Y_%m_%d_%H_%M_%S')}"
{{product_name}}_batch_inference.with_options(**pipeline_args)(**run_args_inference)

artifact = ExternalArtifact(
pipeline_name="{{product_name}}_batch_inference",
artifact_name="predictions",
)
logger.info(
"Batch inference pipeline finished successfully! "
"You can find predictions in Artifact Store using ID: "
f"`{str(artifact.get_artifact_id())}`."
)
7 changes: 6 additions & 1 deletion template/run.py
@@ -140,7 +140,12 @@ def main(
# Run a pipeline with the required parameters. This executes
# all steps in the pipeline in the correct order using the orchestrator
# stack component that is configured in your active ZenML stack.
pipeline_args = {}
pipeline_args = {
"config_path":os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"config.yaml",
)
}
if no_cache:
pipeline_args["enable_cache"] = False

