From d652e6393ce0700cab3e051aed5f572d79408a9e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:05:37 +0200 Subject: [PATCH 01/15] Port #3032 and #3056 --- pyproject.toml | 4 ++-- src/spikeinterface/core/core_tools.py | 7 +++++- .../core/tests/test_jsonification.py | 1 - .../tests/test_highpass_spatial_filter.py | 19 +++++++-------- .../tests/test_interpolate_bad_channels.py | 23 +++++++++++-------- .../sortingcomponents/peak_localization.py | 4 ++-- .../test_waveform_thresholder.py | 2 +- .../waveforms/waveform_thresholder.py | 2 +- 8 files changed, 36 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a51d22ec18..382d8bc52c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.100.7" +version = "0.100.8" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ - "numpy", + "numpy>=1.20, <2.0", # 1.20 np.ptp, 1.26 might be necessary for avoiding pickling errors when numpy >2.0 "threadpoolctl>=3.0.0", "tqdm", "zarr>=2.16,<2.18", diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 3725fcfba8..e42244c648 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -93,7 +93,12 @@ def default(self, obj): if isinstance(obj, np.generic): return obj.item() - if np.issctype(obj): # Cast numpy datatypes to their names + # Standard numpy dtypes like np.dtype('int32") are transformed this way + if isinstance(obj, np.dtype): + return np.dtype(obj).name + + # This will transform to a string canonical representation of the dtype (e.g. 
np.int32 -> 'int32') + if isinstance(obj, type) and issubclass(obj, np.generic): return np.dtype(obj).name if isinstance(obj, np.ndarray): diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index f63cfb16d8..4417ea342f 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -122,7 +122,6 @@ def test_numpy_dtype_alises_encoding(): # People tend to use this a dtype instead of the proper classes json.dumps(np.int32, cls=SIJsonEncoder) json.dumps(np.float32, cls=SIJsonEncoder) - json.dumps(np.bool_, cls=SIJsonEncoder) # Note that np.bool was deperecated in numpy 1.20.0 def test_recording_encoding(numpy_generated_recording): diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 5b2bec2eda..f4aff7bf1f 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -7,15 +7,8 @@ import spikeinterface.preprocessing as spre import spikeinterface.extractors as se from spikeinterface.core import generate_recording -import spikeinterface.widgets as sw +import importlib.util -try: - import spikeglx - import neurodsp.voltage as voltage - - HAVE_IBL_NPIX = True -except ImportError: - HAVE_IBL_NPIX = False ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -31,7 +24,10 @@ # ---------------------------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install") +@pytest.mark.skipif( + importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB, + reason="Only local. 
Requires ibl-neuropixel install", +) @pytest.mark.parametrize("lagc", [False, 1, 300]) def test_highpass_spatial_filter_real_data(lagc): """ @@ -56,6 +52,9 @@ def test_highpass_spatial_filter_real_data(lagc): use DEBUG = true to visualise. """ + import spikeglx + import neurodsp.voltage as voltage + options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) print(options) @@ -118,6 +117,8 @@ def get_ibl_si_data(): """ Set fixture to session to ensure origional data is not changed. """ + import spikeglx + local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") ibl_recording = spikeglx.Reader( local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index ad073e40aa..fd66cf9bc3 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -6,14 +6,8 @@ import spikeinterface.preprocessing as spre import spikeinterface.extractors as se from spikeinterface.core.generate import generate_recording +import importlib.util -try: - import spikeglx - import neurodsp.voltage as voltage - - HAVE_IBL_NPIX = True -except ImportError: - HAVE_IBL_NPIX = False ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -30,7 +24,10 @@ # ------------------------------------------------------------------------------- -@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install") +@pytest.mark.skipif( + importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB, + reason="Only local. Requires ibl-neuropixel install", +) def test_compare_real_data_with_ibl(): """ Test SI implementation of bad channel interpolation against native IBL. 
@@ -42,6 +39,9 @@ def test_compare_real_data_with_ibl(): voltage.interpolate_bad_channel() with ibl_channel geometry to si_scaled_recordin.get_traces(0) is also close to 1e-2. """ + import spikeglx + import neurodsp.voltage as voltage + # Download and load data local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") @@ -80,7 +80,10 @@ def test_compare_real_data_with_ibl(): assert np.mean(is_close) > 0.999 -@pytest.mark.skipif(not HAVE_IBL_NPIX, reason="Requires ibl-neuropixel install") +@pytest.mark.skipif( + importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") is not None, + reason="Requires ibl-neuropixel install", +) @pytest.mark.parametrize("num_channels", [32, 64]) @pytest.mark.parametrize("sigma_um", [1.25, 40]) @pytest.mark.parametrize("p", [0, -0.5, 1, 5]) @@ -90,6 +93,8 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan Perform an extended test across a range of function inputs to check IBL and SI interpolation results match. 
""" + import neurodsp.voltage as voltage + recording = generate_recording(num_channels=num_channels, durations=[1]) # distribute default probe locations across 4 shanks if set diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 20328dda6d..e8b4127429 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -203,7 +203,7 @@ def compute(self, traces, peaks, waveforms): wf = waveforms[idx][:, :, chan_inds] if self.feature == "ptp": - wf_data = wf.ptp(axis=1) + wf_data = np.ptp(wf, axis=1) elif self.feature == "mean": wf_data = wf.mean(axis=1) elif self.feature == "energy": @@ -292,7 +292,7 @@ def compute(self, traces, peaks, waveforms): wf = waveforms[i, :][:, chan_inds] if self.feature == "ptp": - wf_data = wf.ptp(axis=0) + wf_data = np.ptp(wf, axis=0) elif self.feature == "energy": wf_data = np.linalg.norm(wf, axis=0) elif self.feature == "peak_voltage": diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py index 572e6c36c1..17f6e5b86c 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py @@ -35,7 +35,7 @@ def test_waveform_thresholder_ptp(extract_waveforms, generated_recording, detect recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs ) - data = tresholded_waveforms.ptp(axis=1) / noise_levels + data = np.ptp(tresholded_waveforms, axis=1) / noise_levels assert np.all(data[data != 0] > 3) diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py index 8dd925ff14..567f26e098 100644 --- 
a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py @@ -81,7 +81,7 @@ def __init__( def compute(self, traces, peaks, waveforms): if self.feature == "ptp": - wf_data = waveforms.ptp(axis=1) / self.noise_levels + wf_data = np.ptp(waveforms, axis=1) / self.noise_levels elif self.feature == "mean": wf_data = waveforms.mean(axis=1) / self.noise_levels elif self.feature == "energy": From f7a81a98c2fbb184480fe0733f6ea038cf818cef Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:14:53 +0200 Subject: [PATCH 02/15] Port #3055 --- src/spikeinterface/core/generate.py | 6 +++++ .../sorters/external/kilosort4.py | 26 ++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0c92eab0e4..2853a4fc55 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1094,6 +1094,8 @@ def get_traces( ) -> np.ndarray: start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) + start_frame = int(start_frame) + end_frame = int(end_frame) start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size @@ -1652,6 +1654,8 @@ def get_traces( ) -> np.ndarray: start_frame = 0 if start_frame is None else start_frame end_frame = self.num_samples if end_frame is None else end_frame + start_frame = int(start_frame) + end_frame = int(end_frame) if channel_indices is None: n_channels = self.templates.shape[2] @@ -1688,6 +1692,8 @@ def get_traces( end_traces = start_traces + template.shape[0] if start_traces >= end_frame - start_frame or end_traces <= 0: continue + start_traces = int(start_traces) + end_traces = int(end_traces) start_template = 0 end_template = template.shape[0] diff --git 
a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 47846f10ce..a7f40a9558 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Union +from packaging import version from ..basesorter import BaseSorter from .kilosortbase import KilosortBase @@ -24,11 +25,14 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, + "shift": None, + "scale": None, "artifact_threshold": None, "nskip": 25, "whitening_range": 32, "binning_depth": 5, "sig_interp": 20, + "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, "dminx": 32, @@ -63,11 +67,14 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": "Whether to perform common average reference. Default value: True.", "invert_sign": "Invert the sign of the data. Default value: False.", "nt": "Number of samples per waveform. Also size of symmetric padding for filtering. Default value: 61.", + "shift": "Scalar shift to apply to data before all other operations. Default None.", + "scale": "Scaling factor to apply to data before all other operations. Default None.", "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", "nskip": "Batch stride for computing whitening matrix. Default value: 25.", "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. 
Default value: 20.", + "drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.", "nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.", "dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Default value: 32.", @@ -153,6 +160,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import torch import numpy as np + if verbose: + import logging + + logging.basicConfig(level=logging.INFO) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" @@ -194,11 +206,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): data_dir = "" results_dir = sorter_output_folder filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) + else: + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( + get_run_parameters(ops) + ) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( - get_run_parameters(ops) - ) # Set preprocessing and 
drift correction parameters if not params["skip_kilosort_preprocessing"]: ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) From 56d5911f6a2fcce6a8a485a0f1e12fc5104d864e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:16:40 +0200 Subject: [PATCH 03/15] Port #3037 --- src/spikeinterface/sorters/basesorter.py | 2 +- src/spikeinterface/sorters/external/mountainsort5.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 1465d205b4..8eefb9acd6 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -184,7 +184,7 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): # custom check params params = cls._check_params(recording, output_folder, params) # common check : filter warning - if recording.is_filtered and cls._check_apply_filter_in_params(params) and verbose: + if recording.is_filtered() and cls._check_apply_filter_in_params(params) and verbose: print(f"Warning! 
The recording is already filtered, but {cls.sorter_name} filter is enabled") # dump parameters inside the folder with json diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index d516089d34..8916c0181a 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -120,7 +120,6 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - from mountainsort5.util import create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if recording is None: From 18bda8e515c5faf357c24c1f72b312b44cd5f117 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:18:06 +0200 Subject: [PATCH 04/15] Port #2997 --- src/spikeinterface/sorters/external/kilosort2.py | 6 ++++-- src/spikeinterface/sorters/external/kilosort2_5.py | 4 +++- src/spikeinterface/sorters/external/kilosort3.py | 6 ++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index fa5ff9889f..0425ad5e53 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -37,6 +37,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "detect_threshold": 6, "projection_threshold": [10, 4], "preclust_threshold": 8, + "whiteningRange": 32, # samples of the template to use for whitening "spatial" dimension "momentum": [20.0, 400.0], "car": True, "minFR": 0.1, @@ -62,6 +63,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "detect_threshold": "Threshold for spike detection", "projection_threshold": "Threshold on projections", "preclust_threshold": "Threshold crossings for pre-clustering (in PCA projection space)", + 
"whiteningRange": "Number of channels to use for whitening each channel", "momentum": "Number of samples to average over (annealed from first to second value)", "car": "Enable or disable common reference", "minFR": "Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed", @@ -72,7 +74,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "nPCs": "Number of PCA dimensions", "ntbuff": "Samples of symmetrical buffer for whitening and spike detection", "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", - "NT": "Batch size (if None it is automatically computed)", + "NT": "Batch size (if None it is automatically computed--recommended Kilosort behavior if ntbuff also not changed)", "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", @@ -199,7 +201,7 @@ def _get_specific_options(cls, ops, params): ops["NT"] = params[ "NT" ] # must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). 
- ops["whiteningRange"] = 32.0 # number of channels to use for whitening each channel + ops["whiteningRange"] = params["whiteningRange"] # number of channels to use for whitening each channel ops["nSkipCov"] = 25.0 # compute whitening matrix from every N-th batch ops["nPCs"] = params["nPCs"] # how many PCs to project the spikes into ops["useRAM"] = 0.0 # not yet available diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index abde2ab324..b3d1718d59 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -41,6 +41,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "detect_threshold": 6, "projection_threshold": [10, 4], "preclust_threshold": 8, + "whiteningRange": 32.0, "momentum": [20.0, 400.0], "car": True, "minFR": 0.1, @@ -69,6 +70,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "detect_threshold": "Threshold for spike detection", "projection_threshold": "Threshold on projections", "preclust_threshold": "Threshold crossings for pre-clustering (in PCA projection space)", + "whiteningRange": "Number of channels to use for whitening each channel", "momentum": "Number of samples to average over (annealed from first to second value)", "car": "Enable or disable common reference", "minFR": "Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed", @@ -220,7 +222,7 @@ def _get_specific_options(cls, ops, params): ops["NT"] = params[ "NT" ] # must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). 
- ops["whiteningRange"] = 32.0 # number of channels to use for whitening each channel + ops["whiteningRange"] = params["whiteningRange"] # number of channels to use for whitening each channel ops["nSkipCov"] = 25.0 # compute whitening matrix from every N-th batch ops["nPCs"] = params["nPCs"] # how many PCs to project the spikes into ops["useRAM"] = 0.0 # not yet available diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index ef42b5b6a3..f560fd7e1e 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -38,6 +38,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "detect_threshold": 6, "projection_threshold": [9, 9], "preclust_threshold": 8, + "whiteningRange": 32, "car": True, "minFR": 0.2, "minfr_goodchannels": 0.2, @@ -65,6 +66,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "detect_threshold": "Threshold for spike detection", "projection_threshold": "Threshold on projections", "preclust_threshold": "Threshold crossings for pre-clustering (in PCA projection space)", + "whiteningRange": "number of channels to use for whitening each channel", "car": "Enable or disable common reference", "minFR": "Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed", "minfr_goodchannels": "Minimum firing rate on a 'good' channel", @@ -77,7 +79,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "ntbuff": "Samples of symmetrical buffer for whitening and spike detection", "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", - "NT": "Batch size (if None it is automatically computed)", + "NT": "Batch size (if None it is automatically computed--recommended Kilosort behavior if ntbuff also not changed)", "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size 
of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", @@ -212,7 +214,7 @@ def _get_specific_options(cls, ops, params): ops["NT"] = params[ "NT" ] # must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). - ops["whiteningRange"] = 32.0 # number of channels to use for whitening each channel + ops["whiteningRange"] = params["whiteningRange"] # number of channels to use for whitening each channel ops["nSkipCov"] = 25.0 # compute whitening matrix from every N-th batch ops["scaleproc"] = 200.0 # int16 scaling of whitened data ops["nPCs"] = params["nPCs"] # how many PCs to project the spikes into From 39bf722b6ab0569e7859c8bd87e5d7248137ff06 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:23:30 +0200 Subject: [PATCH 05/15] Port #2712 --- src/spikeinterface/core/globals.py | 13 +++++--- src/spikeinterface/core/job_tools.py | 32 +++++++++++++------ src/spikeinterface/core/tests/test_globals.py | 12 +++++++ src/spikeinterface/sorters/basesorter.py | 9 +++--- 4 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index bbce3998c6..23d60a5ac5 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -42,9 +42,9 @@ def set_global_tmp_folder(folder): temp_folder_set = True -def is_set_global_tmp_folder(): +def is_set_global_tmp_folder() -> bool: """ - Check is the global path temporary folder have been manually set. + Check if the global path temporary folder have been manually set. """ global temp_folder_set return temp_folder_set @@ -88,9 +88,9 @@ def set_global_dataset_folder(folder): dataset_folder_set = True -def is_set_global_dataset_folder(): +def is_set_global_dataset_folder() -> bool: """ - Check is the global path dataset folder have been manually set. 
+ Check if the global path dataset folder has been manually set. """ global dataset_folder_set return dataset_folder_set @@ -138,7 +138,10 @@ def reset_global_job_kwargs(): global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) -def is_set_global_job_kwargs_set(): +def is_set_global_job_kwargs_set() -> bool: + """ + Check if the global job kwargs have been manually set. + """ global global_job_kwargs_set return global_job_kwargs_set diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 981ad6320b..5812fbdc5a 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -10,7 +10,6 @@ import warnings import sys -import contextlib from tqdm.auto import tqdm from concurrent.futures import ProcessPoolExecutor @@ -28,8 +27,9 @@ Total memory usage (e.g. "500M", "2G") - chunk_duration : str or float or None Chunk duration in s if float or with units if str (e.g. "1s", "500ms") - * n_jobs: int - Number of jobs to use. With -1 the number of jobs is the same as number of cores + * n_jobs: int | float + Number of jobs to use. With -1 the number of jobs is the same as number of cores. + Using a float between 0 and 1 will use that fraction of the total cores. * progress_bar: bool If True, a progress bar is printed * mp_context: "fork" | "spawn" | None, default: None @@ -60,7 +60,7 @@ def fix_job_kwargs(runtime_job_kwargs): - from .globals import get_global_job_kwargs + from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set job_kwargs = get_global_job_kwargs() @@ -68,30 +68,42 @@ def fix_job_kwargs(runtime_job_kwargs): assert k in job_keys, ( f"{k} is not a valid job keyword argument. 
" f"Available keyword arguments are: {list(job_keys)}" ) - # remove mutually exclusive from global job kwargs for k, v in runtime_job_kwargs.items(): if k in _mutually_exclusive and v is not None: for key_to_remove in _mutually_exclusive: if key_to_remove in job_kwargs: job_kwargs.pop(key_to_remove) - # remove None runtime_job_kwargs_exclude_none = runtime_job_kwargs.copy() for job_key, job_value in runtime_job_kwargs.items(): if job_value is None: del runtime_job_kwargs_exclude_none[job_key] job_kwargs.update(runtime_job_kwargs_exclude_none) - # if n_jobs is -1, set to os.cpu_count() (n_jobs is always in global job_kwargs) n_jobs = job_kwargs["n_jobs"] - assert isinstance(n_jobs, (float, np.integer, int)) - if isinstance(n_jobs, float): + assert isinstance(n_jobs, (float, np.integer, int)) and n_jobs != 0, "n_jobs must be a non-zero int or float" + # for a fraction we do fraction of total cores + if isinstance(n_jobs, float) and 0 < n_jobs <= 1: n_jobs = int(n_jobs * os.cpu_count()) + # for negative numbers we count down from total cores (with -1 being all) elif n_jobs < 0: - n_jobs = os.cpu_count() + 1 + n_jobs + n_jobs = int(os.cpu_count() + 1 + n_jobs) + # otherwise we just take the value given + else: + n_jobs = int(n_jobs) + job_kwargs["n_jobs"] = max(n_jobs, 1) + if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set(): + warnings.warn( + "`n_jobs` is not set so parallel processing is disabled! " + "To speed up computations, it is recommended to set n_jobs either " + "globally (with the `spikeinterface.set_global_job_kwargs()` function) or " + "locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` " + "for more information about job_kwargs." 
+ ) + return job_kwargs diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index d0672405d6..a45bb6f49c 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -1,4 +1,5 @@ import pytest +import warnings from pathlib import Path from spikeinterface import ( @@ -39,11 +40,22 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() + + # test warning when not setting n_jobs and calling fix_job_kwargs + with pytest.warns(UserWarning): + job_kwargs_split = fix_job_kwargs({}) + assert global_job_kwargs == dict( n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs + + # after setting global job kwargs, fix_job_kwargs should not raise a warning + with warnings.catch_warnings(): + warnings.simplefilter("error") + job_kwargs_split = fix_job_kwargs({}) + # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8eefb9acd6..1adb4103dd 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -17,15 +17,14 @@ from spikeinterface.core import load_extractor, BaseRecordingSnippets from spikeinterface.core.core_tools import check_json +from spikeinterface.core.globals import get_global_job_kwargs from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .utils import SpikeSortingError, ShellScript -default_job_kwargs = {"n_jobs": -1} - default_job_kwargs_description = { - "n_jobs": "Number of jobs (when saving ti binary) - default -1 (all cores)", - "chunk_size": 
"Number of samples per chunk (when saving ti binary) - default global", + "n_jobs": "Number of jobs (when saving to binary) - default global", + "chunk_size": "Number of samples per chunk (when saving to binary) - default global", "chunk_memory": "Memory usage for each job (e.g. '100M', '1G') (when saving to binary) - default global", "total_memory": "Total memory usage (e.g. '500M', '2G') (when saving to binary) - default global", "chunk_duration": "Chunk duration in s if float or with units if str (e.g. '1s', '500ms') (when saving to binary)" @@ -156,7 +155,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo def default_params(cls): p = copy.deepcopy(cls._default_params) if cls.requires_binary_data: - job_kwargs = fix_job_kwargs(default_job_kwargs) + job_kwargs = get_global_job_kwargs() p.update(job_kwargs) return p From 7fb826df81e5ef20ce5d84f305da734da5bbaed9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:29:29 +0200 Subject: [PATCH 06/15] Add release ntoes and fix fix_job_kwargs test --- doc/releases/0.100.7.rst | 2 +- doc/releases/0.100.8.rst | 15 +++++++++++++++ doc/whatisnew.rst | 7 +++++++ src/spikeinterface/core/tests/test_job_tools.py | 6 +++--- 4 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 doc/releases/0.100.8.rst diff --git a/doc/releases/0.100.7.rst b/doc/releases/0.100.7.rst index a224494da5..c6cc1ecc0d 100644 --- a/doc/releases/0.100.7.rst +++ b/doc/releases/0.100.7.rst @@ -3,7 +3,7 @@ SpikeInterface 0.100.7 release notes ------------------------------------ -7th June 2024 +7th June 2024 Minor release with bug fixes diff --git a/doc/releases/0.100.8.rst b/doc/releases/0.100.8.rst new file mode 100644 index 0000000000..88db26da7e --- /dev/null +++ b/doc/releases/0.100.8.rst @@ -0,0 +1,15 @@ +.. 
_release0.100.8: + +SpikeInterface 0.100.8 release notes +------------------------------------ + +24th June 2024 + +Minor release with bug fixes + +* Remove separate default job_kwarg n_jobs for sorters (#2712) +* Add `whiteningRange` as Kilosort 2/2.5/3 parameter (#2997) +* Make sure we check `is_filtered()` rather than bound method during run basesorter (#3037) +* Numpy 2.0: fix most egregious deprecated behavior and cap version (#3032, #3056) +* Add support for kilosort>=4.0.12 (#3055) +* Check start_frame/end_frame in BaseRecording.get_traces() (#3059) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 5f35b3efd2..1a875f3ab3 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.100.8.rst releases/0.100.7.rst releases/0.100.6.rst releases/0.100.5.rst @@ -41,6 +42,12 @@ Release notes releases/0.9.1.rst +Version 0.100.8 +=============== + +* Minor release with bug fixes + + Version 0.100.7 =============== diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 1bfe3a5e79..e8a2f902e0 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -180,10 +180,10 @@ def test_fix_job_kwargs(): else: assert fixed_job_kwargs["n_jobs"] == 1 - # test minimum n_jobs - job_kwargs = dict(n_jobs=0, progress_bar=False, chunk_duration="1s") + # test float value > 1 is cast to correct int + job_kwargs = dict(n_jobs=float(os.cpu_count()), progress_bar=False, chunk_duration="1s") fixed_job_kwargs = fix_job_kwargs(job_kwargs) - assert fixed_job_kwargs["n_jobs"] == 1 + assert fixed_job_kwargs["n_jobs"] == os.cpu_count() # test wrong keys with pytest.raises(AssertionError): From 5b538f5b1a92226f35a7119a222557c721706658 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:37:53 +0200 Subject: [PATCH 07/15] Port #3059 --- src/spikeinterface/core/baserecording.py | 2 
++ src/spikeinterface/core/frameslicerecording.py | 4 ---- src/spikeinterface/core/generate.py | 10 ---------- src/spikeinterface/core/segmentutils.py | 10 ---------- src/spikeinterface/extractors/cbin_ibl.py | 4 ---- src/spikeinterface/extractors/nwbextractors.py | 5 ----- .../preprocessing/average_across_direction.py | 5 ----- src/spikeinterface/preprocessing/decimate.py | 7 ------- .../deepinterpolation/deepinterpolation.py | 6 ------ .../preprocessing/directional_derivative.py | 5 ----- src/spikeinterface/preprocessing/phase_shift.py | 4 ---- src/spikeinterface/preprocessing/remove_artifacts.py | 5 ----- src/spikeinterface/preprocessing/resample.py | 5 ----- src/spikeinterface/preprocessing/silence_periods.py | 5 ----- src/spikeinterface/preprocessing/zero_channel_pad.py | 9 --------- 15 files changed, 2 insertions(+), 84 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b65409e033..07ebc9efb7 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -285,6 +285,8 @@ def get_traces( segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = self._recording_segments[segment_index] + start_frame = int(max(0, start_frame)) if start_frame is not None else 0 + end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index cc899650d7..adb0c0e3ce 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -85,10 +85,6 @@ def get_num_samples(self): return self.end_frame - self.start_frame def get_traces(self, start_frame, 
end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() parent_start = self.start_frame + start_frame parent_end = self.start_frame + end_frame traces = self._parent_recording_segment.get_traces( diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 2853a4fc55..e4a0dcb0be 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1092,11 +1092,6 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - start_frame = 0 if start_frame is None else max(start_frame, 0) - end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) - start_frame = int(start_frame) - end_frame = int(end_frame) - start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size num_samples = end_frame - start_frame @@ -1652,11 +1647,6 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - start_frame = 0 if start_frame is None else start_frame - end_frame = self.num_samples if end_frame is None else end_frame - start_frame = int(start_frame) - end_frame = int(end_frame) - if channel_indices is None: n_channels = self.templates.shape[2] elif isinstance(channel_indices, slice): diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index c3881cc1f8..426fc5c523 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -163,11 +163,6 @@ def get_num_samples(self): return self.total_length def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - # # Ensures that we won't request invalid segment indices if (start_frame >= self.get_num_samples()) or 
(end_frame <= start_frame): # Return (0 * num_channels) array of correct dtype @@ -462,11 +457,6 @@ def get_unit_spike_train( start_frame, end_frame, ): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 423b91ff63..1ac7da56cf 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -130,10 +130,6 @@ def get_num_samples(self): return self._cbuffer.shape[0] def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() if channel_indices is None: channel_indices = slice(None) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 39fdceceb0..27071cfe7b 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -873,11 +873,6 @@ def get_num_samples(self): return self._num_samples def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - electrical_series_data = self.electrical_series_data if electrical_series_data.ndim == 1: traces = electrical_series_data[start_frame:end_frame][:, np.newaxis] diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 71051f07ab..53f0d54147 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -116,11 +116,6 @@ def get_num_samples(self): return self.parent_recording_segment.get_num_samples() def get_traces(self, start_frame, 
end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - parent_traces = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 8c4970c4e4..aa5c600182 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -123,13 +123,6 @@ def get_num_samples(self): return int(np.ceil((parent_n_samp - self._decimation_offset) / self._decimation_factor)) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - end_frame = min(end_frame, self.get_num_samples()) - start_frame = min(start_frame, self.get_num_samples()) - # Account for offset and end when querying parent traces parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 8543c01218..5a376db061 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -147,12 +147,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): n_frames = self.parent_recording_segment.get_num_samples() - if start_frame == None: - start_frame = 0 - - if end_frame == None: - end_frame = n_frames - # for frames that lack full training data (i.e. 
pre and post frames including omissinos), # just return uninterpolated if start_frame < self.pre_frame + self.pre_post_omission: diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index 5e77cc8ae6..f8aeac05fc 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -103,11 +103,6 @@ def __init__( self.unique_pos_other_dims, self.column_inds = np.unique(geom_other_dims, axis=0, return_inverse=True) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - parent_traces = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 23f4320053..8bb89623c9 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -81,10 +81,6 @@ def __init__(self, parent_recording_segment, sample_shifts, margin, dtype, tmp_d self.tmp_dtype = tmp_dtype def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() if channel_indices is None: channel_indices = slice(None) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 793b44f099..a62c8ec1e5 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -258,11 +258,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) traces = traces.copy() - if start_frame is None: - start_frame = 0 - if end_frame is None: - 
end_frame = self.get_num_samples() - mask = (self.triggers >= start_frame) & (self.triggers < end_frame) triggers = self.triggers[mask] - start_frame labels = self.labels[mask] diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 9ec9c5779b..9d7e6611da 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -115,11 +115,6 @@ def get_num_samples(self): return int(self._parent_segment.get_num_samples() / self._parent_rate * self.sampling_frequency) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - # get parent traces with margin parent_start_frame, parent_end_frame = [ int((frame / self.sampling_frequency) * self._parent_rate) for frame in [start_frame, end_frame] diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 6413ec06b4..32ef545fa3 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -99,11 +99,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces.copy() num_channels = traces.shape[1] - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - if len(self.periods) > 0: new_interval = np.array([start_frame, end_frame]) lower_index = np.searchsorted(self.periods[:, 1], new_interval[0]) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index ca56542475..308c0dc085 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -77,11 +77,6 @@ def __init__( super().__init__(parent_recording_segment=parent_recording_segment) def get_traces(self, start_frame, end_frame, 
channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - # This contains the padded elements by default and we add the original traces if necessary trace_size = end_frame - start_frame if isinstance(channel_indices, (np.ndarray, list)): @@ -208,10 +203,6 @@ def __init__(self, parent_recording_segment: BaseRecordingSegment, num_channels: self.channel_mapping = channel_mapping def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() traces = np.zeros((end_frame - start_frame, self.num_channels)) traces[:, self.channel_mapping] = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=self.channel_mapping From 131913e1ac78818657337ad7ced5fe014451e7bd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 17:39:51 +0200 Subject: [PATCH 08/15] Port #3059 fix --- src/spikeinterface/core/segmentutils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 426fc5c523..74e44d0255 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -457,6 +457,11 @@ def get_unit_spike_train( start_frame, end_frame, ): + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: From 8fde57afa113f64af437e3488672caa17b2d6c20 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 18:00:15 +0200 Subject: [PATCH 09/15] Port #2964 and fix qm tests --- doc/releases/0.100.8.rst | 1 + src/spikeinterface/qualitymetrics/misc_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 6 +++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff 
--git a/doc/releases/0.100.8.rst b/doc/releases/0.100.8.rst index 88db26da7e..9171884eca 100644 --- a/doc/releases/0.100.8.rst +++ b/doc/releases/0.100.8.rst @@ -8,6 +8,7 @@ SpikeInterface 0.100.8 release notes Minor release with bug fixes * Remove separate default job_kwarg n_jobs for sorters (#2712) +* Fix math error in sd_ratio (#2964) * Add `whiteningRange` added as Kilosort 2/2.5/3 parameter (#2997) * Make sure we check `is_filtered()`` rather than bound method during run basesorter (#3037) * Numpy 2.0 cap Fix most egregorious deprecated behavior and cap version (#3032, #3056) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 079f7dc027..3276e6d6b8 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -1473,7 +1473,7 @@ def compute_sd_ratio( # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. # TODO: Take into account that templates for different segments might differ. 
p = wvf_extractor.nsamples * n_spikes[unit_id] / wvf_extractor.get_total_samples() - total_variance = p * np.mean(template**2) - p**2 * np.mean(template) + total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2 std_noise = np.sqrt(std_noise**2 - total_variance) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index b1055a716d..328ed06e00 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -15,6 +15,7 @@ select_segment_sorting, load_waveforms, aggregate_units, + set_global_job_kwargs, ) from spikeinterface.extractors import toy_example @@ -29,12 +30,14 @@ from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "qualitymetrics" else: cache_folder = Path("cache_folder") / "qualitymetrics" +# needed to suppress warnings +set_global_job_kwargs(n_jobs=1) + class QualityMetricsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = QualityMetricCalculator @@ -144,6 +147,7 @@ def test_metrics(self): # print(metrics_sparse) def test_amplitude_cutoff(self): + we = self.we_short _ = compute_spike_amplitudes(we, peak_sign="neg") From ca33ff582b2a2c629bb5493b4cbd08af0c6e963c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 18:12:39 +0200 Subject: [PATCH 10/15] Move global job kwargs to test setup class --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 328ed06e00..a4fd283530 100644 --- 
a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -35,9 +35,6 @@ else: cache_folder = Path("cache_folder") / "qualitymetrics" -# needed to suppress warnings -set_global_job_kwargs(n_jobs=1) - class QualityMetricsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = QualityMetricCalculator @@ -93,6 +90,9 @@ def setUp(self): self.we_long = we_long self.we_short = we_short + # needed to suppress warnings + set_global_job_kwargs(n_jobs=1) + def tearDown(self): super().tearDown() # delete object to release memmap From 5957c8c2662460a3996c8f6d00488fd20c9920b6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jun 2024 15:56:28 +0200 Subject: [PATCH 11/15] Update #3059 --- src/spikeinterface/core/baserecording.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 07ebc9efb7..10604c6f60 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -285,8 +285,12 @@ def get_traces( segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = self._recording_segments[segment_index] - start_frame = int(max(0, start_frame)) if start_frame is not None else 0 + start_frame = int(start_frame) if start_frame is not None else 0 end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() + if start_frame < 0: + raise ValueError("start_frame cannot be negative") + if start_frame > end_frame: + raise ValueError("start_frame cannot be greater than end_frame") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] From 5353469f8e7065c985933ec79410840c649d16d3 Mon Sep 17 00:00:00 
2001 From: Alessio Buccino Date: Sat, 22 Jun 2024 16:01:07 +0200 Subject: [PATCH 12/15] Update #3059 --- src/spikeinterface/core/tests/test_binaryrecordingextractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index fb4c3ee3c4..db2e87ecd3 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -38,7 +38,7 @@ def test_BinaryRecordingExtractor(): def test_round_trip(tmp_path): num_channels = 10 - num_samples = 50 + num_samples = 500 traces_list = [np.ones(shape=(num_samples, num_channels), dtype="int32")] sampling_frequency = 30_000.0 recording = NumpyRecording(traces_list=traces_list, sampling_frequency=sampling_frequency) From f874cd230885d57a2a88475adf804833a23747d1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jun 2024 17:03:55 +0200 Subject: [PATCH 13/15] Update #3059 --- src/spikeinterface/core/baserecording.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 10604c6f60..256607c331 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -289,8 +289,6 @@ def get_traces( end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() if start_frame < 0: raise ValueError("start_frame cannot be negative") - if start_frame > end_frame: - raise ValueError("start_frame cannot be greater than end_frame") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] From 68886b070ea6c275e79b9a88f2e7b81ea817c349 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 12:51:29 +0200 Subject: [PATCH 14/15] Upgrade #3059 --- 
src/spikeinterface/core/baserecording.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 256607c331..748888e03a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -286,7 +286,8 @@ def get_traces( channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = self._recording_segments[segment_index] start_frame = int(start_frame) if start_frame is not None else 0 - end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() + num_samples = rs.get_num_samples() + end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples if start_frame < 0: raise ValueError("start_frame cannot be negative") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) From 9e8a709dc72be0e597a15c33c3d535c1995e1328 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 17:50:53 +0200 Subject: [PATCH 15/15] Upgrade #3059 --- src/spikeinterface/core/baserecording.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 748888e03a..afe1ac46e6 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -288,8 +288,6 @@ def get_traces( start_frame = int(start_frame) if start_frame is not None else 0 num_samples = rs.get_num_samples() end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples - if start_frame < 0: - raise ValueError("start_frame cannot be negative") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"]