diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ff747fe2a0..610ae42398 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -141,7 +141,7 @@ class SpikeRetriever(PeakSource): """ This class is usefull to inject a sorting object in the node pipepline mechanisim. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is a first step to totaly refactor: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() @@ -164,16 +164,14 @@ class SpikeRetriever(PeakSource): Peak sign to find the max channel. Used only when channel_from_template=False """ - def __init__(self, recording, sorting, - channel_from_template=True, - extremum_channel_inds=None, - radius_um=50, - peak_sign="neg" - ): + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): PipelineNode.__init__(self, recording, return_output=False) self.channel_from_template = channel_from_template - + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" self.peaks = sorting_to_peak(sorting, extremum_channel_inds) @@ -181,8 +179,7 @@ def __init__(self, recording, sorting, if not channel_from_template: channel_distance = get_channel_distances(recording) self.neighbours_mask = channel_distance < radius_um - self.peak_sign = peak_sign - + self.peak_sign = peak_sign # precompute segment slice self.segment_slices = [] @@ -219,7 +216,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "pos": local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] elif self.peak_sign == "both": - local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] # TODO: "amplitude" ??? diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8bea0bafb1..d0d49b865c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -81,25 +81,24 @@ def test_run_node_pipeline(): we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") peaks = sorting_to_peak(sorting, extremum_channel_inds) - + peak_retriever = PeakRetriever(recording, peaks) # channel index is from template - spike_retriever_T = SpikeRetriever(recording, sorting, - channel_from_template=True, - extremum_channel_inds=extremum_channel_inds) + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds + ) # channel index is per spike - spike_retriever_S = SpikeRetriever(recording, sorting, - channel_from_template=False, - extremum_channel_inds=extremum_channel_inds, - radius_um=50, - peak_sign="neg") + spike_retriever_S = SpikeRetriever( + recording, + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", + ) # test with 2 diffrents first node for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): - - - - # one step only : squeeze output nodes = [ peak_source,