diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index a0ded216d1..a6dabf77b5 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -213,7 +213,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
             elif self.peak_sign == "both":
                 local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]
 
-        # TODO: "amplitude" ???
+        # handle amplitude
+        for i, peak in enumerate(local_peaks):
+            local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]
 
         return (local_peaks,)
 
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index bcb15b6455..e5a6dd055c 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -72,7 +72,8 @@ def compute(self, traces, peaks, waveforms):
 
 def test_run_node_pipeline():
     recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
-    job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
+    # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
+    job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
 
     spikes = sorting.to_spike_vector()
 
@@ -104,7 +105,8 @@ def test_run_node_pipeline():
         AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
     ]
     step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
-    assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])
+    if loop == 0:
+        assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])
 
     # 3 nodes two have outputs
     ms_before = 0.5
@@ -132,7 +134,6 @@ def test_run_node_pipeline():
         # gather memory mode
         output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
         amplitudes, waveforms_rms, denoised_waveforms_rms = output
-        assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])
 
         num_peaks = peaks.shape[0]
         num_channels = recording.get_num_channels()
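Note on the node_pipeline.py hunk above: the patched compute() now fills the "amplitude" field of each retrieved peak with the raw trace value at (sample_index, channel_index). A minimal self-contained sketch of that indexing, with an illustrative subset of the peak dtype (the real dtype in node_pipeline.py carries more fields, e.g. segment_index):

    import numpy as np

    # illustrative subset of the structured peak dtype
    peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"), ("amplitude", "float64")]

    traces = np.random.randn(1000, 4)  # stand-in for one traces chunk: (num_samples, num_channels)
    local_peaks = np.zeros(3, dtype=peak_dtype)
    local_peaks["sample_index"] = [10, 500, 990]
    local_peaks["channel_index"] = [0, 2, 3]

    # same logic as the loop added in the diff: one scalar read per peak
    for i, peak in enumerate(local_peaks):
        local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]
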
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 4cbe4d665e..28eed131cd 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -5,6 +5,7 @@
 
 from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift
 from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
+from spikeinterface.core.node_pipeline import SpikeRetriever
 
 
 class SpikeLocationsCalculator(BaseWaveformExtractorExtension):
@@ -25,8 +26,21 @@ def __init__(self, waveform_extractor):
         extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
         self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
 
-    def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
-        params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
+    def _set_params(
+        self,
+        ms_before=0.5,
+        ms_after=0.5,
+        spike_retriver_kwargs=dict(
+            channel_from_template=True,
+            radius_um=50,
+            peak_sign="neg",
+        ),
+        method="center_of_mass",
+        method_kwargs={},
+    ):
+        params = dict(
+            ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method
+        )
         params.update(**method_kwargs)
         return params
 
@@ -44,13 +58,22 @@ def _run(self, **job_kwargs):
         uses the `sortingcomponents.peak_localization.localize_peaks()` function to triangulate
         spike locations.
         """
-        from spikeinterface.sortingcomponents.peak_localization import localize_peaks
+        from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source
 
         job_kwargs = fix_job_kwargs(job_kwargs)
 
         we = self.waveform_extractor
 
-        spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
+        extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
+
+        params = self._params.copy()
+        spike_retriver_kwargs = params.pop("spike_retriver_kwargs")
+
+        spike_retriever = SpikeRetriever(
+            we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs
+        )
+        spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)
+
         self._extension_data["spike_locations"] = spike_locations
 
     def get_data(self, outputs="concatenated"):
@@ -101,6 +124,11 @@ def compute_spike_locations(
     load_if_exists=False,
     ms_before=0.5,
     ms_after=0.5,
+    spike_retriver_kwargs=dict(
+        channel_from_template=True,
+        radius_um=50,
+        peak_sign="neg",
+    ),
     method="center_of_mass",
     method_kwargs={},
     outputs="concatenated",
@@ -119,6 +147,17 @@ def compute_spike_locations(
         The left window, before a peak, in milliseconds.
     ms_after : float
         The right window, after a peak, in milliseconds.
+    spike_retriver_kwargs: dict
+        A dictionary to control the behavior for getting the maximum channel for each spike.
+        This dictionary contains:
+          * channel_from_template: bool, default True
+              Whether the maximum channel of each spike is taken from the template or re-estimated for every spike.
+              channel_from_template=True is the previous behavior, but is less accurate;
+              channel_from_template=False is slower, but more accurate.
+          * radius_um: float, default 50
+              In case channel_from_template=False, this is the radius used to find the true peak.
+          * peak_sign, default "neg"
+              In case channel_from_template=False, this is the peak sign.
     method : str
         'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
     method_kwargs : dict
@@ -138,7 +177,13 @@ def compute_spike_locations(
         slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name)
     else:
         slc = SpikeLocationsCalculator(waveform_extractor)
-        slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs)
+        slc.set_params(
+            ms_before=ms_before,
+            ms_after=ms_after,
+            spike_retriver_kwargs=spike_retriver_kwargs,
+            method=method,
+            method_kwargs=method_kwargs,
+        )
         slc.run(**job_kwargs)
 
     locs = slc.get_data(outputs=outputs)
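With the spike_locations.py changes above, compute_spike_locations() exposes the retrieval behavior through spike_retriver_kwargs. A hedged usage sketch, with parameter names taken from the diff; the toy recording and waveform-extractor setup is illustrative only:

    from spikeinterface.core import generate_ground_truth_recording, extract_waveforms
    from spikeinterface.postprocessing import compute_spike_locations

    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=5, durations=[10.0])
    we = extract_waveforms(recording, sorting, mode="memory", sparse=False)

    # channel_from_template=False re-estimates the extremum channel per spike (slower, more accurate)
    locations = compute_spike_locations(
        we,
        spike_retriver_kwargs=dict(channel_from_template=False, radius_um=50, peak_sign="neg"),
        method="center_of_mass",
    )
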
diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py
index 521b49e6cd..d047a2f67e 100644
--- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py
+++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py
@@ -10,7 +10,12 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
     extension_class = SpikeLocationsCalculator
     extension_data_names = ["spike_locations"]
     extension_function_kwargs_list = [
-        dict(method="center_of_mass", chunk_size=10000, n_jobs=1),
+        dict(
+            method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)
+        ),
+        dict(
+            method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)
+        ),
         dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
         dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
         dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
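The two new test cases above differ only in channel_from_template. Continuing from the previous sketch (same `we`), one way to eyeball the difference between the two retrieval modes (illustrative, not part of the test suite):

    import numpy as np
    from spikeinterface.postprocessing import compute_spike_locations

    locs_template = compute_spike_locations(we, spike_retriver_kwargs=dict(channel_from_template=True))
    locs_per_spike = compute_spike_locations(we, spike_retriver_kwargs=dict(channel_from_template=False))

    # on clean synthetic data the two location estimates should agree closely
    print(np.median(np.abs(locs_template["x"] - locs_per_spike["x"])))
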
diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py
index fa6101f896..6495503b43 100644
--- a/src/spikeinterface/sortingcomponents/peak_localization.py
+++ b/src/spikeinterface/sortingcomponents/peak_localization.py
@@ -7,6 +7,7 @@
     run_node_pipeline,
     find_parent_of_type,
     PeakRetriever,
+    SpikeRetriever,
     PipelineNode,
     WaveformsNode,
     ExtractDenseWaveforms,
@@ -27,72 +28,49 @@
 from .tools import get_prototype_spike
 
 
-def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
-    """Localize peak (spike) in 2D or 3D depending the method.
-
-    When a probe is 2D then:
-       * X is axis 0 of the probe
-       * Y is axis 1 of the probe
-       * Z is orthogonal to the plane of the probe
-
-    Parameters
-    ----------
-    recording: RecordingExtractor
-        The recording extractor object.
-    peaks: array
-        Peaks array, as returned by detect_peaks() in "compact_numpy" way.
-
-    {method_doc}
-
-    {job_doc}
-
-    Returns
-    -------
-    peak_locations: ndarray
-        Array with estimated location for each spike.
-        The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
-    """
+def _run_localization_from_peak_source(
+    recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs
+):
+    # used by localize_peaks() and compute_spike_locations()
     assert (
         method in possible_localization_methods
     ), f"Method {method} is not supported. Choose from {possible_localization_methods}"
 
     method_kwargs, job_kwargs = split_job_kwargs(kwargs)
 
-    peak_retriever = PeakRetriever(recording, peaks)
     if method == "center_of_mass":
         extract_dense_waveforms = ExtractDenseWaveforms(
-            recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
+            recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
         )
         pipeline_nodes = [
-            peak_retriever,
+            peak_source,
             extract_dense_waveforms,
-            LocalizeCenterOfMass(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
+            LocalizeCenterOfMass(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
         ]
     elif method == "monopolar_triangulation":
         extract_dense_waveforms = ExtractDenseWaveforms(
-            recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
+            recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
         )
         pipeline_nodes = [
-            peak_retriever,
+            peak_source,
             extract_dense_waveforms,
-            LocalizeMonopolarTriangulation(
-                recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs
-            ),
+            LocalizeMonopolarTriangulation(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
         ]
     elif method == "peak_channel":
-        pipeline_nodes = [peak_retriever, LocalizePeakChannel(recording, parents=[peak_retriever], **method_kwargs)]
+        pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)]
     elif method == "grid_convolution":
         if "prototype" not in method_kwargs:
+            assert isinstance(peak_source, (PeakRetriever, SpikeRetriever))
             method_kwargs["prototype"] = get_prototype_spike(
-                recording, peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
+                recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
            )
         extract_dense_waveforms = ExtractDenseWaveforms(
-            recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
+            recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
        )
         pipeline_nodes = [
-            peak_retriever,
+            peak_source,
             extract_dense_waveforms,
-            LocalizeGridConvolution(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
+            LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
         ]
 
     job_name = f"localize peaks using {method}"
@@ -101,6 +79,38 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_
     return peak_locations
 
 
+def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
+    """Localize peaks (spikes) in 2D or 3D depending on the method.
+
+    When a probe is 2D then:
+       * X is axis 0 of the probe
+       * Y is axis 1 of the probe
+       * Z is orthogonal to the plane of the probe
+
+    Parameters
+    ----------
+    recording: RecordingExtractor
+        The recording extractor object.
+    peaks: array
+        Peaks array, as returned by detect_peaks() in the "compact_numpy" way.
+
+    {method_doc}
+
+    {job_doc}
+
+    Returns
+    -------
+    peak_locations: ndarray
+        Array with estimated location for each spike.
+        The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
+    """
+    peak_retriever = PeakRetriever(recording, peaks)
+    peak_locations = _run_localization_from_peak_source(
+        recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs
+    )
+    return peak_locations
+
+
 class LocalizeBase(PipelineNode):
     def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
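After the refactor above, localize_peaks() keeps its public signature and simply wraps a PeakRetriever around the shared helper, so existing call sites are unchanged. A typical call still looks like this (a sketch; `recording` as in the earlier snippet, detection parameters illustrative):

    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.peak_localization import localize_peaks

    peaks = detect_peaks(recording, method="by_channel", detect_threshold=5, n_jobs=1)
    peak_locations = localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, n_jobs=1)
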
+ """ + peak_retriever = PeakRetriever(recording, peaks) + peak_locations = _run_localization_from_peak_source( + recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs + ) + return peak_locations + + class LocalizeBase(PipelineNode): def __init__(self, recording, return_output=True, parents=None, radius_um=75.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 45b9079ea9..cd9226d5e8 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -19,6 +19,8 @@ def make_multi_method_doc(methods, ident=" "): def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): + # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few + # spikes, which means that all traces need to be accesses! Please find a better way nb_peaks = min(len(peaks), nb_peaks) idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) peak_retriever = PeakRetriever(recording, peaks[idx])