Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Dec 17, 2024
1 parent 493ba4e commit 832b4cd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ def trainer():
scaling_techniques = ["standard_scaler"]
classifiers = ["LogisticRegression"]
metric_names = ["metric1", "metric2", "metric3"]
search_kwargs = {"cv": 3}
return CurationModelTrainer(
labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
folder=folder,
metric_names=metric_names,
imputation_strategies=imputation_strategies,
scaling_techniques=scaling_techniques,
classifiers=classifiers,
search_kwargs=search_kwargs,
)


Expand Down Expand Up @@ -187,6 +189,7 @@ def test_train_model():
scaling_techniques=["standard_scaler"],
classifiers=["LogisticRegression"],
overwrite=True,
search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1},
)
assert isinstance(trainer, CurationModelTrainer)

Expand Down
11 changes: 8 additions & 3 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,12 @@ def _evaluate(
)

test_accuracies, models = zip(*results)
scoring_method = self.search_kwargs.get("scoring")

if self.search_kwargs is None or self.search_kwargs.get("scoring"):
scoring_method = "balanced_accuracy"
else:
scoring_method = self.search_kwargs.get("scoring")

self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False)

best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"])
Expand Down Expand Up @@ -597,7 +602,7 @@ def _train_and_evaluate(
if self.verbose is True:
print(f"Running {classifier.__class__.__name__} with imputation {imputation_strategy} and scaling {scaler}")
model, param_space = self.get_classifier_search_space(classifier.__class__.__name__)
print("search kwargs:", search_kwargs, flush=True)

try:
from skopt import BayesSearchCV

Expand All @@ -610,7 +615,7 @@ def _train_and_evaluate(
)
except:
if self.verbose is True:
print("BayesSearchCV from scikit-optimize not available, using GridSearchCV")
print("BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV")
from sklearn.model_selection import RandomizedSearchCV

model = RandomizedSearchCV(model, param_space, n_jobs=self.n_jobs, **search_kwargs)
Expand Down

0 comments on commit 832b4cd

Please sign in to comment.