Merge branch 'main' into skip_no_peaks
yger authored Sep 19, 2023
2 parents 4bee4b1 + fc95465 · commit f4accb7
Showing 11 changed files with 40 additions and 61 deletions.
3 changes: 1 addition & 2 deletions src/spikeinterface/core/basesorting.py
@@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
         if not concatenated:
             spikes_ = []
             for segment_index in range(self.get_num_segments()):
-                s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left")
-                s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left")
+                s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left")
                 spikes_.append(spikes[s0:s1])
             spikes = spikes_

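Every hunk in this commit applies the same refactor: two scalar np.searchsorted calls over the same sorted array are folded into one call with a pair of needles, which returns both insertion points at once. A minimal standalone sketch of the equivalence (plain numpy, not SpikeInterface code; the saving is mainly fewer Python-level dispatches, each needle is still its own O(log n) bisection):

```python
import numpy as np

# a sorted "segment_index" column, as produced by to_spike_vector()
segment_index = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])

# old style: two scalar binary searches
s0_old = np.searchsorted(segment_index, 1, side="left")
s1_old = np.searchsorted(segment_index, 2, side="left")

# new style: one call with a sequence of needles
s0, s1 = np.searchsorted(segment_index, [1, 2], side="left")

# both give the half-open range of rows belonging to segment 1
assert (s0, s1) == (s0_old, s1_old) == (3, 5)
```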
12 changes: 4 additions & 8 deletions src/spikeinterface/core/node_pipeline.py
@@ -111,8 +111,7 @@ def __init__(self, recording, peaks):
         # precompute segment slice
         self.segment_slices = []
         for segment_index in range(recording.get_num_segments()):
-            i0 = np.searchsorted(peaks["segment_index"], segment_index)
-            i1 = np.searchsorted(peaks["segment_index"], segment_index + 1)
+            i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])
             self.segment_slices.append(slice(i0, i1))

     def get_trace_margin(self):
@@ -131,8 +130,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         # get local peaks
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
-        i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
-        i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
+        i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
         local_peaks = peaks_in_segment[i0:i1]

         # make sample index local to traces
@@ -189,8 +187,7 @@ def __init__(
         # precompute segment slice
         self.segment_slices = []
         for segment_index in range(recording.get_num_segments()):
-            i0 = np.searchsorted(self.peaks["segment_index"], segment_index)
-            i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1)
+            i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1])
             self.segment_slices.append(slice(i0, i1))

     def get_trace_margin(self):
@@ -209,8 +206,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         # get local peaks
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
-        i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
-        i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
+        i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
         local_peaks = peaks_in_segment[i0:i1]

         # make sample index local to traces
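The two nodes above (the peak retriever and its pipeline variant) cache one slice per segment in __init__ so that compute() only has to bisect within the current segment. A hedged sketch of that pattern with a plain structured array (names mirror the node code, but this is not the actual class):

```python
import numpy as np

peak_dtype = [("sample_index", "int64"), ("segment_index", "int64")]
# peaks are sorted by segment first, then by sample, which is what searchsorted needs
peaks = np.array([(10, 0), (50, 0), (7, 1), (90, 1)], dtype=peak_dtype)
num_segments = 2

# precompute segment slices once, as in __init__
segment_slices = []
for segment_index in range(num_segments):
    i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])
    segment_slices.append(slice(i0, i1))

# per chunk, as in compute(): restrict to the segment, then to the frame range
peaks_in_segment = peaks[segment_slices[1]]
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [0, 80])
local_peaks = peaks_in_segment[i0:i1]
assert local_peaks["sample_index"].tolist() == [7]
```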
3 changes: 1 addition & 2 deletions src/spikeinterface/core/numpyextractors.py
@@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
         if self.spikes_in_seg is None:
             # the slicing of the segment is done only once, the first time;
             # this speeds up the constructor a lot
-            s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left")
-            s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left")
+            s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
             self.spikes_in_seg = self.spikes[s0:s1]

         unit_index = self.unit_ids.index(unit_id)
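The "no copy" comment holds because basic slicing of a numpy array returns a view, so caching spikes_in_seg costs nothing beyond the slice object. A quick standalone check (illustrative only):

```python
import numpy as np

spikes = np.zeros(10, dtype=[("sample_index", "int64"), ("segment_index", "int64")])
spikes_in_seg = spikes[2:5]             # basic slice: a view, nothing copied
assert spikes_in_seg.base is spikes     # shares memory with the parent array

spikes_in_seg["sample_index"][0] = 99   # writes are visible through the parent
assert spikes["sample_index"][2] == 99
```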
6 changes: 2 additions & 4 deletions src/spikeinterface/core/segmentutils.py
@@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
             # Return (0 * num_channels) array of correct dtype
             return self.parent_segments[0].get_traces(0, 0, channel_indices)

-        i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1
-        i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1
+        i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1

         # several cases:
         # * comes from one segment (i0 == i1)
@@ -469,8 +468,7 @@ def get_unit_spike_train(
         if end_frame is None:
             end_frame = self.get_num_samples()

-        i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1
-        i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1
+        i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1

         # several cases:
         # * comes from one segment (i0 == i1)
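In segmentutils the needles are frame positions and the haystack is the running total of segment lengths, so side="right" minus one yields the segment that contains each frame. A sketch under the assumption that cumsum_length starts at 0, i.e. holds each segment's cumulative start frame:

```python
import numpy as np

# assumed layout: segments of 100, 150 and 150 samples
cumsum_length = np.array([0, 100, 250, 400])

start_frame, end_frame = 120, 260
i0, i1 = np.searchsorted(cumsum_length, [start_frame, end_frame], side="right") - 1
assert (i0, i1) == (1, 2)  # the request spans segments 1 and 2

# side="right" matters exactly at a boundary: frame 100 is the first
# sample of segment 1, not the last of segment 0
assert np.searchsorted(cumsum_length, 100, side="right") - 1 == 1
```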
19 changes: 9 additions & 10 deletions src/spikeinterface/core/waveform_tools.py
@@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx

     # take only spikes with the correct segment_index
     # this is a slice, so no copy!!
-    s0 = np.searchsorted(spikes["segment_index"], segment_index)
-    s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
+    s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
     in_seg_spikes = spikes[s0:s1]

     # take only spikes in range [start_frame, end_frame]
     # this is a slice, so no copy!!
     # the borders of the segment are protected by nbefore on the left and nafter on the right
-    i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
-    i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
+    i0, i1 = np.searchsorted(
+        in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
+    )

     # absolute slice into the spikes vector
     l0 = i0 + s0
@@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer(
     # prepare segment slices
     segment_slices = []
     for segment_index in range(recording.get_num_segments()):
-        s0 = np.searchsorted(spikes["segment_index"], segment_index)
-        s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
+        s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
         segment_slices.append((s0, s1))
     worker_ctx["segment_slices"] = segment_slices

Expand All @@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
     # take only spikes in range [start_frame, end_frame]
     # this is a slice, so no copy!!
     # the borders of the segment are protected by nbefore on the left and nafter on the right
-    i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
-    i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
+    i0, i1 = np.searchsorted(
+        in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
+    )

     # absolute slice into the spikes vector
     l0 = i0 + s0
@@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting):
"""
spike_vector = sorting.to_spike_vector()
for segment_index in range(recording.get_num_segments()):
start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index)
end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1)
start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind]
if len(spike_vector_seg) > 0:
if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1:
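In both workers the chunk boundaries are clamped by nbefore/nafter before the bisection, so every selected spike keeps a full waveform window inside the segment, and the local indices are then re-based with s0 into the concatenated spike vector. A compact sketch of that arithmetic (illustrative values; s0 is an assumed offset, not taken from the real worker):

```python
import numpy as np

nbefore, nafter = 20, 40
seg_size = 1000
sample_index = np.array([5, 30, 500, 985])  # sorted spikes of one segment

start_frame, end_frame = 0, 1000  # this chunk covers the whole segment
i0, i1 = np.searchsorted(
    sample_index, [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
)
# spikes at 5 (window cut on the left) and 985 (cut on the right) are skipped
assert sample_index[i0:i1].tolist() == [30, 500]

# re-base into the full multi-segment spikes array
s0 = 120            # assumed: first row of this segment in the spike vector
l0, l1 = i0 + s0, i1 + s0
```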
20 changes: 8 additions & 12 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -99,8 +99,7 @@ def _run(self, **job_kwargs):
         # precompute segment slice
         segment_slices = []
         for segment_index in range(we.get_num_segments()):
-            i0 = np.searchsorted(self.spikes["segment_index"], segment_index)
-            i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1)
+            i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1])
             segment_slices.append(slice(i0, i1))

         # and run
@@ -317,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)

     spikes_in_segment = spikes[segment_slices[segment_index]]

-    i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
-    i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
+    i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])

     if i0 != i1:
         local_spikes = spikes_in_segment[i0:i1]
Expand All @@ -335,8 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
         # set colliding spikes apart (if needed)
         if handle_collisions:
             # local spikes with margin!
-            i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left)
-            i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right)
+            i0_margin, i1_margin = np.searchsorted(
+                spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]
+            )
             local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
             collisions_local = find_collisions(
                 local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
@@ -462,14 +461,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_
         spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0]

         # find the possible spikes pre and post within delta_collision_samples
-        consecutive_window_pre = np.searchsorted(
-            spikes_w_margin["sample_index"],
-            spike["sample_index"] - delta_collision_samples,
-        )
-        consecutive_window_post = np.searchsorted(
+        consecutive_window_pre, consecutive_window_post = np.searchsorted(
             spikes_w_margin["sample_index"],
-            spike["sample_index"] + delta_collision_samples,
+            [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples],
         )

         # exclude the spike itself (it is included in the collision_spikes by construction)
         pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin)
         post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post)
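find_collisions asks, for each spike, which neighbours fall within ±delta_collision_samples; both edges of that window query the same sorted sample_index array, hence the single searchsorted call. A sketch of the windowing with hypothetical values (the real function excludes the spike by its index and also compares channel footprints):

```python
import numpy as np

sample_index = np.array([100, 130, 150, 165, 400])  # sorted spike times in samples
spike_time = 150
delta = 30

lo, hi = np.searchsorted(sample_index, [spike_time - delta, spike_time + delta])
window = sample_index[lo:hi]  # spikes in [120, 180)
assert window.tolist() == [130, 150, 165]

# drop the spike itself (the real code uses its position in the margin array,
# which is safer when two spikes share a sample)
neighbours = window[window != spike_time]
assert neighbours.tolist() == [130, 165]
```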
3 changes: 1 addition & 2 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):

     seg_size = recording.get_num_samples(segment_index=segment_index)

-    i0 = np.searchsorted(spike_times, start_frame)
-    i1 = np.searchsorted(spike_times, end_frame)
+    i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame])

     if i0 != i1:
         # protect from spikes on the border: spike_time < 0 or spike_time > seg_size
4 changes: 1 addition & 3 deletions src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx):
         d = np.diff(spike_times)
         assert np.all(d >= 0)

-        i0 = np.searchsorted(spike_times, start_frame)
-        i1 = np.searchsorted(spike_times, end_frame)
-
+        i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame])
         n_spikes = i1 - i0
         amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype())

22 changes: 10 additions & 12 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -18,7 +18,7 @@ def detect_bad_channels(
     nyquist_threshold=0.8,
     direction="y",
     chunk_duration_s=0.3,
-    num_random_chunks=10,
+    num_random_chunks=100,
     welch_window_ms=10.0,
     highpass_filter_cutoff=300,
     neighborhood_r2_threshold=0.9,
@@ -81,9 +81,10 @@ def detect_bad_channels(
     highpass_filter_cutoff : float
         If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300
     chunk_duration_s : float
-        Duration of each chunk, by default 0.3
+        Duration of each chunk, by default 0.5
     num_random_chunks : int
-        Number of random chunks, by default 10
+        Number of random chunks, by default 100
+        Having many chunks is important for reproducibility.
     welch_window_ms : float
         Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms
     neighborhood_r2_threshold : float, default 0.95
@@ -174,20 +175,18 @@ def detect_bad_channels(
         channel_locations = recording.get_channel_locations()
         dim = ["x", "y", "z"].index(direction)
         assert dim < channel_locations.shape[1], f"Direction {direction} is wrong"
-        locs_depth = channel_locations[:, dim]
-        if np.array_equal(np.sort(locs_depth), locs_depth):
+        order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y"))
+        if np.all(np.diff(order_f) == 1):
             # already ordered
             order_f = None
             order_r = None
-        else:
-            # sort by x, y to avoid ambiguity
-            order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y"))

         # Create empty channel labels and fill with bad-channel detection estimate for each chunk
         chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8)

         for i, random_chunk in enumerate(random_data):
-            random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk
-            chunk_channel_labels[:, i] = detect_bad_channels_ibl(
+            random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk
+            chunk_labels = detect_bad_channels_ibl(
                 raw=random_chunk_sorted,
                 fs=recording.sampling_frequency,
                 psd_hf_threshold=psd_hf_threshold,
@@ -198,11 +197,10 @@
                 nyquist_threshold=nyquist_threshold,
                 welch_window_ms=welch_window_ms,
             )
+            chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

         # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output.
         mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False)
-        if order_r is not None:
-            mode_channel_labels = mode_channel_labels[order_r]

         (bad_inds,) = np.where(mode_channel_labels != 0)
         bad_channel_ids = recording.channel_ids[bad_inds]
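Two bugs are fixed here: random_chunk has shape (num_samples, num_channels), so the depth ordering must be applied along axis 1, and the reverse permutation must be applied to each chunk's labels before they are accumulated, rather than once to the mode. order_f and order_r are a permutation and its inverse; a sketch with plain argsort standing in for order_channels_by_depth:

```python
import numpy as np

depths = np.array([40.0, 0.0, 20.0])         # one y coordinate per channel
order_f = np.argsort(depths, kind="stable")  # forward: original -> depth order
order_r = np.argsort(order_f)                # reverse: depth order -> original

chunk = np.random.randn(500, 3)              # (num_samples, num_channels)
chunk_sorted = chunk[:, order_f]             # channels live on axis 1, as in the fix

labels_sorted = np.array([0, 1, 0])          # per-channel labels in depth order
labels = labels_sorted[order_r]              # back in original channel order
assert labels[order_f].tolist() == labels_sorted.tolist()
```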
6 changes: 2 additions & 4 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -848,16 +848,14 @@ def compute_drift_metrics(
     spike_vector = sorting.to_spike_vector()

     # retrieve spikes in segment
-    i0 = np.searchsorted(spike_vector["segment_index"], segment_index)
-    i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1)
+    i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
     spikes_in_segment = spike_vector[i0:i1]
     spike_locations_in_segment = spike_locations[i0:i1]

     # compute median positions (if less than min_spikes_per_interval, median position is 0)
     median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1))
     for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])):
-        i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
-        i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
+        i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])
         spikes_in_bin = spikes_in_segment[i0:i1]
         spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]
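compute_drift_metrics still bisects one interval at a time inside the loop; when only the bin boundaries are needed, the same trick extends to passing the whole sorted edge array in a single call. A sketch of that further vectorization (an observation, not part of this commit):

```python
import numpy as np

sample_index = np.array([3, 8, 15, 16, 42])  # sorted spike samples in one segment
bins = np.array([0, 10, 20, 30, 40, 50])     # interval edges, as in compute_drift_metrics

edges = np.searchsorted(sample_index, bins)  # one call for every boundary
counts = np.diff(edges)                      # spikes per interval
assert counts.tolist() == [2, 2, 0, 0, 1]

# each bin's slice is recovered as sample_index[edges[k]:edges[k + 1]]
assert sample_index[edges[1]:edges[2]].tolist() == [15, 16]
```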
3 changes: 1 addition & 2 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py
@@ -95,8 +95,7 @@ def plot(self):
         num_frames = int(duration / self.bin_duration_s)

         def animate_func(i):
-            i0 = np.searchsorted(peaks["sample_index"], bin_size * i)
-            i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1))
+            i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)])
             local_peaks = peaks[i0:i1]
             artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s)
             return artists
