From e99c3b9d456bbffb410d00d721dd38c6f9d688fc Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Fri, 8 Nov 2024 09:07:29 +0000 Subject: [PATCH] Add train using sorting_analyzer tests --- .../tests/test_model_based_curation.py | 2 +- .../tests/test_train_manual_curation.py | 63 +++++++++++++++++++ .../curation/train_manual_curation.py | 38 +++++++++-- 3 files changed, 98 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 74db8e6996..0ef6dc3cc9 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -1,6 +1,6 @@ import pytest from pathlib import Path -from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from spikeinterface.curation.tests.common import make_sorting_analyzer from spikeinterface.curation.model_based_curation import ModelBasedClassification import numpy as np diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py index 919640678f..759e560329 100644 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -3,6 +3,9 @@ import tempfile, csv from pathlib import Path +from spikeinterface.curation.tests.common import make_sorting_analyzer + + from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model @@ -141,3 +144,63 @@ def test_train_model(): overwrite=True, ) assert isinstance(trainer, CurationModelTrainer) + + +def test_train_using_two_sorting_analyzers(): + + sorting_analyzer_1 = make_sorting_analyzer() + sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + labels_1 = [0, 1, 1, 1, 1] + labels_2 = [1, 1, 0, 1, 1] + + folder = tempfile.mkdtemp() + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + assert isinstance(trainer, CurationModelTrainer) + + # Xheck that there is an error raised if the metric names are different + + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}}) + + with pytest.raises(Exception): + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + # Now check that there is an error raised if we demand the same metric params, but don't have them + + sorting_analyzer_2.compute( + {"quality_metrics": {"metric_names": ["num_spikes", "snr"], "qm_params": {"snr": {"peak_mode": "at_index"}}}} + ) + + with pytest.raises(Exception): + train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + overwrite=True, + enforce_metric_params=True, + ) diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index d4a73ccec1..f70cbb3b57 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -220,9 +220,10 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): """ import pandas as pd - self.testing_metrics = pd.concat( - [self._get_metrics_for_classification(an, an_index) for an_index, an in enumerate(analyzers)], axis=0 - ) + metrics_for_each_analyzer = [self._get_metrics_for_classification(an) for an in analyzers] + check_metric_names_are_the_same(metrics_for_each_analyzer) + + self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0) # Set metric names to those calculated if not provided if self.metric_names is None: @@ -262,6 +263,12 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params): if analyzer_index_1 <= analyzer_index_2: continue else: + + qm_params_1 = {} + qm_params_2 = {} + tm_params_1 = {} + tm_params_2 = {} + if analyzer_1.has_extension("quality_metrics") is True: qm_params_1 = analyzer_1.extensions["quality_metrics"].params["qm_params"] if analyzer_2.has_extension("quality_metrics") is True: @@ -484,7 +491,7 @@ def evaluate_model_config(self): self.search_kwargs, ) - def _get_metrics_for_classification(self, analyzer, analyzer_index): + def _get_metrics_for_classification(self, analyzer): """Check if required metrics are present and return a DataFrame of metrics for classification.""" import pandas as pd @@ -765,3 +772,26 @@ def set_default_search_kwargs(search_kwargs): search_kwargs["n_iter"] = 25 return search_kwargs + + +def check_metric_names_are_the_same(metrics_for_each_analyzer): + """ + Given a list of dataframes, checks that the keys are all equal. + """ + + for i, metrics_for_analyzer_1 in enumerate(metrics_for_each_analyzer): + for j, metrics_for_analyzer_2 in enumerate(metrics_for_each_analyzer): + if i > j: + metric_names_1 = set(metrics_for_analyzer_1.keys()) + metric_names_2 = set(metrics_for_analyzer_2.keys()) + if metric_names_1 != metric_names_2: + metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2) + metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1) + print(metrics_in_1_but_not_2) + print(metrics_in_2_but_not_1) + error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n" + if metrics_in_1_but_not_2: + error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does." + if metrics_in_2_but_not_1: + error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does." + raise Exception(error_message)