From c1e5e4ec1f889e3070a19734e4ea9169a872e280 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 10:59:16 +0200 Subject: [PATCH 1/8] extend estimate_sparsity methods and fix from_ptp --- src/spikeinterface/core/sparsity.py | 137 +++++++++--------- .../core/tests/test_sparsity.py | 29 +++- 2 files changed, 98 insertions(+), 68 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index a38562ea2c..57e1fa4769 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -338,7 +338,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None + assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -353,17 +353,17 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the SNR threshold. + Use the "threshold" argument to specify the peak-to-peak threshold. """ assert ( templates_or_sorting_analyzer.sparsity is None ), "To compute sparsity you need a dense SortingAnalyzer or Templates" - from .template_tools import get_template_amplitudes + from .template_tools import get_dense_templates_array from .sortinganalyzer import SortingAnalyzer from .template import Templates @@ -371,23 +371,17 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None return_scaled = templates_or_sorting_analyzer.is_scaled - from .template_tools import get_dense_templates_array - mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -455,15 +449,15 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( - templates_or_sorting_analyzer, - noise_levels=None, - method="radius", - peak_sign="neg", - num_channels=5, - radius_um=100.0, - threshold=5, - by_property=None, -): + templates_or_sorting_analyzer: Union[Templates, SortingAnalyzer], + noise_levels: np.ndarray | None = None, + method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", + num_channels: int | None = 5, + radius_um: float | None = 100.0, + threshold: float | None = 5, + by_property: str | None = None, +) -> ChannelSparsity: """ Get channel sparsity (subset of channels) for each template with several methods. @@ -500,11 +494,6 @@ def compute_sparsity( templates_or_sorting_analyzer, SortingAnalyzer ), f"compute_sparsity(method='{method}') need SortingAnalyzer" - if method in ("snr", "ptp") and isinstance(templates_or_sorting_analyzer, Templates): - assert ( - noise_levels is not None - ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" - if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) @@ -521,7 +510,6 @@ def compute_sparsity( sparsity = ChannelSparsity.from_ptp( templates_or_sorting_analyzer, threshold, - noise_levels=noise_levels, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -544,10 +532,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" = "radius", - peak_sign: str = "neg", + method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, + threshold: float | None = 5, + by_property: str | None = None, **job_kwargs, ): """ @@ -567,16 +557,15 @@ def estimate_sparsity( The sorting recording : BaseRecording The recording - num_spikes_for_sparsity : int, default: 100 How many spikes per units to compute the sparsity ms_before : float, default: 1.0 Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels", default: "radius" + method : "radius" | "best_channels" | "ptp" | "by_property", default: "radius" Sparsity method propagated to the `compute_sparsity()` function. - Only "radius" or "best_channels" are implemented + "snr" and "energy" are not available here, because they require noise levels. peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 100.0 @@ -594,7 +583,10 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" + assert method in ("radius", "best_channels", "ptp", "by_property"), ( + f"method={method} is not available for `estimate_sparsity()`. " + "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + ) if recording.get_probes() == 1: # standard case @@ -605,43 +597,54 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, - ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name="estimate_sparsity", - **job_kwargs, - ) - templates = Templates( - templates_array=templates_array, - sampling_frequency=recording.sampling_frequency, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=sorting.unit_ids, - probe=probe, - ) + if method != "by_property": + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_sparsity, + margin_size=max(nbefore, nafter), + seed=2205, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name="estimate_sparsity", + **job_kwargs, + ) + templates = Templates( + templates_array=templates_array, + sampling_frequency=recording.sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=sorting.unit_ids, + probe=probe, + ) + templates_or_analyzer = templates + else: + from .sortinganalyzer import create_sorting_analyzer + templates_or_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) sparsity = compute_sparsity( - templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um + templates_or_analyzer, + method=method, + peak_sign=peak_sign, + num_channels=num_channels, + radius_um=radius_um, + threshold=threshold, + by_property=by_property, ) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a192d90502..6ee023fc12 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -86,7 +86,7 @@ def test_sparsify_waveforms(): num_active_channels = len(non_zero_indices) assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) - # Test round-trip (note that this is loosy) + # Test round-trip (note that this is lossy) unit_id = unit_ids[unit_id] non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) @@ -195,6 +195,33 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # ptp : just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=3, + progress_bar=True, + n_jobs=1, + ) + + # by_property : just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="by_property", + by_property="group", + progress_bar=True, + n_jobs=1, + ) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + def test_compute_sparsity(): recording, sorting = get_dataset() From 30d7dbbc3998fb820b62a3029a4c36fffd48d71b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 12:38:03 +0200 Subject: [PATCH 2/8] Revert ptp changes --- src/spikeinterface/core/sparsity.py | 24 +++++++++++-------- .../core/tests/test_sparsity.py | 13 ---------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 57e1fa4769..4de786cb37 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -353,10 +353,10 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ - Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the peak-to-peak threshold. + Construct sparsity from a thresholds based on template peak-to-peak relative values. + Use the "threshold" argument to specify the peak-to-peak threshold (with respect to noise_levels). """ assert ( @@ -371,8 +371,12 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + ext = templates_or_sorting_analyzer.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): + assert noise_levels is not None, "To compute sparsity from ptp you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -381,7 +385,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -397,7 +401,7 @@ def from_energy(cls, sorting_analyzer, threshold): # noise_levels ext = sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext is not None, "To compute sparsity from energy you need to compute 'noise_levels' first" noise_levels = ext.data["noise_levels"] # waveforms @@ -532,7 +536,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + method: "radius" | "best_channels" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -563,9 +567,9 @@ def estimate_sparsity( Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels" | "ptp" | "by_property", default: "radius" + method : "radius" | "best_channels" | "by_property", default: "radius" Sparsity method propagated to the `compute_sparsity()` function. - "snr" and "energy" are not available here, because they require noise levels. + "snr", "ptp", and "energy" are not available here because they require noise levels. peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 100.0 @@ -583,9 +587,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "by_property"), ( + assert method in ("radius", "best_channels", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + "Available methods are 'radius', 'best_channels', 'by_property'" ) if recording.get_probes() == 1: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 6ee023fc12..b60c1c2eca 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -195,19 +195,6 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) - # ptp : just run it - sparsity = estimate_sparsity( - sorting, - recording, - num_spikes_for_sparsity=50, - ms_before=1.0, - ms_after=2.0, - method="ptp", - threshold=3, - progress_bar=True, - n_jobs=1, - ) - # by_property : just run it sparsity = estimate_sparsity( sorting, From fa8eb1f5b349846f13247537f8b3460b45ae9deb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:10:14 +0200 Subject: [PATCH 3/8] Expose snr_amplitude_mode for snr sparsity --- src/spikeinterface/core/sparsity.py | 43 +++++++++++-------- .../core/tests/test_sparsity.py | 19 +++++++- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4de786cb37..6904886dcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -25,14 +25,16 @@ * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. - peak_sign : str - Sign of the template to compute best channels ("neg", "pos", "both") + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels num_channels : int Number of channels for "best_channels" method radius_um : float Radius in um for "radius" method threshold : float Threshold in SNR "threshold" method + snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr" method by_property : object Property name for "by_property" method """ @@ -316,7 +318,9 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, peak_sign="neg"): + def from_snr( + cls, templates_or_sorting_analyzer, threshold, snr_amplitude_mode="extremum", noise_levels=None, peak_sign="neg" + ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. @@ -344,7 +348,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode="extremum", return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=snr_amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -353,10 +357,10 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ - Construct sparsity from a thresholds based on template peak-to-peak relative values. - Use the "threshold" argument to specify the peak-to-peak threshold (with respect to noise_levels). + Construct sparsity from a thresholds based on template peak-to-peak values. + Use the "threshold" argument to specify the peak-to-peak threshold. """ assert ( @@ -371,12 +375,8 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None, "To compute sparsity from ptp you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -385,7 +385,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -453,7 +453,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( - templates_or_sorting_analyzer: Union[Templates, SortingAnalyzer], + templates_or_sorting_analyzer: Templates | SortingAnalyzer, noise_levels: np.ndarray | None = None, method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", @@ -461,6 +461,7 @@ def compute_sparsity( radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, + snr_amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> ChannelSparsity: """ Get channel sparsity (subset of channels) for each template with several methods. @@ -507,7 +508,11 @@ def compute_sparsity( elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( - templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, peak_sign=peak_sign + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + snr_amplitude_mode=snr_amplitude_mode, ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" @@ -536,11 +541,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "by_property" = "radius", + method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, + snr_amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, **job_kwargs, ): @@ -576,6 +582,8 @@ def estimate_sparsity( Used for "radius" method num_channels : int, default: 5 Used for "best_channels" method + snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Used for "snr" method to compute the amplitude of the templates. {} @@ -587,9 +595,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "by_property"), ( + assert method in ("radius", "best_channels", "ptp", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'by_property'" + "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" ) if recording.get_probes() == 1: @@ -649,6 +657,7 @@ def estimate_sparsity( radius_um=radius_um, threshold=threshold, by_property=by_property, + snr_amplitude_mode=snr_amplitude_mode, ) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index b60c1c2eca..16b3bbc996 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -195,7 +195,7 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) - # by_property : just run it + # by_property sparsity = estimate_sparsity( sorting, recording, @@ -209,6 +209,20 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + # ptp: just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -226,6 +240,9 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") + sparsity = compute_sparsity( + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", snr_amplitude_mode="peak_to_peak" + ) sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") From b5c56d64373b217dcf395e95e1a2b7b70a5f8dbe Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:35:45 +0200 Subject: [PATCH 4/8] Propagate sparsity change --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../benchmark/tests/test_benchmark_matching.py | 4 +++- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4701d76012..a6d212425d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index aa9b16bb97..d6d0440a02 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,7 +27,9 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) + sparsity = compute_sparsity( + gt_templates, noise_levels, method="snr", snr_amplitude_mode="peak_to_peak", threshold=0.25 + ) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 77d47aec16..b56fd3e02b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy", From cc14ab2697331ca4b9ca35e972d959e2cbe07a03 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:39:10 +0200 Subject: [PATCH 5/8] last one --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2bacf36ac9..11a628bb53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, From fa84e8c6869a475046e4b11c959a8bbac7c5d106 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 16:52:54 +0200 Subject: [PATCH 6/8] Potentiate 'estimate_sparsity' and refactor from_property() constructor --- src/spikeinterface/core/sparsity.py | 261 +++++++++++++----- .../core/tests/test_sparsity.py | 34 ++- .../sorters/internal/spyking_circus2.py | 2 +- .../tests/test_benchmark_matching.py | 4 +- .../sortingcomponents/clustering/circus.py | 2 +- .../clustering/random_projections.py | 2 +- 6 files changed, 226 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 6904886dcf..c72f89520e 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -13,30 +13,33 @@ _sparsity_doc = """ method : str * "best_channels" : N best channels with the largest amplitude. Use the "num_channels" argument to specify the - number of channels. - * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um + number of channels. + * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um. * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument - to specify the SNR threshold (in units of noise levels) + to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of noise levels) + to specify the ptp threshold (in units of amplitude). * "energy" : threshold based on the expected energy that should be present on the channels, - given their noise levels. Use the "threshold" argument to specify the SNR threshold + given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) - * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). - Use the "by_property" argument to specify the property name. + * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). + In this case the sparsity for each unit is given by the channels that have the same property + value as the unit. Use the "by_property" argument to specify the property name. peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels + Sign of the template to compute best channels. num_channels : int - Number of channels for "best_channels" method + Number of channels for "best_channels" method. radius_um : float - Radius in um for "radius" method + Radius in um for "radius" method. threshold : float - Threshold in SNR "threshold" method - snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" method + Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). + For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. by_property : object - Property name for "by_property" method + Property name for "by_property" method. """ @@ -279,18 +282,35 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_best_channels( + cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg", amplitude_mode="extremum" + ): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + num_channels : int + Number of channels for "best_channels" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity """ from .template_tools import get_template_amplitudes - print(templates_or_sorting_analyzer) mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode) for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] @@ -301,7 +321,21 @@ def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_si def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. - Use the "radius_um" argument to specify the radius in um + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + radius_um : float + Radius in um for "radius" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ from .template_tools import get_template_extremum_channel @@ -319,11 +353,37 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): @classmethod def from_snr( - cls, templates_or_sorting_analyzer, threshold, snr_amplitude_mode="extremum", noise_levels=None, peak_sign="neg" + cls, + templates_or_sorting_analyzer, + threshold, + amplitude_mode="extremum", + peak_sign="neg", + noise_levels=None, ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "snr" method (in units of noise levels). + noise_levels : np.array | None, default: None + Noise levels required for the "snr" method. You can use the + `get_noise_levels()` function to compute them. + If the input is a `SortingAnalyzer`, the noise levels are automatically retrieved + if the `noise_levels` extension is present. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute amplitudes. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates for the "snr" method. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer @@ -348,7 +408,7 @@ def from_snr( mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=snr_amplitude_mode, return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -361,6 +421,18 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the peak-to-peak threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "ptp" method (in units of amplitude). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ assert ( @@ -394,6 +466,19 @@ def from_energy(cls, sorting_analyzer, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. + This method requires the "waveforms" and "noise_levels" extensions to be computed. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + threshold : float + Threshold for "energy" method (in units of noise levels). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ assert sorting_analyzer.sparsity is None, "To compute sparsity with energy you need a dense SortingAnalyzer" @@ -419,41 +504,61 @@ def from_energy(cls, sorting_analyzer, threshold): return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) @classmethod - def from_property(cls, sorting_analyzer, by_property): + def from_property(cls, sorting, recording, by_property): """ Construct sparsity witha property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. + + Parameters + ---------- + sorting : Sorting + A Sorting object. + recording : Recording + A Recording object. + by_property : object + Property name for "by_property" method. Both the recording and sorting must have this property set. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ # check consistency - assert ( - by_property in sorting_analyzer.recording.get_property_keys() - ), f"Property {by_property} is not a recording property" - assert ( - by_property in sorting_analyzer.sorting.get_property_keys() - ), f"Property {by_property} is not a sorting property" + assert by_property in recording.get_property_keys(), f"Property {by_property} is not a recording property" + assert by_property in sorting.get_property_keys(), f"Property {by_property} is not a sorting property" - mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") - rec_by = sorting_analyzer.recording.split_by(by_property) - for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - unit_property = sorting_analyzer.sorting.get_property(by_property)[unit_ind] + mask = np.zeros((sorting.unit_ids.size, recording.channel_ids.size), dtype="bool") + rec_by = recording.split_by(by_property) + for unit_ind, unit_id in enumerate(sorting.unit_ids): + unit_property = sorting.get_property(by_property)[unit_ind] assert ( unit_property in rec_by.keys() ), f"Unit property {unit_property} cannot be found in the recording properties" - chan_inds = sorting_analyzer.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) + chan_inds = recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) + return cls(mask, sorting.unit_ids, recording.channel_ids) @classmethod def create_dense(cls, sorting_analyzer): """ Create a sparsity object with all selected channel for all units. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + + Returns + ------- + sparsity : ChannelSparsity + The full sparsity. """ mask = np.ones((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) def compute_sparsity( - templates_or_sorting_analyzer: Templates | SortingAnalyzer, + templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", @@ -461,10 +566,10 @@ def compute_sparsity( radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - snr_amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> ChannelSparsity: """ - Get channel sparsity (subset of channels) for each template with several methods. + Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. Parameters ---------- @@ -512,7 +617,7 @@ def compute_sparsity( threshold, noise_levels=noise_levels, peak_sign=peak_sign, - snr_amplitude_mode=snr_amplitude_mode, + amplitude_mode=amplitude_mode, ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" @@ -525,7 +630,9 @@ def compute_sparsity( sparsity = ChannelSparsity.from_energy(templates_or_sorting_analyzer, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_sorting_analyzer, by_property) + sparsity = ChannelSparsity.from_property( + templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -541,19 +648,20 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, - snr_amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, + noise_levels: np.ndarray | list | None = None, **job_kwargs, ): """ - Estimate the sparsity without needing a SortingAnalyzer or Templates object - This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it - traverses the recording to compute the average templates for each unit. + Estimate the sparsity without needing a SortingAnalyzer or Templates object. + In case the sparsity method needs templates, they are computed on-the-fly. + The same is done for noise levels, if needed by the method ("snr"). Contrary to the previous implementation: * all units are computed in one read of recording @@ -561,6 +669,8 @@ def estimate_sparsity( * it doesn't consume too much memory * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. + Parameters ---------- sorting : BaseSorting @@ -573,18 +683,9 @@ def estimate_sparsity( Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels" | "by_property", default: "radius" - Sparsity method propagated to the `compute_sparsity()` function. - "snr", "ptp", and "energy" are not available here because they require noise levels. - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute best channels - radius_um : float, default: 100.0 - Used for "radius" method - num_channels : int, default: 5 - Used for "best_channels" method - snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" - Used for "snr" method to compute the amplitude of the templates. - + noise_levels : np.array | None, default: None + Noise levels required for the "snr" and "energy" methods. You can use the + `get_noise_levels()` function to compute them. {} Returns @@ -595,9 +696,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "by_property"), ( + assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'" ) if recording.get_probes() == 1: @@ -644,21 +745,39 @@ def estimate_sparsity( unit_ids=sorting.unit_ids, probe=probe, ) - templates_or_analyzer = templates + + if method == "best_channels": + assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" + sparsity = ChannelSparsity.from_best_channels( + templates, num_channels, peak_sign=peak_sign, amplitude_mode=amplitude_mode + ) + elif method == "radius": + assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" + sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + elif method == "snr": + assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." + ) + sparsity = ChannelSparsity.from_snr( + templates, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + amplitude_mode=amplitude_mode, + ) + elif method == "ptp": + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + templates, + threshold, + ) + else: + raise ValueError(f"compute_sparsity() method={method} does not exists") else: - from .sortinganalyzer import create_sorting_analyzer - - templates_or_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) - sparsity = compute_sparsity( - templates_or_analyzer, - method=method, - peak_sign=peak_sign, - num_channels=num_channels, - radius_um=radius_um, - threshold=threshold, - by_property=by_property, - snr_amplitude_mode=snr_amplitude_mode, - ) + assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" + sparsity = ChannelSparsity.from_property(sorting, recording, by_property) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 16b3bbc996..64517c106d 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,7 +3,7 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates +from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, get_noise_levels from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer @@ -223,6 +223,36 @@ def test_estimate_sparsity(): n_jobs=1, ) + # snr: fails without noise levels + with pytest.raises(AssertionError): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + # snr: works with noise levels + noise_levels = get_noise_levels(recording) + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -241,7 +271,7 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") sparsity = compute_sparsity( - sorting_analyzer, method="snr", threshold=5, peak_sign="neg", snr_amplitude_mode="peak_to_peak" + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" ) sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a6d212425d..c3b3099535 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index d6d0440a02..71a5f282a8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,9 +27,7 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity( - gt_templates, noise_levels, method="snr", snr_amplitude_mode="peak_to_peak", threshold=0.25 - ) + sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 11a628bb53..b08ee4d9cb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index b56fd3e02b..f7ca999d53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy", From 15a4a11bad4e08bdf4d8ce5de67e05dd2a2a8fab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 18:14:34 +0200 Subject: [PATCH 7/8] minor docstring fix --- src/spikeinterface/core/sparsity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index c72f89520e..471302d57e 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -661,7 +661,8 @@ def estimate_sparsity( """ Estimate the sparsity without needing a SortingAnalyzer or Templates object. In case the sparsity method needs templates, they are computed on-the-fly. - The same is done for noise levels, if needed by the method ("snr"). + For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. + These can be computed with the `get_noise_levels()` function. Contrary to the previous implementation: * all units are computed in one read of recording From 9ffda3543aa794418af02830be70144de64c54f2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 12:33:48 +0200 Subject: [PATCH 8/8] Add from_amplitude() option to sparsity and deprecate ptp --- src/spikeinterface/core/sparsity.py | 105 ++++++++++++++---- .../core/tests/test_sparsity.py | 31 +++++- 2 files changed, 108 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 471302d57e..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from .basesorting import BaseSorting @@ -18,14 +19,16 @@ * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument to specify the mode to compute the amplitude of the templates. - * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of amplitude). + * "amplitude" : threshold based on the amplitude values on every channels. Use the "threshold" argument + to specify the ptp threshold (in units of amplitude) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. + * "ptp: : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. peak_sign : "neg" | "pos" | "both" Sign of the template to compute best channels. @@ -37,7 +40,7 @@ Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. + Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. by_property : object Property name for "by_property" method. """ @@ -417,7 +420,7 @@ def from_snr( return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the peak-to-peak threshold. @@ -434,30 +437,67 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): sparsity : ChannelSparsity The estimated sparsity. """ + warnings.warn( + "The 'ptp' method is deprecated and will be removed in version 0.103.0. " + "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", + DeprecationWarning, + ) + return cls.from_snr( + templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels + ) - assert ( - templates_or_sorting_analyzer.sparsity is None - ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + @classmethod + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + """ + Construct sparsity from a threshold based on template amplitude. + The amplitude is computed with the specified amplitude mode and it is assumed + that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must + have scaled templates. - from .template_tools import get_dense_templates_array + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "amplitude" method (in uV). + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. + """ + from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer from .template import Templates + assert ( + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + unit_ids = templates_or_sorting_analyzer.unit_ids channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - return_scaled = templates_or_sorting_analyzer.return_scaled + assert templates_or_sorting_analyzer.return_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `return_scaled=True` when computing the templates." + ) elif isinstance(templates_or_sorting_analyzer, Templates): - return_scaled = templates_or_sorting_analyzer.is_scaled + assert templates_or_sorting_analyzer.is_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `is_scaled=True` when creating the Templates object." + ) mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) - templates_ptps = np.ptp(templates_array, axis=1) + peak_values = get_template_amplitudes( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=True + ) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) + chan_inds = np.nonzero((np.abs(peak_values[unit_id])) >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -560,7 +600,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", + method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, @@ -595,7 +635,7 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "radius", "snr", "ptp"): + if method in ("best_channels", "radius", "snr", "amplitude", "ptp"): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -619,11 +659,13 @@ def compute_sparsity( peak_sign=peak_sign, amplitude_mode=amplitude_mode, ) - elif method == "ptp": - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( templates_or_sorting_analyzer, threshold, + amplitude_mode=amplitude_mode, + peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -633,6 +675,14 @@ def compute_sparsity( sparsity = ChannelSparsity.from_property( templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -648,7 +698,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius", + method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -697,9 +747,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), ( + assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'" + "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" ) if recording.get_probes() == 1: @@ -768,12 +818,19 @@ def estimate_sparsity( peak_sign=peak_sign, amplitude_mode=amplitude_mode, ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign + ) elif method == "ptp": + # TODO: remove after deprecation assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( - templates, - threshold, + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." ) + sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels) else: raise ValueError(f"compute_sparsity() method={method} does not exists") else: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 64517c106d..ace869df8c 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -209,15 +209,16 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) - # ptp: just run it + # amplitude sparsity = estimate_sparsity( sorting, recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, - method="ptp", + method="amplitude", threshold=5, + amplitude_mode="peak_to_peak", chunk_duration="1s", progress_bar=True, n_jobs=1, @@ -252,6 +253,23 @@ def test_estimate_sparsity(): progress_bar=True, n_jobs=1, ) + # ptp: just run it + print(noise_levels) + + with pytest.warns(DeprecationWarning): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) def test_compute_sparsity(): @@ -273,9 +291,11 @@ def test_compute_sparsity(): sparsity = compute_sparsity( sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" ) - sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") @@ -283,7 +303,10 @@ def test_compute_sparsity(): sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) if __name__ == "__main__":