diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 471302d57e..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from .basesorting import BaseSorting @@ -18,14 +19,16 @@ * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument to specify the mode to compute the amplitude of the templates. - * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of amplitude). + * "amplitude" : threshold based on the amplitude values on every channels. Use the "threshold" argument + to specify the ptp threshold (in units of amplitude) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. + * "ptp: : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. peak_sign : "neg" | "pos" | "both" Sign of the template to compute best channels. @@ -37,7 +40,7 @@ Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. + Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. by_property : object Property name for "by_property" method. """ @@ -417,7 +420,7 @@ def from_snr( return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the peak-to-peak threshold. @@ -434,30 +437,67 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): sparsity : ChannelSparsity The estimated sparsity. """ + warnings.warn( + "The 'ptp' method is deprecated and will be removed in version 0.103.0. " + "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", + DeprecationWarning, + ) + return cls.from_snr( + templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels + ) - assert ( - templates_or_sorting_analyzer.sparsity is None - ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + @classmethod + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + """ + Construct sparsity from a threshold based on template amplitude. + The amplitude is computed with the specified amplitude mode and it is assumed + that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must + have scaled templates. - from .template_tools import get_dense_templates_array + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "amplitude" method (in uV). + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. + """ + from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer from .template import Templates + assert ( + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + unit_ids = templates_or_sorting_analyzer.unit_ids channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - return_scaled = templates_or_sorting_analyzer.return_scaled + assert templates_or_sorting_analyzer.return_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `return_scaled=True` when computing the templates." + ) elif isinstance(templates_or_sorting_analyzer, Templates): - return_scaled = templates_or_sorting_analyzer.is_scaled + assert templates_or_sorting_analyzer.is_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `is_scaled=True` when creating the Templates object." + ) mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) - templates_ptps = np.ptp(templates_array, axis=1) + peak_values = get_template_amplitudes( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=True + ) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) + chan_inds = np.nonzero((np.abs(peak_values[unit_id])) >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -560,7 +600,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", + method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, @@ -595,7 +635,7 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "radius", "snr", "ptp"): + if method in ("best_channels", "radius", "snr", "amplitude", "ptp"): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -619,11 +659,13 @@ def compute_sparsity( peak_sign=peak_sign, amplitude_mode=amplitude_mode, ) - elif method == "ptp": - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( templates_or_sorting_analyzer, threshold, + amplitude_mode=amplitude_mode, + peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -633,6 +675,14 @@ def compute_sparsity( sparsity = ChannelSparsity.from_property( templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -648,7 +698,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius", + method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -697,9 +747,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), ( + assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'" + "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" ) if recording.get_probes() == 1: @@ -768,12 +818,19 @@ def estimate_sparsity( peak_sign=peak_sign, amplitude_mode=amplitude_mode, ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign + ) elif method == "ptp": + # TODO: remove after deprecation assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( - templates, - threshold, + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." ) + sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels) else: raise ValueError(f"compute_sparsity() method={method} does not exists") else: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 64517c106d..ace869df8c 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -209,15 +209,16 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) - # ptp: just run it + # amplitude sparsity = estimate_sparsity( sorting, recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, - method="ptp", + method="amplitude", threshold=5, + amplitude_mode="peak_to_peak", chunk_duration="1s", progress_bar=True, n_jobs=1, @@ -252,6 +253,23 @@ def test_estimate_sparsity(): progress_bar=True, n_jobs=1, ) + # ptp: just run it + print(noise_levels) + + with pytest.warns(DeprecationWarning): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) def test_compute_sparsity(): @@ -273,9 +291,11 @@ def test_compute_sparsity(): sparsity = compute_sparsity( sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" ) - sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") @@ -283,7 +303,10 @@ def test_compute_sparsity(): sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) if __name__ == "__main__":