From d18f48f7ca9a3168c15fdd896377a64c56695f9c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 15:45:29 +0200 Subject: [PATCH 1/3] Add extra protection for template metrix --- .../postprocessing/template_metrics.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index e54ff87221..9d21e56611 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,11 +8,9 @@ import numpy as np import warnings -from typing import Optional from copy import deepcopy from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension -from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.template_tools import get_dense_templates_array @@ -238,13 +236,17 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - sampling_frequency=sampling_frequency_up, - trough_idx=trough_idx, - peak_idx=peak_idx, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + sampling_frequency=sampling_frequency_up, + trough_idx=trough_idx, + peak_idx=peak_idx, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value # compute metrics multi_channel From ea13bcb9996e4894e7d9ea1be49fe6a2c5dee6c8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 19:16:27 +0200 Subject: [PATCH 2/3] Add protection for multi-channel metrics (thanks Chris) --- .../qualitymetrics/quality_metric_calculator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0c7cf25237..cdf6151e95 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -164,7 +164,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: if not sorting_analyzer.has_extension("principal_components"): - raise ValueError("waveform_principal_component must be provied") + raise ValueError( + "To compute principal components base metrics, the principal components " + "extension must be computed first." + ) pc_metrics = compute_pc_metrics( sorting_analyzer, unit_ids=non_empty_unit_ids, From 4e000ed041b11a9f2195691caf0bcb39bca4a500 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 19:24:07 +0200 Subject: [PATCH 3/3] same for multi-channel --- .../postprocessing/template_metrics.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 9d21e56611..726ec49558 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -276,12 +276,16 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job sampling_frequency_up = sampling_frequency func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - channel_locations=channel_locations_sparse, - sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value return template_metrics