Skip to content

Commit

Permalink
Add from_amplitude() option to sparsity and deprecate ptp
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 13, 2024
1 parent 25e7e87 commit 9ffda35
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 28 deletions.
105 changes: 81 additions & 24 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import warnings


from .basesorting import BaseSorting
Expand All @@ -18,14 +19,16 @@
* "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument
to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument
to specify the mode to compute the amplitude of the templates.
* "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument
to specify the ptp threshold (in units of amplitude).
* "amplitude" : threshold based on the amplitude values on every channels. Use the "threshold" argument
to specify the ptp threshold (in units of amplitude) and the "amplitude_mode" argument
to specify the mode to compute the amplitude of the templates.
* "energy" : threshold based on the expected energy that should be present on the channels,
given their noise levels. Use the "threshold" argument to specify the energy threshold
(in units of noise levels)
* "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group").
In this case the sparsity for each unit is given by the channels that have the same property
value as the unit. Use the "by_property" argument to specify the property name.
* "ptp: : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead.
peak_sign : "neg" | "pos" | "both"
Sign of the template to compute best channels.
Expand All @@ -37,7 +40,7 @@
Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude).
For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument.
amplitude_mode : "extremum" | "at_index" | "peak_to_peak"
Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods.
Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods.
by_property : object
Property name for "by_property" method.
"""
Expand Down Expand Up @@ -417,7 +420,7 @@ def from_snr(
return cls(mask, unit_ids, channel_ids)

@classmethod
def from_ptp(cls, templates_or_sorting_analyzer, threshold):
def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None):
"""
Construct sparsity from a thresholds based on template peak-to-peak values.
Use the "threshold" argument to specify the peak-to-peak threshold.
Expand All @@ -434,30 +437,67 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold):
sparsity : ChannelSparsity
The estimated sparsity.
"""
warnings.warn(
"The 'ptp' method is deprecated and will be removed in version 0.103.0. "
"Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.",
DeprecationWarning,
)
return cls.from_snr(
templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels
)

assert (
templates_or_sorting_analyzer.sparsity is None
), "To compute sparsity you need a dense SortingAnalyzer or Templates"
@classmethod
def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"):
"""
Construct sparsity from a threshold based on template amplitude.
The amplitude is computed with the specified amplitude mode and it is assumed
that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must
have scaled templates.
from .template_tools import get_dense_templates_array
Parameters
----------
templates_or_sorting_analyzer : Templates | SortingAnalyzer
A Templates or a SortingAnalyzer object.
threshold : float
Threshold for "amplitude" method (in uV).
amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum"
Mode to compute the amplitude of the templates.
Returns
-------
sparsity : ChannelSparsity
The estimated sparsity.
"""
from .template_tools import get_template_amplitudes
from .sortinganalyzer import SortingAnalyzer
from .template import Templates

assert (
templates_or_sorting_analyzer.sparsity is None
), "To compute sparsity you need a dense SortingAnalyzer or Templates"

unit_ids = templates_or_sorting_analyzer.unit_ids
channel_ids = templates_or_sorting_analyzer.channel_ids

if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
assert templates_or_sorting_analyzer.return_scaled, (
"To compute sparsity from amplitude you need to have scaled templates. "
"You can set `return_scaled=True` when computing the templates."
)
elif isinstance(templates_or_sorting_analyzer, Templates):
return_scaled = templates_or_sorting_analyzer.is_scaled
assert templates_or_sorting_analyzer.is_scaled, (
"To compute sparsity from amplitude you need to have scaled templates. "
"You can set `is_scaled=True` when creating the Templates object."
)

mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")

templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled)
templates_ptps = np.ptp(templates_array, axis=1)
peak_values = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=True
)

for unit_ind, unit_id in enumerate(unit_ids):
chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold)
chan_inds = np.nonzero((np.abs(peak_values[unit_id])) >= threshold)
mask[unit_ind, chan_inds] = True
return cls(mask, unit_ids, channel_ids)

Expand Down Expand Up @@ -560,7 +600,7 @@ def create_dense(cls, sorting_analyzer):
def compute_sparsity(
templates_or_sorting_analyzer: "Templates | SortingAnalyzer",
noise_levels: np.ndarray | None = None,
method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius",
method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius",
peak_sign: "neg" | "pos" | "both" = "neg",
num_channels: int | None = 5,
radius_um: float | None = 100.0,
Expand Down Expand Up @@ -595,7 +635,7 @@ def compute_sparsity(
# to keep backward compatibility
templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer

if method in ("best_channels", "radius", "snr", "ptp"):
if method in ("best_channels", "radius", "snr", "amplitude", "ptp"):
assert isinstance(
templates_or_sorting_analyzer, (Templates, SortingAnalyzer)
), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer"
Expand All @@ -619,11 +659,13 @@ def compute_sparsity(
peak_sign=peak_sign,
amplitude_mode=amplitude_mode,
)
elif method == "ptp":
assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_ptp(
elif method == "amplitude":
assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_amplitude(
templates_or_sorting_analyzer,
threshold,
amplitude_mode=amplitude_mode,
peak_sign=peak_sign,
)
elif method == "energy":
assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given"
Expand All @@ -633,6 +675,14 @@ def compute_sparsity(
sparsity = ChannelSparsity.from_property(
templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property
)
elif method == "ptp":
# TODO: remove after deprecation
assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_ptp(
templates_or_sorting_analyzer,
threshold,
noise_levels=noise_levels,
)
else:
raise ValueError(f"compute_sparsity() method={method} does not exists")

Expand All @@ -648,7 +698,7 @@ def estimate_sparsity(
num_spikes_for_sparsity: int = 100,
ms_before: float = 1.0,
ms_after: float = 2.5,
method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius",
method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius",
peak_sign: "neg" | "pos" | "both" = "neg",
radius_um: float = 100.0,
num_channels: int = 5,
Expand Down Expand Up @@ -697,9 +747,9 @@ def estimate_sparsity(
# Can't be done at module because this is a cyclic import, too bad
from .template import Templates

assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), (
assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), (
f"method={method} is not available for `estimate_sparsity()`. "
"Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'"
"Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)"
)

if recording.get_probes() == 1:
Expand Down Expand Up @@ -768,12 +818,19 @@ def estimate_sparsity(
peak_sign=peak_sign,
amplitude_mode=amplitude_mode,
)
elif method == "amplitude":
assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_amplitude(
templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign
)
elif method == "ptp":
# TODO: remove after deprecation
assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_ptp(
templates,
threshold,
assert noise_levels is not None, (
"For the 'snr' method, 'noise_levels' needs to be given. You can use the "
"`get_noise_levels()` function to compute them."
)
sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels)
else:
raise ValueError(f"compute_sparsity() method={method} does not exists")
else:
Expand Down
31 changes: 27 additions & 4 deletions src/spikeinterface/core/tests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,16 @@ def test_estimate_sparsity():
)
assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5)

# ptp: just run it
# amplitude
sparsity = estimate_sparsity(
sorting,
recording,
num_spikes_for_sparsity=50,
ms_before=1.0,
ms_after=2.0,
method="ptp",
method="amplitude",
threshold=5,
amplitude_mode="peak_to_peak",
chunk_duration="1s",
progress_bar=True,
n_jobs=1,
Expand Down Expand Up @@ -252,6 +253,23 @@ def test_estimate_sparsity():
progress_bar=True,
n_jobs=1,
)
# ptp: just run it
print(noise_levels)

with pytest.warns(DeprecationWarning):
sparsity = estimate_sparsity(
sorting,
recording,
num_spikes_for_sparsity=50,
ms_before=1.0,
ms_after=2.0,
method="ptp",
threshold=5,
noise_levels=noise_levels,
chunk_duration="1s",
progress_bar=True,
n_jobs=1,
)


def test_compute_sparsity():
Expand All @@ -273,17 +291,22 @@ def test_compute_sparsity():
sparsity = compute_sparsity(
sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak"
)
sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5)
sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak")
sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5)
sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group")
with pytest.warns(DeprecationWarning):
sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5)

# using object Templates
templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates")
noise_levels = sorting_analyzer.get_extension("noise_levels").get_data()
sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg")
sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg")
sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg")
sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5)
sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak")

with pytest.warns(DeprecationWarning):
sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5)


if __name__ == "__main__":
Expand Down

0 comments on commit 9ffda35

Please sign in to comment.