Skip to content

Commit

Permalink
Merge pull request #2000 from yger/searchsorted
Browse files Browse the repository at this point in the history
Speed up searchsorted calls
  • Loading branch information
samuelgarcia authored Sep 19, 2023
2 parents 2f3bb29 + 3e860d4 commit 855a264
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 49 deletions.
3 changes: 1 addition & 2 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
if not concatenated:
spikes_ = []
for segment_index in range(self.get_num_segments()):
s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left")
s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left")
s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left")
spikes_.append(spikes[s0:s1])
spikes = spikes_

Expand Down
12 changes: 4 additions & 8 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def __init__(self, recording, peaks):
# precompute segment slice
self.segment_slices = []
for segment_index in range(recording.get_num_segments()):
i0 = np.searchsorted(peaks["segment_index"], segment_index)
i1 = np.searchsorted(peaks["segment_index"], segment_index + 1)
i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])
self.segment_slices.append(slice(i0, i1))

def get_trace_margin(self):
Expand All @@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
Expand Down Expand Up @@ -183,8 +181,7 @@ def __init__(
# precompute segment slice
self.segment_slices = []
for segment_index in range(recording.get_num_segments()):
i0 = np.searchsorted(self.peaks["segment_index"], segment_index)
i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1)
i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1])
self.segment_slices.append(slice(i0, i1))

def get_trace_margin(self):
Expand All @@ -197,8 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
if self.spikes_in_seg is None:
# the slicing of segment is done only once the first time
# this speeds up the constructor a lot
s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left")
s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left")
s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
self.spikes_in_seg = self.spikes[s0:s1]

unit_index = self.unit_ids.index(unit_id)
Expand Down
6 changes: 2 additions & 4 deletions src/spikeinterface/core/segmentutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# Return (0 * num_channels) array of correct dtype
return self.parent_segments[0].get_traces(0, 0, channel_indices)

i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1
i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1
i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1

# several cases:
# * come from one segment (i0 == i1)
Expand Down Expand Up @@ -469,8 +468,7 @@ def get_unit_spike_train(
if end_frame is None:
end_frame = self.get_num_samples()

i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1
i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1
i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1

# several cases:
# * come from one segment (i0 == i1)
Expand Down
19 changes: 9 additions & 10 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx

# take only spikes with the correct segment_index
# this is a slice so no copy!!
s0 = np.searchsorted(spikes["segment_index"], segment_index)
s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
in_seg_spikes = spikes[s0:s1]

# take only spikes in range [start_frame, end_frame]
# this is a slice so no copy!!
# the borders of the segment are protected by nbefore on the left and nafter on the right
i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
i0, i1 = np.searchsorted(
in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
)

# slice in absolute indices of the spikes vector
l0 = i0 + s0
Expand Down Expand Up @@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer(
# prepare segment slices
segment_slices = []
for segment_index in range(recording.get_num_segments()):
s0 = np.searchsorted(spikes["segment_index"], segment_index)
s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
segment_slices.append((s0, s1))
worker_ctx["segment_slices"] = segment_slices

Expand All @@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
# take only spikes in range [start_frame, end_frame]
# this is a slice so no copy!!
# the borders of the segment are protected by nbefore on the left and nafter on the right
i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
i0, i1 = np.searchsorted(
in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
)

# slice in absolute indices of the spikes vector
l0 = i0 + s0
Expand Down Expand Up @@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting):
"""
spike_vector = sorting.to_spike_vector()
for segment_index in range(recording.get_num_segments()):
start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index)
end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1)
start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind]
if len(spike_vector_seg) > 0:
if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1:
Expand Down
20 changes: 8 additions & 12 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ def _run(self, **job_kwargs):
# precompute segment slice
segment_slices = []
for segment_index in range(we.get_num_segments()):
i0 = np.searchsorted(self.spikes["segment_index"], segment_index)
i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1)
i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1])
segment_slices.append(slice(i0, i1))

# and run
Expand Down Expand Up @@ -317,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)

spikes_in_segment = spikes[segment_slices[segment_index]]

i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])

if i0 != i1:
local_spikes = spikes_in_segment[i0:i1]
Expand All @@ -335,8 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
# set colliding spikes apart (if needed)
if handle_collisions:
# local spikes with margin!
i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left)
i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right)
i0_margin, i1_margin = np.searchsorted(
spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]
)
local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
collisions_local = find_collisions(
local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
Expand Down Expand Up @@ -462,14 +461,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_
spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0]

# find the possible spikes pre and post within delta_collision_samples
consecutive_window_pre = np.searchsorted(
spikes_w_margin["sample_index"],
spike["sample_index"] - delta_collision_samples,
)
consecutive_window_post = np.searchsorted(
consecutive_window_pre, consecutive_window_post = np.searchsorted(
spikes_w_margin["sample_index"],
spike["sample_index"] + delta_collision_samples,
[spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples],
)

# exclude the spike itself (it is included in the collision_spikes by construction)
pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin)
post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post)
Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):

seg_size = recording.get_num_samples(segment_index=segment_index)

i0 = np.searchsorted(spike_times, start_frame)
i1 = np.searchsorted(spike_times, end_frame)
i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame])

if i0 != i1:
# protect from spikes on border : spike_time<0 or spike_time>seg_size
Expand Down
4 changes: 1 addition & 3 deletions src/spikeinterface/postprocessing/spike_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx):
d = np.diff(spike_times)
assert np.all(d >= 0)

i0 = np.searchsorted(spike_times, start_frame)
i1 = np.searchsorted(spike_times, end_frame)

i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame])
n_spikes = i1 - i0
amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype())

Expand Down
6 changes: 2 additions & 4 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,16 +848,14 @@ def compute_drift_metrics(
spike_vector = sorting.to_spike_vector()

# retrieve spikes in segment
i0 = np.searchsorted(spike_vector["segment_index"], segment_index)
i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1)
i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
spikes_in_segment = spike_vector[i0:i1]
spike_locations_in_segment = spike_locations[i0:i1]

# compute median positions (if less than min_spikes_per_interval, median position is 0)
median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1))
for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])):
i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])
spikes_in_bin = spikes_in_segment[i0:i1]
spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]

Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def plot(self):
num_frames = int(duration / self.bin_duration_s)

def animate_func(i):
i0 = np.searchsorted(peaks["sample_index"], bin_size * i)
i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1))
i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)])
local_peaks = peaks[i0:i1]
artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s)
return artists
Expand Down

0 comments on commit 855a264

Please sign in to comment.