Skip to content

Commit

Permalink
Some more searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Sep 15, 2023
1 parent 9c6e6c1 commit 646455a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
15 changes: 5 additions & 10 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)

spikes_in_segment = spikes[segment_slices[segment_index]]

i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])

if i0 != i1:
local_spikes = spikes_in_segment[i0:i1]
Expand All @@ -334,8 +333,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
# set colliding spikes apart (if needed)
if handle_collisions:
# local spikes with margin!
i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left)
i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right)
i0_margin, i1_margin = np.searchsorted(spikes_in_segment["sample_index"], [start_frame - left, end_frame + right])
local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
collisions_local = find_collisions(
local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
Expand Down Expand Up @@ -461,14 +459,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_
spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0]

# find the possible spikes per and post within delta_collision_samples
consecutive_window_pre = np.searchsorted(
consecutive_window_pre, consecutive_window_post = np.searchsorted(
spikes_w_margin["sample_index"],
spike["sample_index"] - delta_collision_samples,
)
consecutive_window_post = np.searchsorted(
spikes_w_margin["sample_index"],
spike["sample_index"] + delta_collision_samples,
[spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples]
)

# exclude the spike itself (it is included in the collision_spikes by construction)
pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin)
post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post)
Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def plot(self):
num_frames = int(duration / self.bin_duration_s)

def animate_func(i):
i0 = np.searchsorted(peaks["sample_index"], bin_size * i)
i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1))
i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)])
local_peaks = peaks[i0:i1]
artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s)
return artists
Expand Down

0 comments on commit 646455a

Please sign in to comment.