From 741834ab3fe72c1617d72142d2b860914db90e89 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 23 Nov 2023 15:49:37 +0100
Subject: [PATCH 01/19] Add outside_channels_location (top, bottom, both) in detect_bad_channels

---
 .../preprocessing/detect_bad_channels.py      | 76 +++++++++++++------
 .../tests/test_detect_bad_channels.py         | 56 ++++++++++++--
 2 files changed, 103 insertions(+), 29 deletions(-)

diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py
index a162cfe636..f013537d6d 100644
--- a/src/spikeinterface/preprocessing/detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -1,29 +1,32 @@
+from __future__ import annotations
 import warnings
 import numpy as np
+from typing import Literal

 from .filter import highpass_filter
-from ..core import get_random_data_chunks, order_channels_by_depth
+from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording


 def detect_bad_channels(
-    recording,
-    method="coherence+psd",
-    std_mad_threshold=5,
-    psd_hf_threshold=0.02,
-    dead_channel_threshold=-0.5,
-    noisy_channel_threshold=1.0,
-    outside_channel_threshold=-0.75,
-    n_neighbors=11,
-    nyquist_threshold=0.8,
-    direction="y",
-    chunk_duration_s=0.3,
-    num_random_chunks=100,
-    welch_window_ms=10.0,
-    highpass_filter_cutoff=300,
-    neighborhood_r2_threshold=0.9,
-    neighborhood_r2_radius_um=30.0,
-    seed=None,
+    recording: BaseRecording,
+    method: str = "coherence+psd",
+    std_mad_threshold: float = 5,
+    psd_hf_threshold: float = 0.02,
+    dead_channel_threshold: float = -0.5,
+    noisy_channel_threshold: float = 1.0,
+    outside_channel_threshold: float = -0.75,
+    outside_channels_location: Literal["top", "bottom", "both"] = "top",
+    n_neighbors: int = 11,
+    nyquist_threshold: float = 0.8,
+    direction: Literal["x", "y", "z"] = "y",
+    chunk_duration_s: float = 0.3,
+    num_random_chunks: int = 100,
+    welch_window_ms: float = 10.0,
+    highpass_filter_cutoff: float = 300,
+    neighborhood_r2_threshold: float = 0.9,
+    neighborhood_r2_radius_um: float = 30.0,
+    seed: int | None = None,
 ):
     """
     Perform bad channel detection.
@@ -65,6 +68,11 @@
     outside_channel_threshold (coherence+psd) : float, default: -0.75
         Threshold for channel coherence above which channels at the edge of the recording are marked as outside
         of the brain
+    outside_channels_location (coherence+psd) : "top" | "bottom" | "both", default: "top"
+        Location of the outside channels. If "top", only the channels at the top of the probe can be
+        marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
+        marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
+        marked as outside channels
     n_neighbors (coherence+psd) : int, default: 11
         Number of channel neighbors to compute median filter (needs to be odd)
     nyquist_threshold (coherence+psd) : float, default: 0.8
@@ -190,6 +198,7 @@
             n_neighbors=n_neighbors,
             nyquist_threshold=nyquist_threshold,
             welch_window_ms=welch_window_ms,
+            outside_channels_location=outside_channels_location,
         )
         chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

@@ -275,6 +284,7 @@
     n_neighbors=11,
     nyquist_threshold=0.8,
     welch_window_ms=0.3,
+    outside_channels_location="top",
 ):
     """
     Bad channels detection for Neuropixel probes developed by IBL
@@ -300,6 +310,11 @@
         Threshold on Nyquist frequency to calculate HF noise band
     welch_window_ms: float, default: 0.3
         Window size for the scipy.signal.welch that will be converted to nperseg
+    outside_channels_location : "top" | "bottom" | "both", default: "top"
+        Location of the outside channels. If "top", only the channels at the top of the probe can be
+        marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
+        marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
+        marked as outside channels

     Returns
     -------
@@ -332,12 +347,25 @@
     ichannels[inoisy] = 2

     # the channels outside of the brains are the contiguous channels below the threshold on the trend coherency
-    # the chanels outide need to be at either extremes of the probe
-    ioutside = np.where(xcorr_distant < outside_channel_thr)[0]
-    if ioutside.size > 0 and (ioutside[-1] == (nc - 1) or ioutside[0] == 0):
-        a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
-        ioutside = ioutside[a == np.max(a)]
-        ichannels[ioutside] = 3
+    # the chanels outide need to be at the extreme of the probe
+    (ioutside,) = np.where(xcorr_distant < outside_channel_thr)
+    ichannels = np.zeros_like(xcorr_distant, dtype=int)
+    a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
+    if ioutside.size > 0:
+        if outside_channels_location == "top":
+            # channels are sorted bottom to top, so the last channel needs to be (nc - 1)
+            if ioutside[-1] == (nc - 1):
+                ioutside = ioutside[(a == np.max(a)) & (a > 0)]
+                ichannels[ioutside] = 3
+        elif outside_channels_location == "bottom":
+            # outside channels are at the bottom of the probe, so the first channel needs to be 0
+            if ioutside[0] == 0:
+                ioutside = ioutside[(a == np.min(a)) & (a < np.max(a))]
+                ichannels[ioutside] = 3
+        else:  # both extremes are considered
+            if ioutside[-1] == (nc - 1) or ioutside[0] == 0:
+                ioutside = ioutside[(a == np.max(a)) | (a == np.min(a))]
+                ichannels[ioutside] = 3

     return ichannels
diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
index c2de263063..4071bfe0ea 100644
--- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
@@ -19,7 +19,7 @@
     HAVE_NPIX = False


-def test_remove_bad_channels_std_mad():
+def test_detect_bad_channels_std_mad():
     num_channels = 4
     sampling_frequency = 30000.0
     durations = [10.325, 3.5]
@@ -60,9 +60,48 @@
         ), "wrong channels locations."


+@pytest.mark.parametrize("outside_channels_location", ["bottom", "top", "both"])
+def test_detect_bad_channels_extremes(outside_channels_location):
+    num_channels = 64
+    sampling_frequency = 30000.0
+    durations = [20]
+    num_out_channels = 10
+
+    num_segments = len(durations)
+    num_timepoints = [int(sampling_frequency * d) for d in durations]
+
+    traces_list = []
+    for i in range(num_segments):
+        traces = np.random.randn(num_timepoints[i], num_channels).astype("float32")
+        # extreme channels are "out"
+        traces[:, :num_out_channels] *= 0.05
+        traces[:, -num_out_channels:] *= 0.05
+        traces_list.append(traces)
+
+    rec = NumpyRecording(traces_list, sampling_frequency)
+    rec.set_channel_gains(1)
+    rec.set_channel_offsets(0)
+
+    probe = generate_linear_probe(num_elec=num_channels)
+    probe.set_device_channel_indices(np.arange(num_channels))
+    rec.set_probe(probe, in_place=True)
+
+    bad_channel_ids, bad_labels = detect_bad_channels(
+        rec, method="coherence+psd", outside_channels_location=outside_channels_location
+    )
+    if outside_channels_location == "top":
+        assert np.array_equal(bad_channel_ids, rec.channel_ids[-num_out_channels:])
+    elif outside_channels_location == "bottom":
+        assert np.array_equal(bad_channel_ids, rec.channel_ids[:num_out_channels])
+    elif outside_channels_location == "both":
+        assert np.array_equal(
+            bad_channel_ids, np.concatenate((rec.channel_ids[:num_out_channels], rec.channel_ids[-num_out_channels:]))
+        )
+
+
 @pytest.mark.skipif(not HAVE_NPIX, reason="ibl-neuropixel is not installed")
 @pytest.mark.parametrize("num_channels", [32, 64, 384])
-def test_remove_bad_channels_ibl(num_channels):
+def test_detect_bad_channels_ibl(num_channels):
     """
     Cannot test against DL datasets because they are too short
     and need to control the PSD scaling. Here generate a dataset
@@ -121,7 +160,9 @@
         traces_uV = random_chunk.T
         traces_V = traces_uV * 1e-6
         channel_flags, _ = neurodsp.voltage.detect_bad_channels(
-            traces_V, recording.get_sampling_frequency(), psd_hf_threshold=psd_cutoff
+            traces_V,
+            recording.get_sampling_frequency(),
+            psd_hf_threshold=psd_cutoff,
         )
         channel_flags_ibl[:, i] = channel_flags

@@ -209,5 +250,10 @@ def add_dead_channels(recording, is_dead):


 if __name__ == "__main__":
-    test_remove_bad_channels_std_mad()
-    test_remove_bad_channels_ibl(num_channels=384)
+    # test_detect_bad_channels_std_mad()
+    test_detect_bad_channels_ibl(num_channels=32)
+    test_detect_bad_channels_ibl(num_channels=64)
+    test_detect_bad_channels_ibl(num_channels=384)
+    # test_detect_bad_channels_extremes("top")
+    # test_detect_bad_channels_extremes("bottom")
+    # test_detect_bad_channels_extremes("both")
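For reference, a minimal sketch of the new argument in action, adapted directly from the test added in this patch (the toy-recording setup with `NumpyRecording` and probeinterface's `generate_linear_probe` mirrors the test; treat the exact labels returned as an assumption):

```python
import numpy as np
from probeinterface import generate_linear_probe
from spikeinterface.core import NumpyRecording
from spikeinterface.preprocessing import detect_bad_channels

# toy recording where channels at both probe extremes are strongly attenuated
num_channels, num_out = 64, 10
fs = 30000.0
traces = np.random.randn(int(20 * fs), num_channels).astype("float32")
traces[:, :num_out] *= 0.05   # dim "outside" channels at the bottom of the probe
traces[:, -num_out:] *= 0.05  # dim "outside" channels at the top of the probe

rec = NumpyRecording([traces], fs)
rec.set_channel_gains(1)
rec.set_channel_offsets(0)
probe = generate_linear_probe(num_elec=num_channels)
probe.set_device_channel_indices(np.arange(num_channels))
rec.set_probe(probe, in_place=True)

# with "both", low-coherence channels at either extreme can be flagged as outside
bad_channel_ids, channel_labels = detect_bad_channels(
    rec, method="coherence+psd", outside_channels_location="both"
)
```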
From 58293022cba9f4ec646a70c8c7d2f1ff4bf3d66d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 23 Nov 2023 16:02:10 +0100
Subject: [PATCH 02/19] Oops

---
 src/spikeinterface/preprocessing/detect_bad_channels.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py
index f013537d6d..8e323e4566 100644
--- a/src/spikeinterface/preprocessing/detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -349,7 +349,6 @@ def detect_bad_channels_ibl(
     # the channels outside of the brains are the contiguous channels below the threshold on the trend coherency
     # the chanels outide need to be at the extreme of the probe
     (ioutside,) = np.where(xcorr_distant < outside_channel_thr)
-    ichannels = np.zeros_like(xcorr_distant, dtype=int)
     a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
     if ioutside.size > 0:
         if outside_channels_location == "top":

From d0263c043039900167d4d19d0789b20adaeb52f0 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 27 Nov 2023 10:36:47 +0100
Subject: [PATCH 03/19] Fix memory leak in lsmr solver and optimize correct_motion

---
 src/spikeinterface/preprocessing/motion.py    | 66 ++++++++++---------
 .../sortingcomponents/motion_estimation.py    |  3 +
 2 files changed, 37 insertions(+), 32 deletions(-)

diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index 3992a4c8c6..e26ae6dbc3 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -8,6 +8,7 @@
 from spikeinterface.core import get_noise_levels, fix_job_kwargs
 from spikeinterface.core.job_tools import _shared_job_kwargs_doc
 from spikeinterface.core.core_tools import SIJsonEncoder
+from torch import gather

 motion_options_preset = {
     # This preset should be the most acccurate
@@ -20,7 +21,7 @@
         exclude_sweep_ms=0.1,
         radius_um=50,
     ),
-    "select_kwargs": None,
+    "select_kwargs": dict(),
     "localize_peaks_kwargs": dict(
         method="monopolar_triangulation",
         radius_um=75.0,
@@ -83,7 +84,7 @@
         exclude_sweep_ms=0.1,
         radius_um=50,
     ),
-    "select_kwargs": None,
+    "select_kwargs": dict(),
     "localize_peaks_kwargs": dict(
         method="center_of_mass",
         radius_um=75.0,
@@ -111,7 +112,7 @@
         exclude_sweep_ms=0.1,
         radius_um=50,
     ),
-    "select_kwargs": None,
+    "select_kwargs": dict(),
     "localize_peaks_kwargs": dict(
         method="grid_convolution",
         radius_um=40.0,
@@ -157,7 +158,7 @@
     folder=None,
     output_motion_info=False,
     detect_kwargs={},
-    select_kwargs=None,
+    select_kwargs={},
     localize_peaks_kwargs={},
     estimate_motion_kwargs={},
     interpolate_motion_kwargs={},
@@ -241,13 +242,22 @@ def correct_motion(
     # get preset params and update if necessary
     params = motion_options_preset[preset]
     detect_kwargs = dict(params["detect_kwargs"], **detect_kwargs)
-    if params["select_kwargs"] is None:
-        select_kwargs = None
-    else:
-        select_kwargs = dict(params["select_kwargs"], **select_kwargs)
+    select_kwargs = dict(params["select_kwargs"], **select_kwargs)
     localize_peaks_kwargs = dict(params["localize_peaks_kwargs"], **localize_peaks_kwargs)
     estimate_motion_kwargs = dict(params["estimate_motion_kwargs"], **estimate_motion_kwargs)
     interpolate_motion_kwargs = dict(params["interpolate_motion_kwargs"], **interpolate_motion_kwargs)
+    do_selection = len(select_kwargs) > 0
+
+    # params
+    parameters = dict(
+        detect_kwargs=detect_kwargs,
+        select_kwargs=select_kwargs,
+        localize_peaks_kwargs=localize_peaks_kwargs,
+        estimate_motion_kwargs=estimate_motion_kwargs,
+        interpolate_motion_kwargs=interpolate_motion_kwargs,
+        job_kwargs=job_kwargs,
+        sampling_frequency=recording.sampling_frequency,
+    )

     if output_motion_info:
         motion_info = {}
@@ -255,13 +265,20 @@ def correct_motion(
         motion_info = None

     job_kwargs = fix_job_kwargs(job_kwargs)
-
     noise_levels = get_noise_levels(recording, return_scaled=False)

-    if select_kwargs is None:
-        # maybe do this directly in the folder when not None
+    if folder is not None:
+        folder = Path(folder)
+        folder.mkdir(exist_ok=True, parents=True)
+
+        (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
+        if recording.check_serializability("json"):
+            recording.dump_to_json(folder / "recording.json")
+        gather_mode = "npy"
+    else:
         gather_mode = "memory"

+    if not do_selection:
         # node detect
         method = detect_kwargs.pop("method", "locally_exclusive")
         method_class = detect_peak_methods[method]
@@ -281,9 +298,10 @@ def correct_motion(
             pipeline_nodes,
             job_kwargs,
             job_name="detect and localize",
             gather_mode=gather_mode,
+            gather_kwargs={"exist_ok": True},
             squeeze_output=False,
-            folder=None,
-            names=None,
+            folder=folder,
+            names=["peaks", "peak_locations"],
         )
         t1 = time.perf_counter()
@@ -307,6 +325,9 @@ def correct_motion(
             select_peaks=t2 - t1,
             localize_peaks=t3 - t2,
         )
+        if folder is not None:
+            np.save(folder / "peaks.npy", peaks)
+            np.save(folder / "peak_locations.npy", peak_locations)

     t0 = time.perf_counter()
     motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs)
@@ -318,29 +339,10 @@ def correct_motion(
     )

     if folder is not None:
-        folder = Path(folder)
-        folder.mkdir(exist_ok=True, parents=True)
-
-        # params and run times
-        parameters = dict(
-            detect_kwargs=detect_kwargs,
-            select_kwargs=select_kwargs,
-            localize_peaks_kwargs=localize_peaks_kwargs,
-            estimate_motion_kwargs=estimate_motion_kwargs,
-            interpolate_motion_kwargs=interpolate_motion_kwargs,
-            job_kwargs=job_kwargs,
-            sampling_frequency=recording.sampling_frequency,
-        )
-        (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
         (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
-        if recording.check_serializability("json"):
-            recording.dump_to_json(folder / "recording.json")
-        np.save(folder / "peaks.npy", peaks)
-        np.save(folder / "peak_locations.npy", peak_locations)
         np.save(folder / "temporal_bins.npy", temporal_bins)
         np.save(folder / "motion.npy", motion)
-        np.save(folder / "peak_locations.npy", peak_locations)
         if spatial_bins is not None:
             np.save(folder / "spatial_bins.npy", spatial_bins)
diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py
index df73575a01..141fc531f4 100644
--- a/src/spikeinterface/sortingcomponents/motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/motion_estimation.py
@@ -1039,6 +1039,7 @@ def jac(p):
         displacement = p

     elif convergence_method == "lsmr":
+        import gc
         from scipy import sparse
         from scipy.stats import zscore

@@ -1170,6 +1171,8 @@ def jac(p):

         # warm start next iteration
         p0 = displacement
+        # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy)
+        gc.collect()

         displacement = displacement.reshape(B, T).T
     else:
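Taken together, the changes above affect a standard `correct_motion` call: parameter and recording JSON files are now written to the folder up front, and peaks/peak locations are saved as soon as the detection pipeline finishes. A sketch of the intended usage (it assumes a preprocessed recording `rec` is already loaded; the preset name is illustrative):

```python
from spikeinterface.preprocessing import correct_motion

# an interrupted run now keeps parameters.json, recording.json, peaks.npy and
# peak_locations.npy in the folder, instead of writing everything at the end
rec_corrected, motion_info = correct_motion(
    rec,
    preset="nonrigid_accurate",
    folder="motion_folder",
    output_motion_info=True,
)
```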
/ "spatial_bins.npy", spatial_bins) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index df73575a01..141fc531f4 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1039,6 +1039,7 @@ def jac(p): displacement = p elif convergence_method == "lsmr": + import gc from scipy import sparse from scipy.stats import zscore @@ -1170,6 +1171,8 @@ def jac(p): # warm start next iteration p0 = displacement + # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) + gc.collect() displacement = displacement.reshape(B, T).T else: From 5f127c40bb6284c67179ae349294dae929ef574c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 27 Nov 2023 11:32:28 +0100 Subject: [PATCH 04/19] Remove unused import --- src/spikeinterface/preprocessing/motion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e26ae6dbc3..f451ef8618 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -8,7 +8,6 @@ from spikeinterface.core import get_noise_levels, fix_job_kwargs from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.core_tools import SIJsonEncoder -from torch import gather motion_options_preset = { # This preset should be the most acccurate From 54bb7fdcab895fddd25a5f59682e5974717f33e9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 27 Nov 2023 13:21:07 +0100 Subject: [PATCH 05/19] Add TODO --- src/spikeinterface/sortingcomponents/motion_estimation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 141fc531f4..1345bd312c 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1172,6 +1172,7 @@ def jac(p): # warm start next iteration p0 = displacement # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) + # TODO: check if this gets fixed in scipy gc.collect() displacement = displacement.reshape(B, T).T From fdbd852b38bf932eff016a8c5e0bddbecd7c5171 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 29 Nov 2023 12:04:08 +0100 Subject: [PATCH 06/19] run_sorter in container check json or pickle --- src/spikeinterface/sorters/runsorter.py | 26 ++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index da8b48085c..0a13f8d754 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -2,6 +2,7 @@ import os from pathlib import Path import json +import pickle import platform from warnings import warn from typing import Optional, Union @@ -414,9 +415,15 @@ def run_sorter_container( # create 3 files for communication with container # recording dict inside - (parent_folder / "in_container_recording.json").write_text( - json.dumps(check_json(rec_dict), indent=4), encoding="utf8" - ) + if recording.check_serializability("json"): + (parent_folder / "in_container_recording.json").write_text( + json.dumps(check_json(rec_dict), indent=4), encoding="utf8" + ) + elif recording.check_serializability("pickle"): + (parent_folder / 
"in_container_recording.pickle").write_bytes(pickle.dumps(rec_dict)) + else: + raise RuntimeError("To use run_sorter with container the recording must serializable") + # need to share specific parameters (parent_folder / "in_container_params.json").write_text( json.dumps(check_json(sorter_params), indent=4), encoding="utf8" @@ -433,13 +440,19 @@ def run_sorter_container( # the py script py_script = f""" import json +from pathlib import Path from spikeinterface import load_extractor from spikeinterface.sorters import run_sorter_local if __name__ == '__main__': # this __name__ protection help in some case with multiprocessing (for instance HS2) # load recording in container - recording = load_extractor('{parent_folder_unix}/in_container_recording.json') + json_rec = Path('{parent_folder_unix}/in_container_recording.json') + pickle_rec = Path('{parent_folder_unix}/in_container_recording.pickle') + if json_rec.exists(): + recording = load_extractor(json_rec) + else: + recording = load_extractor(pickle_rec) # load params in container with open('{parent_folder_unix}/in_container_params.json', encoding='utf8', mode='r') as f: @@ -593,7 +606,10 @@ def run_sorter_container( # clean useless files if delete_container_files: - os.remove(parent_folder / "in_container_recording.json") + if (parent_folder / "in_container_recording.json").exists(): + os.remove(parent_folder / "in_container_recording.json") + if (parent_folder / "in_container_recording.pickle").exists(): + os.remove(parent_folder / "in_container_recording.pickle") os.remove(parent_folder / "in_container_params.json") os.remove(parent_folder / "in_container_sorter_script.py") if mode == "singularity": From 20cb4bd4c8e7807c51765727694dad1ae20c2ca5 Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Thu, 30 Nov 2023 07:42:42 +0100 Subject: [PATCH 07/19] add nwb sorting rem file support --- .../extractors/nwbextractors.py | 28 +++++++---------- .../extractors/tests/test_nwb_s3_extractor.py | 31 +++++++++++++++++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d8284c7abe..c87bf02586 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -192,8 +192,8 @@ class NwbRecordingExtractor(BaseRecording): samples_for_rate_estimation: int, default: 100000 The number of timestamp samples to use to estimate the rate. Used if "rate" is not specified in the ElectricalSeries. - stream_mode: str or None, default: None - Specify the stream mode: "fsspec" or "ros3". + stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None + The streaming mode to use. If None it assumes the file is on the local disk. cache: bool, default: False If True, the file is cached in the file passed to stream_cache_path if False, the file is not cached. 
From 18de6a7d835039a246a11de51573bf13b83a5d16 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 30 Nov 2023 06:44:16 +0000
Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/extractors/nwbextractors.py               | 4 ++--
 src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index c87bf02586..e4c6e264fc 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -411,7 +411,7 @@
             self.set_channel_groups(groups)
         else:
             self.set_property(property_name, values)
-
+
         if stream_mode is None and file_path is not None:
             file_path = str(Path(file_path).resolve())

@@ -592,7 +592,7 @@
         if stream_mode is None and file_path is not None:
             file_path = str(Path(file_path).resolve())
-
+
         self._kwargs = {
             "file_path": file_path,
             "electrical_series_name": self._electrical_series_name,
diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
index 34c4d17fd0..9183c5b728 100644
--- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
+++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
@@ -219,6 +219,7 @@
     check_sortings_equal(reloaded_sorting, sorting)

+
 @pytest.mark.streaming_extractors
 def test_sorting_s3_nwb_remfile(tmp_path):
     file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
From 2f7ca19b7cc15ec72ba058b30388f6fc1a0378b3 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 30 Nov 2023 10:26:44 +0100
Subject: [PATCH 09/19] Update src/spikeinterface/sorters/runsorter.py

Co-authored-by: Heberto Mayorquin
---
 src/spikeinterface/sorters/runsorter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index 0a13f8d754..7591c9eb2c 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -422,7 +422,7 @@
     elif recording.check_serializability("pickle"):
         (parent_folder / "in_container_recording.pickle").write_bytes(pickle.dumps(rec_dict))
     else:
-        raise RuntimeError("To use run_sorter with container the recording must serializable")
+        raise RuntimeError("To use run_sorter with container the recording must be serializable")

     # need to share specific parameters
     (parent_folder / "in_container_params.json").write_text(

From 688afa7c07396a6ad57203bb89649dd294f4c511 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 30 Nov 2023 10:54:05 +0100
Subject: [PATCH 10/19] Strict inequality for radius_um

---
 src/spikeinterface/core/sparsity.py                   | 2 +-
 .../sortingcomponents/clustering/clustering_tools.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 3b8b6025ca..893da59d74 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -292,7 +292,7 @@ def from_radius(cls, we, radius_um, peak_sign="neg"):
         best_chan = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index")
         for unit_ind, unit_id in enumerate(we.unit_ids):
             chan_ind = best_chan[unit_id]
-            (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
+            (chan_inds,) = np.nonzero(distances[chan_ind, :] < radius_um)
             mask[unit_ind, chan_inds] = True
         return cls(mask, we.unit_ids, we.channel_ids)

diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 629b0b13ac..050ba10efb 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -368,7 +368,7 @@ def auto_clean_clustering(

     # we use
     (radius_chans,) = np.nonzero(
-        (channel_distances[main_chan0, :] <= radius_um) | (channel_distances[main_chan1, :] <= radius_um)
+        (channel_distances[main_chan0, :] < radius_um) | (channel_distances[main_chan1, :] < radius_um)
     )
     if radius_chans.size < (intersect_chans.size * ratio_num_channel_intersect):
         # ~ print('WARNING INTERSECT')
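Patch 10 makes the radius test strictly exclusive, and patch 13 below flips everything back to inclusive; the whole disagreement is about whether a channel sitting exactly at `radius_um` belongs to the neighborhood. A self-contained illustration:

```python
import numpy as np

distances = np.array([0.0, 25.0, 50.0, 75.0])  # distances to one channel, in um
radius_um = 50.0

(strict,) = np.nonzero(distances < radius_um)   # -> [0, 1]: the boundary channel is excluded
(loose,) = np.nonzero(distances <= radius_um)   # -> [0, 1, 2]: the boundary channel is included
```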
From 7f3be4040546900ff0f1a4779990e557140018de Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 30 Nov 2023 11:31:24 +0100
Subject: [PATCH 11/19] Update src/spikeinterface/preprocessing/detect_bad_channels.py

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/preprocessing/detect_bad_channels.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py
index 8e323e4566..1fdd7737d0 100644
--- a/src/spikeinterface/preprocessing/detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -347,7 +347,7 @@ def detect_bad_channels_ibl(
     ichannels[inoisy] = 2

     # the channels outside of the brains are the contiguous channels below the threshold on the trend coherency
-    # the chanels outide need to be at the extreme of the probe
+    # the chanels outside need to be at the extreme of the probe
     (ioutside,) = np.where(xcorr_distant < outside_channel_thr)
     a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
     if ioutside.size > 0:

From 0fd1e67f722918026e35f62bed292e3253114641 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 30 Nov 2023 16:27:30 +0100
Subject: [PATCH 12/19] Always use gather_mode='memory' and then save

---
 src/spikeinterface/preprocessing/motion.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index f451ef8618..f2c7983f2b 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -273,9 +273,6 @@ def correct_motion(
         (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
         if recording.check_serializability("json"):
             recording.dump_to_json(folder / "recording.json")
-        gather_mode = "npy"
-    else:
-        gather_mode = "memory"

     if not do_selection:
         # node detect
@@ -296,11 +293,11 @@ def correct_motion(
             pipeline_nodes,
             job_kwargs,
             job_name="detect and localize",
-            gather_mode=gather_mode,
+            gather_mode="memory",
             gather_kwargs={"exist_ok": True},
             squeeze_output=False,
-            folder=folder,
-            names=["peaks", "peak_locations"],
+            folder=None,
+            names=None,
         )
         t1 = time.perf_counter()
         run_times = dict(
@@ -324,9 +321,9 @@ def correct_motion(
             select_peaks=t2 - t1,
             localize_peaks=t3 - t2,
         )
-        if folder is not None:
-            np.save(folder / "peaks.npy", peaks)
-            np.save(folder / "peak_locations.npy", peak_locations)
+    if folder is not None:
+        np.save(folder / "peaks.npy", peaks)
+        np.save(folder / "peak_locations.npy", peak_locations)

     t0 = time.perf_counter()
     motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs)
From c4994617b2b2e88a0897f38124905a0b32a88b8a Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 1 Dec 2023 09:52:43 +0100
Subject: [PATCH 13/19] Radius_um now <= everywhere

---
 src/spikeinterface/core/node_pipeline.py          |  4 ++--
 src/spikeinterface/core/sparsity.py               |  2 +-
 .../postprocessing/unit_localization.py           |  2 +-
 src/spikeinterface/preprocessing/whiten.py        |  2 +-
 .../clustering/clustering_tools.py                |  2 +-
 .../sortingcomponents/clustering/merge.py         |  4 ++--
 .../sortingcomponents/features_from_peaks.py      | 16 ++++++++--------
 .../sortingcomponents/peak_detection.py           |  6 +++---
 8 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index a00df98e05..fd8dbd35b6 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -175,7 +175,7 @@

         if not channel_from_template:
             channel_distance = get_channel_distances(recording)
-            self.neighbours_mask = channel_distance < radius_um
+            self.neighbours_mask = channel_distance <= radius_um
             self.peak_sign = peak_sign

         # precompute segment slice
@@ -367,7 +367,7 @@
         self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

     def get_trace_margin(self):
diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 893da59d74..3b8b6025ca 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -292,7 +292,7 @@ def from_radius(cls, we, radius_um, peak_sign="neg"):
         best_chan = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index")
         for unit_ind, unit_id in enumerate(we.unit_ids):
             chan_ind = best_chan[unit_id]
-            (chan_inds,) = np.nonzero(distances[chan_ind, :] < radius_um)
+            (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
             mask[unit_ind, chan_inds] = True
         return cls(mask, we.unit_ids, we.channel_ids)

diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py
index f665bac8d6..2ac841c148 100644
--- a/src/spikeinterface/postprocessing/unit_localization.py
+++ b/src/spikeinterface/postprocessing/unit_localization.py
@@ -597,7 +597,7 @@

     # mask to get nearest template given a channel
     dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
-    nearest_template_mask = dist < radius_um
+    nearest_template_mask = dist <= radius_um

     weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32)
     for count, sigma in enumerate(sigma_um):
diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py
index 3bea9b91bb..766229b62a 100644
--- a/src/spikeinterface/preprocessing/whiten.py
+++ b/src/spikeinterface/preprocessing/whiten.py
@@ -197,7 +197,7 @@
         distances = get_channel_distances(recording)
         W = np.zeros((n, n), dtype="float64")
         for c in range(n):
-            (inds,) = np.nonzero(distances[c, :] < radius_um)
+            (inds,) = np.nonzero(distances[c, :] <= radius_um)
             cov_local = cov[inds, :][:, inds]
             U, S, Ut = np.linalg.svd(cov_local, full_matrices=True)
             W_local = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut
diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 050ba10efb..629b0b13ac 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -368,7 +368,7 @@ def auto_clean_clustering(

     # we use
     (radius_chans,) = np.nonzero(
-        (channel_distances[main_chan0, :] < radius_um) | (channel_distances[main_chan1, :] < radius_um)
+        (channel_distances[main_chan0, :] <= radius_um) | (channel_distances[main_chan1, :] <= radius_um)
     )
     if radius_chans.size < (intersect_chans.size * ratio_num_channel_intersect):
         # ~ print('WARNING INTERSECT')
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index 24ec923f06..285a9ff2f2 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -291,7 +291,7 @@
     template_locs = channel_locs[max_chans, :]
     template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean")

-    pair_mask = pair_mask & (template_dist < radius_um)
+    pair_mask = pair_mask & (template_dist <= radius_um)
     indices0, indices1 = np.nonzero(pair_mask)

     n_jobs = job_kwargs["n_jobs"]
@@ -337,7 +337,7 @@
             pair_shift[ind0, ind1] = shift
             pair_values[ind0, ind1] = merge_value

-    pair_mask = pair_mask & (template_dist < radius_um)
+    pair_mask = pair_mask & (template_dist <= radius_um)
     indices0, indices1 = np.nonzero(pair_mask)

     return labels_set, pair_mask, pair_shift, pair_values
diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py
index f7f020d153..4006939b22 100644
--- a/src/spikeinterface/sortingcomponents/features_from_peaks.py
+++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -119,7 +119,7 @@

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.all_channels = all_channels
         self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()
@@ -157,7 +157,7 @@

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()

@@ -202,7 +202,7 @@
         self.sigmoid = sigmoid
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.radius_um = radius_um
         self.sparse = sparse
         self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
@@ -253,7 +253,7 @@

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.projections = projections
         self.min_values = min_values

@@ -288,7 +288,7 @@ def __init__(self, recording, name="std_ptp_feature", return_output=True, parent

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))

@@ -313,7 +313,7 @@ def __init__(self, recording, name="global_ptp_feature", return_output=True, par

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))

@@ -338,7 +338,7 @@ def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, p

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))

@@ -365,7 +365,7 @@ def __init__(self, recording, name="energy_feature", return_output=True, parents

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))

diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index e66c8be874..22438c0934 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -542,7 +542,7 @@ def check_params(
         )

         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < radius_um
+        neighbours_mask = channel_distance <= radius_um
         return args + (neighbours_mask,)

     @classmethod
@@ -624,7 +624,7 @@ def check_params(
         neighbour_indices_by_chan = []
         num_channels = recording.get_num_channels()
         for chan in range(num_channels):
-            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < radius_um)[0])
+            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] <= radius_um)[0])
         max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan])
         neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int)
         for i, neigh in enumerate(neighbour_indices_by_chan):
@@ -856,7 +856,7 @@ def check_params(
         abs_threholds = noise_levels * detect_threshold
         exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < radius_um
+        neighbours_mask = channel_distance <= radius_um

         executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign)

From cdc8d58492725e6dcd5c3d624ac5c59902e45b78 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 1 Dec 2023 16:43:21 +0100
Subject: [PATCH 14/19] Sam's suggestions

---
 src/spikeinterface/preprocessing/motion.py                | 6 ++++--
 src/spikeinterface/sortingcomponents/motion_estimation.py | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index f2c7983f2b..c81630fc1b 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -275,6 +275,8 @@ def correct_motion(
             recording.dump_to_json(folder / "recording.json")

     if not do_selection:
+        # maybe do this directly in the folder when not None, but might be slow on external storage
+        gather_mode = "memory"
         # node detect
         method = detect_kwargs.pop("method", "locally_exclusive")
         method_class = detect_peak_methods[method]
@@ -293,8 +295,8 @@ def correct_motion(
             pipeline_nodes,
             job_kwargs,
             job_name="detect and localize",
-            gather_mode="memory",
-            gather_kwargs={"exist_ok": True},
+            gather_mode=gather_mode,
+            gather_kwargs=None,
             squeeze_output=False,
             folder=None,
             names=None,
diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py
index 1345bd312c..8eb9cafe9d 100644
--- a/src/spikeinterface/sortingcomponents/motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/motion_estimation.py
@@ -220,7 +220,7 @@ class DecentralizedRegistration:
     pairwise_displacement_method: "conv" or "phase_cross_correlation"
         How to estimate the displacement in the pairwise matrix.
     max_displacement_um: float
-        Maximum possible discplacement in micrometers.
+        Maximum possible displacement in micrometers.
     weight_scale: "linear" or "exp"
         For parwaise displacement, how to to rescale the associated weight matrix.
     error_sigma: float, default: 0.2

From 6da585a1589216fb269966fead7b439d0fed8ef3 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 1 Dec 2023 16:45:08 +0100
Subject: [PATCH 15/19] Make sure sampling frequency is always float

---
 src/spikeinterface/core/baserecordingsnippets.py | 2 +-
 src/spikeinterface/core/basesorting.py           | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py
index 5d0d1b130a..c2386d0af0 100644
--- a/src/spikeinterface/core/baserecordingsnippets.py
+++ b/src/spikeinterface/core/baserecordingsnippets.py
@@ -21,7 +21,7 @@ class BaseRecordingSnippets(BaseExtractor):

     def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype):
         BaseExtractor.__init__(self, channel_ids)
-        self._sampling_frequency = sampling_frequency
+        self._sampling_frequency = float(sampling_frequency)
         self._dtype = np.dtype(dtype)

     @property
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 2535009642..50fa2d01b7 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -20,7 +20,7 @@ class BaseSorting(BaseExtractor):

     def __init__(self, sampling_frequency: float, unit_ids: List):
         BaseExtractor.__init__(self, unit_ids)
-        self._sampling_frequency = sampling_frequency
+        self._sampling_frequency = float(sampling_frequency)
         self._sorting_segments: List[BaseSortingSegment] = []
         # this weak link is to handle times from a recording object
         self._recording = None
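A small sketch of what the float coercion buys. It assumes `NumpyRecording`, whose `__init__` forwards `sampling_frequency` to the base class shown in the diff above:

```python
import numpy as np
from spikeinterface.core import NumpyRecording

traces = np.zeros((30000, 4), dtype="float32")
# a numpy scalar (e.g. read back from an HDF5 attribute) is now normalized to a plain float
rec = NumpyRecording([traces], sampling_frequency=np.float32(30000.0))
assert isinstance(rec.sampling_frequency, float)
```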
From 296985f1ec8fd6bcf19a7ca0660216b995b11888 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 1 Dec 2023 16:48:33 +0100
Subject: [PATCH 16/19] Use sampling_frequency instead of get_sampling_frequency in _make_bins

---
 src/spikeinterface/postprocessing/correlograms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 369354fe04..1b82548c15 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -68,7 +68,7 @@ def get_extension_function():


 def _make_bins(sorting, window_ms, bin_ms):
-    fs = sorting.get_sampling_frequency()
+    fs = sorting.sampling_frequency

     window_size = int(round(fs * window_ms / 2 * 1e-3))
     bin_size = int(round(fs * bin_ms * 1e-3))

From e58967e988067b43ea0f4f9eb732793b1e70f74f Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 1 Dec 2023 16:52:41 +0100
Subject: [PATCH 17/19] Avoid loading channel_name property in nwb recording

---
 src/spikeinterface/extractors/nwbextractors.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index e4c6e264fc..4b604e9aea 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -376,6 +376,9 @@ def __init__(
         for column in electrodes_table.colnames:
             if isinstance(electrodes_table[column][electrode_table_index], ElectrodeGroup):
                 continue
+            if column == "channel_name":
+                # channel_names are already set as channel ids!
+                continue
             elif column == "group_name":
                 group = unique_electrode_group_names.index(electrodes_table[column][electrode_table_index])
                 if "group" not in properties:

From 2ee688ca4bf2f8fddcfa247ad51be6b4acbef422 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 1 Dec 2023 16:54:42 +0100
Subject: [PATCH 18/19] if -> elif

---
 src/spikeinterface/extractors/nwbextractors.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index 4b604e9aea..24e400bdaf 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -376,7 +376,7 @@ def __init__(
         for column in electrodes_table.colnames:
             if isinstance(electrodes_table[column][electrode_table_index], ElectrodeGroup):
                 continue
-            if column == "channel_name":
+            elif column == "channel_name":
                 # channel_names are already set as channel ids!
                 continue
             elif column == "group_name":

From a87f34b2b317eddd56da9e24706b750af0702e45 Mon Sep 17 00:00:00 2001
From: zm711 <92116279+zm711@users.noreply.github.com>
Date: Fri, 1 Dec 2023 13:39:54 -0500
Subject: [PATCH 19/19] Add warning and clip for plot traces

---
 src/spikeinterface/widgets/traces.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index 45ae9c91aa..ae8a02f4a7 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -120,6 +120,12 @@
         if time_range is None:
             time_range = (0, 1.0)
         time_range = np.array(time_range)
+        if time_range[1] > rec0.get_duration(segment_index=segment_index):
+            warnings.warn(
+                "You have selected a time after the end of the segment. The range will be clipped to "
+                f"{rec0.get_duration(segment_index=segment_index)}"
+            )
+            time_range[1] = rec0.get_duration(segment_index=segment_index)

         assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"'
         if mode == "auto":
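The last patch changes the failure mode of an over-long time window from downstream indexing problems to a warning plus clipping. A sketch (it assumes `rec` holds a single 10 s segment and that matplotlib is available for the widget backend):

```python
import spikeinterface.widgets as sw

# previously this could index past the end of the traces; now it emits a
# warning and clips the requested window to (8.0, 10.0)
w = sw.plot_traces(rec, time_range=(8.0, 15.0))
```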