Skip to content

Commit

Permalink
Merge pull request #2827 from alejoe91/prepare-0.100.7
Browse files Browse the repository at this point in the history
Prepare 0.100.7 release
  • Loading branch information
alejoe91 authored Jun 7, 2024
2 parents a75f1d7 + a340c85 commit afdfae0
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 30 deletions.
14 changes: 14 additions & 0 deletions doc/releases/0.100.7.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. _release0.100.7:

SpikeInterface 0.100.7 release notes
------------------------------------

7th June 2024

Minor release with bug fixes

* Fix get_traces for a local common reference (#2649)
* Update KS4 parameters (#2810)
* Zarr: extract time vector once and for all! (#2828)
* Fix waveforms save in recordingless mode (#2889)
* Fix the new way of handling cmap in matpltolib. This fix the matplotib 3.9 problem related to this (#2891)
7 changes: 7 additions & 0 deletions doc/whatisnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Release notes
.. toctree::
:maxdepth: 1

releases/0.100.7.rst
releases/0.100.6.rst
releases/0.100.5.rst
releases/0.100.4.rst
Expand Down Expand Up @@ -40,6 +41,12 @@ Release notes
releases/0.9.1.rst


Version 0.100.7
===============

* Minor release with bug fixes


Version 0.100.6
===============

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ full = [
"scikit-learn",
"networkx",
"distinctipy",
"matplotlib",
"matplotlib>=3.6",
"cuda-python; platform_system != 'Darwin'",
"numba",
]

widgets = [
"matplotlib",
"matplotlib>=3.6",
"ipympl",
"ipywidgets",
"sortingview>=0.12.0",
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None)
time_kwargs = {}
time_vector = self._root.get(f"times_seg{segment_index}", None)
if time_vector is not None:
time_kwargs["time_vector"] = time_vector
time_kwargs["time_vector"] = time_vector[:]
else:
if t_starts is None:
t_start = None
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,12 @@ def get_traces(self, start_frame, end_frame, channel_indices):
shift = traces[:, self.ref_channel_indices]
re_referenced_traces = traces[:, channel_indices] - shift
else: # then it must be local
re_referenced_traces = np.zeros_like(traces[:, channel_indices])
channel_indices_array = np.arange(traces.shape[1])[channel_indices]
for channel_index in channel_indices_array:
re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32")
for i, channel_index in enumerate(channel_indices_array):
channel_neighborhood = self.neighbors[channel_index]
channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1)
re_referenced_traces[:, channel_index] = traces[:, channel_index] - channel_shift
re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift

return re_referenced_traces.astype(self.dtype, copy=False)

Expand Down
55 changes: 48 additions & 7 deletions src/spikeinterface/preprocessing/tests/test_common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def _generate_test_recording():
recording = generate_recording(durations=[5.0], num_channels=4)
recording = generate_recording(durations=[1.0], num_channels=4)
recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"]))
return recording

Expand Down Expand Up @@ -46,11 +46,15 @@ def test_common_reference(recording):
def test_common_reference_channel_slicing(recording):
recording_cmr = common_reference(recording, reference="global", operator="median")
recording_car = common_reference(recording, reference="global", operator="average")
recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["a"])
recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"])
recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")

channel_ids = ["a", "b"]
indices = recording.ids_to_indices(["a", "b"])
channel_ids = ["b", "d"]
indices = recording.ids_to_indices(channel_ids)

all_channel_ids = recording.channel_ids
all_indices = recording.ids_to_indices(all_channel_ids)

original_traces = recording.get_traces()

cmr_trace = recording_cmr.get_traces(channel_ids=channel_ids)
Expand All @@ -62,13 +66,50 @@ def test_common_reference_channel_slicing(recording):
assert np.allclose(car_trace, expected_trace, atol=0.01)

single_reference_trace = recording_single_reference.get_traces(channel_ids=channel_ids)
single_reference_index = recording.ids_to_indices(["a"])
single_reference_index = recording.ids_to_indices(["b"])
expected_trace = original_traces[:, indices] - original_traces[:, single_reference_index]

assert np.allclose(single_reference_trace, expected_trace, atol=0.01)

# local car
local_trace = recording_local_car.get_traces(channel_ids=channel_ids)
local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids)
local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids)

assert np.all(local_trace[:, indices] == local_trace_sub)

# test segment slicing

start_frame = 0
end_frame = 10

recording_segment_cmr = recording_cmr._recording_segments[0]
traces_cmr_all = recording_segment_cmr.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
)
traces_cmr_sub = recording_segment_cmr.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=indices
)

assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub)

recording_segment_car = recording_car._recording_segments[0]
traces_car_all = recording_segment_car.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
)
traces_car_sub = recording_segment_car.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=indices
)

assert np.all(traces_car_all[:, indices] == traces_car_sub)

recording_segment_local = recording_local_car._recording_segments[0]
traces_local_all = recording_segment_local.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
)
traces_local_sub = recording_segment_local.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=indices
)

assert np.all(traces_local_all[:, indices] == traces_local_sub)


def test_common_reference_groups(recording):
Expand Down
9 changes: 6 additions & 3 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pathlib import Path
from typing import Union
from packaging.version import parse

from ..basesorter import BaseSorter
from .kilosortbase import KilosortBase
Expand Down Expand Up @@ -37,6 +36,7 @@ class Kilosort4Sorter(BaseSorter):
"template_sizes": 5,
"nearest_chans": 10,
"nearest_templates": 100,
"max_channel_distance": None,
"templates_from_data": True,
"n_templates": 6,
"n_pcs": 6,
Expand All @@ -45,7 +45,8 @@ class Kilosort4Sorter(BaseSorter):
"ccg_threshold": 0.25,
"cluster_downsampling": 20,
"cluster_pcs": 64,
"duplicate_spike_bins": 15,
"x_centers": None,
"duplicate_spike_bins": 7,
"do_correction": True,
"keep_good_only": False,
"save_extra_kwargs": False,
Expand Down Expand Up @@ -74,6 +75,7 @@ class Kilosort4Sorter(BaseSorter):
"template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.",
"nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.",
"nearest_templates": "Number of nearest spike template locations to consider when finding local maxima during spike detection. Default value: 100.",
"max_channel_distance": "Templates farther away than this from their nearest channel will not be used. Also limits distance between compared channels during clustering. Default value: None.",
"templates_from_data": "Indicates whether spike shapes used in universal templates should be estimated from the data or loaded from the predefined templates. Default value: True.",
"n_templates": "Number of single-channel templates to use for the universal templates (only used if templates_from_data is True). Default value: 6.",
"n_pcs": "Number of single-channel PCs to use for extracting spike features (only used if templates_from_data is True). Default value: 6.",
Expand All @@ -82,7 +84,8 @@ class Kilosort4Sorter(BaseSorter):
"ccg_threshold": "Fraction of refractory period violations that are allowed in the CCG compared to baseline; used to perform splits and merges. Default value: 0.25.",
"cluster_downsampling": "Inverse fraction of nodes used as landmarks during clustering (can be 1, but that slows down the optimization). Default value: 20.",
"cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.",
"duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 15.",
"x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.",
"duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.",
"keep_good_only": "If True only 'good' units are returned",
"do_correction": "If True, drift correction is performed",
"save_extra_kwargs": "If True, additional kwargs are saved to the output",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _split_waveforms(
local_feature_plot = local_feature

unique_lab = np.unique(local_labels_with_noise)
cmap = plt.get_cmap("jet", unique_lab.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
cmap[-1] = "k"
active_ind = np.arange(local_feature.shape[0])
Expand Down Expand Up @@ -144,7 +144,7 @@ def _split_waveforms_nested(
local_feature_plot = reducer.fit_transform(local_feature)

unique_lab = np.unique(active_labels_with_noise)
cmap = plt.get_cmap("jet", unique_lab.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
cmap[-1] = "k"
cmap[-2] = "b"
Expand Down Expand Up @@ -275,7 +275,7 @@ def auto_split_clustering(

fig, ax = plt.subplots()
plot_labels_set = np.unique(local_labels_with_noise)
cmap = plt.get_cmap("jet", plot_labels_set.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)}
cmap[-1] = "k"
cmap[-2] = "b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d):
wfs_no_noise = wfs[: -noise.shape[0]]

fig, axs = plt.subplots(ncols=3)
cmap = plt.get_cmap("jet", np.unique(local_labels).size)
cmap = plt.colormaps["jet"].resampled(np.unique(local_labels).size)
cmap = {label: cmap(l) for l, label in enumerate(local_labels_set)}
cmap[-1] = "k"
for label in local_labels_set:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def split(
import matplotlib.pyplot as plt

labels_set = np.setdiff1d(possible_labels, [-1])
colors = plt.get_cmap("tab10", len(labels_set))
colors = plt.colormaps["tab10"].resampled(len(labels_set))
colors = {k: colors(i) for i, k in enumerate(labels_set)}
colors[-1] = "k"
fix, axs = plt.subplots(nrows=2)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/widgets/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

ax1.set_xlabel("lag (ms)")
elif dp.mode == "lines":
my_cmap = plt.get_cmap(dp.cmap)
my_cmap = plt.colormaps[dp.cmap]
cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max())
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)

Expand Down Expand Up @@ -245,7 +245,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

study = dp.study

my_cmap = plt.get_cmap(dp.cmap)
my_cmap = plt.colormaps[dp.cmap]
cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max())
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
study.precompute_scores_by_similarities(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
if dp.scatter_decimate is not None:
amps = amps[:: dp.scatter_decimate]
amps_abs = amps_abs[:: dp.scatter_decimate]
cmap = plt.get_cmap(dp.amplitude_cmap)
cmap = plt.colormaps[dp.amplitude_cmap]
if dp.amplitude_clim is None:
amps = amps_abs
amps /= q_95
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/widgets/multicomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
nodelist=sorted(g.nodes),
edge_color=edge_col,
alpha=dp.alpha_edges,
edge_cmap=plt.cm.get_cmap(dp.edge_cmap),
edge_cmap=plt.colormaps[dp.edge_cmap],
edge_vmin=mcmp.match_score,
edge_vmax=1,
ax=self.ax,
Expand All @@ -106,7 +106,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt

norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1)
cmap = plt.cm.get_cmap(dp.edge_cmap)
cmap = plt.colormaps[dp.edge_cmap]
m = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
self.figure.colorbar(m)

Expand Down Expand Up @@ -159,7 +159,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

mcmp = dp.multi_comparison
cmap = plt.get_cmap(dp.cmap)
cmap = plt.colormaps[dp.cmap]
colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))])
sg_names, sg_units = mcmp.compute_subgraphs()
# fraction of units with agreement > threshold
Expand Down Expand Up @@ -242,7 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
backend_kwargs["ncols"] = len(name_list)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

cmap = plt.get_cmap(dp.cmap)
cmap = plt.colormaps[dp.cmap]
colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))])
sg_names, sg_units = mcmp.compute_subgraphs()
# fraction of units with agreement > threshold
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB
elif color_engine == "matplotlib":
# some map have black or white at border so +10
margin = max(4, int(N * 0.08))
cmap = plt.get_cmap(map_name, N + 2 * margin)
cmap = plt.colormaps[map_name].resampled(N + 2 * margin)
colors = [cmap(i + margin) for i, key in enumerate(keys)]

elif color_engine == "colorsys":
Expand Down Expand Up @@ -153,7 +153,7 @@ def array_to_image(
num_channels = data.shape[1]
spacing = int(num_channels * spatial_zoom[1] * row_spacing)

cmap = plt.get_cmap(colormap)
cmap = plt.colormaps[colormap]
zoomed_data = zoom(data, spatial_zoom)
num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape
num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0])
Expand Down

0 comments on commit afdfae0

Please sign in to comment.