diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index aac96be7b6..6302fe37ae 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -19,14 +19,13 @@ class SpikeLocationsCalculator(BaseWaveformExtractorExtension): extension_name = "spike_locations" - def __init__(self, waveform_extractor): + def __init__(self, waveform_extractor, peak_sign="neg"): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - - extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") + extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, peak_sign, outputs="index") self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}): - params = dict(ms_before=ms_before, ms_after=ms_after, method=method) + def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}, radius_um=None): + params = dict(ms_before=ms_before, ms_after=ms_after, method=method, radius_um=radius_um) params.update(**method_kwargs) return params @@ -47,12 +46,8 @@ def _run(self, **job_kwargs): from spikeinterface.sortingcomponents.peak_localization import localize_peaks job_kwargs = fix_job_kwargs(job_kwargs) - we = self.waveform_extractor - extremum_channel_inds = get_template_extremum_channel(we, outputs="index") - self.spikes = we.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs) self._extension_data["spike_locations"] = spike_locations @@ -107,6 +102,8 @@ def compute_spike_locations( method="center_of_mass", method_kwargs={}, outputs="concatenated", + peak_sign="neg", + radius_um=None, **job_kwargs, ): """ @@ -126,6 +123,11 @@ def compute_spike_locations( 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' method_kwargs : dict Other kwargs depending on the method. + peak_sign: str + Sign of the template to compute best channels ('neg', 'pos', 'both') + radius_um: None + If not None, the radius used to perform peak centering, i.e. look for the + real channel, in the data, where the peaks occur outputs : str 'concatenated' (default) / 'by_unit' {} @@ -140,8 +142,10 @@ def compute_spike_locations( if load_if_exists and waveform_extractor.is_extension(SpikeLocationsCalculator.extension_name): slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) else: - slc = SpikeLocationsCalculator(waveform_extractor) - slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs) + slc = SpikeLocationsCalculator(waveform_extractor, peak_sign) + slc.set_params( + ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs, radius_um=radius_um + ) slc.run(**job_kwargs) locs = slc.get_data(outputs=outputs) diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index bd793b3f53..a9c0063a8f 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -6,6 +6,7 @@ run_node_pipeline, find_parent_of_type, PeakRetriever, + PeakCenterer, PipelineNode, WaveformsNode, ExtractDenseWaveforms, @@ -26,7 +27,7 @@ from .tools import get_prototype_spike -def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): +def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, radius_um=None, **kwargs): """Localize peak (spike) in 2D or 3D depending the method. When a probe is 2D then: @@ -40,6 +41,9 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ The recording extractor object. peaks: array Peaks array, as returned by detect_peaks() in "compact_numpy" way. + radius_um: float (default None) + If not None, the radius used to perform peak centering, i.e. find the real channel where + the peaks actually occur {method_doc} @@ -57,7 +61,11 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ method_kwargs, job_kwargs = split_job_kwargs(kwargs) - peak_retriever = PeakRetriever(recording, peaks) + if radius_um is not None: + peak_retriever = PeakCenterer(recording, peaks, radius_um=radius_um) + else: + peak_retriever = PeakRetriever(recording, peaks) + if method == "center_of_mass": extract_dense_waveforms = ExtractDenseWaveforms( recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False diff --git a/src/spikeinterface/sortingcomponents/peak_pipeline.py b/src/spikeinterface/sortingcomponents/peak_pipeline.py index 6f0f26201f..732bcd3554 100644 --- a/src/spikeinterface/sortingcomponents/peak_pipeline.py +++ b/src/spikeinterface/sortingcomponents/peak_pipeline.py @@ -126,6 +126,46 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) +class PeakCenterer(PeakRetriever): + def __init__(self, recording, peaks, radius_um=50, peak_sign="neg"): + PeakRetriever.__init__(self, recording, peaks) + self.radius_um = radius_um + self.contact_locations = recording.get_channel_locations() + self.channel_distance = get_channel_distances(recording) + self.neighbours_mask = self.channel_distance < radius_um + self.peak_sign = peak_sign + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return base_peak_dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) + i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + local_peaks = peaks_in_segment[i0:i1] + + # make sample index local to traces + local_peaks = local_peaks.copy() + local_peaks["sample_index"] -= start_frame - max_margin + + for i, peak in enumerate(local_peaks): + (chans,) = np.nonzero(self.neighbours_mask[peak["channel_index"]]) + sparse_wfs = traces[peak["sample_index"], chans] + if self.peak_sign == "neg": + local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] + elif self.peak_sign == "pos": + local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] + elif self.peak_sign == "both": + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + + return (local_peaks,) + + class WaveformsNode(PipelineNode): """ Base class for waveforms in a node pipeline. @@ -304,8 +344,10 @@ def check_graph(nodes): """ node0 = nodes[0] - if not (isinstance(node0, PeakDetector) or isinstance(node0, PeakRetriever)): - raise ValueError("Peak pipeline graph must contain PeakDetector or PeakRetriever as first element") + if not (isinstance(node0, PeakDetector) or isinstance(node0, PeakRetriever) or isinstance(node0, PeakCenterer)): + raise ValueError( + "Peak pipeline graph must contain PeakDetector or PeakRetriever or PeakCenterer as first element" + ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode"