diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 056134a24e..e6d08d38f7 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if not concatenated: spikes_ = [] for segment_index in range(self.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes_.append(spikes[s0:s1]) spikes = spikes_ diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1ba7c59c82..3b2b610cf5 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -111,8 +111,7 @@ def __init__(self, recording, peaks): # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(peaks["segment_index"], segment_index) - i1 = np.searchsorted(peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -131,8 +130,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -189,8 +187,7 @@ def __init__( # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(self.peaks["segment_index"], segment_index) - i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -209,8 +206,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 97f22615df..d5663156c7 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1]) self.spikes_in_seg = self.spikes[s0:s1] unit_index = self.unit_ids.index(unit_id) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index f70c45bfe5..85e36cf7a5 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Return (0 * num_channels) array of correct dtype return self.parent_segments[0].get_traces(0, 0, channel_indices) - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) @@ -469,8 +468,7 @@ def get_unit_spike_train( if end_frame is None: end_frame = self.get_num_samples() - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index da8e3d64b6..a2f1296e31 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes with the correct segment_index # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer( # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices @@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting): """ spike_vector = sorting.to_spike_vector() for segment_index in range(recording.get_num_segments()): - start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index) - end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind] if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a3542cdf9..22b40a51c5 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -99,8 +99,7 @@ def _run(self, **job_kwargs): # precompute segment slice segment_slices = [] for segment_index in range(we.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index) - i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append(slice(i0, i1)) # and run @@ -317,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) if i0 != i1: local_spikes = spikes_in_segment[i0:i1] @@ -335,8 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) - i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + i0_margin, i1_margin = np.searchsorted( + spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] + ) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -462,14 +461,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] # find the possible spikes per and post within delta_collision_samples - consecutive_window_pre = np.searchsorted( - spikes_w_margin["sample_index"], - spike["sample_index"] - delta_collision_samples, - ) - consecutive_window_post = np.searchsorted( + consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - spike["sample_index"] + delta_collision_samples, + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples], ) + # exclude the spike itself (it is included in the collision_spikes by construction) pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..ce1c3bd5a0 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) if i0 != i1: # protect from spikes on border : spike_time<0 or spike_time>seg_size diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index b6f25cda95..38cb714d59 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): d = np.diff(spike_times) assert np.all(d >= 0) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) - + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) n_spikes = i1 - i0 amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype()) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 0f4800c6e8..cc4e8601e2 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -18,7 +18,7 @@ def detect_bad_channels( nyquist_threshold=0.8, direction="y", chunk_duration_s=0.3, - num_random_chunks=10, + num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, neighborhood_r2_threshold=0.9, @@ -81,9 +81,10 @@ def detect_bad_channels( highpass_filter_cutoff : float If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300 chunk_duration_s : float - Duration of each chunk, by default 0.3 + Duration of each chunk, by default 0.5 num_random_chunks : int - Number of random chunks, by default 10 + Number of random chunks, by default 100 + Having many chunks is important for reproducibility. welch_window_ms : float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms neighborhood_r2_threshold : float, default 0.95 @@ -174,20 +175,18 @@ def detect_bad_channels( channel_locations = recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) assert dim < channel_locations.shape[1], f"Direction {direction} is wrong" - locs_depth = channel_locations[:, dim] - if np.array_equal(np.sort(locs_depth), locs_depth): + order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) + if np.all(np.diff(order_f) == 1): + # already ordered order_f = None order_r = None - else: - # sort by x, y to avoid ambiguity - order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) # Create empty channel labels and fill with bad-channel detection estimate for each chunk chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8) for i, random_chunk in enumerate(random_data): - random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk - chunk_channel_labels[:, i] = detect_bad_channels_ibl( + random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk + chunk_labels = detect_bad_channels_ibl( raw=random_chunk_sorted, fs=recording.sampling_frequency, psd_hf_threshold=psd_hf_threshold, @@ -198,11 +197,10 @@ def detect_bad_channels( nyquist_threshold=nyquist_threshold, welch_window_ms=welch_window_ms, ) + chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output. mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False) - if order_r is not None: - mode_channel_labels = mode_channel_labels[order_r] (bad_inds,) = np.where(mode_channel_labels != 0) bad_channel_ids = recording.channel_ids[bad_inds] diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 4e871492f8..8dd5f857f6 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -848,16 +848,14 @@ def compute_drift_metrics( spike_vector = sorting.to_spike_vector() # retrieve spikes in segment - i0 = np.searchsorted(spike_vector["segment_index"], segment_index) - i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spikes_in_segment = spike_vector[i0:i1] spike_locations_in_segment = spike_locations[i0:i1] # compute median positions (if less than min_spikes_per_interval, median position is 0) median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py index 939475c17d..9715b7ea87 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py @@ -95,8 +95,7 @@ def plot(self): num_frames = int(duration / self.bin_duration_s) def animate_func(i): - i0 = np.searchsorted(peaks["sample_index"], bin_size * i) - i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1)) + i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) return artists