diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index bdd340459c..6d5fe9f3f3 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -4,7 +4,7 @@ import warnings from spikeinterface.core import SortingAnalyzer -from spikeinterface.curation.train_manual_curation import try_to_get_metrics_from_analyzer +from spikeinterface.curation.train_manual_curation import try_to_get_metrics_from_analyzer, _get_computed_metrics class ModelBasedClassification: @@ -75,7 +75,7 @@ def predict_labels( # Get metrics DataFrame for classification if input_data is None: - input_data = self._get_computed_metrics() + input_data = _get_computed_metrics(self.sorting_analyzer) else: if not isinstance(input_data, pd.DataFrame): raise ValueError("Input data must be a pandas DataFrame") @@ -122,23 +122,6 @@ def predict_labels( return classified_units - def _get_computed_metrics(self): - """Check if all required metrics are present and return a DataFrame of metrics for classification""" - - import pandas as pd - - quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(self.sorting_analyzer) - calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1) - - # Remove any metrics for non-existent units, raise error if no units are present - calculated_metrics = calculated_metrics.loc[ - calculated_metrics.index.isin(self.sorting_analyzer.sorting.get_unit_ids()) - ] - if calculated_metrics.shape[0] == 0: - raise ValueError("No units present in sorting data") - - return calculated_metrics - def _check_required_metrics_are_present(self, calculated_metrics): # Check all the required metrics have been calculated diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 6521e65dd7..2c3515e06b 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -3,6 +3,8 @@ from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation from spikeinterface.curation.model_based_curation import ModelBasedClassification from spikeinterface.curation import auto_label_units, load_model +from spikeinterface.curation.train_manual_curation import _get_computed_metrics + import numpy as np if hasattr(pytest, "global_test_folder"): @@ -68,11 +70,11 @@ def test_model_based_classification_get_metrics_for_classification( # Check that ValueError is returned when quality_metrics are not present in sorting_analyzer with pytest.raises(ValueError): - computed_metrics = model_based_classification._get_computed_metrics() + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) # Compute some (but not all) of the required metrics in sorting_analyzer sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]]) - computed_metrics = model_based_classification._get_computed_metrics() + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) with pytest.raises(ValueError): model_based_classification._check_required_metrics_are_present(computed_metrics) @@ -81,7 +83,7 @@ def test_model_based_classification_get_metrics_for_classification( sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) # Check that the metrics data is returned as a pandas DataFrame - metrics_data = model_based_classification._get_computed_metrics() + metrics_data = _get_computed_metrics(sorting_analyzer_for_curation) assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) assert set(metrics_data.columns.to_list()) == set(required_metrics) diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index 4b8399bc36..ccffbef0a2 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -224,7 +224,7 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): """ import pandas as pd - metrics_for_each_analyzer = [self._get_metrics_for_classification(an) for an in analyzers] + metrics_for_each_analyzer = [_get_computed_metrics(an) for an in analyzers] check_metric_names_are_the_same(metrics_for_each_analyzer) self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0) @@ -494,23 +494,6 @@ def evaluate_model_config(self): self.search_kwargs, ) - def _get_metrics_for_classification(self, analyzer): - """Check if required metrics are present and return a DataFrame of metrics for classification.""" - - import pandas as pd - - quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(analyzer) - - # Concatenate the available metrics - calculated_metrics = pd.concat([m for m in [quality_metrics, template_metrics] if m is not None], axis=1) - - # Remove any metrics for non-existent units, raise error if no units are present - calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(analyzer.sorting.get_unit_ids())] - if calculated_metrics.shape[0] == 0: - raise ValueError("No units present in sorting data") - - return calculated_metrics - def _load_data_files(self, paths): import pandas as pd @@ -736,6 +719,22 @@ def train_model( return trainer +def _get_computed_metrics(sorting_analyzer): + """Loads and organises the computed metrics from a sorting_analyzer into a single dataframe""" + + import pandas as pd + + quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer) + calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1) + + # Remove any metrics for non-existent units, raise error if no units are present + calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(sorting_analyzer.sorting.get_unit_ids())] + if calculated_metrics.shape[0] == 0: + raise ValueError("No units present in sorting data") + + return calculated_metrics + + def try_to_get_metrics_from_analyzer(sorting_analyzer): quality_metrics = None