Add train using sorting_analyzer tests
chrishalcrow committed Nov 8, 2024
1 parent dbbb1b8 commit e99c3b9
Showing 3 changed files with 98 additions and 5 deletions.
@@ -1,6 +1,6 @@
import pytest
from pathlib import Path
from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
from spikeinterface.curation.tests.common import make_sorting_analyzer
from spikeinterface.curation.model_based_curation import ModelBasedClassification
import numpy as np

63 changes: 63 additions & 0 deletions src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -3,6 +3,9 @@
import tempfile, csv
from pathlib import Path

from spikeinterface.curation.tests.common import make_sorting_analyzer


from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model


@@ -141,3 +144,63 @@ def test_train_model():
        overwrite=True,
    )
    assert isinstance(trainer, CurationModelTrainer)


def test_train_using_two_sorting_analyzers():

    sorting_analyzer_1 = make_sorting_analyzer()
    sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}})

    sorting_analyzer_2 = make_sorting_analyzer()
    sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}})

    labels_1 = [0, 1, 1, 1, 1]
    labels_2 = [1, 1, 0, 1, 1]

    folder = tempfile.mkdtemp()
    trainer = train_model(
        analyzers=[sorting_analyzer_1, sorting_analyzer_2],
        folder=folder,
        labels=[labels_1, labels_2],
        imputation_strategies=["median"],
        scaling_techniques=["standard_scaler"],
        classifiers=["LogisticRegression"],
        overwrite=True,
    )

    assert isinstance(trainer, CurationModelTrainer)

    # Check that an error is raised if the metric names are different

    sorting_analyzer_2 = make_sorting_analyzer()
    sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}})

    with pytest.raises(Exception):
        trainer = train_model(
            analyzers=[sorting_analyzer_1, sorting_analyzer_2],
            folder=folder,
            labels=[labels_1, labels_2],
            imputation_strategies=["median"],
            scaling_techniques=["standard_scaler"],
            classifiers=["LogisticRegression"],
            overwrite=True,
        )

    # Now check that an error is raised if we enforce matching metric params but the analyzers were computed with different ones

    sorting_analyzer_2.compute(
        {"quality_metrics": {"metric_names": ["num_spikes", "snr"], "qm_params": {"snr": {"peak_mode": "at_index"}}}}
    )

    with pytest.raises(Exception):
        train_model(
            analyzers=[sorting_analyzer_1, sorting_analyzer_2],
            folder=folder,
            labels=[labels_1, labels_2],
            imputation_strategies=["median"],
            scaling_techniques=["standard_scaler"],
            classifiers=["LogisticRegression"],
            search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1},
            overwrite=True,
            enforce_metric_params=True,
        )
38 changes: 34 additions & 4 deletions src/spikeinterface/curation/train_manual_curation.py
@@ -220,9 +220,10 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
"""
import pandas as pd

self.testing_metrics = pd.concat(
[self._get_metrics_for_classification(an, an_index) for an_index, an in enumerate(analyzers)], axis=0
)
metrics_for_each_analyzer = [self._get_metrics_for_classification(an) for an in analyzers]
check_metric_names_are_the_same(metrics_for_each_analyzer)

self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0)

# Set metric names to those calculated if not provided
if self.metric_names is None:
@@ -262,6 +263,12 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params):
                if analyzer_index_1 <= analyzer_index_2:
                    continue
                else:

                    qm_params_1 = {}
                    qm_params_2 = {}
                    tm_params_1 = {}
                    tm_params_2 = {}

                    if analyzer_1.has_extension("quality_metrics") is True:
                        qm_params_1 = analyzer_1.extensions["quality_metrics"].params["qm_params"]
                    if analyzer_2.has_extension("quality_metrics") is True:
@@ -484,7 +491,7 @@ def evaluate_model_config(self):
            self.search_kwargs,
        )

    def _get_metrics_for_classification(self, analyzer, analyzer_index):
    def _get_metrics_for_classification(self, analyzer):
        """Check if required metrics are present and return a DataFrame of metrics for classification."""

        import pandas as pd
@@ -765,3 +772,26 @@ def set_default_search_kwargs(search_kwargs):
        search_kwargs["n_iter"] = 25

    return search_kwargs


def check_metric_names_are_the_same(metrics_for_each_analyzer):
"""
Given a list of dataframes, checks that the keys are all equal.
"""

    for i, metrics_for_analyzer_1 in enumerate(metrics_for_each_analyzer):
        for j, metrics_for_analyzer_2 in enumerate(metrics_for_each_analyzer):
            if i > j:
                metric_names_1 = set(metrics_for_analyzer_1.keys())
                metric_names_2 = set(metrics_for_analyzer_2.keys())
                if metric_names_1 != metric_names_2:
                    metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2)
                    metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1)
                    print(metrics_in_1_but_not_2)
                    print(metrics_in_2_but_not_1)
                    error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n"
                    if metrics_in_1_but_not_2:
                        error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does."
                    if metrics_in_2_but_not_1:
                        error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does."
                    raise Exception(error_message)
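
A minimal usage sketch of the new check (not part of this commit): only the function and its import path come from the diff above; the DataFrames, column sets, and values are illustrative assumptions standing in for real quality-metric tables.

import pandas as pd
from spikeinterface.curation.train_manual_curation import check_metric_names_are_the_same

# Two stand-in metric tables: the second is missing the "snr" column
metrics_a = pd.DataFrame({"num_spikes": [100, 200], "snr": [5.1, 7.3]})
metrics_b = pd.DataFrame({"num_spikes": [150, 120]})

check_metric_names_are_the_same([metrics_a, metrics_a])  # identical columns, passes silently
check_metric_names_are_the_same([metrics_a, metrics_b])  # raises: "#1 does not contain {'snr'}, which #0 does."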
