Skip to content

Commit

Permalink
Merge pull request #3537 from chrishalcrow/unify_template_and_quality…
Browse files Browse the repository at this point in the history
…_metrics

Unify template and quality metrics
  • Loading branch information
alejoe91 authored Dec 3, 2024
2 parents 853d8a4 + de7210a commit d991382
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 84 deletions.
121 changes: 77 additions & 44 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,10 @@ class ComputeTemplateMetrics(AnalyzerExtension):
include_multi_channel_metrics : bool, default: False
Whether to compute multi-channel metrics
delete_existing_metrics : bool, default: False
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged.
metrics_kwargs : dict
Additional arguments to pass to the metric functions. Including:
* recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7
* peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2
* peak_width_ms: the width in samples to detect peaks, default: 0.2
* depth_direction: the direction to compute velocity above and below, default: "y" (see notes)
* min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5
* min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7
* exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp"
* min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5
* spread_threshold: the threshold to compute the spread, default: 0.2
* spread_smooth_um: the smoothing in um to compute the spread, default: 20
* column_range: the range in um in the horizontal direction to consider channels for velocity, default: None
- If None, all channels all channels are considered
- If 0 or 1, only the "column" that includes the max channel is considered
- If > 1, only channels within range (+/-) um from the max channel horizontal position are used
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged.
metric_params : dict of dicts or None, default: None
Dictionary with parameters for template metrics calculation.
Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()`
Returns
-------
Expand All @@ -100,15 +87,29 @@ class ComputeTemplateMetrics(AnalyzerExtension):
need_recording = False
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

min_channels_for_multi_channel_warning = 10

def _handle_backward_compatibility_on_load(self):

# For backwards compatibility - this reformats metrics_kwargs as metric_params
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None:

metric_params = {}
for metric_name in self.params["metric_names"]:
metric_params[metric_name] = deepcopy(metrics_kwargs)
self.params["metric_params"] = metric_params

del self.params["metrics_kwargs"]

def _set_params(
self,
metric_names=None,
peak_sign="neg",
upsampling_factor=10,
sparsity=None,
metric_params=None,
metrics_kwargs=None,
include_multi_channel_metrics=False,
delete_existing_metrics=False,
Expand All @@ -134,33 +135,24 @@ def _set_params(
if include_multi_channel_metrics:
metric_names += get_multi_channel_template_metric_names()

if metrics_kwargs is None:
metrics_kwargs_ = _default_function_kwargs.copy()
if len(other_kwargs) > 0:
for m in other_kwargs:
if m in metrics_kwargs_:
metrics_kwargs_[m] = other_kwargs[m]
else:
metrics_kwargs_ = _default_function_kwargs.copy()
metrics_kwargs_.update(metrics_kwargs)
if metrics_kwargs is not None and metric_params is None:
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead"
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead"

metric_params = {}
for metric_name in metric_names:
metric_params[metric_name] = deepcopy(metrics_kwargs)

metric_params_ = get_default_tm_params(metric_names)
for k in metric_params_:
if metric_params is not None and k in metric_params:
metric_params_[k].update(metric_params[k])

metrics_to_compute = metric_names
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:

existing_params = tm_extension.params["metrics_kwargs"]
# checks that existing metrics were calculated using the same params
if existing_params != metrics_kwargs_:
warnings.warn(
f"The parameters used to calculate the previous template metrics are different"
f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
)
tm_extension.params["metric_names"] = []
existing_metric_names = []
else:
existing_metric_names = tm_extension.params["metric_names"]

existing_metric_names = tm_extension.params["metric_names"]
existing_metric_names_propogated = [
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
]
Expand All @@ -171,7 +163,7 @@ def _set_params(
sparsity=sparsity,
peak_sign=peak_sign,
upsampling_factor=int(upsampling_factor),
metrics_kwargs=metrics_kwargs_,
metric_params=metric_params_,
delete_existing_metrics=delete_existing_metrics,
metrics_to_compute=metrics_to_compute,
)
Expand Down Expand Up @@ -273,7 +265,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
sampling_frequency=sampling_frequency_up,
trough_idx=trough_idx,
peak_idx=peak_idx,
**self.params["metrics_kwargs"],
**self.params["metric_params"][metric_name],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
Expand Down Expand Up @@ -312,7 +304,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
template_upsampled,
channel_locations=channel_locations_sparse,
sampling_frequency=sampling_frequency_up,
**self.params["metrics_kwargs"],
**self.params["metric_params"][metric_name],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
Expand All @@ -326,8 +318,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri

def _run(self, verbose=False):

delete_existing_metrics = self.params["delete_existing_metrics"]
metrics_to_compute = self.params["metrics_to_compute"]
delete_existing_metrics = self.params["delete_existing_metrics"]

# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
Expand All @@ -343,9 +335,21 @@ def _run(self, verbose=False):
):
existing_metrics = tm_extension.params["metric_names"]

existing_metrics = []
# here we get in the loaded via the dict only (to avoid full loading from disk after params reset)
tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None)
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metrics_to_compute):
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]
# some metrics names produce data columns with other names. This deals with that.
for column_name in tm_compute_name_to_column_names[metric_name]:
computed_metrics[column_name] = tm_extension.data["metrics"][column_name]

self.data["metrics"] = computed_metrics

Expand All @@ -372,6 +376,35 @@ def _get_data(self):
)


def get_default_tm_params(metric_names):
if metric_names is None:
metric_names = get_template_metric_names()

base_tm_params = _default_function_kwargs

metric_params = {}
for metric_name in metric_names:
metric_params[metric_name] = deepcopy(base_tm_params)

return metric_params


# a dict converting the name of the metric for computation to the output of that computation
tm_compute_name_to_column_names = {
"peak_to_valley": ["peak_to_valley"],
"peak_trough_ratio": ["peak_trough_ratio"],
"half_width": ["half_width"],
"repolarization_slope": ["repolarization_slope"],
"recovery_slope": ["recovery_slope"],
"num_positive_peaks": ["num_positive_peaks"],
"num_negative_peaks": ["num_negative_peaks"],
"velocity_above": ["velocity_above"],
"velocity_below": ["velocity_below"],
"exp_decay": ["exp_decay"],
"spread": ["spread"],
}


def get_trough_and_peak_idx(template):
"""
Return the indices into the input template of the detected trough
Expand Down
49 changes: 47 additions & 2 deletions src/spikeinterface/postprocessing/tests/test_template_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
from spikeinterface.postprocessing import ComputeTemplateMetrics
from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics
import pytest
import csv

Expand All @@ -8,6 +8,49 @@
template_metrics = list(_single_channel_metric_name_to_func.keys())


def test_different_params_template_metrics(small_sorting_analyzer):
"""
Computes template metrics using different params, and check that they are
actually calculated using the different params.
"""
compute_template_metrics(
sorting_analyzer=small_sorting_analyzer,
metric_names=["exp_decay", "spread", "half_width"],
metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}},
)

tm_extension = small_sorting_analyzer.get_extension("template_metrics")
tm_params = tm_extension.params["metric_params"]

assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8
assert tm_params["spread"]["recovery_window_ms"] == 0.7
assert tm_params["half_width"]["recovery_window_ms"] == 0.7

assert tm_params["spread"]["spread_smooth_um"] == 15
assert tm_params["exp_decay"]["spread_smooth_um"] == 20
assert tm_params["half_width"]["spread_smooth_um"] == 20


def test_backwards_compat_params_template_metrics(small_sorting_analyzer):
"""
Computes template metrics using the metrics_kwargs keyword
"""
compute_template_metrics(
sorting_analyzer=small_sorting_analyzer,
metric_names=["exp_decay", "spread"],
metrics_kwargs={"recovery_window_ms": 0.8},
)

tm_extension = small_sorting_analyzer.get_extension("template_metrics")
tm_params = tm_extension.params["metric_params"]

assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8
assert tm_params["spread"]["recovery_window_ms"] == 0.8

assert tm_params["spread"]["spread_smooth_um"] == 20
assert tm_params["exp_decay"]["spread_smooth_um"] == 20


def test_compute_new_template_metrics(small_sorting_analyzer):
"""
Computes template metrics then computes a subset of template metrics, and checks
Expand All @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
are deleted.
"""

small_sorting_analyzer.delete_extension("template_metrics")

# calculate just exp_decay
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
Expand Down Expand Up @@ -47,7 +92,7 @@ def test_compute_new_template_metrics(small_sorting_analyzer):

# check that, when parameters are changed, the old metrics are deleted
small_sorting_analyzer.compute(
{"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
{"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}}
)


Expand Down
Loading

0 comments on commit d991382

Please sign in to comment.