diff --git a/template/steps/model_evaluator.py b/template/steps/model_evaluator.py index 8ae51c0..0a1511c 100644 --- a/template/steps/model_evaluator.py +++ b/template/steps/model_evaluator.py @@ -4,7 +4,7 @@ import pandas as pd from sklearn.base import ClassifierMixin -from zenml import log_, step, Client +from zenml import log_metadata, step, Client from zenml.logger import get_logger logger = get_logger(__name__) @@ -81,6 +81,13 @@ def model_evaluator( client = Client() latest_classifier = client.get_artifact_version("sklearn_classifier") - log_metadata(metadata=metadata, artifact_version_id=latest_classifier.id) + + log_metadata( + metadata={ + "train_accuracy": float(trn_acc), + "test_accuracy": float(tst_acc) + }, + artifact_version_id=latest_classifier.id + ) return float(tst_acc)