From b78257cf7217764de00be0eac72b56deb499e1bd Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 11:51:55 +0200 Subject: [PATCH 01/10] Speed up searchsorted calls --- src/spikeinterface/core/basesorting.py | 3 +-- src/spikeinterface/core/generate.py | 7 +++---- src/spikeinterface/core/node_pipeline.py | 12 ++++-------- src/spikeinterface/core/numpyextractors.py | 3 +-- src/spikeinterface/core/segmentutils.py | 6 ++---- src/spikeinterface/core/waveform_tools.py | 15 +++++---------- .../curation/remove_duplicated_spikes.py | 3 +-- .../postprocessing/amplitude_scalings.py | 3 +-- .../postprocessing/principal_component.py | 3 +-- .../postprocessing/spike_amplitudes.py | 4 +--- .../postprocessing/spike_locations.py | 3 +-- src/spikeinterface/qualitymetrics/misc_metrics.py | 6 ++---- .../sortingcomponents/motion_interpolation.py | 3 +-- 13 files changed, 24 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..eb141abde4 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if not concatenated: spikes_ = [] for segment_index in range(self.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes_.append(spikes[s0:s1]) spikes = spikes_ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..56a2bb4f48 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1109,8 +1109,7 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left") - end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right") + start, end = np.searchsorted(self.spike_vector["segment_index"], [segment_index, segment_index+1], side="left") spikes = self.spike_vector[start:end] amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None @@ -1208,8 +1207,8 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") - end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") + start, end = np.searchsorted(self.spike_vector["sample_index"], [start_frame - self.templates.shape[1], + end_frame + self.templates.shape[1] + 1], side="left") for i in range(start, end): spike = self.spike_vector[i] diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index b11f40a441..5627eba518 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -111,8 +111,7 @@ def __init__(self, recording, peaks): # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(peaks["segment_index"], segment_index) - i1 = np.searchsorted(peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -183,8 +181,7 @@ def __init__( # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(self.peaks["segment_index"], segment_index) - i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -197,8 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 97f22615df..d5663156c7 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1]) self.spikes_in_seg = self.spikes[s0:s1] unit_index = self.unit_ids.index(unit_id) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index f70c45bfe5..85e36cf7a5 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Return (0 * num_channels) array of correct dtype return self.parent_segments[0].get_traces(0, 0, channel_indices) - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) @@ -469,8 +468,7 @@ def get_unit_spike_train( if end_frame is None: end_frame = self.get_num_samples() - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index da8e3d64b6..0ac20b9fec 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -344,15 +344,13 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes with the correct segment_index # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) # slice in absolut in spikes vector l0 = i0 + s0 @@ -562,8 +560,7 @@ def _init_worker_distribute_single_buffer( # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices @@ -590,8 +587,7 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) # slice in absolut in spikes vector l0 = i0 + s0 @@ -685,8 +681,7 @@ def has_exceeding_spikes(recording, sorting): """ spike_vector = sorting.to_spike_vector() for segment_index in range(recording.get_num_segments()): - start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index) - end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind] if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 04af69b37a..3badaa9402 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -82,8 +82,7 @@ def get_unit_spike_train( if end_frame == None: end_frame = spike_train[-1] if len(spike_train) > 0 else 0 - start = np.searchsorted(spike_train, start_frame, side="left") - end = np.searchsorted(spike_train, end_frame, side="right") + start, end = np.searchsorted(spike_train, [start_frame, end + 1], side="left") return spike_train[start:end] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..bb97f246d9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -99,8 +99,7 @@ def _run(self, **job_kwargs): # precompute segment slice segment_slices = [] for segment_index in range(we.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index) - i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append(slice(i0, i1)) # and run diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..ce1c3bd5a0 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) if i0 != i1: # protect from spikes on border : spike_time<0 or spike_time>seg_size diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..fd6078b9b0 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): d = np.diff(spike_times) assert np.all(d >= 0) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) - + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) n_spikes = i1 - i0 amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype()) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..5f23e25b32 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -77,8 +77,7 @@ def get_data(self, outputs="concatenated"): elif outputs == "by_unit": locations_by_unit = [] for segment_index in range(self.waveform_extractor.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") - i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes = self.spikes[i0:i1] locations = self._extension_data["spike_locations"][i0:i1] diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..01701e4f65 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -848,16 +848,14 @@ def compute_drift_metrics( spike_vector = sorting.to_spike_vector() # retrieve spikes in segment - i0 = np.searchsorted(spike_vector["segment_index"], segment_index) - i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spikes_in_segment = spike_vector[i0:i1] spike_locations_in_segment = spike_locations[i0:i1] # compute median positions (if less than min_spikes_per_interval, median position is 0) median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index b4a44105e4..1f6c348574 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -155,8 +155,7 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0 = np.searchsorted(bin_inds, bin_ind, side="left") - i1 = np.searchsorted(bin_inds, bin_ind, side="right") + i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1] side="left") # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing From 164430c83cf66221bed677198fa8d468a8781c1d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:54:35 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 11 ++++++++--- src/spikeinterface/core/waveform_tools.py | 8 ++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 56a2bb4f48..6f85e76f1f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1109,7 +1109,9 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start, end = np.searchsorted(self.spike_vector["segment_index"], [segment_index, segment_index+1], side="left") + start, end = np.searchsorted( + self.spike_vector["segment_index"], [segment_index, segment_index + 1], side="left" + ) spikes = self.spike_vector[start:end] amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None @@ -1207,8 +1209,11 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start, end = np.searchsorted(self.spike_vector["sample_index"], [start_frame - self.templates.shape[1], - end_frame + self.templates.shape[1] + 1], side="left") + start, end = np.searchsorted( + self.spike_vector["sample_index"], + [start_frame - self.templates.shape[1], end_frame + self.templates.shape[1] + 1], + side="left", + ) for i in range(start, end): spike = self.spike_vector[i] diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 0ac20b9fec..a2f1296e31 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -350,7 +350,9 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -587,7 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 From 9ad5f56907a848b757977e8dc2316445f867e269 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 13:01:43 +0200 Subject: [PATCH 03/10] Update src/spikeinterface/sortingcomponents/motion_interpolation.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 1f6c348574..18bb4f5a99 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -155,7 +155,7 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1] side="left") + i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left") # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing From 9c6e6c1cef249d0382c6c441cdd7d2a7b0194cb1 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 13:30:36 +0200 Subject: [PATCH 04/10] Typos while copy/paste --- src/spikeinterface/core/node_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 5627eba518..651804c995 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -124,7 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -194,7 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces From 646455a1054bf4cebed133c3197e8598ef75e59f Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 13:38:26 +0200 Subject: [PATCH 05/10] Some more searchsorted --- .../postprocessing/amplitude_scalings.py | 15 +++++---------- .../widgets/_legacy_mpl_widgets/activity.py | 3 +-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index bb97f246d9..73e75870f9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -316,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) if i0 != i1: local_spikes = spikes_in_segment[i0:i1] @@ -334,8 +333,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) - i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + i0_margin, i1_margin = np.searchsorted(spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -461,14 +459,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] # find the possible spikes per and post within delta_collision_samples - consecutive_window_pre = np.searchsorted( + consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - spike["sample_index"] - delta_collision_samples, - ) - consecutive_window_post = np.searchsorted( - spikes_w_margin["sample_index"], - spike["sample_index"] + delta_collision_samples, + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples] ) + # exclude the spike itself (it is included in the collision_spikes by construction) pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py index 939475c17d..9715b7ea87 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py @@ -95,8 +95,7 @@ def plot(self): num_frames = int(duration / self.bin_duration_s) def animate_func(i): - i0 = np.searchsorted(peaks["sample_index"], bin_size * i) - i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1)) + i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) return artists From 4410d6e8d06a8f3db8004846152be90bf04b8615 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:40:20 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/postprocessing/amplitude_scalings.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 73e75870f9..d4446e2289 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -333,7 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin, i1_margin = np.searchsorted(spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]) + i0_margin, i1_margin = np.searchsorted( + spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] + ) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -461,7 +463,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the possible spikes per and post within delta_collision_samples consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples] + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples], ) # exclude the spike itself (it is included in the collision_spikes by construction) From 334f178aaafc0cccbc81db9821749691b7d67da6 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 13:52:20 +0200 Subject: [PATCH 07/10] Fix --- src/spikeinterface/curation/remove_duplicated_spikes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 3badaa9402..d01ca1f6a1 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -82,7 +82,7 @@ def get_unit_spike_train( if end_frame == None: end_frame = spike_train[-1] if len(spike_train) > 0 else 0 - start, end = np.searchsorted(spike_train, [start_frame, end + 1], side="left") + start, end = np.searchsorted(spike_train, [start_frame, end_frame + 1], side="left") return spike_train[start:end] From ef0d66e6cfeea0b1f3392c5a0a8758194a9c884d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 19 Sep 2023 09:27:18 +0200 Subject: [PATCH 08/10] Bringing back right searches --- src/spikeinterface/core/generate.py | 8 +++----- src/spikeinterface/curation/remove_duplicated_spikes.py | 3 ++- src/spikeinterface/postprocessing/spike_locations.py | 3 ++- .../sortingcomponents/motion_interpolation.py | 5 +++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6f85e76f1f..33f3dea923 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1209,11 +1209,9 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start, end = np.searchsorted( - self.spike_vector["sample_index"], - [start_frame - self.templates.shape[1], end_frame + self.templates.shape[1] + 1], - side="left", - ) + start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") + end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") + for i in range(start, end): spike = self.spike_vector[i] diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index d01ca1f6a1..04af69b37a 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -82,7 +82,8 @@ def get_unit_spike_train( if end_frame == None: end_frame = spike_train[-1] if len(spike_train) > 0 else 0 - start, end = np.searchsorted(spike_train, [start_frame, end_frame + 1], side="left") + start = np.searchsorted(spike_train, start_frame, side="left") + end = np.searchsorted(spike_train, end_frame, side="right") return spike_train[start:end] diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 5f23e25b32..c6f498f7e8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -77,7 +77,8 @@ def get_data(self, outputs="concatenated"): elif outputs == "by_unit": locations_by_unit = [] for segment_index in range(self.waveform_extractor.get_num_segments()): - i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1], side="left") + i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") + i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") spikes = self.spikes[i0:i1] locations = self._extension_data["spike_locations"][i0:i1] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 18bb4f5a99..9a4cd688c5 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -155,8 +155,9 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left") - + i0 = np.searchsorted(bin_inds, bin_ind, side="left") + i1 = np.searchsorted(bin_inds, bin_ind, side="right") + # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) From f2d702a7e20f7fb6459a18b17dd9a4881c1fe337 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 07:27:40 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 1 - src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 33f3dea923..9adda4cb2b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1212,7 +1212,6 @@ def get_traces( start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") - for i in range(start, end): spike = self.spike_vector[i] t = spike["sample_index"] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 9a4cd688c5..b4a44105e4 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -157,7 +157,7 @@ def interpolate_motion_on_traces( i0 = np.searchsorted(bin_inds, bin_ind, side="left") i1 = np.searchsorted(bin_inds, bin_ind, side="right") - + # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) From 9d07ec2fb467e4bc035f2e36566ea9a2aead772e Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 19 Sep 2023 09:31:02 +0200 Subject: [PATCH 10/10] One more --- src/spikeinterface/core/generate.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9adda4cb2b..401c498f03 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1109,9 +1109,8 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start, end = np.searchsorted( - self.spike_vector["segment_index"], [segment_index, segment_index + 1], side="left" - ) + start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left") + end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right") spikes = self.spike_vector[start:end] amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None