diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py
index 5d0d1b130a..c2386d0af0 100644
--- a/src/spikeinterface/core/baserecordingsnippets.py
+++ b/src/spikeinterface/core/baserecordingsnippets.py
@@ -21,7 +21,7 @@ class BaseRecordingSnippets(BaseExtractor):

     def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype):
         BaseExtractor.__init__(self, channel_ids)
-        self._sampling_frequency = sampling_frequency
+        self._sampling_frequency = float(sampling_frequency)
         self._dtype = np.dtype(dtype)

     @property
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 2535009642..50fa2d01b7 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -20,7 +20,7 @@ class BaseSorting(BaseExtractor):

     def __init__(self, sampling_frequency: float, unit_ids: List):
         BaseExtractor.__init__(self, unit_ids)
-        self._sampling_frequency = sampling_frequency
+        self._sampling_frequency = float(sampling_frequency)
         self._sorting_segments: List[BaseSortingSegment] = []
         # this weak link is to handle times from a recording object
         self._recording = None
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index a00df98e05..fd8dbd35b6 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -175,7 +175,7 @@ def __init__(

         if not channel_from_template:
             channel_distance = get_channel_distances(recording)
-            self.neighbours_mask = channel_distance < radius_um
+            self.neighbours_mask = channel_distance <= radius_um
         self.peak_sign = peak_sign

         # precompute segment slice
@@ -367,7 +367,7 @@ def __init__(
         self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

     def get_trace_margin(self):
diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index d8284c7abe..24e400bdaf 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -192,8 +192,8 @@ class NwbRecordingExtractor(BaseRecording):
     samples_for_rate_estimation: int, default: 100000
         The number of timestamp samples to use to estimate the rate.
         Used if "rate" is not specified in the ElectricalSeries.
-    stream_mode: str or None, default: None
-        Specify the stream mode: "fsspec" or "ros3".
+    stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
+        The streaming mode to use. If None, the file is assumed to be on the local disk.
     cache: bool, default: False
         If True, the file is cached in the file passed to stream_cache_path
         if False, the file is not cached.
@@ -376,6 +376,9 @@ def __init__(
         for column in electrodes_table.colnames:
             if isinstance(electrodes_table[column][electrode_table_index], ElectrodeGroup):
                 continue
+            elif column == "channel_name":
+                # channel_names are already set as channel ids!
+                continue
             elif column == "group_name":
                 group = unique_electrode_group_names.index(electrodes_table[column][electrode_table_index])
                 if "group" not in properties:
@@ -412,12 +415,11 @@ def __init__(
             else:
                 self.set_property(property_name, values)

-        if stream_mode not in ["fsspec", "ros3", "remfile"]:
-            if file_path is not None:
-                file_path = str(Path(file_path).absolute())
-        if stream_mode == "fsspec":
-            if stream_cache_path is not None:
-                stream_cache_path = str(Path(self.stream_cache_path).absolute())
+        if stream_mode is None and file_path is not None:
+            file_path = str(Path(file_path).resolve())
+
+        if stream_mode == "fsspec" and stream_cache_path is not None:
+            stream_cache_path = str(Path(self.stream_cache_path).absolute())

         self.extra_requirements.extend(["pandas", "pynwb", "hdmf"])
         self._electrical_series = electrical_series
@@ -493,8 +495,8 @@ class NwbSortingExtractor(BaseSorting):
     samples_for_rate_estimation: int, default: 100000
         The number of timestamp samples to use to estimate the rate.
         Used if "rate" is not specified in the ElectricalSeries.
-    stream_mode: str or None, default: None
-        Specify the stream mode: "fsspec" or "ros3".
+    stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
+        The streaming mode to use. If None, the file is assumed to be on the local disk.
    cache: bool, default: False
        If True, the file is cached in the file passed to stream_cache_path
        if False, the file is not cached.
@@ -591,12 +593,9 @@ def __init__(
         for prop_name, values in properties.items():
             self.set_property(prop_name, np.array(values))

-        if stream_mode not in ["fsspec", "ros3"]:
-            file_path = str(Path(file_path).absolute())
-        if stream_mode == "fsspec":
-            # only add stream_cache_path to kwargs if it was passed as an argument
-            if stream_cache_path is not None:
-                stream_cache_path = str(Path(self.stream_cache_path).absolute())
+        if stream_mode is None and file_path is not None:
+            file_path = str(Path(file_path).resolve())
+
         self._kwargs = {
             "file_path": file_path,
             "electrical_series_name": self._electrical_series_name,
diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
index 45d969dde9..9183c5b728 100644
--- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
+++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
@@ -220,6 +220,38 @@ def test_sorting_s3_nwb_fsspec(tmp_path, cache):
     check_sortings_equal(reloaded_sorting, sorting)


+@pytest.mark.streaming_extractors
+def test_sorting_s3_nwb_remfile(tmp_path):
+    file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
+    # We provide the 'sampling_frequency' because the NWB file does not have the electrical series
+    sorting = NwbSortingExtractor(
+        file_path,
+        sampling_frequency=30000.0,
+        stream_mode="remfile",
+    )
+
+    num_seg = sorting.get_num_segments()
+    assert num_seg == 1
+    num_units = len(sorting.unit_ids)
+    assert num_units == 64
+
+    for segment_index in range(num_seg):
+        for unit in sorting.unit_ids:
+            spike_train = sorting.get_unit_spike_train(unit_id=unit, segment_index=segment_index)
+            assert len(spike_train) > 0
+            assert spike_train.dtype == "int64"
+            assert np.all(spike_train >= 0)
+
+    tmp_file = tmp_path / "test_remfile_sorting.pkl"
+    with open(tmp_file, "wb") as f:
+        pickle.dump(sorting, f)
+
+    with open(tmp_file, "rb") as f:
+        reloaded_sorting = pickle.load(f)
+
+    check_sortings_equal(reloaded_sorting, sorting)
+
+
 if __name__ == "__main__":
     test_recording_s3_nwb_ros3()
    test_recording_s3_nwb_fsspec()
diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 369354fe04..1b82548c15 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -68,7 +68,7 @@ def get_extension_function():


 def _make_bins(sorting, window_ms, bin_ms):
-    fs = sorting.get_sampling_frequency()
+    fs = sorting.sampling_frequency

     window_size = int(round(fs * window_ms / 2 * 1e-3))
     bin_size = int(round(fs * bin_ms * 1e-3))
diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py
index f665bac8d6..2ac841c148 100644
--- a/src/spikeinterface/postprocessing/unit_localization.py
+++ b/src/spikeinterface/postprocessing/unit_localization.py
@@ -597,7 +597,7 @@ def get_grid_convolution_templates_and_weights(
     # mask to get nearest template given a channel
     dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
-    nearest_template_mask = dist < radius_um
+    nearest_template_mask = dist <= radius_um

     weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32)
     for count, sigma in enumerate(sigma_um):
diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py
index a162cfe636..1fdd7737d0 100644
--- a/src/spikeinterface/preprocessing/detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -1,29 +1,32 @@
+from __future__ import annotations
 import warnings

 import numpy as np
+from typing import Literal

 from .filter import highpass_filter
-from ..core import get_random_data_chunks, order_channels_by_depth
+from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording


 def detect_bad_channels(
-    recording,
-    method="coherence+psd",
-    std_mad_threshold=5,
-    psd_hf_threshold=0.02,
-    dead_channel_threshold=-0.5,
-    noisy_channel_threshold=1.0,
-    outside_channel_threshold=-0.75,
-    n_neighbors=11,
-    nyquist_threshold=0.8,
-    direction="y",
-    chunk_duration_s=0.3,
-    num_random_chunks=100,
-    welch_window_ms=10.0,
-    highpass_filter_cutoff=300,
-    neighborhood_r2_threshold=0.9,
-    neighborhood_r2_radius_um=30.0,
-    seed=None,
+    recording: BaseRecording,
+    method: str = "coherence+psd",
+    std_mad_threshold: float = 5,
+    psd_hf_threshold: float = 0.02,
+    dead_channel_threshold: float = -0.5,
+    noisy_channel_threshold: float = 1.0,
+    outside_channel_threshold: float = -0.75,
+    outside_channels_location: Literal["top", "bottom", "both"] = "top",
+    n_neighbors: int = 11,
+    nyquist_threshold: float = 0.8,
+    direction: Literal["x", "y", "z"] = "y",
+    chunk_duration_s: float = 0.3,
+    num_random_chunks: int = 100,
+    welch_window_ms: float = 10.0,
+    highpass_filter_cutoff: float = 300,
+    neighborhood_r2_threshold: float = 0.9,
+    neighborhood_r2_radius_um: float = 30.0,
+    seed: int | None = None,
 ):
     """
     Perform bad channel detection.
@@ -65,6 +68,11 @@ def detect_bad_channels(
     outside_channel_threshold (coherence+psd) : float, default: -0.75
         Threshold for channel coherence above which channels at the edge of the recording are marked as outside
         of the brain
+    outside_channels_location (coherence+psd) : "top" | "bottom" | "both", default: "top"
+        Location of the outside channels. If "top", only the channels at the top of the probe can be
+        marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
+        marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
+        marked as outside channels
     n_neighbors (coherence+psd) : int, default: 11
         Number of channel neighbors to compute median filter (needs to be odd)
     nyquist_threshold (coherence+psd) : float, default: 0.8
         Threshold on Nyquist frequency to calculate HF noise band
     welch_window_ms: float, default: 10.0
         Window size for the scipy.signal.welch that will be converted to nperseg
@@ -190,6 +198,7 @@ def detect_bad_channels(
                 n_neighbors=n_neighbors,
                 nyquist_threshold=nyquist_threshold,
                 welch_window_ms=welch_window_ms,
+                outside_channels_location=outside_channels_location,
             )
             chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

@@ -275,6 +284,7 @@ def detect_bad_channels_ibl(
    n_neighbors=11,
    nyquist_threshold=0.8,
    welch_window_ms=0.3,
+    outside_channels_location="top",
 ):
     """
     Bad channels detection for Neuropixel probes developed by IBL
@@ -300,6 +310,11 @@ def detect_bad_channels_ibl(
         Threshold on Nyquist frequency to calculate HF noise band
     welch_window_ms: float, default: 0.3
         Window size for the scipy.signal.welch that will be converted to nperseg
+    outside_channels_location : "top" | "bottom" | "both", default: "top"
+        Location of the outside channels. If "top", only the channels at the top of the probe can be
+        marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
+        marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
+        marked as outside channels

     Returns
     -------
@@ -332,12 +347,24 @@ def detect_bad_channels_ibl(
     ichannels[inoisy] = 2

     # the channels outside of the brains are the contiguous channels below the threshold on the trend coherency
-    # the chanels outide need to be at either extremes of the probe
-    ioutside = np.where(xcorr_distant < outside_channel_thr)[0]
-    if ioutside.size > 0 and (ioutside[-1] == (nc - 1) or ioutside[0] == 0):
-        a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
-        ioutside = ioutside[a == np.max(a)]
-        ichannels[ioutside] = 3
+    # the channels outside need to be at an extreme of the probe
+    (ioutside,) = np.where(xcorr_distant < outside_channel_thr)
+    a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
+    if ioutside.size > 0:
+        if outside_channels_location == "top":
+            # channels are sorted bottom to top, so the last channel needs to be (nc - 1)
+            if ioutside[-1] == (nc - 1):
+                ioutside = ioutside[(a == np.max(a)) & (a > 0)]
+                ichannels[ioutside] = 3
+        elif outside_channels_location == "bottom":
+            # outside channels are at the bottom of the probe, so the first channel needs to be 0
+            if ioutside[0] == 0:
+                ioutside = ioutside[(a == np.min(a)) & (a < np.max(a))]
+                ichannels[ioutside] = 3
+        else:  # both extremes are considered
+            if ioutside[-1] == (nc - 1) or ioutside[0] == 0:
+                ioutside = ioutside[(a == np.max(a)) | (a == np.min(a))]
+                ichannels[ioutside] = 3

     return ichannels
diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index 3992a4c8c6..c81630fc1b 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -20,7 +20,7 @@
         exclude_sweep_ms=0.1,
         radius_um=50,
     ),
-    "select_kwargs": None,
+    "select_kwargs": dict(),
     "localize_peaks_kwargs": dict(
         method="monopolar_triangulation",
         radius_um=75.0,
@@ -83,7 +83,7 @@
         exclude_sweep_ms=0.1,
         radius_um=50,
     ),
-    "select_kwargs": None,
+    "select_kwargs": dict(),
     "localize_peaks_kwargs": dict(
         method="center_of_mass",
         radius_um=75.0,
@@ -111,7 +111,7 @@
         exclude_sweep_ms=0.1,
         radius_um=50,
     ),
-    "select_kwargs": None,
+    "select_kwargs": dict(),
     "localize_peaks_kwargs": dict(
         method="grid_convolution",
         radius_um=40.0,
@@ -157,7 +157,7 @@ def correct_motion(
     folder=None,
     output_motion_info=False,
     detect_kwargs={},
-    select_kwargs=None,
+    select_kwargs={},
     localize_peaks_kwargs={},
     estimate_motion_kwargs={},
     interpolate_motion_kwargs={},
@@ -241,13 +241,22 @@ def correct_motion(
     # get preset params and update if necessary
     params = motion_options_preset[preset]
     detect_kwargs = dict(params["detect_kwargs"], **detect_kwargs)
-    if params["select_kwargs"] is None:
-        select_kwargs = None
-    else:
-        select_kwargs = dict(params["select_kwargs"], **select_kwargs)
+    select_kwargs = dict(params["select_kwargs"], **select_kwargs)
     localize_peaks_kwargs = dict(params["localize_peaks_kwargs"], **localize_peaks_kwargs)
     estimate_motion_kwargs = dict(params["estimate_motion_kwargs"], **estimate_motion_kwargs)
     interpolate_motion_kwargs = dict(params["interpolate_motion_kwargs"], **interpolate_motion_kwargs)
+    do_selection = len(select_kwargs) > 0
+
+    # params
+    parameters = dict(
+        detect_kwargs=detect_kwargs,
+        select_kwargs=select_kwargs,
+        localize_peaks_kwargs=localize_peaks_kwargs,
+        estimate_motion_kwargs=estimate_motion_kwargs,
+        interpolate_motion_kwargs=interpolate_motion_kwargs,
+        job_kwargs=job_kwargs,
+        sampling_frequency=recording.sampling_frequency,
+    )

     if output_motion_info:
         motion_info = {}
@@ -255,13 +264,19 @@ def correct_motion(
         motion_info = None

     job_kwargs = fix_job_kwargs(job_kwargs)

-    noise_levels = get_noise_levels(recording, return_scaled=False)
-    if select_kwargs is None:
-        # maybe do this directly in the folder when not None
-        gather_mode = "memory"
+    if folder is not None:
+        folder = Path(folder)
+        folder.mkdir(exist_ok=True, parents=True)
+
+        (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
+        if recording.check_serializability("json"):
+            recording.dump_to_json(folder / "recording.json")
+
+    if not do_selection:
+        # maybe do this directly in the folder when not None, but might be slow on external storage
+        gather_mode = "memory"

     # node detect
     method = detect_kwargs.pop("method", "locally_exclusive")
     method_class = detect_peak_methods[method]
@@ -281,6 +296,7 @@ def correct_motion(
         job_kwargs,
         job_name="detect and localize",
         gather_mode=gather_mode,
+        gather_kwargs=None,
         squeeze_output=False,
         folder=None,
         names=None,
@@ -307,6 +323,9 @@ def correct_motion(
         select_peaks=t2 - t1,
         localize_peaks=t3 - t2,
     )
+    if folder is not None:
+        np.save(folder / "peaks.npy", peaks)
+        np.save(folder / "peak_locations.npy", peak_locations)

     t0 = time.perf_counter()
     motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs)
@@ -318,29 +337,10 @@ def correct_motion(
     )

     if folder is not None:
-        folder = Path(folder)
-        folder.mkdir(exist_ok=True, parents=True)
-
-        # params and run times
-        parameters = dict(
-            detect_kwargs=detect_kwargs,
-            select_kwargs=select_kwargs,
-            localize_peaks_kwargs=localize_peaks_kwargs,
-            estimate_motion_kwargs=estimate_motion_kwargs,
-            interpolate_motion_kwargs=interpolate_motion_kwargs,
-            job_kwargs=job_kwargs,
-            sampling_frequency=recording.sampling_frequency,
-        )
-        (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
         (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")

-        if recording.check_serializability("json"):
-            recording.dump_to_json(folder / "recording.json")
-
-        np.save(folder / "peaks.npy", peaks)
-        np.save(folder / "peak_locations.npy", peak_locations)
         np.save(folder / "temporal_bins.npy", temporal_bins)
         np.save(folder / "motion.npy", motion)
-        np.save(folder / "peak_locations.npy", peak_locations)
         if spatial_bins is not None:
             np.save(folder / "spatial_bins.npy", spatial_bins)
diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
index c2de263063..4071bfe0ea 100644
--- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
@@ -19,7 +19,7 @@
     HAVE_NPIX = False


-def test_remove_bad_channels_std_mad():
+def test_detect_bad_channels_std_mad():
     num_channels = 4
     sampling_frequency = 30000.0
     durations = [10.325, 3.5]
@@ -60,9 +60,48 @@
         ), "wrong channels locations."


+@pytest.mark.parametrize("outside_channels_location", ["bottom", "top", "both"])
+def test_detect_bad_channels_extremes(outside_channels_location):
+    num_channels = 64
+    sampling_frequency = 30000.0
+    durations = [20]
+    num_out_channels = 10
+
+    num_segments = len(durations)
+    num_timepoints = [int(sampling_frequency * d) for d in durations]
+
+    traces_list = []
+    for i in range(num_segments):
+        traces = np.random.randn(num_timepoints[i], num_channels).astype("float32")
+        # extreme channels are "out"
+        traces[:, :num_out_channels] *= 0.05
+        traces[:, -num_out_channels:] *= 0.05
+        traces_list.append(traces)
+
+    rec = NumpyRecording(traces_list, sampling_frequency)
+    rec.set_channel_gains(1)
+    rec.set_channel_offsets(0)
+
+    probe = generate_linear_probe(num_elec=num_channels)
+    probe.set_device_channel_indices(np.arange(num_channels))
+    rec.set_probe(probe, in_place=True)
+
+    bad_channel_ids, bad_labels = detect_bad_channels(
+        rec, method="coherence+psd", outside_channels_location=outside_channels_location
+    )
+    if outside_channels_location == "top":
+        assert np.array_equal(bad_channel_ids, rec.channel_ids[-num_out_channels:])
+    elif outside_channels_location == "bottom":
+        assert np.array_equal(bad_channel_ids, rec.channel_ids[:num_out_channels])
+    elif outside_channels_location == "both":
+        assert np.array_equal(
+            bad_channel_ids, np.concatenate((rec.channel_ids[:num_out_channels], rec.channel_ids[-num_out_channels:]))
+        )
+
+
 @pytest.mark.skipif(not HAVE_NPIX, reason="ibl-neuropixel is not installed")
 @pytest.mark.parametrize("num_channels", [32, 64, 384])
-def test_remove_bad_channels_ibl(num_channels):
+def test_detect_bad_channels_ibl(num_channels):
     """
     Cannot test against DL datasets because they are too short and need to
     control the PSD scaling. Here we generate a dataset
@@ -121,7 +160,9 @@ def test_remove_bad_channels_ibl(num_channels):
         traces_uV = random_chunk.T
         traces_V = traces_uV * 1e-6
         channel_flags, _ = neurodsp.voltage.detect_bad_channels(
-            traces_V, recording.get_sampling_frequency(), psd_hf_threshold=psd_cutoff
+            traces_V,
+            recording.get_sampling_frequency(),
+            psd_hf_threshold=psd_cutoff,
         )
         channel_flags_ibl[:, i] = channel_flags

@@ -209,5 +250,10 @@ def add_dead_channels(recording, is_dead):

 if __name__ == "__main__":
-    test_remove_bad_channels_std_mad()
-    test_remove_bad_channels_ibl(num_channels=384)
+    # test_detect_bad_channels_std_mad()
+    test_detect_bad_channels_ibl(num_channels=32)
+    test_detect_bad_channels_ibl(num_channels=64)
+    test_detect_bad_channels_ibl(num_channels=384)
+    # test_detect_bad_channels_extremes("top")
+    # test_detect_bad_channels_extremes("bottom")
+    # test_detect_bad_channels_extremes("both")
diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py
index 3bea9b91bb..766229b62a 100644
--- a/src/spikeinterface/preprocessing/whiten.py
+++ b/src/spikeinterface/preprocessing/whiten.py
@@ -197,7 +197,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
         distances = get_channel_distances(recording)
         W = np.zeros((n, n), dtype="float64")
         for c in range(n):
-            (inds,) = np.nonzero(distances[c, :] < radius_um)
+            (inds,) = np.nonzero(distances[c, :] <= radius_um)
             cov_local = cov[inds, :][:, inds]
             U, S, Ut = np.linalg.svd(cov_local, full_matrices=True)
             W_local = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index da8b48085c..7591c9eb2c 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -2,6 +2,7 @@
 import os
 from pathlib import Path
 import json
+import pickle
 import platform
 from warnings import warn
 from typing import Optional, Union
@@ -414,9 +415,15 @@ def run_sorter_container(
     # create 3 files for communication with container
     # recording dict inside
-    (parent_folder / "in_container_recording.json").write_text(
-        json.dumps(check_json(rec_dict), indent=4), encoding="utf8"
-    )
+    if recording.check_serializability("json"):
+        (parent_folder / "in_container_recording.json").write_text(
+            json.dumps(check_json(rec_dict), indent=4), encoding="utf8"
+        )
+    elif recording.check_serializability("pickle"):
+        (parent_folder / "in_container_recording.pickle").write_bytes(pickle.dumps(rec_dict))
+    else:
+        raise RuntimeError("To use run_sorter with a container, the recording must be serializable")
+
     # need to share specific parameters
     (parent_folder / "in_container_params.json").write_text(
         json.dumps(check_json(sorter_params), indent=4), encoding="utf8"
@@ -433,13 +440,19 @@ def run_sorter_container(
     # the py script
     py_script = f"""
 import json
+from pathlib import Path
 from spikeinterface import load_extractor
 from spikeinterface.sorters import run_sorter_local

 if __name__ == '__main__':
     # this __name__ protection helps in some cases with multiprocessing (for instance HS2)
     # load recording in container
-    recording = load_extractor('{parent_folder_unix}/in_container_recording.json')
+    json_rec = Path('{parent_folder_unix}/in_container_recording.json')
+    pickle_rec = Path('{parent_folder_unix}/in_container_recording.pickle')
+    if json_rec.exists():
+        recording = load_extractor(json_rec)
+    else:
+        recording = load_extractor(pickle_rec)

     # load params in container
     with open('{parent_folder_unix}/in_container_params.json', encoding='utf8', mode='r') as f:
@@ -593,7 +606,10 @@ def run_sorter_container(
     # clean useless files
     if delete_container_files:
-        os.remove(parent_folder / "in_container_recording.json")
+        if (parent_folder / "in_container_recording.json").exists():
+            os.remove(parent_folder / "in_container_recording.json")
+        if (parent_folder / "in_container_recording.pickle").exists():
+            os.remove(parent_folder / "in_container_recording.pickle")
         os.remove(parent_folder / "in_container_params.json")
         os.remove(parent_folder / "in_container_sorter_script.py")
         if mode == "singularity":
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index 24ec923f06..285a9ff2f2 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -291,7 +291,7 @@ def find_merge_pairs(
     template_locs = channel_locs[max_chans, :]
     template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean")

-    pair_mask = pair_mask & (template_dist < radius_um)
+    pair_mask = pair_mask & (template_dist <= radius_um)
     indices0, indices1 = np.nonzero(pair_mask)

     n_jobs = job_kwargs["n_jobs"]
@@ -337,7 +337,7 @@ def find_merge_pairs(
             pair_shift[ind0, ind1] = shift
             pair_values[ind0, ind1] = merge_value

-    pair_mask = pair_mask & (template_dist < radius_um)
+    pair_mask = pair_mask & (template_dist <= radius_um)
     indices0, indices1 = np.nonzero(pair_mask)

     return labels_set, pair_mask, pair_shift, pair_values
diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py
index f7f020d153..4006939b22 100644
--- a/src/spikeinterface/sortingcomponents/features_from_peaks.py
+++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -119,7 +119,7 @@ def __init__(
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.all_channels = all_channels
         self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()
@@ -157,7 +157,7 @@ def __init__(
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()
@@ -202,7 +202,7 @@ def __init__(
         self.sigmoid = sigmoid
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.radius_um = radius_um
         self.sparse = sparse
         self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
@@ -253,7 +253,7 @@ def __init__(
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.projections = projections
         self.min_values = min_values
@@ -288,7 +288,7 @@ def __init__(self, recording, name="std_ptp_feature", return_output=True, parent
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um))
@@ -313,7 +313,7 @@ def __init__(self, recording, name="global_ptp_feature", return_output=True, par
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um))
@@ -338,7 +338,7 @@ def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, p
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um))
@@ -365,7 +365,7 @@ def __init__(self, recording, name="energy_feature", return_output=True, parents
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um))
diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py
index df73575a01..8eb9cafe9d 100644
--- a/src/spikeinterface/sortingcomponents/motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/motion_estimation.py
@@ -220,7 +220,7 @@ class DecentralizedRegistration:
     pairwise_displacement_method: "conv" or "phase_cross_correlation"
         How to estimate the displacement in the pairwise matrix.
     max_displacement_um: float
-        Maximum possible discplacement in micrometers.
+        Maximum possible displacement in micrometers.
     weight_scale: "linear" or "exp"
         For pairwise displacement, how to rescale the associated weight matrix.
     error_sigma: float, default: 0.2
@@ -1039,6 +1039,7 @@ def jac(p):
         displacement = p

     elif convergence_method == "lsmr":
+        import gc
         from scipy import sparse
         from scipy.stats import zscore
@@ -1170,6 +1171,9 @@ def jac(p):
             # warm start next iteration
             p0 = displacement

+            # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy)
+            # TODO: check if this gets fixed in scipy
+            gc.collect()

         displacement = displacement.reshape(B, T).T
     else:
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index e66c8be874..22438c0934 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -542,7 +542,7 @@ def check_params(
         )

         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < radius_um
+        neighbours_mask = channel_distance <= radius_um
         return args + (neighbours_mask,)

     @classmethod
@@ -624,7 +624,7 @@ def check_params(
         neighbour_indices_by_chan = []
         num_channels = recording.get_num_channels()
         for chan in range(num_channels):
-            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < radius_um)[0])
+            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] <= radius_um)[0])
         max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan])
         neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int)
         for i, neigh in enumerate(neighbour_indices_by_chan):
@@ -856,7 +856,7 @@ def check_params(
         abs_threholds = noise_levels * detect_threshold
         exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < radius_um
+        neighbours_mask = channel_distance <= radius_um

         executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign)
diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index 45ae9c91aa..ae8a02f4a7 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -120,6 +120,12 @@ def __init__(
         if time_range is None:
             time_range = (0, 1.0)
         time_range = np.array(time_range)
+        if time_range[1] > rec0.get_duration(segment_index=segment_index):
+            warnings.warn(
+                "You have selected a time after the end of the segment. The range will be clipped to "
+                f"{rec0.get_duration(segment_index=segment_index)}"
+            )
+            time_range[1] = rec0.get_duration(segment_index=segment_index)

         assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"'
         if mode == "auto":
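
A note on the float(sampling_frequency) coercion in baserecordingsnippets.py and basesorting.py: a plausible motivation (not stated in the patch itself) is that sampling frequencies often arrive as numpy scalars read from file headers, and those break JSON serialization of extractor kwargs. A minimal sketch:

import json
import numpy as np

fs = np.float32(30000.0)  # e.g. a rate read from a file header

# A numpy scalar is rejected by the json module...
try:
    json.dumps({"sampling_frequency": fs})
except TypeError as err:
    print(err)  # Object of type float32 is not JSON serializable

# ...while the builtin float produced by the coercion serializes fine.
print(json.dumps({"sampling_frequency": float(fs)}))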
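The repeated < to <= changes (node_pipeline.py, unit_localization.py, whiten.py, merge.py, features_from_peaks.py, peak_detection.py) all make radius_um inclusive: a channel sitting exactly at the radius now counts as a neighbour. A self-contained sketch with a toy linear layout (hypothetical coordinates, not the real get_channel_distances):

import numpy as np

# Three channels spaced exactly 50 um apart along y
locations = np.array([[0.0, 0.0], [0.0, 50.0], [0.0, 100.0]])
distances = np.linalg.norm(locations[:, None, :] - locations[None, :, :], axis=2)

radius_um = 50.0
old_mask = distances < radius_um   # strict: a neighbour at exactly 50 um is dropped
new_mask = distances <= radius_um  # inclusive: boundary channels are kept

print(old_mask.sum(axis=1))  # [1 1 1] -> each channel only sees itself
print(new_mask.sum(axis=1))  # [2 3 2] -> adjacent channels at 50 um are included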
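The new "remfile" stream mode is exercised by the added test in test_nwb_s3_extractor.py; a usage sketch mirroring that test (URL and sampling frequency taken directly from it):

from spikeinterface.extractors import NwbSortingExtractor

# Streams from DANDI over HTTP, so this needs network access and remfile installed.
sorting = NwbSortingExtractor(
    "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b",
    sampling_frequency=30000.0,  # the file has no ElectricalSeries to read the rate from
    stream_mode="remfile",
)
print(len(sorting.unit_ids))  # 64 units in this asset, per the test above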
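In the reworked detect_bad_channels_ibl, the accumulator a = np.cumsum(np.r_[0, np.diff(ioutside) - 1]) is what isolates contiguous runs of low-coherence channels at the probe ends: it stays flat inside a run and jumps at every gap, so its minimum selects the bottom run and its maximum the top run. A sketch with made-up indices (the real code additionally requires the run to touch channel 0 or nc - 1):

import numpy as np

# Hypothetical indices of low-coherence channels on a 64-channel probe:
# a run at the bottom (0..3) and a run at the top (60..63)
ioutside = np.array([0, 1, 2, 3, 60, 61, 62, 63])

# Flat inside a contiguous run, jumps by the gap size between runs
a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
print(a)                               # [ 0  0  0  0 56 56 56 56]
bottom_run = ioutside[a == np.min(a)]  # [0 1 2 3]
top_run = ioutside[a == np.max(a)]     # [60 61 62 63]
print(bottom_run, top_run)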
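With the motion.py reordering, parameters.json and recording.json are written as soon as a folder is given, and peaks.npy / peak_locations.npy right after localization, rather than only at the very end, so an interrupted run still leaves partial artifacts behind. A hedged usage sketch (synthetic data via the standard generate_recording helper; the exact on-disk outputs depend on the spikeinterface version):

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import correct_motion

# Toy recording with a dummy probe attached; real use would pass actual data,
# and the run can be slow since a full detect/localize/estimate pipeline executes.
recording = generate_recording(durations=[10.0])
rec_corrected = correct_motion(recording, preset="rigid_fast", folder="motion_folder")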
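The runsorter.py change replaces an unconditional JSON dump of the recording with a JSON-then-pickle dispatch. Restated in isolation as a sketch (the check_serializability calls are the real BaseExtractor API used by the patch; the check_json helper from the original code is omitted for brevity):

import json
import pickle
from pathlib import Path

def dump_recording_dict(rec_dict, recording, parent_folder):
    # Prefer JSON, fall back to pickle, fail loudly when neither applies,
    # mirroring the new branching in run_sorter_container.
    parent_folder = Path(parent_folder)
    if recording.check_serializability("json"):
        (parent_folder / "in_container_recording.json").write_text(
            json.dumps(rec_dict, indent=4), encoding="utf8"
        )
    elif recording.check_serializability("pickle"):
        (parent_folder / "in_container_recording.pickle").write_bytes(pickle.dumps(rec_dict))
    else:
        raise RuntimeError("The recording must be serializable to json or pickle")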