Speed up searchsorted calls #2000

Merged · 13 commits · Sep 19, 2023
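The single change applied throughout this PR: wherever two np.searchsorted calls located the start and end of a slice in the same sorted array, they are folded into one call that takes both query values and returns both indices, so the NumPy call overhead is paid once per slice instead of twice. A minimal sketch of the pattern on made-up data (the spike-vector layout mirrors SpikeInterface, the numbers do not):

```python
import numpy as np

# stand-in for a SpikeInterface spike vector, sorted by segment index
rng = np.random.default_rng(0)
num_spikes = 1_000_000
spikes = np.zeros(num_spikes, dtype=[("sample_index", "int64"), ("segment_index", "int64")])
spikes["segment_index"] = np.sort(rng.integers(0, 4, size=num_spikes))

segment_index = 2

# before: two separate binary searches over the same array
s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left")
s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left")

# after: one call with both query values returns both boundaries at once
s0_new, s1_new = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left")

assert s0 == s0_new and s1 == s1_new
spikes_in_segment = spikes[s0_new:s1_new]  # a view, no copy
```

The result is identical; only the number of Python-level calls changes.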
3 changes: 1 addition & 2 deletions src/spikeinterface/core/basesorting.py
@@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
if not concatenated:
spikes_ = []
for segment_index in range(self.get_num_segments()):
-s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left")
-s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left")
+s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left")
spikes_.append(spikes[s0:s1])
spikes = spikes_

12 changes: 8 additions & 4 deletions src/spikeinterface/core/generate.py
@@ -1109,8 +1109,9 @@ def __init__(
num_samples = [num_samples]

for segment_index in range(sorting.get_num_segments()):
-start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left")
-end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right")
+start, end = np.searchsorted(
+    self.spike_vector["segment_index"], [segment_index, segment_index + 1], side="left"
+)
spikes = self.spike_vector[start:end]
amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None
upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None
@@ -1208,8 +1209,11 @@ def get_traces(
else:
traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype)

-start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left")
-end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right")
+start, end = np.searchsorted(
+    self.spike_vector["sample_index"],
+    [start_frame - self.templates.shape[1], end_frame + self.templates.shape[1] + 1],
+    side="left",
+)

for i in range(start, end):
spike = self.spike_vector[i]
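Several hunks (the get_traces change above, and remove_duplicated_spikes.py, spike_locations.py and motion_interpolation.py further down) also fold a side="right" search into the shared side="left" call by querying value + 1 instead. For integer sample indices the two forms agree: side="right" at v and side="left" at v + 1 both return the index just past the last element less than or equal to v. A standalone check of that equivalence, on invented integer data:

```python
import numpy as np

rng = np.random.default_rng(1)
sample_index = np.sort(rng.integers(0, 100, size=1_000))  # sorted integer sample indices
end_frame = 42

# index just past the last element <= end_frame
right_idx = np.searchsorted(sample_index, end_frame, side="right")
# index of the first element >= end_frame + 1: the same position for integer data
left_idx = np.searchsorted(sample_index, end_frame + 1, side="left")

assert right_idx == left_idx
```

With float data the two forms can differ, so the substitution relies on the sample indices being integers.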
12 changes: 4 additions & 8 deletions src/spikeinterface/core/node_pipeline.py
@@ -111,8 +111,7 @@ def __init__(self, recording, peaks):
# precompute segment slice
self.segment_slices = []
for segment_index in range(recording.get_num_segments()):
-i0 = np.searchsorted(peaks["segment_index"], segment_index)
-i1 = np.searchsorted(peaks["segment_index"], segment_index + 1)
+i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])
self.segment_slices.append(slice(i0, i1))

def get_trace_margin(self):
@@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
-i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
-i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
+i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
@@ -183,8 +181,7 @@ def __init__(
# precompute segment slice
self.segment_slices = []
for segment_index in range(recording.get_num_segments()):
-i0 = np.searchsorted(self.peaks["segment_index"], segment_index)
-i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1)
+i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1])
self.segment_slices.append(slice(i0, i1))

def get_trace_margin(self):
@@ -197,8 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
-i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
-i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
+i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
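The node_pipeline.py changes keep the file's two-level structure: segment boundaries are located once in the constructor and stored as slice objects, so each compute() call only performs one searchsorted over the frame range inside its pre-sliced segment. A condensed sketch of that structure (the class and the data are invented for illustration; the attribute names follow the diff):

```python
import numpy as np

class PeakSlicer:
    """Sketch of the precompute-then-slice pattern used in node_pipeline.py."""

    def __init__(self, peaks, num_segments):
        # peaks: structured array sorted by ("segment_index", "sample_index")
        self.peaks = peaks
        # one pair of boundaries per segment, computed once
        self.segment_slices = []
        for segment_index in range(num_segments):
            i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])
            self.segment_slices.append(slice(i0, i1))

    def peaks_in_chunk(self, segment_index, start_frame, end_frame):
        # per-chunk work: a single searchsorted restricted to the segment's view
        peaks_in_segment = self.peaks[self.segment_slices[segment_index]]
        i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
        return peaks_in_segment[i0:i1]
```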
3 changes: 1 addition & 2 deletions src/spikeinterface/core/numpyextractors.py
@@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
if self.spikes_in_seg is None:
# the slicing of segment is done only once the first time
# this fasten the constructor a lot
-s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left")
-s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left")
+s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
self.spikes_in_seg = self.spikes[s0:s1]

unit_index = self.unit_ids.index(unit_id)
6 changes: 2 additions & 4 deletions src/spikeinterface/core/segmentutils.py
@@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# Return (0 * num_channels) array of correct dtype
return self.parent_segments[0].get_traces(0, 0, channel_indices)

-i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1
-i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1
+i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1

# several case:
# * come from one segment (i0 == i1)
@@ -469,8 +468,7 @@ def get_unit_spike_train(
if end_frame is None:
end_frame = self.get_num_samples()

-i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1
-i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1
+i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1

# several case:
# * come from one segment (i0 == i1)
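In segmentutils.py the searched array is cumsum_length, the cumulative start frame of each concatenated segment, and the side="right" ... - 1 idiom maps a frame to the segment that contains it. The combined call therefore returns the first and last segments spanned by the requested range in one go. A small sketch of that mapping, with invented segment lengths:

```python
import numpy as np

# cumulative start frame of each concatenated segment:
# segment k covers [cumsum_length[k], cumsum_length[k + 1])
cumsum_length = np.array([0, 1000, 2500, 4000])

start_frame, end_frame = 1200, 3100

# side="right" inserts after equal values, so subtracting 1 gives the index of
# the last segment whose start frame is <= the queried frame
i0, i1 = np.searchsorted(cumsum_length, [start_frame, end_frame], side="right") - 1

print(i0, i1)  # 1 2 -> the range starts in segment 1 and ends in segment 2
```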
19 changes: 9 additions & 10 deletions src/spikeinterface/core/waveform_tools.py
@@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx

# take only spikes with the correct segment_index
# this is a slice so no copy!!
-s0 = np.searchsorted(spikes["segment_index"], segment_index)
-s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
+s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
in_seg_spikes = spikes[s0:s1]

# take only spikes in range [start_frame, end_frame]
# this is a slice so no copy!!
# the border of segment are protected by nbefore on left an nafter on the right
-i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
-i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
+i0, i1 = np.searchsorted(
+    in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
+)

# slice in absolut in spikes vector
l0 = i0 + s0
@@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer(
# prepare segment slices
segment_slices = []
for segment_index in range(recording.get_num_segments()):
-s0 = np.searchsorted(spikes["segment_index"], segment_index)
-s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
+s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
segment_slices.append((s0, s1))
worker_ctx["segment_slices"] = segment_slices

@@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
# take only spikes in range [start_frame, end_frame]
# this is a slice so no copy!!
# the border of segment are protected by nbefore on left an nafter on the right
-i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
-i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
+i0, i1 = np.searchsorted(
+    in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
+)

# slice in absolut in spikes vector
l0 = i0 + s0
@@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting):
"""
spike_vector = sorting.to_spike_vector()
for segment_index in range(recording.get_num_segments()):
-start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index)
-end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1)
+start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind]
if len(spike_vector_seg) > 0:
if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1:
3 changes: 1 addition & 2 deletions src/spikeinterface/curation/remove_duplicated_spikes.py
@@ -82,8 +82,7 @@ def get_unit_spike_train(
if end_frame == None:
end_frame = spike_train[-1] if len(spike_train) > 0 else 0

-start = np.searchsorted(spike_train, start_frame, side="left")
-end = np.searchsorted(spike_train, end_frame, side="right")
+start, end = np.searchsorted(spike_train, [start_frame, end_frame + 1], side="left")

return spike_train[start:end]

20 changes: 8 additions & 12 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -99,8 +99,7 @@ def _run(self, **job_kwargs):
# precompute segment slice
segment_slices = []
for segment_index in range(we.get_num_segments()):
-i0 = np.searchsorted(self.spikes["segment_index"], segment_index)
-i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1)
+i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1])
segment_slices.append(slice(i0, i1))

# and run
@@ -317,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)

spikes_in_segment = spikes[segment_slices[segment_index]]

-i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
-i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
+i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])

if i0 != i1:
local_spikes = spikes_in_segment[i0:i1]
@@ -335,8 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
# set colliding spikes apart (if needed)
if handle_collisions:
# local spikes with margin!
-i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left)
-i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right)
+i0_margin, i1_margin = np.searchsorted(
+    spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]
+)
local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
collisions_local = find_collisions(
local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
@@ -462,14 +461,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_
spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0]

# find the possible spikes per and post within delta_collision_samples
-consecutive_window_pre = np.searchsorted(
-    spikes_w_margin["sample_index"],
-    spike["sample_index"] - delta_collision_samples,
-)
-consecutive_window_post = np.searchsorted(
+consecutive_window_pre, consecutive_window_post = np.searchsorted(
     spikes_w_margin["sample_index"],
-    spike["sample_index"] + delta_collision_samples,
+    [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples],
 )

# exclude the spike itself (it is included in the collision_spikes by construction)
pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin)
post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post)
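The find_collisions hunk above merges the pre- and post-window searches, so a single call bounds every spike within ± delta_collision_samples of the reference spike. A toy illustration (variable names follow the diff; the data is invented):

```python
import numpy as np

sample_index = np.array([10, 40, 55, 60, 63, 90, 120])  # sorted spike times, in samples
spike_sample = 60
delta_collision_samples = 10

# one call returns the window [first spike >= t - delta, first spike >= t + delta)
consecutive_window_pre, consecutive_window_post = np.searchsorted(
    sample_index,
    [spike_sample - delta_collision_samples, spike_sample + delta_collision_samples],
)

print(sample_index[consecutive_window_pre:consecutive_window_post])  # [55 60 63]
```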
3 changes: 1 addition & 2 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):

seg_size = recording.get_num_samples(segment_index=segment_index)

-i0 = np.searchsorted(spike_times, start_frame)
-i1 = np.searchsorted(spike_times, end_frame)
+i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame])

if i0 != i1:
# protect from spikes on border : spike_time<0 or spike_time>seg_size
4 changes: 1 addition & 3 deletions src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx):
d = np.diff(spike_times)
assert np.all(d >= 0)

-i0 = np.searchsorted(spike_times, start_frame)
-i1 = np.searchsorted(spike_times, end_frame)
-
+i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame])
n_spikes = i1 - i0
amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype())

3 changes: 1 addition & 2 deletions src/spikeinterface/postprocessing/spike_locations.py
@@ -77,8 +77,7 @@ def get_data(self, outputs="concatenated"):
elif outputs == "by_unit":
locations_by_unit = []
for segment_index in range(self.waveform_extractor.get_num_segments()):
-i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left")
-i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right")
+i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1], side="left")
spikes = self.spikes[i0:i1]
locations = self._extension_data["spike_locations"][i0:i1]

6 changes: 2 additions & 4 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -848,16 +848,14 @@ def compute_drift_metrics(
spike_vector = sorting.to_spike_vector()

# retrieve spikes in segment
-i0 = np.searchsorted(spike_vector["segment_index"], segment_index)
-i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1)
+i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
spikes_in_segment = spike_vector[i0:i1]
spike_locations_in_segment = spike_locations[i0:i1]

# compute median positions (if less than min_spikes_per_interval, median position is 0)
median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1))
for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])):
-i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
-i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
+i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])
spikes_in_bin = spikes_in_segment[i0:i1]
spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]

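compute_drift_metrics still walks the bin edges pairwise, two per loop iteration. Because np.searchsorted accepts an arbitrary array of queries, all interval boundaries could be obtained with one call and the loop kept only for the per-bin statistics; this is a possible further step, sketched below on invented data, and is not part of this PR:

```python
import numpy as np

rng = np.random.default_rng(2)
sample_index = np.sort(rng.integers(0, 30_000, size=5_000))  # spike samples in one segment
bins = np.arange(0, 30_001, 3_000)  # interval edges, analogous to compute_drift_metrics

# one searchsorted over every edge; consecutive entries bound each interval
edges = np.searchsorted(sample_index, bins)

for i0, i1 in zip(edges[:-1], edges[1:]):
    spikes_in_bin = sample_index[i0:i1]
    # ... per-bin work (median positions in the original code) would go here
```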
3 changes: 1 addition & 2 deletions src/spikeinterface/sortingcomponents/motion_interpolation.py
@@ -155,8 +155,7 @@ def interpolate_motion_on_traces(
**spatial_interpolation_kwargs,
)

-i0 = np.searchsorted(bin_inds, bin_ind, side="left")
-i1 = np.searchsorted(bin_inds, bin_ind, side="right")
+i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left")

# here we use a simple np.matmul even if dirft_kernel can be super sparse.
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing
3 changes: 1 addition & 2 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py
@@ -95,8 +95,7 @@ def plot(self):
num_frames = int(duration / self.bin_duration_s)

def animate_func(i):
-i0 = np.searchsorted(peaks["sample_index"], bin_size * i)
-i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1))
+i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)])
local_peaks = peaks[i0:i1]
artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s)
return artists