Skip to content

Commit

Permalink
fixed imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed Nov 28, 2024
1 parent 1124356 commit b60e441
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions template/steps/model_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@

import pandas as pd
from sklearn.base import ClassifierMixin
from zenml import log_metadata, step, Client

from zenml import log_metadata, step
from zenml.client import Client
from zenml.logger import get_logger

logger = get_logger(__name__)


@step
def model_evaluator(
model: ClassifierMixin,
dataset_trn: pd.DataFrame,
dataset_tst: pd.DataFrame,
min_train_accuracy: float = 0.0,
min_test_accuracy: float = 0.0,
target: Optional[str] = "target",
model: ClassifierMixin,
dataset_trn: pd.DataFrame,
dataset_tst: pd.DataFrame,
min_train_accuracy: float = 0.0,
min_test_accuracy: float = 0.0,
target: Optional[str] = "target",
) -> float:
"""Evaluate a trained model.
Expand Down Expand Up @@ -63,17 +65,17 @@ def model_evaluator(
dataset_tst.drop(columns=[target]),
dataset_tst[target],
)
logger.info(f"Train accuracy={trn_acc*100:.2f}%")
logger.info(f"Test accuracy={tst_acc*100:.2f}%")
logger.info(f"Train accuracy={trn_acc * 100:.2f}%")
logger.info(f"Test accuracy={tst_acc * 100:.2f}%")

messages = []
if trn_acc < min_train_accuracy:
messages.append(
f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !"
f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !"
)
if tst_acc < min_test_accuracy:
messages.append(
f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !"
f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !"
)
else:
for message in messages:
Expand Down

0 comments on commit b60e441

Please sign in to comment.