From 5b77ba170788c18fcb9fb06413b8baf58caf73fb Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:29:44 -0500 Subject: [PATCH] fix nan and empty units --- .../qualitymetrics/quality_metric_calculator.py | 3 +++ .../qualitymetrics/tests/test_quality_metric_calculator.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 5cefcaa75d..6fdc21bac2 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -222,6 +222,9 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # add NaN for empty units if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan + # num_spikes is an int and should be 0 + if "num_spikes" in metrics.columns: + metrics.loc[empty_unit_ids, ["num_spikes"]] = 0 # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns # (in case of NaN values) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index c4c1778cf2..56e3975210 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -133,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple): seed=2205, ) + # num_spikes are ints not nans so we confirm empty units are nans for everything except + # num_spikes which should be 0 + nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"] for empty_unit_id in sorting_empty.get_empty_unit_ids(): from pandas import isnull - assert np.all(isnull(metrics_empty.loc[empty_unit_id].values)) + assert np.all(isnull(metrics_empty.loc[empty_unit_id, nan_containing_columns].values)) + if "num_spikes" in metrics_empty.columns: + assert metrics_empty.loc[empty_unit_id, ["num_spikes"]] == 0 # TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()