From bb0de4f2b1a0782f18ea31bbea5b6ec560fdc445 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 10 May 2024 11:03:36 +0200 Subject: [PATCH 1/5] Propagate #2649 and #2810 --- .../preprocessing/common_reference.py | 6 +- .../tests/test_common_reference.py | 55 ++++++++++++++++--- .../sorters/external/kilosort4.py | 9 ++- 3 files changed, 57 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 07920b61ca..a4369309f2 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -189,12 +189,12 @@ def get_traces(self, start_frame, end_frame, channel_indices): shift = traces[:, self.ref_channel_indices] re_referenced_traces = traces[:, channel_indices] - shift else: # then it must be local - re_referenced_traces = np.zeros_like(traces[:, channel_indices]) channel_indices_array = np.arange(traces.shape[1])[channel_indices] - for channel_index in channel_indices_array: + re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") + for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - re_referenced_traces[:, channel_index] = traces[:, channel_index] - channel_shift + re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift return re_referenced_traces.astype(self.dtype, copy=False) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index e64775df25..1df9b21c81 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -10,7 +10,7 @@ def _generate_test_recording(): - recording = generate_recording(durations=[5.0], num_channels=4) + recording = generate_recording(durations=[1.0], num_channels=4) recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"])) return recording @@ -46,11 +46,15 @@ def test_common_reference(recording): def test_common_reference_channel_slicing(recording): recording_cmr = common_reference(recording, reference="global", operator="median") recording_car = common_reference(recording, reference="global", operator="average") - recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["a"]) + recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"]) recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") - channel_ids = ["a", "b"] - indices = recording.ids_to_indices(["a", "b"]) + channel_ids = ["b", "d"] + indices = recording.ids_to_indices(channel_ids) + + all_channel_ids = recording.channel_ids + all_indices = recording.ids_to_indices(all_channel_ids) + original_traces = recording.get_traces() cmr_trace = recording_cmr.get_traces(channel_ids=channel_ids) @@ -62,13 +66,50 @@ def test_common_reference_channel_slicing(recording): assert np.allclose(car_trace, expected_trace, atol=0.01) single_reference_trace = recording_single_reference.get_traces(channel_ids=channel_ids) - single_reference_index = recording.ids_to_indices(["a"]) + single_reference_index = recording.ids_to_indices(["b"]) expected_trace = original_traces[:, indices] - original_traces[:, single_reference_index] assert np.allclose(single_reference_trace, expected_trace, atol=0.01) - # local car - local_trace = recording_local_car.get_traces(channel_ids=channel_ids) + local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids) + local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids) + + assert np.all(local_trace[:, indices] == local_trace_sub) + + # test segment slicing + + start_frame = 0 + end_frame = 10 + + recording_segment_cmr = recording_cmr._recording_segments[0] + traces_cmr_all = recording_segment_cmr.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices + ) + traces_cmr_sub = recording_segment_cmr.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=indices + ) + + assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub) + + recording_segment_car = recording_car._recording_segments[0] + traces_car_all = recording_segment_car.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices + ) + traces_car_sub = recording_segment_car.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=indices + ) + + assert np.all(traces_car_all[:, indices] == traces_car_sub) + + recording_segment_local = recording_local_car._recording_segments[0] + traces_local_all = recording_segment_local.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices + ) + traces_local_sub = recording_segment_local.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=indices + ) + + assert np.all(traces_local_all[:, indices] == traces_local_sub) def test_common_reference_groups(recording): diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 90bdc1056d..47846f10ce 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Union -from packaging.version import parse from ..basesorter import BaseSorter from .kilosortbase import KilosortBase @@ -37,6 +36,7 @@ class Kilosort4Sorter(BaseSorter): "template_sizes": 5, "nearest_chans": 10, "nearest_templates": 100, + "max_channel_distance": None, "templates_from_data": True, "n_templates": 6, "n_pcs": 6, @@ -45,7 +45,8 @@ class Kilosort4Sorter(BaseSorter): "ccg_threshold": 0.25, "cluster_downsampling": 20, "cluster_pcs": 64, - "duplicate_spike_bins": 15, + "x_centers": None, + "duplicate_spike_bins": 7, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -74,6 +75,7 @@ class Kilosort4Sorter(BaseSorter): "template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.", "nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.", "nearest_templates": "Number of nearest spike template locations to consider when finding local maxima during spike detection. Default value: 100.", + "max_channel_distance": "Templates farther away than this from their nearest channel will not be used. Also limits distance between compared channels during clustering. Default value: None.", "templates_from_data": "Indicates whether spike shapes used in universal templates should be estimated from the data or loaded from the predefined templates. Default value: True.", "n_templates": "Number of single-channel templates to use for the universal templates (only used if templates_from_data is True). Default value: 6.", "n_pcs": "Number of single-channel PCs to use for extracting spike features (only used if templates_from_data is True). Default value: 6.", @@ -82,7 +84,8 @@ class Kilosort4Sorter(BaseSorter): "ccg_threshold": "Fraction of refractory period violations that are allowed in the CCG compared to baseline; used to perform splits and merges. Default value: 0.25.", "cluster_downsampling": "Inverse fraction of nodes used as landmarks during clustering (can be 1, but that slows down the optimization). Default value: 20.", "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", - "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 15.", + "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", + "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", "keep_good_only": "If True only 'good' units are returned", "do_correction": "If True, drift correction is performed", "save_extra_kwargs": "If True, additional kwargs are saved to the output", From 696e6e19497f3d60b59993ce56493ce9defc9d5b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 10 May 2024 11:11:05 +0200 Subject: [PATCH 2/5] Propagate #2828 --- src/spikeinterface/core/zarrextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4a0c5f8eef..1ba96defac 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -72,7 +72,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) time_kwargs = {} time_vector = self._root.get(f"times_seg{segment_index}", None) if time_vector is not None: - time_kwargs["time_vector"] = time_vector + time_kwargs["time_vector"] = time_vector[:] else: if t_starts is None: t_start = None From 2b58945fff8e9df66dd17ac9f3001d1025570a05 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 10 May 2024 13:22:27 +0200 Subject: [PATCH 3/5] zarr<2.18 to avoid pynwb failure --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c9fc3ecf78..38940876c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "numpy", "threadpoolctl>=3.0.0", "tqdm", - "zarr>=0.2.16", + "zarr>=0.2.16, <2.18", "neo>=0.13.0", "probeinterface>=0.2.21", ] From 72b52a2371125b5b6d70d6d549b6b5418458795b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 14 May 2024 15:09:16 +0200 Subject: [PATCH 4/5] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 38940876c0..d7a69fbe10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "numpy", "threadpoolctl>=3.0.0", "tqdm", - "zarr>=0.2.16, <2.18", + "zarr>=2.16, <2.18", "neo>=0.13.0", "probeinterface>=0.2.21", ] From a340c852be81d5235559a2cfe453eeebd8640936 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Jun 2024 16:16:11 +0200 Subject: [PATCH 5/5] fix widget utils and add release notes --- doc/releases/0.100.7.rst | 14 ++++++++++++++ doc/whatisnew.rst | 7 +++++++ src/spikeinterface/widgets/utils.py | 4 ++-- 3 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 doc/releases/0.100.7.rst diff --git a/doc/releases/0.100.7.rst b/doc/releases/0.100.7.rst new file mode 100644 index 0000000000..a224494da5 --- /dev/null +++ b/doc/releases/0.100.7.rst @@ -0,0 +1,14 @@ +.. _release0.100.7: + +SpikeInterface 0.100.7 release notes +------------------------------------ + +7th June 2024 + +Minor release with bug fixes + +* Fix get_traces for a local common reference (#2649) +* Update KS4 parameters (#2810) +* Zarr: extract time vector once and for all! (#2828) +* Fix waveforms save in recordingless mode (#2889) +* Fix the new way of handling cmap in matpltolib. This fix the matplotib 3.9 problem related to this (#2891) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 2ba199eb94..5f35b3efd2 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.100.7.rst releases/0.100.6.rst releases/0.100.5.rst releases/0.100.4.rst @@ -40,6 +41,12 @@ Release notes releases/0.9.1.rst +Version 0.100.7 +=============== + +* Minor release with bug fixes + + Version 0.100.6 =============== diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 29e6474ee9..21dc1a931a 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -76,7 +76,7 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB elif color_engine == "matplotlib": # some map have black or white at border so +10 margin = max(4, int(N * 0.08)) - cmap = plt.get_cmap(map_name, N + 2 * margin) + cmap = plt.colormaps[map_name].resampled(N + 2 * margin) colors = [cmap(i + margin) for i, key in enumerate(keys)] elif color_engine == "colorsys": @@ -153,7 +153,7 @@ def array_to_image( num_channels = data.shape[1] spacing = int(num_channels * spatial_zoom[1] * row_spacing) - cmap = plt.get_cmap(colormap) + cmap = plt.colormaps[colormap] zoomed_data = zoom(data, spatial_zoom) num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0])