From bf96fe114b9e479a3db784f4e4de2aa02f65489e Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:13:20 -0500 Subject: [PATCH] fix synchrony --- .../qualitymetrics/quality_metric_calculator.py | 8 +++++++- src/spikeinterface/qualitymetrics/quality_metric_list.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index aef3631438..5cefcaa75d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -230,7 +230,13 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # we do this because the convert_dtypes infers the wrong types sometimes. # the actual types for columns can be found in column_name_to_column_dtype dictionary. for column in metrics.columns: - metrics[column] = metrics[column].astype(column_name_to_column_dtype[column]) + # we have one issue where the name of the columns for synchrony are named based on + # what the user has input as arguments so we need a way to handle this separately + # everything else should be handled with the column name. + if "sync" in column: + metrics[column] = metrics[column].astype(column_name_to_column_dtype["sync"]) + else: + metrics[column] = metrics[column].astype(column_name_to_column_dtype[column]) return metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 8ad3bee44c..685aaddc83 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -84,6 +84,7 @@ "silhouette_full": ["silhouette_full"], } +# this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them column_name_to_column_dtype = { "num_spikes": int, "firing_rate": float, @@ -98,7 +99,7 @@ "amplitude_median": float, "amplitude_cv_median": float, "amplitude_cv_range": float, - "synch": float, + "sync": float, "firing_range": float, "drift_ptp": float, "drift_std": float,