From 65847c17b364cc37338a3b661e43e134059ca849 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 28 Nov 2024 10:30:54 +0000
Subject: [PATCH] Respond to `model_based_curation` review

---
 .../curation/model_based_curation.py          | 21 +++++----
 .../tests/test_model_based_curation.py        | 45 +++++++++++++++++--
 2 files changed, 53 insertions(+), 13 deletions(-)

diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index 0c12e2f5e5..66fdd6118e 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -75,11 +75,13 @@ def predict_labels(
 
         # Get metrics DataFrame for classification
         if input_data is None:
-            input_data = self._get_metrics_for_classification()
+            input_data = self._get_computed_metrics()
         else:
             if not isinstance(input_data, pd.DataFrame):
                 raise ValueError("Input data must be a pandas DataFrame")
 
+        input_data = self._check_required_metrics_are_present(input_data)
+
         if model_info is not None:
             self._check_params_for_classification(enforce_metric_params, model_info=model_info)
 
@@ -103,9 +105,7 @@ def predict_labels(
             probabilities = np.max(probabilities, axis=1)
 
         if isinstance(label_conversion, dict):
-            try:
-                assert set(predictions).issubset(label_conversion.keys())
-            except AssertionError:
+            if set(predictions) != set(label_conversion.keys()):
                 raise ValueError("Labels in predictions do not match those in label_conversion")
             predictions = [label_conversion[label] for label in predictions]
 
@@ -122,14 +122,12 @@ def predict_labels(
 
         return classified_units
 
-    def _get_metrics_for_classification(self):
+    def _get_computed_metrics(self):
         """Check if all required metrics are present and return a DataFrame of metrics for classification"""
         import pandas as pd
 
         quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(self.sorting_analyzer)
-
-        # Create DataFrame of all metrics and reorder columns to match the model
         calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)
 
         # Remove any metrics for non-existent units, raise error if no units are present
@@ -139,6 +137,10 @@
         if calculated_metrics.shape[0] == 0:
             raise ValueError("No units present in sorting data")
 
+        return calculated_metrics
+
+    def _check_required_metrics_are_present(self, calculated_metrics):
+
         # Check all the required metrics have been calculated
         required_metrics = set(self.required_metrics)
         if required_metrics.issubset(set(calculated_metrics)):
@@ -416,8 +418,9 @@ def _load_model_from_folder(model_folder=None, model_name=None, trust_model=Fals
             model = skio.load(skops_file)
         except UntrustedTypesFoundException as e:
             exception_msg = str(e)
-            # the exception message contains the list of untrusted objects after a colon and enswith a period
-            trusted = eval(exception_msg.split(":")[1][:-1])
+            # the exception message contains the list of untrusted objects. The following
+            # search assumes it is the only list in the message.
+            trusted = re.search(r"\[(.*?)\]", exception_msg).group()
             model = skio.load(skops_file, trusted=trusted)
 
     model_info_path = folder / "model_info.json"
diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py
index 6ffe17e85a..6521e65dd7 100644
--- a/src/spikeinterface/curation/tests/test_model_based_curation.py
+++ b/src/spikeinterface/curation/tests/test_model_based_curation.py
@@ -32,27 +32,56 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):
     assert model_based_classification.pipeline == model[0]
 
 
+def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
+
+    sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
+    sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
+
+    model_folder = Path(__file__).parent / Path("trained_pipeline")
+
+    prediction_prob_dataframe_1 = auto_label_units(
+        sorting_analyzer=sorting_analyzer_for_curation,
+        model_folder=model_folder,
+        trusted=["numpy.dtype"],
+    )
+
+    sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
+
+    prediction_prob_dataframe_2 = auto_label_units(
+        sorting_analyzer=sorting_analyzer_for_curation,
+        model_folder=model_folder,
+        trusted=["numpy.dtype"],
+    )
+
+    assert prediction_prob_dataframe_1.equals(prediction_prob_dataframe_2)
+
+
 def test_model_based_classification_get_metrics_for_classification(
     sorting_analyzer_for_curation, model, required_metrics
 ):
-    # Test the _get_metrics_for_classification() method of ModelBasedClassification
+
+    sorting_analyzer_for_curation.delete_extension("quality_metrics")
+    sorting_analyzer_for_curation.delete_extension("template_metrics")
+
+    # Test the _check_required_metrics_are_present() method of ModelBasedClassification
     model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
 
     # Check that ValueError is returned when quality_metrics are not present in sorting_analyzer
     with pytest.raises(ValueError):
-        model_based_classification._get_metrics_for_classification()
+        computed_metrics = model_based_classification._get_computed_metrics()
 
     # Compute some (but not all) of the required metrics in sorting_analyzer
     sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]])
+    computed_metrics = model_based_classification._get_computed_metrics()
     with pytest.raises(ValueError):
-        model_based_classification._get_metrics_for_classification()
+        model_based_classification._check_required_metrics_are_present(computed_metrics)
 
     # Compute all of the required metrics in sorting_analyzer
     sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2])
     sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]])
 
     # Check that the metrics data is returned as a pandas DataFrame
-    metrics_data = model_based_classification._get_metrics_for_classification()
+    metrics_data = model_based_classification._get_computed_metrics()
     assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
     assert set(metrics_data.columns.to_list()) == set(required_metrics)
 
@@ -138,3 +167,11 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curat
         enforce_metric_params=False,
         trusted=["numpy.dtype"],
     )
+
+    classifer_labels = sorting_analyzer_for_curation.get_sorting_property("classifier_label")
+    assert isinstance(classifer_labels, np.ndarray)
+    assert len(classifer_labels) == sorting_analyzer_for_curation.get_num_units()
+
+    classifier_probabilities = sorting_analyzer_for_curation.get_sorting_property("classifier_probability")
+    assert isinstance(classifier_probabilities, np.ndarray)
+    assert len(classifier_probabilities) == sorting_analyzer_for_curation.get_num_units()
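
For reference, a minimal standalone sketch of how the `re.search(r"\[(.*?)\]", exception_msg)` call introduced in `_load_model_from_folder` behaves. The example exception message below is invented for illustration only; the real skops wording may differ, but the patch assumes it contains exactly one bracketed list of untrusted object names.

    import re

    # Hypothetical UntrustedTypesFoundException message (illustrative only);
    # the patch relies on the bracketed list being the only "[...]" in the text.
    exception_msg = "Untrusted types found in the file: ['numpy.dtype', 'sklearn.pipeline.Pipeline']."

    match = re.search(r"\[(.*?)\]", exception_msg)
    print(match.group())   # "['numpy.dtype', 'sklearn.pipeline.Pipeline']"  (full bracketed match)
    print(match.group(1))  # "'numpy.dtype', 'sklearn.pipeline.Pipeline'"    (inner capture group)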