diff --git a/doc/releases/0.100.7.rst b/doc/releases/0.100.7.rst new file mode 100644 index 0000000000..a224494da5 --- /dev/null +++ b/doc/releases/0.100.7.rst @@ -0,0 +1,14 @@ +.. _release0.100.7: + +SpikeInterface 0.100.7 release notes +------------------------------------ + +7th June 2024 + +Minor release with bug fixes + +* Fix get_traces for a local common reference (#2649) +* Update KS4 parameters (#2810) +* Zarr: extract time vector once and for all! (#2828) +* Fix waveforms save in recordingless mode (#2889) +* Fix the new way of handling cmap in matpltolib. This fix the matplotib 3.9 problem related to this (#2891) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 2ba199eb94..5f35b3efd2 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.100.7.rst releases/0.100.6.rst releases/0.100.5.rst releases/0.100.4.rst @@ -40,6 +41,12 @@ Release notes releases/0.9.1.rst +Version 0.100.7 +=============== + +* Minor release with bug fixes + + Version 0.100.6 =============== diff --git a/pyproject.toml b/pyproject.toml index 290393ef51..203454310f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,13 +94,13 @@ full = [ "scikit-learn", "networkx", "distinctipy", - "matplotlib", + "matplotlib>=3.6", "cuda-python; platform_system != 'Darwin'", "numba", ] widgets = [ - "matplotlib", + "matplotlib>=3.6", "ipympl", "ipywidgets", "sortingview>=0.12.0", diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4a0c5f8eef..1ba96defac 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -72,7 +72,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) time_kwargs = {} time_vector = self._root.get(f"times_seg{segment_index}", None) if time_vector is not None: - time_kwargs["time_vector"] = time_vector + time_kwargs["time_vector"] = time_vector[:] else: if t_starts is None: t_start = None diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 07920b61ca..a4369309f2 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -189,12 +189,12 @@ def get_traces(self, start_frame, end_frame, channel_indices): shift = traces[:, self.ref_channel_indices] re_referenced_traces = traces[:, channel_indices] - shift else: # then it must be local - re_referenced_traces = np.zeros_like(traces[:, channel_indices]) channel_indices_array = np.arange(traces.shape[1])[channel_indices] - for channel_index in channel_indices_array: + re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") + for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - re_referenced_traces[:, channel_index] = traces[:, channel_index] - channel_shift + re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift return re_referenced_traces.astype(self.dtype, copy=False) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index e64775df25..1df9b21c81 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -10,7 +10,7 @@ def _generate_test_recording(): - recording = generate_recording(durations=[5.0], num_channels=4) + recording = generate_recording(durations=[1.0], num_channels=4) recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"])) return recording @@ -46,11 +46,15 @@ def test_common_reference(recording): def test_common_reference_channel_slicing(recording): recording_cmr = common_reference(recording, reference="global", operator="median") recording_car = common_reference(recording, reference="global", operator="average") - recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["a"]) + recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"]) recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") - channel_ids = ["a", "b"] - indices = recording.ids_to_indices(["a", "b"]) + channel_ids = ["b", "d"] + indices = recording.ids_to_indices(channel_ids) + + all_channel_ids = recording.channel_ids + all_indices = recording.ids_to_indices(all_channel_ids) + original_traces = recording.get_traces() cmr_trace = recording_cmr.get_traces(channel_ids=channel_ids) @@ -62,13 +66,50 @@ def test_common_reference_channel_slicing(recording): assert np.allclose(car_trace, expected_trace, atol=0.01) single_reference_trace = recording_single_reference.get_traces(channel_ids=channel_ids) - single_reference_index = recording.ids_to_indices(["a"]) + single_reference_index = recording.ids_to_indices(["b"]) expected_trace = original_traces[:, indices] - original_traces[:, single_reference_index] assert np.allclose(single_reference_trace, expected_trace, atol=0.01) - # local car - local_trace = recording_local_car.get_traces(channel_ids=channel_ids) + local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids) + local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids) + + assert np.all(local_trace[:, indices] == local_trace_sub) + + # test segment slicing + + start_frame = 0 + end_frame = 10 + + recording_segment_cmr = recording_cmr._recording_segments[0] + traces_cmr_all = recording_segment_cmr.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices + ) + traces_cmr_sub = recording_segment_cmr.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=indices + ) + + assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub) + + recording_segment_car = recording_car._recording_segments[0] + traces_car_all = recording_segment_car.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices + ) + traces_car_sub = recording_segment_car.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=indices + ) + + assert np.all(traces_car_all[:, indices] == traces_car_sub) + + recording_segment_local = recording_local_car._recording_segments[0] + traces_local_all = recording_segment_local.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices + ) + traces_local_sub = recording_segment_local.get_traces( + start_frame=start_frame, end_frame=end_frame, channel_indices=indices + ) + + assert np.all(traces_local_all[:, indices] == traces_local_sub) def test_common_reference_groups(recording): diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 90bdc1056d..47846f10ce 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Union -from packaging.version import parse from ..basesorter import BaseSorter from .kilosortbase import KilosortBase @@ -37,6 +36,7 @@ class Kilosort4Sorter(BaseSorter): "template_sizes": 5, "nearest_chans": 10, "nearest_templates": 100, + "max_channel_distance": None, "templates_from_data": True, "n_templates": 6, "n_pcs": 6, @@ -45,7 +45,8 @@ class Kilosort4Sorter(BaseSorter): "ccg_threshold": 0.25, "cluster_downsampling": 20, "cluster_pcs": 64, - "duplicate_spike_bins": 15, + "x_centers": None, + "duplicate_spike_bins": 7, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -74,6 +75,7 @@ class Kilosort4Sorter(BaseSorter): "template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.", "nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.", "nearest_templates": "Number of nearest spike template locations to consider when finding local maxima during spike detection. Default value: 100.", + "max_channel_distance": "Templates farther away than this from their nearest channel will not be used. Also limits distance between compared channels during clustering. Default value: None.", "templates_from_data": "Indicates whether spike shapes used in universal templates should be estimated from the data or loaded from the predefined templates. Default value: True.", "n_templates": "Number of single-channel templates to use for the universal templates (only used if templates_from_data is True). Default value: 6.", "n_pcs": "Number of single-channel PCs to use for extracting spike features (only used if templates_from_data is True). Default value: 6.", @@ -82,7 +84,8 @@ class Kilosort4Sorter(BaseSorter): "ccg_threshold": "Fraction of refractory period violations that are allowed in the CCG compared to baseline; used to perform splits and merges. Default value: 0.25.", "cluster_downsampling": "Inverse fraction of nodes used as landmarks during clustering (can be 1, but that slows down the optimization). Default value: 20.", "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", - "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 15.", + "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", + "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", "keep_good_only": "If True only 'good' units are returned", "do_correction": "If True, drift correction is performed", "save_extra_kwargs": "If True, additional kwargs are saved to the output", diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b7f17d99e3..8a27334934 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -61,7 +61,7 @@ def _split_waveforms( local_feature_plot = local_feature unique_lab = np.unique(local_labels_with_noise) - cmap = plt.get_cmap("jet", unique_lab.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(unique_lab)} cmap[-1] = "k" active_ind = np.arange(local_feature.shape[0]) @@ -144,7 +144,7 @@ def _split_waveforms_nested( local_feature_plot = reducer.fit_transform(local_feature) unique_lab = np.unique(active_labels_with_noise) - cmap = plt.get_cmap("jet", unique_lab.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(unique_lab)} cmap[-1] = "k" cmap[-2] = "b" @@ -275,7 +275,7 @@ def auto_split_clustering( fig, ax = plt.subplots() plot_labels_set = np.unique(local_labels_with_noise) - cmap = plt.get_cmap("jet", plot_labels_set.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)} cmap[-1] = "k" cmap[-2] = "b" diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 871d486b9c..ee1a2d36e1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -347,7 +347,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): wfs_no_noise = wfs[: -noise.shape[0]] fig, axs = plt.subplots(ncols=3) - cmap = plt.get_cmap("jet", np.unique(local_labels).size) + cmap = plt.colormaps["jet"].resampled(np.unique(local_labels).size) cmap = {label: cmap(l) for l, label in enumerate(local_labels_set)} cmap[-1] = "k" for label in local_labels_set: diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 3861e7fe83..1e047fa906 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -249,7 +249,7 @@ def split( import matplotlib.pyplot as plt labels_set = np.setdiff1d(possible_labels, [-1]) - colors = plt.get_cmap("tab10", len(labels_set)) + colors = plt.colormaps["tab10"].resampled(len(labels_set)) colors = {k: colors(i) for i, k in enumerate(labels_set)} colors[-1] = "k" fix, axs = plt.subplots(nrows=2) diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py index a5b5891110..34f65a2f89 100644 --- a/src/spikeinterface/widgets/collision.py +++ b/src/spikeinterface/widgets/collision.py @@ -136,7 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_xlabel("lag (ms)") elif dp.mode == "lines": - my_cmap = plt.get_cmap(dp.cmap) + my_cmap = plt.colormaps[dp.cmap] cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) @@ -245,7 +245,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): study = dp.study - my_cmap = plt.get_cmap(dp.cmap) + my_cmap = plt.colormaps[dp.cmap] cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) study.precompute_scores_by_similarities( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 2e4efc82b0..9d64c89e46 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -128,7 +128,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.scatter_decimate is not None: amps = amps[:: dp.scatter_decimate] amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.get_cmap(dp.amplitude_cmap) + cmap = plt.colormaps[dp.amplitude_cmap] if dp.amplitude_clim is None: amps = amps_abs amps /= q_95 diff --git a/src/spikeinterface/widgets/multicomparison.py b/src/spikeinterface/widgets/multicomparison.py index 78693aacc2..2d4a22a2b3 100644 --- a/src/spikeinterface/widgets/multicomparison.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -87,7 +87,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): nodelist=sorted(g.nodes), edge_color=edge_col, alpha=dp.alpha_edges, - edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_cmap=plt.colormaps[dp.edge_cmap], edge_vmin=mcmp.match_score, edge_vmax=1, ax=self.ax, @@ -106,7 +106,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) - cmap = plt.cm.get_cmap(dp.edge_cmap) + cmap = plt.colormaps[dp.edge_cmap] m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) self.figure.colorbar(m) @@ -159,7 +159,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) mcmp = dp.multi_comparison - cmap = plt.get_cmap(dp.cmap) + cmap = plt.colormaps[dp.cmap] colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold @@ -242,7 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["ncols"] = len(name_list) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - cmap = plt.get_cmap(dp.cmap) + cmap = plt.colormaps[dp.cmap] colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 29e6474ee9..21dc1a931a 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -76,7 +76,7 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB elif color_engine == "matplotlib": # some map have black or white at border so +10 margin = max(4, int(N * 0.08)) - cmap = plt.get_cmap(map_name, N + 2 * margin) + cmap = plt.colormaps[map_name].resampled(N + 2 * margin) colors = [cmap(i + margin) for i, key in enumerate(keys)] elif color_engine == "colorsys": @@ -153,7 +153,7 @@ def array_to_image( num_channels = data.shape[1] spacing = int(num_channels * spatial_zoom[1] * row_spacing) - cmap = plt.get_cmap(colormap) + cmap = plt.colormaps[colormap] zoomed_data = zoom(data, spatial_zoom) num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0])