Skip to content

Commit

Permalink
Add scoring_method for sorting best models
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Dec 17, 2024
1 parent 59fad3f commit 493ba4e
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class CurationModelTrainer:
If True, useful information is printed during training.
search_kwargs : dict or None, default: None
Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use
`search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`.
`search_kwargs = {'cv': 5, 'scoring': 'balanced_accuracy', 'n_iter': 25}`.
Attributes
----------
Expand Down Expand Up @@ -544,7 +544,8 @@ def _evaluate(
)

test_accuracies, models = zip(*results)
self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values("accuracy", ascending=False)
scoring_method = self.search_kwargs.get("scoring")
self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False)

best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"])
best_model, best_imputer, best_scaler = models[best_model_id]
Expand Down Expand Up @@ -598,8 +599,6 @@ def _train_and_evaluate(
model, param_space = self.get_classifier_search_space(classifier.__class__.__name__)
print("search kwargs:", search_kwargs, flush=True)
try:
print("now trying the classifier search...")

from skopt import BayesSearchCV

model = BayesSearchCV(
Expand All @@ -614,7 +613,7 @@ def _train_and_evaluate(
print("BayesSearchCV from scikit-optimize not available, using GridSearchCV")
from sklearn.model_selection import RandomizedSearchCV

model = RandomizedSearchCV(model, param_space, n_jobs=self.n_jobs, **search_kwargs, verbose=5)
model = RandomizedSearchCV(model, param_space, n_jobs=self.n_jobs, **search_kwargs)

model.fit(X_train_scaled, y_train)
y_pred = model.predict(X_test_scaled)
Expand All @@ -625,7 +624,7 @@ def _train_and_evaluate(
"classifier name": classifier.__class__.__name__,
"imputation_strategy": imputation_strategy,
"scaling_strategy": scaler,
"accuracy": balanced_acc,
"balanced_accuracy": balanced_acc,
"precision": precision,
"recall": recall,
"model_id": model_id,
Expand Down Expand Up @@ -790,7 +789,7 @@ def set_default_search_kwargs(search_kwargs):
search_kwargs = {}

if search_kwargs.get("cv") is None:
search_kwargs["cv"] = 3
search_kwargs["cv"] = 5
if search_kwargs.get("scoring") is None:
search_kwargs["scoring"] = "balanced_accuracy"
if search_kwargs.get("n_iter") is None:
Expand Down

0 comments on commit 493ba4e

Please sign in to comment.