From 9834996cf38e2589bdbcd44d20cbcd7a4afee205 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 18:34:09 +0200 Subject: [PATCH] Fix metrics widgets for convert_dtypes --- src/spikeinterface/widgets/metrics.py | 3 +++ src/spikeinterface/widgets/tests/test_widgets.py | 2 +- src/spikeinterface/widgets/utils_sortingview.py | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 2fbd0e31eb..813e7d7b63 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -235,6 +235,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): values = check_json(metrics.loc[unit_id].to_dict()) values_skip_nans = {} for k, v in values.items(): + # convert_dypes returns NaN as None or np.nan (for float) + if v is None: + continue if np.isnan(v): continue values_skip_nans[k] = v diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index debcd52085..80f58f5ad9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,7 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes"]), + quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index d18c581b6b..a6cc562ba2 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -106,9 +106,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() else: warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") continue @@ -137,9 +137,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() # Check for NaN values and round floats val0 = np.array(property_values[0])