From 2a0e042bdaa186836377d02d01ca83c7d06b71d3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 29 Aug 2023 09:24:43 +0200 Subject: [PATCH 1/5] Implement SpikeRetriever. --- src/spikeinterface/core/node_pipeline.py | 102 ++++++++- .../core/tests/test_node_pipeline.py | 199 ++++++++++-------- 2 files changed, 205 insertions(+), 96 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9ea5ad59e7..ff747fe2a0 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar raise NotImplementedError -# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) +# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) # as first element they play the same role in pipeline : give some peaks (and eventually more) @@ -138,7 +138,103 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): - pass + """ + This class is usefull to inject a sorting object in the node pipepline mechanisim. + It allows to compute some post processing with the same machinery used for sorting components. + This is a first step to totaly refactor: + * compute_spike_locations() + * compute_amplitude_scalings() + * compute_spike_amplitudes() + * compute_principal_components() + + + recording: + + sorting: + + channel_from_template: bool (default True) + If True then the channel_index is infered from template and extremum_channel_inds must be provided. + If False every spikes compute its own channel index given a radius around the template max channel. + extremum_channel_inds: dict of int + The extremum channel index dict given from template. + radius_um: float (default 50.) + The radius to find the real max channel. + Used only when channel_from_template=False + peak_sign: str (default "neg") + Peak sign to find the max channel. 
+ Used only when channel_from_template=False + """ + def __init__(self, recording, sorting, + channel_from_template=True, + extremum_channel_inds=None, + radius_um=50, + peak_sign="neg" + ): + PipelineNode.__init__(self, recording, return_output=False) + + self.channel_from_template = channel_from_template + + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" + + self.peaks = sorting_to_peak(sorting, extremum_channel_inds) + + if not channel_from_template: + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.peak_sign = peak_sign + + + # precompute segment slice + self.segment_slices = [] + for segment_index in range(recording.get_num_segments()): + i0 = np.searchsorted(self.peaks["segment_index"], segment_index) + i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + self.segment_slices.append(slice(i0, i1)) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return base_peak_dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) + i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + local_peaks = peaks_in_segment[i0:i1] + + # make sample index local to traces + local_peaks = local_peaks.copy() + local_peaks["sample_index"] -= start_frame - max_margin + + if not self.channel_from_template: + # handle channel spike per spike + for i, peak in enumerate(local_peaks): + chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) + sparse_wfs = traces[peak["sample_index"], chans] + if self.peak_sign == "neg": + local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] + elif self.peak_sign == "pos": + local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] + elif self.peak_sign == "both": + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + + # TODO: "amplitude" ??? + + return (local_peaks,) + + +def sorting_to_peak(sorting, extremum_channel_inds): + spikes = sorting.to_spike_vector() + peaks = np.zeros(spikes.size, dtype=base_peak_dtype) + peaks["sample_index"] = spikes["sample_index"] + extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) + peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] + peaks["amplitude"] = 0.0 + peaks["segment_index"] = spikes["segment_index"] + return peaks class WaveformsNode(PipelineNode): @@ -423,7 +519,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) # set sample index to local node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakRetriever): + elif isinstance(node, PeakSource): node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) else: # TODO later when in master: change the signature of all nodes (or maybe not!) 
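[Editor's note — illustration only, not part of the patch. The chunk selection in SpikeRetriever.compute() above works because the spike vector is sorted by (segment_index, sample_index), so both the per-segment slice and the per-chunk window reduce to np.searchsorted lookups. A minimal self-contained sketch with toy data; the structured dtype below is assumed to match base_peak_dtype.]

import numpy as np

# Toy peak vector, sorted by (segment_index, sample_index) as produced by sorting_to_peak().
base_peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"),
                   ("amplitude", "float64"), ("segment_index", "int64")]
peaks = np.zeros(5, dtype=base_peak_dtype)
peaks["sample_index"] = [10, 250, 600, 40, 500]
peaks["segment_index"] = [0, 0, 0, 1, 1]

# Precompute one slice per segment (same idea as self.segment_slices).
segment_index = 0
i0 = np.searchsorted(peaks["segment_index"], segment_index)
i1 = np.searchsorted(peaks["segment_index"], segment_index + 1)
peaks_in_segment = peaks[i0:i1]

# Select the peaks falling in the chunk [start_frame, end_frame).
start_frame, end_frame = 100, 700
i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
local_peaks = peaks_in_segment[i0:i1]  # keeps sample_index 250 and 600 only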
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index e9dfb43a66..35388a33a5 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -14,9 +14,10 @@ from spikeinterface.core.node_pipeline import ( run_node_pipeline, PeakRetriever, + SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - base_peak_dtype, + sorting_to_peak, ) @@ -77,7 +78,8 @@ def test_run_node_pipeline(): # recording = MEArecRecordingExtractor(local_path) recording, sorting = read_mearec(local_path) - job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) spikes = sorting.to_spike_vector() @@ -88,98 +90,109 @@ def test_run_node_pipeline(): # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - print(extremum_channel_inds) - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) - print(ext_channel_inds) - peaks = np.zeros(spikes.size, dtype=base_peak_dtype) - peaks["sample_index"] = spikes["sample_index"] - peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] - peaks["amplitude"] = 0.0 - peaks["segment_index"] = 0 - - # one step only : squeeze output - peak_retriever = PeakRetriever(recording, peaks) - nodes = [ - peak_retriever, - AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6), - ] - step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) - - # 3 nodes two have outputs - ms_before = 0.5 - ms_after = 1.0 + peaks = sorting_to_peak(sorting, extremum_channel_inds) + + peak_retriever = PeakRetriever(recording, peaks) - dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False - ) - waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False) - amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True) - waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True) - denoised_waveforms_rms = WaveformsRootMeanSquare( - recording, parents=[peak_retriever, waveform_denoiser], return_output=True - ) - - nodes = [ - peak_retriever, - dense_waveforms, - waveform_denoiser, - amplitue_extraction, - waveforms_rms, - denoised_waveforms_rms, - ] - - # gather memory mode - output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") - amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) - - num_peaks = peaks.shape[0] - num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - # gather npy mode - folder = cache_folder / "pipeline_folder" - if folder.is_dir(): - shutil.rmtree(folder) - output = run_node_pipeline( - recording, - nodes, - job_kwargs, - gather_mode="npy", - folder=folder, - names=["amplitudes", "waveforms_rms", 
"denoised_waveforms_rms"], - ) - amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output - - amplitudes_file = folder / "amplitudes.npy" - assert amplitudes_file.is_file() - amplitudes3 = np.load(amplitudes_file) - assert np.array_equal(amplitudes, amplitudes2) - assert np.array_equal(amplitudes2, amplitudes3) - - waveforms_rms_file = folder / "waveforms_rms.npy" - assert waveforms_rms_file.is_file() - waveforms_rms3 = np.load(waveforms_rms_file) - assert np.array_equal(waveforms_rms, waveforms_rms2) - assert np.array_equal(waveforms_rms2, waveforms_rms3) - - denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" - assert denoised_waveforms_rms_file.is_file() - denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) - assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) - assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) - - # Test pickle mechanism - for node in nodes: - import pickle - - pickled_node = pickle.dumps(node) - unpickled_node = pickle.loads(pickled_node) + # channel index is from template + spike_retriever_T = SpikeRetriever(recording, sorting, + channel_from_template=True, + extremum_channel_inds=extremum_channel_inds) + # channel index is per spike + spike_retriever_S = SpikeRetriever(recording, sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg") + + # test with 2 diffrents first node + for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S): + + + + + # one step only : squeeze output + nodes = [ + peak_source, + AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), + ] + step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + + # 3 nodes two have outputs + ms_before = 0.5 + ms_after = 1.0 + peak_retriever = PeakRetriever(recording, peaks) + dense_waveforms = ExtractDenseWaveforms( + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False) + amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True) + waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True) + denoised_waveforms_rms = WaveformsRootMeanSquare( + recording, parents=[peak_source, waveform_denoiser], return_output=True + ) + + nodes = [ + peak_source, + dense_waveforms, + waveform_denoiser, + amplitue_extraction, + waveforms_rms, + denoised_waveforms_rms, + ] + + # gather memory mode + output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") + amplitudes, waveforms_rms, denoised_waveforms_rms = output + assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) + + num_peaks = peaks.shape[0] + num_channels = recording.get_num_channels() + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + # gather npy mode + folder = cache_folder / "pipeline_folder" + if folder.is_dir(): + shutil.rmtree(folder) + output = run_node_pipeline( + recording, + nodes, + job_kwargs, + gather_mode="npy", + folder=folder, + names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + ) + amplitudes2, waveforms_rms2, 
denoised_waveforms_rms2 = output + + amplitudes_file = folder / "amplitudes.npy" + assert amplitudes_file.is_file() + amplitudes3 = np.load(amplitudes_file) + assert np.array_equal(amplitudes, amplitudes2) + assert np.array_equal(amplitudes2, amplitudes3) + + waveforms_rms_file = folder / "waveforms_rms.npy" + assert waveforms_rms_file.is_file() + waveforms_rms3 = np.load(waveforms_rms_file) + assert np.array_equal(waveforms_rms, waveforms_rms2) + assert np.array_equal(waveforms_rms2, waveforms_rms3) + + denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" + assert denoised_waveforms_rms_file.is_file() + denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) + assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) + assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) + + # Test pickle mechanism + for node in nodes: + import pickle + + pickled_node = pickle.dumps(node) + unpickled_node = pickle.loads(pickled_node) if __name__ == "__main__": From b2e737ec734d8a488bf8997ca134ad7190b82d2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:46:48 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 19 ++++++-------- .../core/tests/test_node_pipeline.py | 25 +++++++++---------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ff747fe2a0..610ae42398 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -141,7 +141,7 @@ class SpikeRetriever(PeakSource): """ This class is usefull to inject a sorting object in the node pipepline mechanisim. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is a first step to totaly refactor: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() @@ -164,16 +164,14 @@ class SpikeRetriever(PeakSource): Peak sign to find the max channel. Used only when channel_from_template=False """ - def __init__(self, recording, sorting, - channel_from_template=True, - extremum_channel_inds=None, - radius_um=50, - peak_sign="neg" - ): + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): PipelineNode.__init__(self, recording, return_output=False) self.channel_from_template = channel_from_template - + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" self.peaks = sorting_to_peak(sorting, extremum_channel_inds) @@ -181,8 +179,7 @@ def __init__(self, recording, sorting, if not channel_from_template: channel_distance = get_channel_distances(recording) self.neighbours_mask = channel_distance < radius_um - self.peak_sign = peak_sign - + self.peak_sign = peak_sign # precompute segment slice self.segment_slices = [] @@ -219,7 +216,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "pos": local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] elif self.peak_sign == "both": - local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] # TODO: "amplitude" ??? 
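[Editor's note — illustration only, not part of the patch. When channel_from_template=False, the hunk above re-evaluates the max channel spike by spike within the neighbourhood mask. A standalone sketch of that selection; the channel indices and trace values are made up.]

import numpy as np

# Trace values at one spike's sample on 4 neighbouring channels.
sparse_wfs = np.array([-5.0, 2.0, -9.0, 7.0])
chans = np.array([3, 4, 5, 6])  # channel indices of the neighbourhood

peak_sign = "both"
if peak_sign == "neg":
    best_chan = chans[np.argmin(sparse_wfs)]          # most negative -> channel 5
elif peak_sign == "pos":
    best_chan = chans[np.argmax(sparse_wfs)]          # most positive -> channel 6
elif peak_sign == "both":
    best_chan = chans[np.argmax(np.abs(sparse_wfs))]  # largest |value| -> channel 5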
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8bea0bafb1..d0d49b865c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -81,25 +81,24 @@ def test_run_node_pipeline(): we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") peaks = sorting_to_peak(sorting, extremum_channel_inds) - + peak_retriever = PeakRetriever(recording, peaks) # channel index is from template - spike_retriever_T = SpikeRetriever(recording, sorting, - channel_from_template=True, - extremum_channel_inds=extremum_channel_inds) + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds + ) # channel index is per spike - spike_retriever_S = SpikeRetriever(recording, sorting, - channel_from_template=False, - extremum_channel_inds=extremum_channel_inds, - radius_um=50, - peak_sign="neg") + spike_retriever_S = SpikeRetriever( + recording, + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", + ) # test with 2 diffrents first node for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): - - - - # one step only : squeeze output nodes = [ peak_source, From 45f2b15b286e5b071cf92ec5f18257e3a641e332 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 11:41:34 +0200 Subject: [PATCH 3/5] Feeback from Alessio --- src/spikeinterface/core/node_pipeline.py | 21 +++++++++---------- .../core/tests/test_node_pipeline.py | 9 ++++---- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 610ae42398..64949357c4 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -139,21 +139,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ - This class is usefull to inject a sorting object in the node pipepline mechanisim. + This class is useful to inject a sorting object in the node pipepline mechanism. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is used by: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() * compute_principal_components() - - recording: - - sorting: - - channel_from_template: bool (default True) - If True then the channel_index is infered from template and extremum_channel_inds must be provided. + recording : BaseRecording + The recording object. + sorting: BaseSorting + The sorting object. + channel_from_template: bool, default: True + If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. If False every spikes compute its own channel index given a radius around the template max channel. extremum_channel_inds: dict of int The extremum channel index dict given from template. 
@@ -174,7 +173,7 @@ def __init__( assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" - self.peaks = sorting_to_peak(sorting, extremum_channel_inds) + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) if not channel_from_template: channel_distance = get_channel_distances(recording) @@ -223,7 +222,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) -def sorting_to_peak(sorting, extremum_channel_inds): +def sorting_to_peaks(sorting, extremum_channel_inds): spikes = sorting.to_spike_vector() peaks = np.zeros(spikes.size, dtype=base_peak_dtype) peaks["sample_index"] = spikes["sample_index"] diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index d0d49b865c..bcb15b6455 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -15,7 +15,7 @@ SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - sorting_to_peak, + sorting_to_peaks, ) @@ -72,15 +72,14 @@ def compute(self, traces, peaks, waveforms): def test_run_node_pipeline(): recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) - # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) - job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) spikes = sorting.to_spike_vector() # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - peaks = sorting_to_peak(sorting, extremum_channel_inds) + peaks = sorting_to_peaks(sorting, extremum_channel_inds) peak_retriever = PeakRetriever(recording, peaks) # channel index is from template @@ -97,7 +96,7 @@ def test_run_node_pipeline(): peak_sign="neg", ) - # test with 2 diffrents first node + # test with 3 differents first nodes for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [ From e2a0472d2c1c53e5d5fd58775d7e8677cf8912d7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Sep 2023 13:35:07 +0200 Subject: [PATCH 4/5] oups --- src/spikeinterface/core/node_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 64949357c4..14964ac7c3 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -140,7 +140,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): class SpikeRetriever(PeakSource): """ This class is useful to inject a sorting object in the node pipepline mechanism. - It allows to compute some post processing with the same machinery used for sorting components. + It allows to compute some post-processing steps with the same machinery used for sorting components. This is used by: * compute_spike_locations() * compute_amplitude_scalings() @@ -153,7 +153,7 @@ class SpikeRetriever(PeakSource): The sorting object. channel_from_template: bool, default: True If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. - If False every spikes compute its own channel index given a radius around the template max channel. 
+ If False, the max channel is computed for each spike given a radius around the template max channel. extremum_channel_inds: dict of int The extremum channel index dict given from template. radius_um: float (default 50.) From ad0f05e555d1e910ec80c8759d963ca27d71bf58 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 16:39:59 +0200 Subject: [PATCH 5/5] Update src/spikeinterface/core/node_pipeline.py --- src/spikeinterface/core/node_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 14964ac7c3..b11f40a441 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -171,7 +171,7 @@ def __init__( self.channel_from_template = channel_from_template - assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" + assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary" self.peaks = sorting_to_peaks(sorting, extremum_channel_inds)
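[Editor's note — a usage sketch distilled from the updated test, not part of the patches. Import paths are assumed from the spikeinterface version these patches target.]

import numpy as np
from spikeinterface.core import generate_ground_truth_recording, extract_waveforms, get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever, ExtractDenseWaveforms, run_node_pipeline

recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)

# The per-unit extremum channel comes from the templates.
we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

# First node of the pipeline: inject the sorting's spikes instead of detected peaks.
spike_retriever = SpikeRetriever(
    recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)
dense_waveforms = ExtractDenseWaveforms(
    recording, parents=[spike_retriever], ms_before=0.5, ms_after=1.0, return_output=True
)
waveforms = run_node_pipeline(recording, [spike_retriever, dense_waveforms], job_kwargs, squeeze_output=True)
# waveforms has shape (num_spikes, num_samples, num_channels)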