diff --git a/pyproject.toml b/pyproject.toml index 51efe1f585..dede4f9052 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,13 @@ widgets = [ "sortingview>=0.11.15", ] +qualitymetrics = [ + "scikit-learn", + "scipy", + "pandas", + "numba", +] + test_core = [ "pytest", "zarr", diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 501fd8cc79..6d5a753cad 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -297,6 +297,48 @@ def test_extract_waveforms(): ) assert we4.sparsity is not None + # test with sparsity estimation + folder5 = cache_folder / "test_extract_waveforms_compute_sparsity_tmp_folder" + sparsity_temp_folder = cache_folder / "tmp_sparsity" + if folder5.is_dir(): + shutil.rmtree(folder5) + + we5 = extract_waveforms( + recording, + sorting, + folder5, + max_spikes_per_unit=100, + return_scaled=True, + sparse=True, + sparsity_temp_folder=sparsity_temp_folder, + method="radius", + radius_um=50.0, + n_jobs=2, + chunk_duration="500ms", + ) + assert we5.sparsity is not None + # tmp folder is cleaned up + assert not sparsity_temp_folder.is_dir() + + # should raise an error if sparsity_temp_folder already exists + with pytest.raises(AssertionError): + if folder5.is_dir(): + shutil.rmtree(folder5) + sparsity_temp_folder.mkdir() + we5 = extract_waveforms( + recording, + sorting, + folder5, + max_spikes_per_unit=100, + return_scaled=True, + sparse=True, + sparsity_temp_folder=sparsity_temp_folder, + method="radius", + radius_um=50.0, + n_jobs=2, + chunk_duration="500ms", + ) + def test_recordingless(): durations = [30, 40] diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index d4ae140b90..9eef8b791a 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1481,7 +1481,9 @@ def 
extract_waveforms( dtype=None, sparse=True, sparsity=None, + sparsity_temp_folder=None, num_spikes_for_sparsity=100, + unit_batch_size=200, allow_unfiltered=False, use_relative_path=False, seed=None, @@ -1531,8 +1533,15 @@ def extract_waveforms( sparsity you want to apply (by radius, by best channels, ...). sparsity: ChannelSparsity or None The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None. + sparsity_temp_folder: str or Path or None, default: None + If sparse is True, this is the temporary folder where the dense waveforms are temporarily saved. + If None, dense waveforms are extracted in memory in batches (which can be controlled by the `unit_batch_size` + parameter). With a large number of units (e.g., > 400), it is advisable to use a temporary folder. num_spikes_for_sparsity: int (default 100) The number of spikes to use to estimate sparsity (if sparse=True). + unit_batch_size: int, default 200 + The number of units to process at once when extracting dense waveforms (if sparse=True and sparsity_temp_folder + is None). allow_unfiltered: bool If true, will accept an allow_unfiltered recording. False by default. @@ -1612,6 +1621,8 @@ def extract_waveforms( ms_before=ms_before, ms_after=ms_after, num_spikes_for_sparsity=num_spikes_for_sparsity, + unit_batch_size=unit_batch_size, + temp_folder=sparsity_temp_folder, allow_unfiltered=allow_unfiltered, **estimate_kwargs, **job_kwargs, @@ -1675,6 +1686,7 @@ def precompute_sparsity( unit_batch_size=200, ms_before=2.0, ms_after=3.0, + temp_folder=None, allow_unfiltered=False, **kwargs, ): @@ -1689,25 +1701,25 @@ def precompute_sparsity( The recording object sorting: Sorting The sorting object - num_spikes_for_sparsity: int - How many spikes per unit. - unit_batch_size: int or None + num_spikes_for_sparsity: int, default 100 + How many spikes per unit + unit_batch_size: int or None, default 200 How many units are extracted at once to estimate sparsity. 
- If None then they are extracted all at one (consum many memory) - ms_before: float + If None then they are extracted all at once (but uses a lot of memory) + ms_before: float, default 2.0 Time in ms to cut before spike peak - ms_after: float + ms_after: float, default 3.0 Time in ms to cut after spike peak - allow_unfiltered: bool + temp_folder: str or Path or None, default: None + If provided, dense waveforms are saved to this temporary folder + allow_unfiltered: bool, default: False If true, will accept an allow_unfiltered recording. - False by default. - kwargs for sparsity strategy: {} - Job kwargs: + job kwargs: {} Returns @@ -1724,18 +1736,38 @@ def precompute_sparsity( if unit_batch_size is None: unit_batch_size = len(unit_ids) - mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") - - nloop = int(np.ceil((unit_ids.size / unit_batch_size))) - for i in range(nloop): - sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size) - local_ids = unit_ids[sl] - local_sorting = sorting.select_units(local_ids) - local_we = extract_waveforms( + if temp_folder is None: + mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") + nloop = int(np.ceil((unit_ids.size / unit_batch_size))) + for i in range(nloop): + sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size) + local_ids = unit_ids[sl] + local_sorting = sorting.select_units(local_ids) + local_we = extract_waveforms( + recording, + local_sorting, + folder=None, + mode="memory", + precompute_template=("average",), + ms_before=ms_before, + ms_after=ms_after, + max_spikes_per_unit=num_spikes_for_sparsity, + return_scaled=False, + allow_unfiltered=allow_unfiltered, + sparse=False, + **job_kwargs, + ) + local_sparsity = compute_sparsity(local_we, **sparse_kwargs) + mask[sl, :] = local_sparsity.mask + else: + temp_folder = Path(temp_folder) + assert ( + not temp_folder.is_dir() + ), "Temporary folder for pre-computing sparsity already exists. 
Provide a non-existing folder" + dense_we = extract_waveforms( recording, - local_sorting, - folder=None, - mode="memory", + sorting, + folder=temp_folder, precompute_template=("average",), ms_before=ms_before, ms_after=ms_after, @@ -1745,8 +1777,9 @@ def precompute_sparsity( sparse=False, **job_kwargs, ) - local_sparsity = compute_sparsity(local_we, **sparse_kwargs) - mask[sl, :] = local_sparsity.mask + sparsity = compute_sparsity(dense_we, **sparse_kwargs) + mask = sparsity.mask + shutil.rmtree(temp_folder) sparsity = ChannelSparsity(mask, unit_ids, channel_ids) return sparsity diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index ed06f7d738..ff1995a7d9 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -1,6 +1,5 @@ """Cluster quality metrics computed from principal components.""" -from cmath import nan from copy import deepcopy import numpy as np @@ -16,7 +15,6 @@ except: pass -import spikeinterface as si from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor from ..core.job_tools import tqdm_joblib from ..core.template_tools import get_template_extremum_channel diff --git a/src/spikeinterface/qualitymetrics/utils.py b/src/spikeinterface/qualitymetrics/utils.py index 741308270b..4f2195b1a9 100644 --- a/src/spikeinterface/qualitymetrics/utils.py +++ b/src/spikeinterface/qualitymetrics/utils.py @@ -1,5 +1,5 @@ import numpy as np -from scipy.stats import norm, multivariate_normal +from scipy.stats import multivariate_normal def create_ground_truth_pc_distributions(center_locations, total_points):