Commit

Respond to model_based_curation review
chrishalcrow committed Nov 28, 2024
1 parent 26ae2b5 commit 65847c1
Showing 2 changed files with 53 additions and 13 deletions.
21 changes: 12 additions & 9 deletions src/spikeinterface/curation/model_based_curation.py
@@ -75,11 +75,13 @@ def predict_labels(
 
         # Get metrics DataFrame for classification
         if input_data is None:
-            input_data = self._get_metrics_for_classification()
+            input_data = self._get_computed_metrics()
         else:
             if not isinstance(input_data, pd.DataFrame):
                 raise ValueError("Input data must be a pandas DataFrame")
 
+            input_data = self._check_required_metrics_are_present(input_data)
+
         if model_info is not None:
             self._check_params_for_classification(enforce_metric_params, model_info=model_info)
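
For context, the new flow splits fetching from validation: get whatever metrics the analyzer has computed, then fail fast if any metric the model was trained on is missing. A minimal sketch of that check, with an invented helper name and error message rather than the library's exact implementation:

    import pandas as pd

    def check_required_metrics(computed: pd.DataFrame, required: list) -> pd.DataFrame:
        # Fail fast when a metric the model expects was never computed
        missing = set(required) - set(computed.columns)
        if missing:
            raise ValueError(f"Missing required metrics: {sorted(missing)}")
        # Return the columns in the model's expected order
        return computed[required]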

@@ -103,9 +105,7 @@ def predict_labels(
         probabilities = np.max(probabilities, axis=1)
 
         if isinstance(label_conversion, dict):
-            try:
-                assert set(predictions).issubset(label_conversion.keys())
-            except AssertionError:
+            if set(predictions) != set(label_conversion.keys()):
                 raise ValueError("Labels in predictions do not match those in label_conversion")
             predictions = [label_conversion[label] for label in predictions]
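
The replaced try/assert pattern had two drawbacks: assert statements are stripped when Python runs with -O, and issubset() is a weaker condition than the error message implies. A toy illustration with invented labels:

    predictions = ["good", "mua"]
    label_conversion = {"good": 1, "mua": 2, "noise": 3}

    # The old subset check passes even though "noise" never occurs in the
    # predictions; the new equality check rejects the mismatch.
    print(set(predictions).issubset(label_conversion.keys()))  # True
    print(set(predictions) == set(label_conversion.keys()))    # False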

@@ -122,14 +122,12 @@ def predict_labels(
 
         return classified_units
 
-    def _get_metrics_for_classification(self):
+    def _get_computed_metrics(self):
         """Check if all required metrics are present and return a DataFrame of metrics for classification"""
 
         import pandas as pd
 
         quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(self.sorting_analyzer)
-
-        # Create DataFrame of all metrics and reorder columns to match the model
         calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)
 
         # Remove any metrics for non-existent units, raise error if no units are present
@@ -139,6 +137,10 @@ def _get_metrics_for_classification(self):
         if calculated_metrics.shape[0] == 0:
             raise ValueError("No units present in sorting data")
 
+        return calculated_metrics
+
+    def _check_required_metrics_are_present(self, calculated_metrics):
+
         # Check all the required metrics have been calculated
         required_metrics = set(self.required_metrics)
         if required_metrics.issubset(set(calculated_metrics)):
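
For reference, pd.concat with axis=1 aligns the quality and template metric DataFrames on their shared unit-id index, so each unit ends up with one row of combined metrics. A toy example with invented values:

    import pandas as pd

    quality_metrics = pd.DataFrame({"snr": [5.0, 3.2]}, index=[0, 1])
    template_metrics = pd.DataFrame({"half_width": [0.2, 0.3]}, index=[0, 1])

    # Rows are aligned by unit id; columns come from both frames
    combined = pd.concat([quality_metrics, template_metrics], axis=1)
    print(combined.columns.to_list())  # ['snr', 'half_width']
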
@@ -416,8 +418,9 @@ def _load_model_from_folder(model_folder=None, model_name=None, trust_model=False):
         model = skio.load(skops_file)
     except UntrustedTypesFoundException as e:
         exception_msg = str(e)
-        # the exception message contains the list of untrusted objects after a colon and enswith a period
-        trusted = eval(exception_msg.split(":")[1][:-1])
+        # the exception message contains the list of untrusted objects. The following
+        # search assumes it is the only list in the message.
+        trusted = re.search(r"\[(.*?)\]", exception_msg).group()
         model = skio.load(skops_file, trusted=trusted)
 
     model_info_path = folder / "model_info.json"
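
A sketch of the same message-parsing idea using ast.literal_eval, which turns the bracketed text into an actual Python list without eval's arbitrary-code-execution risk. Like the comment above, it assumes the message contains exactly one list; the helper name is illustrative:

    import ast
    import re

    def parse_untrusted_types(exception_msg: str) -> list:
        # Find the single bracketed list embedded in the message ...
        match = re.search(r"\[.*?\]", exception_msg)
        if match is None:
            raise ValueError("no list of untrusted types found in message")
        # ... and parse it as a literal, so nothing in the message can execute
        return ast.literal_eval(match.group())

Note that re.search(...).group() alone returns the bracketed text as a string such as "['numpy.dtype']", not a list.
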
45 changes: 41 additions & 4 deletions src/spikeinterface/curation/tests/test_model_based_curation.py
@@ -32,27 +32,56 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):
     assert model_based_classification.pipeline == model[0]
 
 
+def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
+
+    sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
+    sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
+
+    model_folder = Path(__file__).parent / Path("trained_pipeline")
+
+    prediction_prob_dataframe_1 = auto_label_units(
+        sorting_analyzer=sorting_analyzer_for_curation,
+        model_folder=model_folder,
+        trusted=["numpy.dtype"],
+    )
+
+    sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
+
+    prediction_prob_dataframe_2 = auto_label_units(
+        sorting_analyzer=sorting_analyzer_for_curation,
+        model_folder=model_folder,
+        trusted=["numpy.dtype"],
+    )
+
+    assert prediction_prob_dataframe_1.equals(prediction_prob_dataframe_2)
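
The new test pins down a regression target: predictions must not depend on the order in which metrics were computed, since columns are reordered to match the model before classification. A toy illustration with invented values:

    import pandas as pd

    metrics = pd.DataFrame({"num_spikes": [100], "snr": [5.0], "half_width": [0.2]})
    reordered = metrics[["snr", "half_width", "num_spikes"]]

    # Selecting columns in the model's expected order makes both frames
    # identical, whatever order the metrics were computed in
    required_metrics = ["num_spikes", "snr", "half_width"]
    assert metrics[required_metrics].equals(reordered[required_metrics])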


 def test_model_based_classification_get_metrics_for_classification(
     sorting_analyzer_for_curation, model, required_metrics
 ):
-    # Test the _get_metrics_for_classification() method of ModelBasedClassification
+
+    sorting_analyzer_for_curation.delete_extension("quality_metrics")
+    sorting_analyzer_for_curation.delete_extension("template_metrics")
+
+    # Test the _check_required_metrics_are_present() method of ModelBasedClassification
     model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
 
     # Check that ValueError is returned when quality_metrics are not present in sorting_analyzer
     with pytest.raises(ValueError):
-        model_based_classification._get_metrics_for_classification()
+        computed_metrics = model_based_classification._get_computed_metrics()
 
     # Compute some (but not all) of the required metrics in sorting_analyzer
     sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]])
+    computed_metrics = model_based_classification._get_computed_metrics()
     with pytest.raises(ValueError):
-        model_based_classification._get_metrics_for_classification()
+        model_based_classification._check_required_metrics_are_present(computed_metrics)
 
     # Compute all of the required metrics in sorting_analyzer
     sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2])
     sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]])
 
     # Check that the metrics data is returned as a pandas DataFrame
-    metrics_data = model_based_classification._get_metrics_for_classification()
+    metrics_data = model_based_classification._get_computed_metrics()
     assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
     assert set(metrics_data.columns.to_list()) == set(required_metrics)

@@ -138,3 +167,11 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation
         enforce_metric_params=False,
         trusted=["numpy.dtype"],
     )
+
+    classifer_labels = sorting_analyzer_for_curation.get_sorting_property("classifier_label")
+    assert isinstance(classifer_labels, np.ndarray)
+    assert len(classifer_labels) == sorting_analyzer_for_curation.get_num_units()
+
+    classifier_probabilities = sorting_analyzer_for_curation.get_sorting_property("classifier_probability")
+    assert isinstance(classifier_probabilities, np.ndarray)
+    assert len(classifier_probabilities) == sorting_analyzer_for_curation.get_num_units()
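
The new assertions exercise the stored outputs: after auto_label_units runs, each unit carries a label and a probability as sorting properties. A usage sketch built only from calls that appear in these tests:

    labels = sorting_analyzer_for_curation.get_sorting_property("classifier_label")
    probabilities = sorting_analyzer_for_curation.get_sorting_property("classifier_probability")
    unit_ids = sorting_analyzer_for_curation.sorting.get_unit_ids()

    # One (label, probability) pair per unit, aligned with the unit ids
    for unit_id, label, probability in zip(unit_ids, labels, probabilities):
        print(unit_id, label, float(probability))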
