Skip to content

Commit

Permalink
Merge pull request #3364 from alejoe91/protect-template-metrics
Browse files Browse the repository at this point in the history
Add extra protection for template metrics
  • Loading branch information
samuelgarcia authored Sep 10, 2024
2 parents 358e0d3 + 4e000ed commit 07c9bff
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
36 changes: 21 additions & 15 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@

import numpy as np
import warnings
from typing import Optional
from copy import deepcopy

from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension
from ..core import ChannelSparsity
from ..core.template_tools import get_template_extremum_channel
from ..core.template_tools import get_dense_templates_array

Expand Down Expand Up @@ -238,13 +236,17 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job

for metric_name in metrics_single_channel:
func = _metric_name_to_func[metric_name]
value = func(
template_upsampled,
sampling_frequency=sampling_frequency_up,
trough_idx=trough_idx,
peak_idx=peak_idx,
**self.params["metrics_kwargs"],
)
try:
value = func(
template_upsampled,
sampling_frequency=sampling_frequency_up,
trough_idx=trough_idx,
peak_idx=peak_idx,
**self.params["metrics_kwargs"],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
value = np.nan
template_metrics.at[index, metric_name] = value

# compute metrics multi_channel
Expand Down Expand Up @@ -274,12 +276,16 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
sampling_frequency_up = sampling_frequency

func = _metric_name_to_func[metric_name]
value = func(
template_upsampled,
channel_locations=channel_locations_sparse,
sampling_frequency=sampling_frequency_up,
**self.params["metrics_kwargs"],
)
try:
value = func(
template_upsampled,
channel_locations=channel_locations_sparse,
sampling_frequency=sampling_frequency_up,
**self.params["metrics_kwargs"],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
value = np.nan
template_metrics.at[index, metric_name] = value
return template_metrics

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names]
if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]:
if not sorting_analyzer.has_extension("principal_components"):
raise ValueError("waveform_principal_component must be provied")
raise ValueError(
"To compute principal components base metrics, the principal components "
"extension must be computed first."
)
pc_metrics = compute_pc_metrics(
sorting_analyzer,
unit_ids=non_empty_unit_ids,
Expand Down

0 comments on commit 07c9bff

Please sign in to comment.