From b60e4416b9fcc5ac1d15051c20f8034432483041 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 28 Nov 2024 15:31:24 +0100 Subject: [PATCH] fixed imports --- template/steps/model_evaluator.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/template/steps/model_evaluator.py b/template/steps/model_evaluator.py index 0a1511c..fe6c6a0 100644 --- a/template/steps/model_evaluator.py +++ b/template/steps/model_evaluator.py @@ -4,7 +4,9 @@ import pandas as pd from sklearn.base import ClassifierMixin -from zenml import log_metadata, step, Client + +from zenml import log_metadata, step +from zenml.client import Client from zenml.logger import get_logger logger = get_logger(__name__) @@ -12,12 +14,12 @@ @step def model_evaluator( - model: ClassifierMixin, - dataset_trn: pd.DataFrame, - dataset_tst: pd.DataFrame, - min_train_accuracy: float = 0.0, - min_test_accuracy: float = 0.0, - target: Optional[str] = "target", + model: ClassifierMixin, + dataset_trn: pd.DataFrame, + dataset_tst: pd.DataFrame, + min_train_accuracy: float = 0.0, + min_test_accuracy: float = 0.0, + target: Optional[str] = "target", ) -> float: """Evaluate a trained model. @@ -63,17 +65,17 @@ def model_evaluator( dataset_tst.drop(columns=[target]), dataset_tst[target], ) - logger.info(f"Train accuracy={trn_acc*100:.2f}%") - logger.info(f"Test accuracy={tst_acc*100:.2f}%") + logger.info(f"Train accuracy={trn_acc * 100:.2f}%") + logger.info(f"Test accuracy={tst_acc * 100:.2f}%") messages = [] if trn_acc < min_train_accuracy: messages.append( - f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !" + f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !" ) if tst_acc < min_test_accuracy: messages.append( - f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !" + f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !" ) else: for message in messages: