diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index a38562ea2c..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from .basesorting import BaseSorting @@ -13,28 +14,35 @@ _sparsity_doc = """ method : str * "best_channels" : N best channels with the largest amplitude. Use the "num_channels" argument to specify the - number of channels. - * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um + number of channels. + * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um. * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument - to specify the SNR threshold (in units of noise levels) - * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of noise levels) + to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. + * "amplitude" : threshold based on the amplitude values on every channel. Use the "threshold" argument + to specify the amplitude threshold (in units of amplitude) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, - given their noise levels. Use the "threshold" argument to specify the SNR threshold + given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) - * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). - Use the "by_property" argument to specify the property name. 
+ * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). + In this case the sparsity for each unit is given by the channels that have the same property + value as the unit. Use the "by_property" argument to specify the property name. + * "ptp" : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. - peak_sign : str - Sign of the template to compute best channels ("neg", "pos", "both") + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. num_channels : int - Number of channels for "best_channels" method + Number of channels for "best_channels" method. radius_um : float - Radius in um for "radius" method + Radius in um for "radius" method. threshold : float - Threshold in SNR "threshold" method + Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). + For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. by_property : object - Property name for "by_property" method + Property name for "by_property" method. """ @@ -277,18 +285,35 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_best_channels( + cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg", amplitude_mode="extremum" + ): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + num_channels : int + Number of channels for "best_channels" method. 
+ peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity """ from .template_tools import get_template_amplitudes - print(templates_or_sorting_analyzer) mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode) for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] @@ -299,7 +324,21 @@ def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_si def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. - Use the "radius_um" argument to specify the radius in um + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + radius_um : float + Radius in um for "radius" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
""" from .template_tools import get_template_extremum_channel @@ -316,10 +355,38 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, peak_sign="neg"): + def from_snr( + cls, + templates_or_sorting_analyzer, + threshold, + amplitude_mode="extremum", + peak_sign="neg", + noise_levels=None, + ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "snr" method (in units of noise levels). + noise_levels : np.array | None, default: None + Noise levels required for the "snr" method. You can use the + `get_noise_levels()` function to compute them. + If the input is a `SortingAnalyzer`, the noise levels are automatically retrieved + if the `noise_levels` extension is present. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute amplitudes. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates for the "snr" method. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
""" from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer @@ -338,13 +405,13 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None + assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode="extremum", return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -356,38 +423,81 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the SNR threshold. + Use the "threshold" argument to specify the peak-to-peak threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "ptp" method (in units of amplitude). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ + warnings.warn( + "The 'ptp' method is deprecated and will be removed in version 0.103.0. 
" + "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", + DeprecationWarning, + ) + return cls.from_snr( + templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels + ) - assert ( - templates_or_sorting_analyzer.sparsity is None - ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + @classmethod + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + """ + Construct sparsity from a threshold based on template amplitude. + The amplitude is computed with the specified amplitude mode and it is assumed + that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must + have scaled templates. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "amplitude" method (in uV). + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
+ """ from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer from .template import Templates + assert ( + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + unit_ids = templates_or_sorting_analyzer.unit_ids channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] - return_scaled = templates_or_sorting_analyzer.return_scaled + assert templates_or_sorting_analyzer.return_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `return_scaled=True` when computing the templates." + ) elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None - return_scaled = templates_or_sorting_analyzer.is_scaled - - from .template_tools import get_dense_templates_array + assert templates_or_sorting_analyzer.is_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `is_scaled=True` when creating the Templates object." 
+ ) mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) - templates_ptps = np.ptp(templates_array, axis=1) + peak_values = get_template_amplitudes( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=True + ) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero((np.abs(peak_values[unit_id])) >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -396,6 +506,19 @@ def from_energy(cls, sorting_analyzer, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. + This method requires the "waveforms" and "noise_levels" extensions to be computed. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + threshold : float + Threshold for "energy" method (in units of noise levels). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
""" assert sorting_analyzer.sparsity is None, "To compute sparsity with energy you need a dense SortingAnalyzer" @@ -403,7 +526,7 @@ def from_energy(cls, sorting_analyzer, threshold): # noise_levels ext = sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext is not None, "To compute sparsity from energy you need to compute 'noise_levels' first" noise_levels = ext.data["noise_levels"] # waveforms @@ -421,51 +544,72 @@ def from_energy(cls, sorting_analyzer, threshold): return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) @classmethod - def from_property(cls, sorting_analyzer, by_property): + def from_property(cls, sorting, recording, by_property): """ Construct sparsity witha property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. + + Parameters + ---------- + sorting : Sorting + A Sorting object. + recording : Recording + A Recording object. + by_property : object + Property name for "by_property" method. Both the recording and sorting must have this property set. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
""" # check consistency - assert ( - by_property in sorting_analyzer.recording.get_property_keys() - ), f"Property {by_property} is not a recording property" - assert ( - by_property in sorting_analyzer.sorting.get_property_keys() - ), f"Property {by_property} is not a sorting property" + assert by_property in recording.get_property_keys(), f"Property {by_property} is not a recording property" + assert by_property in sorting.get_property_keys(), f"Property {by_property} is not a sorting property" - mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") - rec_by = sorting_analyzer.recording.split_by(by_property) - for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - unit_property = sorting_analyzer.sorting.get_property(by_property)[unit_ind] + mask = np.zeros((sorting.unit_ids.size, recording.channel_ids.size), dtype="bool") + rec_by = recording.split_by(by_property) + for unit_ind, unit_id in enumerate(sorting.unit_ids): + unit_property = sorting.get_property(by_property)[unit_ind] assert ( unit_property in rec_by.keys() ), f"Unit property {unit_property} cannot be found in the recording properties" - chan_inds = sorting_analyzer.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) + chan_inds = recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) + return cls(mask, sorting.unit_ids, recording.channel_ids) @classmethod def create_dense(cls, sorting_analyzer): """ Create a sparsity object with all selected channel for all units. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + + Returns + ------- + sparsity : ChannelSparsity + The full sparsity. 
""" mask = np.ones((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) def compute_sparsity( - templates_or_sorting_analyzer, - noise_levels=None, - method="radius", - peak_sign="neg", - num_channels=5, - radius_um=100.0, - threshold=5, - by_property=None, -): + templates_or_sorting_analyzer: "Templates | SortingAnalyzer", + noise_levels: np.ndarray | None = None, + method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", + num_channels: int | None = 5, + radius_um: float | None = 100.0, + threshold: float | None = 5, + by_property: str | None = None, + amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> ChannelSparsity: """ - Get channel sparsity (subset of channels) for each template with several methods. + Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. 
Parameters ---------- @@ -491,7 +635,7 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "radius", "snr", "ptp"): + if method in ("best_channels", "radius", "snr", "amplitude", "ptp"): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -500,11 +644,6 @@ def compute_sparsity( templates_or_sorting_analyzer, SortingAnalyzer ), f"compute_sparsity(method='{method}') need SortingAnalyzer" - if method in ("snr", "ptp") and isinstance(templates_or_sorting_analyzer, Templates): - assert ( - noise_levels is not None - ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" - if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) @@ -514,21 +653,36 @@ def compute_sparsity( elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( - templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, peak_sign=peak_sign - ) - elif method == "ptp": - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, + peak_sign=peak_sign, + amplitude_mode=amplitude_mode, + ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates_or_sorting_analyzer, + threshold, + amplitude_mode=amplitude_mode, + peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' 
needs to be given" sparsity = ChannelSparsity.from_energy(templates_or_sorting_analyzer, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_sorting_analyzer, by_property) + sparsity = ChannelSparsity.from_property( + templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property + ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -544,16 +698,21 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" = "radius", - peak_sign: str = "neg", + method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, + threshold: float | None = 5, + amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + by_property: str | None = None, + noise_levels: np.ndarray | list | None = None, **job_kwargs, ): """ - Estimate the sparsity without needing a SortingAnalyzer or Templates object - This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it - traverses the recording to compute the average templates for each unit. + Estimate the sparsity without needing a SortingAnalyzer or Templates object. + In case the sparsity method needs templates, they are computed on-the-fly. + For the "snr" method, `noise_levels` must be passed with the `noise_levels` argument. + These can be computed with the `get_noise_levels()` function. 
Contrary to the previous implementation: * all units are computed in one read of recording @@ -561,29 +720,23 @@ def estimate_sparsity( * it doesn't consume too much memory * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. + Parameters ---------- sorting : BaseSorting The sorting recording : BaseRecording The recording - num_spikes_for_sparsity : int, default: 100 How many spikes per units to compute the sparsity ms_before : float, default: 1.0 Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels", default: "radius" - Sparsity method propagated to the `compute_sparsity()` function. - Only "radius" or "best_channels" are implemented - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute best channels - radius_um : float, default: 100.0 - Used for "radius" method - num_channels : int, default: 5 - Used for "best_channels" method - + noise_levels : np.array | None, default: None + Noise levels required for the "snr" and "energy" methods. You can use the + `get_noise_levels()` function to compute them. {} Returns @@ -594,7 +747,10 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" + assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), ( + f"method={method} is not available for `estimate_sparsity()`. 
" + "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" + ) if recording.get_probes() == 1: # standard case @@ -605,44 +761,81 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, - ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name="estimate_sparsity", - **job_kwargs, - ) - templates = Templates( - templates_array=templates_array, - sampling_frequency=recording.sampling_frequency, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=sorting.unit_ids, - probe=probe, - ) + if method != "by_property": + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_sparsity, + margin_size=max(nbefore, nafter), + seed=2205, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name="estimate_sparsity", + 
**job_kwargs, + ) + templates = Templates( + templates_array=templates_array, + sampling_frequency=recording.sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=sorting.unit_ids, + probe=probe, + ) - sparsity = compute_sparsity( - templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um - ) + if method == "best_channels": + assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" + sparsity = ChannelSparsity.from_best_channels( + templates, num_channels, peak_sign=peak_sign, amplitude_mode=amplitude_mode + ) + elif method == "radius": + assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" + sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + elif method == "snr": + assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." + ) + sparsity = ChannelSparsity.from_snr( + templates, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + amplitude_mode=amplitude_mode, + ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign + ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + assert noise_levels is not None, ( + "For the 'ptp' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." 
+ ) + sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels) + else: + raise ValueError(f"compute_sparsity() method={method} does not exists") + else: + assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" + sparsity = ChannelSparsity.from_property(sorting, recording, by_property) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a192d90502..ace869df8c 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,7 +3,7 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates +from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, get_noise_levels from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer @@ -86,7 +86,7 @@ def test_sparsify_waveforms(): num_active_channels = len(non_zero_indices) assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) - # Test round-trip (note that this is loosy) + # Test round-trip (note that this is lossy) unit_id = unit_ids[unit_id] non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) @@ -195,6 +195,82 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # by_property + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="by_property", + by_property="group", + progress_bar=True, + n_jobs=1, + ) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + + # amplitude + sparsity = estimate_sparsity( + sorting, + recording, + 
num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="amplitude", + threshold=5, + amplitude_mode="peak_to_peak", + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + + # snr: fails without noise levels + with pytest.raises(AssertionError): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + # snr: works with noise levels + noise_levels = get_noise_levels(recording) + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + # ptp: just run it + print(noise_levels) + + with pytest.warns(DeprecationWarning): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -212,9 +288,14 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") - sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) + sparsity = compute_sparsity( + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" + ) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", 
by_property="group") + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") @@ -222,7 +303,10 @@ def test_compute_sparsity(): sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) if __name__ == "__main__": diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4701d76012..c3b3099535 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index aa9b16bb97..71a5f282a8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ 
b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,7 +27,7 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) + sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2bacf36ac9..b08ee4d9cb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 77d47aec16..f7ca999d53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy",