Skip to content

Commit

Permalink
Refactor to use _get_computed_metrics for both CurationModelTrainer…
Browse files Browse the repository at this point in the history
… and ModelBasedClassification
  • Loading branch information
chrishalcrow committed Nov 28, 2024
1 parent 4d0e259 commit a939d81
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 40 deletions.
21 changes: 2 additions & 19 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings

from spikeinterface.core import SortingAnalyzer
from spikeinterface.curation.train_manual_curation import try_to_get_metrics_from_analyzer
from spikeinterface.curation.train_manual_curation import try_to_get_metrics_from_analyzer, _get_computed_metrics


class ModelBasedClassification:
Expand Down Expand Up @@ -75,7 +75,7 @@ def predict_labels(

# Get metrics DataFrame for classification
if input_data is None:
input_data = self._get_computed_metrics()
input_data = _get_computed_metrics(self.sorting_analyzer)
else:
if not isinstance(input_data, pd.DataFrame):
raise ValueError("Input data must be a pandas DataFrame")
Expand Down Expand Up @@ -122,23 +122,6 @@ def predict_labels(

return classified_units

def _get_computed_metrics(self):
"""Check if all required metrics are present and return a DataFrame of metrics for classification"""

import pandas as pd

quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(self.sorting_analyzer)
calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)

# Remove any metrics for non-existent units, raise error if no units are present
calculated_metrics = calculated_metrics.loc[
calculated_metrics.index.isin(self.sorting_analyzer.sorting.get_unit_ids())
]
if calculated_metrics.shape[0] == 0:
raise ValueError("No units present in sorting data")

return calculated_metrics

def _check_required_metrics_are_present(self, calculated_metrics):

# Check all the required metrics have been calculated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
from spikeinterface.curation.model_based_curation import ModelBasedClassification
from spikeinterface.curation import auto_label_units, load_model
from spikeinterface.curation.train_manual_curation import _get_computed_metrics

import numpy as np

if hasattr(pytest, "global_test_folder"):
Expand Down Expand Up @@ -68,11 +70,11 @@ def test_model_based_classification_get_metrics_for_classification(

# Check that ValueError is returned when quality_metrics are not present in sorting_analyzer
with pytest.raises(ValueError):
computed_metrics = model_based_classification._get_computed_metrics()
computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation)

# Compute some (but not all) of the required metrics in sorting_analyzer
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]])
computed_metrics = model_based_classification._get_computed_metrics()
computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation)
with pytest.raises(ValueError):
model_based_classification._check_required_metrics_are_present(computed_metrics)

Expand All @@ -81,7 +83,7 @@ def test_model_based_classification_get_metrics_for_classification(
sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]])

# Check that the metrics data is returned as a pandas DataFrame
metrics_data = model_based_classification._get_computed_metrics()
metrics_data = _get_computed_metrics(sorting_analyzer_for_curation)
assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
assert set(metrics_data.columns.to_list()) == set(required_metrics)

Expand Down
35 changes: 17 additions & 18 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
"""
import pandas as pd

metrics_for_each_analyzer = [self._get_metrics_for_classification(an) for an in analyzers]
metrics_for_each_analyzer = [_get_computed_metrics(an) for an in analyzers]
check_metric_names_are_the_same(metrics_for_each_analyzer)

self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0)
Expand Down Expand Up @@ -494,23 +494,6 @@ def evaluate_model_config(self):
self.search_kwargs,
)

def _get_metrics_for_classification(self, analyzer):
"""Check if required metrics are present and return a DataFrame of metrics for classification."""

import pandas as pd

quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(analyzer)

# Concatenate the available metrics
calculated_metrics = pd.concat([m for m in [quality_metrics, template_metrics] if m is not None], axis=1)

# Remove any metrics for non-existent units, raise error if no units are present
calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(analyzer.sorting.get_unit_ids())]
if calculated_metrics.shape[0] == 0:
raise ValueError("No units present in sorting data")

return calculated_metrics

def _load_data_files(self, paths):
import pandas as pd

Expand Down Expand Up @@ -736,6 +719,22 @@ def train_model(
return trainer


def _get_computed_metrics(sorting_analyzer):
"""Loads and organises the computed metrics from a sorting_analyzer into a single dataframe"""

import pandas as pd

quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer)
calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)

# Remove any metrics for non-existent units, raise error if no units are present
calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(sorting_analyzer.sorting.get_unit_ids())]
if calculated_metrics.shape[0] == 0:
raise ValueError("No units present in sorting data")

return calculated_metrics


def try_to_get_metrics_from_analyzer(sorting_analyzer):

quality_metrics = None
Expand Down

0 comments on commit a939d81

Please sign in to comment.