diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 610ae42398..64949357c4 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -139,21 +139,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ - This class is usefull to inject a sorting object in the node pipepline mechanisim. + This class is useful to inject a sorting object in the node pipepline mechanism. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is used by: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() * compute_principal_components() - - recording: - - sorting: - - channel_from_template: bool (default True) - If True then the channel_index is infered from template and extremum_channel_inds must be provided. + recording : BaseRecording + The recording object. + sorting: BaseSorting + The sorting object. + channel_from_template: bool, default: True + If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. If False every spikes compute its own channel index given a radius around the template max channel. extremum_channel_inds: dict of int The extremum channel index dict given from template. @@ -174,7 +173,7 @@ def __init__( assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" - self.peaks = sorting_to_peak(sorting, extremum_channel_inds) + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) if not channel_from_template: channel_distance = get_channel_distances(recording) @@ -223,7 +222,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) -def sorting_to_peak(sorting, extremum_channel_inds): +def sorting_to_peaks(sorting, extremum_channel_inds): spikes = sorting.to_spike_vector() peaks = np.zeros(spikes.size, dtype=base_peak_dtype) peaks["sample_index"] = spikes["sample_index"] diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index d0d49b865c..bcb15b6455 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -15,7 +15,7 @@ SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - sorting_to_peak, + sorting_to_peaks, ) @@ -72,15 +72,14 @@ def compute(self, traces, peaks, waveforms): def test_run_node_pipeline(): recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) - # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) - job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) spikes = sorting.to_spike_vector() # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - peaks = sorting_to_peak(sorting, extremum_channel_inds) + peaks = sorting_to_peaks(sorting, extremum_channel_inds) peak_retriever = PeakRetriever(recording, peaks) # channel index is from template @@ -97,7 +96,7 @@ def test_run_node_pipeline(): peak_sign="neg", ) - # test with 2 diffrents first node + # test with 3 differents first nodes for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [