From 832b4cd32a3b3f2cb807c037c012280849f894f3 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:44:59 +0000 Subject: [PATCH] Fix tests --- .../curation/tests/test_train_manual_curation.py | 3 +++ src/spikeinterface/curation/train_manual_curation.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py index 0a3ca2d45b..f455fbdb9c 100644 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -17,6 +17,7 @@ def trainer(): scaling_techniques = ["standard_scaler"] classifiers = ["LogisticRegression"] metric_names = ["metric1", "metric2", "metric3"] + search_kwargs = {"cv": 3} return CurationModelTrainer( labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], folder=folder, @@ -24,6 +25,7 @@ def trainer(): imputation_strategies=imputation_strategies, scaling_techniques=scaling_techniques, classifiers=classifiers, + search_kwargs=search_kwargs, ) @@ -187,6 +189,7 @@ def test_train_model(): scaling_techniques=["standard_scaler"], classifiers=["LogisticRegression"], overwrite=True, + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, ) assert isinstance(trainer, CurationModelTrainer) diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index 1c3b6de34a..af4ad8c6ef 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -544,7 +544,12 @@ def _evaluate( ) test_accuracies, models = zip(*results) - scoring_method = self.search_kwargs.get("scoring") + + if self.search_kwargs is None or self.search_kwargs.get("scoring"): + scoring_method = "balanced_accuracy" + else: + scoring_method = self.search_kwargs.get("scoring") + self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False) best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"]) @@ -597,7 +602,7 @@ def _train_and_evaluate( if self.verbose is True: print(f"Running {classifier.__class__.__name__} with imputation {imputation_strategy} and scaling {scaler}") model, param_space = self.get_classifier_search_space(classifier.__class__.__name__) - print("search kwargs:", search_kwargs, flush=True) + try: from skopt import BayesSearchCV @@ -610,7 +615,7 @@ def _train_and_evaluate( ) except: if self.verbose is True: - print("BayesSearchCV from scikit-optimize not available, using GridSearchCV") + print("BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV") from sklearn.model_selection import RandomizedSearchCV model = RandomizedSearchCV(model, param_space, n_jobs=self.n_jobs, **search_kwargs)