diff --git a/src/spikeinterface/sortingcomponents/peak_selection.py b/src/spikeinterface/sortingcomponents/peak_selection.py index 1ccfbc4d22..fed026b6a7 100644 --- a/src/spikeinterface/sortingcomponents/peak_selection.py +++ b/src/spikeinterface/sortingcomponents/peak_selection.py @@ -76,19 +76,18 @@ def select_peaks( selected_indices = select_peak_indices(peaks, method=method, seed=seed, **method_kwargs) selected_peaks = peaks[selected_indices] + num_segments = len(np.unique(selected_peaks["segment_index"])) if margin is not None: to_keep = np.zeros(len(selected_peaks), dtype=bool) - offset = 0 - for segment_index in range(recording.get_num_segments()): - duration = recording.get_num_frames(segment_index) + for segment_index in range(num_segments): + num_samples_in_segment = recording.get_num_samples(segment_index) i0, i1 = np.searchsorted(selected_peaks["segment_index"], [segment_index, segment_index + 1]) - while selected_peaks["sample_index"][i0] <= margin[0] + offset: + while selected_peaks["sample_index"][i0] <= margin[0]: i0 += 1 - while selected_peaks["sample_index"][i1 - 1] >= (duration - margin[1]) + offset: + while selected_peaks["sample_index"][i1 - 1] >= (num_samples_in_segment - margin[1]): i1 -= 1 to_keep[i0:i1] = True - offset += duration selected_indices = selected_indices[to_keep] selected_peaks = peaks[selected_indices] @@ -284,7 +283,9 @@ def select_peak_indices(peaks, method, seed, **method_kwargs): ) selected_indices = np.concatenate(selected_indices) - selected_indices = selected_indices[np.argsort(peaks[selected_indices]["sample_index"])] + selected_indices = selected_indices[ + np.lexsort((peaks[selected_indices]["sample_index"], peaks[selected_indices]["segment_index"])) + ] return selected_indices