From f3fbd5d4017f9ea0888c0a1846ed513c4a1fbe9f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 29 Aug 2024 15:59:44 +0200 Subject: [PATCH 1/5] nodepipeline : skip chunks when no peaks inside --- src/spikeinterface/core/node_pipeline.py | 162 +++++++++++------- .../core/tests/test_node_pipeline.py | 15 +- 2 files changed, 114 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ceff8577d3..e72f87f794 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -103,6 +103,9 @@ def get_trace_margin(self): def get_dtype(self): return base_peak_dtype + def get_peak_slice(self, segment_index, start_frame, end_frame, ): + # not needed for PeakDetector + raise NotImplementedError # this is used in sorting components class PeakDetector(PeakSource): @@ -127,11 +130,18 @@ def get_trace_margin(self): def get_dtype(self): return base_peak_dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # get local peaks + def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + return i0, i1 + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + i0, i1 = peak_slice local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -212,8 +222,7 @@ def get_trace_margin(self): def get_dtype(self): return self._dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # get local peaks + def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] if self.include_spikes_in_margin: @@ -222,6 +231,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): ) else: i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + return i0, i1 + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + # if self.include_spikes_in_margin: + # i0, i1 = np.searchsorted( + # peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin] + # ) + # else: + # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + i0, i1 = peak_slice + local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -525,64 +548,79 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c nodes = worker_ctx["nodes"] recording_segment = recording._recording_segments[segment_index] - traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True - ) - - # compute the graph - pipeline_outputs = {} - for node in nodes: - node_parents = node.parents if node.parents else list() - node_input_args = tuple() - for parent in node_parents: - parent_output = pipeline_outputs[parent] - parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) - node_input_args += parent_outputs_tuple 
- if isinstance(node, PeakDetector): - # to handle compatibility peak detector is a special case - # with specific margin - # TODO later when in master: change this later - extra_margin = max_margin - node.get_trace_margin() - if extra_margin: - trace_detection = traces_chunk[extra_margin:-extra_margin] + node0 = nodes[0] + + if isinstance(node0, (SpikeRetriever, PeakRetriever)): + # in this case PeakSource could have no peaks and so no need to load traces just skip + peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin) + load_trace_and_compute = i0 < i1 + else: + # PeakDetector always need traces + load_trace_and_compute = True + + if load_trace_and_compute: + traces_chunk, left_margin, right_margin = get_chunk_with_margin( + recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True + ) + # compute the graph + pipeline_outputs = {} + for node in nodes: + node_parents = node.parents if node.parents else list() + node_input_args = tuple() + for parent in node_parents: + parent_output = pipeline_outputs[parent] + parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) + node_input_args += parent_outputs_tuple + if isinstance(node, PeakDetector): + # to handle compatibility peak detector is a special case + # with specific margin + # TODO later when in master: change this later + extra_margin = max_margin - node.get_trace_margin() + if extra_margin: + trace_detection = traces_chunk[extra_margin:-extra_margin] + else: + trace_detection = traces_chunk + node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) + # set sample index to local + node_output[0]["sample_index"] += extra_margin + elif isinstance(node, PeakSource): + node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice) else: - trace_detection = traces_chunk - node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) - # set sample index to local - node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakSource): - node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) - else: - # TODO later when in master: change the signature of all nodes (or maybe not!) - node_output = node.compute(traces_chunk, *node_input_args) - pipeline_outputs[node] = node_output - - # propagate the output - pipeline_outputs_tuple = tuple() - for node in nodes: - # handle which buffer are given to the output - # this is controlled by node.return_output being a bool or tuple of bool - out = pipeline_outputs[node] - if isinstance(out, tuple): - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += out - elif isinstance(node.return_output, tuple): - for flag, e in zip(node.return_output, out): - if flag: - pipeline_outputs_tuple += (e,) - else: - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += (out,) - elif isinstance(node.return_output, tuple): - # this should not apppend : maybe a checker somewhere before ? - pass + # TODO later when in master: change the signature of all nodes (or maybe not!) 
+ node_output = node.compute(traces_chunk, *node_input_args) + pipeline_outputs[node] = node_output + + # propagate the output + pipeline_outputs_tuple = tuple() + for node in nodes: + # handle which buffer are given to the output + # this is controlled by node.return_output being a bool or tuple of bool + out = pipeline_outputs[node] + if isinstance(out, tuple): + if isinstance(node.return_output, bool) and node.return_output: + pipeline_outputs_tuple += out + elif isinstance(node.return_output, tuple): + for flag, e in zip(node.return_output, out): + if flag: + pipeline_outputs_tuple += (e,) + else: + if isinstance(node.return_output, bool) and node.return_output: + pipeline_outputs_tuple += (out,) + elif isinstance(node.return_output, tuple): + # this should not apppend : maybe a checker somewhere before ? + pass + + if isinstance(nodes[0], PeakDetector): + # the first out element is the peak vector + # we need to go back to absolut sample index + pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin + + return pipeline_outputs_tuple - if isinstance(nodes[0], PeakDetector): - # the first out element is the peak vector - # we need to go back to absolut sample index - pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin + else: + # the gather will skip this output and not concatenate it + return - return pipeline_outputs_tuple class GatherToMemory: @@ -595,6 +633,9 @@ def __init__(self): self.tuple_mode = None def __call__(self, res): + if res is None: + return + if self.tuple_mode is None: # first loop only self.tuple_mode = isinstance(res, tuple) @@ -655,6 +696,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.final_shapes.append(None) def __call__(self, res): + if res is None: + return + if self.tuple_mode is None: # first loop only self.tuple_mode = isinstance(res, tuple) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8d788acbad..a2919f5673 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -83,8 +83,12 @@ def test_run_node_pipeline(cache_folder_creation): extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + print(peaks.size) peak_retriever = PeakRetriever(recording, peaks) + # this test when no spikes in last chunks + peak_retriever_few = PeakRetriever(recording, peaks[:peaks.size//2]) + # channel index is from template spike_retriever_T = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds @@ -100,7 +104,7 @@ def test_run_node_pipeline(cache_folder_creation): ) # test with 3 differents first nodes - for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + for loop, peak_source in enumerate((peak_retriever, peak_retriever_few, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [ peak_source, @@ -139,10 +143,12 @@ def test_run_node_pipeline(cache_folder_creation): num_peaks = peaks.shape[0] num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert 
waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels # gather npy mode @@ -186,4 +192,5 @@ def test_run_node_pipeline(cache_folder_creation): if __name__ == "__main__": - test_run_node_pipeline() + folder = Path("./cache_folder/core") + test_run_node_pipeline(folder) From 6590e0f1845a62300d617d3896969ec303bebc46 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Sep 2024 14:24:16 +0200 Subject: [PATCH 2/5] nodepipeline add skip_after_n_peaks option --- src/spikeinterface/core/node_pipeline.py | 22 +++++++++-- .../core/tests/test_node_pipeline.py | 37 +++++++++++++++++-- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index e72f87f794..d04ad59f46 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -497,6 +497,7 @@ def run_node_pipeline( folder=None, names=None, verbose=False, + skip_after_n_peaks=None, ): """ Common function to run pipeline with peak detector or already detected peak. @@ -507,6 +508,11 @@ def run_node_pipeline( job_kwargs = fix_job_kwargs(job_kwargs) assert all(isinstance(node, PipelineNode) for node in nodes) + if skip_after_n_peaks is not None: + skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"] + else: + skip_after_n_peaks_per_worker = None + if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": @@ -514,7 +520,7 @@ def run_node_pipeline( else: raise ValueError(f"wrong gather_mode : {gather_mode}") - init_args = (recording, nodes) + init_args = (recording, nodes, skip_after_n_peaks_per_worker) processor = ChunkRecordingExecutor( recording, @@ -533,12 +539,14 @@ def run_node_pipeline( return outs -def _init_peak_pipeline(recording, nodes): +def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker): # create a local dict per worker worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["nodes"] = nodes worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) + worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker + worker_ctx["num_peaks"] = 0 return worker_ctx @@ -546,6 +554,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c recording = worker_ctx["recording"] max_margin = worker_ctx["max_margin"] nodes = worker_ctx["nodes"] + skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] recording_segment = recording._recording_segments[segment_index] node0 = nodes[0] @@ -557,7 +566,11 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c else: # PeakDetector always need traces load_trace_and_compute = True - + + if skip_after_n_peaks_per_worker is not None: + if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker: + load_trace_and_compute = False + if load_trace_and_compute: traces_chunk, left_margin, right_margin = get_chunk_with_margin( recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True @@ -590,6 +603,9 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c node_output = node.compute(traces_chunk, *node_input_args) pipeline_outputs[node] = node_output + if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource): + worker_ctx["num_peaks"] += node_output[0].size + # propagate the output pipeline_outputs_tuple = tuple() for node in nodes: diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py 
b/src/spikeinterface/core/tests/test_node_pipeline.py index a2919f5673..f31757d6bc 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -83,7 +83,7 @@ def test_run_node_pipeline(cache_folder_creation): extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) - print(peaks.size) + # print(peaks.size) peak_retriever = PeakRetriever(recording, peaks) # this test when no spikes in last chunks @@ -191,6 +191,37 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) +def test_skip_after_n_peaks(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) + + # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + + spikes = sorting.to_spike_vector() + + # create peaks from spikes + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + + peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) + + node0 = PeakRetriever(recording, peaks) + node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) + nodes = [node0, node1] + + skip_after_n_peaks = 30 + some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks) + + assert some_amplitudes.size >= skip_after_n_peaks + assert some_amplitudes.size < spikes.size + + + + if __name__ == "__main__": - folder = Path("./cache_folder/core") - test_run_node_pipeline(folder) + # folder = Path("./cache_folder/core") + # test_run_node_pipeline(folder) + + test_skip_after_n_peaks() From 9111c13f1994afc6a970353424015ae8be465d2a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Sep 2024 14:48:20 +0200 Subject: [PATCH 3/5] make Zach happy --- src/spikeinterface/core/node_pipeline.py | 78 ++++++++++++++++++------ 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d04ad59f46..9b53f08520 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -1,22 +1,6 @@ """ -Pipeline on spikes/peaks/detected peaks - -Functions that can be chained: - * after peak detection - * already detected peaks - * spikes (labeled peaks) -to compute some additional features on-the-fly: - * peak localization - * peak-to-peak - * pca - * amplitude - * amplitude scaling - * ... - -There are two ways for using theses "plugin nodes": - * during `peak_detect()` - * when peaks are already detected and reduced with `select_peaks()` - * on a sorting object + + """ from __future__ import annotations @@ -490,7 +474,7 @@ def run_node_pipeline( nodes, job_kwargs, job_name="pipeline", - mp_context=None, + #mp_context=None, gather_mode="memory", gather_kwargs={}, squeeze_output=True, @@ -500,7 +484,61 @@ def run_node_pipeline( skip_after_n_peaks=None, ): """ - Common function to run pipeline with peak detector or already detected peak. + Machinery to compute in paralell operations on peaks and traces. 
+ + This usefull in several use cases: + * in sortingcomponents : detect peaks and make some computation on then (localize, pca, ...) + * in sortingcomponents : replay some peaks and make some computation on then (localize, pca, ...) + * postprocessing : replay some spikes and make some computation on then (localize, pca, ...) + + Here a "peak" is a spike without any labels just a "detected". + Here a "spike" is a spike with any a label so already sorted. + + The main idea is to have a graph of nodes. + Every node is doing a computaion of some peaks and related traces. + The first node is PeakSource so either a peak detector PeakDetector or peak/spike replay (PeakRetriever/SpikeRetriever) + + Every can have one or several output that can be directed to other nodes (aka nodes have parents). + + Every node can optionaly have an global output that will be globaly gather by the main process. + This is controlled by return_output = True. + + The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector. + Theses vector can be in "memory" or in file ("npy") + + + Parameters + ---------- + + recording: Recording + + nodes: a list of PipelineNode + + job_kwargs: dict + The classical job_kwargs + job_name : str + The name of the pipeline used for the progress_bar + gather_mode : "memory" | "npz" + + gather_kwargs : dict + OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. + squeeze_output : bool, default True + If only one output node, the, squeeze the tuple + folder : str | Path | None + Used for gather_mode="npz" + names : list of str + Names of outputs. + verbose : bool, default False + Verbosity. + skip_after_n_peaks : None | int + Skip the computaion after n_peaks. + This is not an exact because internally this skip is done per worker in average. + + Returns + ------- + outputs: tuple of np.array | np.array + a tuple of vector for the output of nodes having return_output=True. + If squeeze_output=True and only one output then directly np.array. """ check_graph(nodes) From 1f527153dc909f93e25bb85c178579d56c624913 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 9 Sep 2024 10:25:16 +0200 Subject: [PATCH 4/5] Merci Zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/node_pipeline.py | 14 +++++++------- .../core/tests/test_node_pipeline.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9b53f08520..a617272753 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -484,9 +484,9 @@ def run_node_pipeline( skip_after_n_peaks=None, ): """ - Machinery to compute in paralell operations on peaks and traces. + Machinery to compute in parallel operations on peaks and traces. - This usefull in several use cases: + This useful in several use cases: * in sortingcomponents : detect peaks and make some computation on then (localize, pca, ...) * in sortingcomponents : replay some peaks and make some computation on then (localize, pca, ...) * postprocessing : replay some spikes and make some computation on then (localize, pca, ...) @@ -498,13 +498,13 @@ def run_node_pipeline( Every node is doing a computaion of some peaks and related traces. 
The first node is PeakSource so either a peak detector PeakDetector or peak/spike replay (PeakRetriever/SpikeRetriever) - Every can have one or several output that can be directed to other nodes (aka nodes have parents). + Every node can have one or several output that can be directed to other nodes (aka nodes have parents). - Every node can optionaly have an global output that will be globaly gather by the main process. + Every node can optionally have a global output that will be gathered by the main process. This is controlled by return_output = True. The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector. - Theses vector can be in "memory" or in file ("npy") + These vectors can be in "memory" or in files ("npy") Parameters @@ -523,7 +523,7 @@ def run_node_pipeline( gather_kwargs : dict OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. squeeze_output : bool, default True - If only one output node, the, squeeze the tuple + If only one output node then squeeze the tuple folder : str | Path | None Used for gather_mode="npz" names : list of str @@ -531,7 +531,7 @@ def run_node_pipeline( verbose : bool, default False Verbosity. skip_after_n_peaks : None | int - Skip the computaion after n_peaks. + Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. Returns diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index f31757d6bc..3d3a642371 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -219,7 +219,7 @@ def test_skip_after_n_peaks(): - +# the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") # test_run_node_pipeline(folder) From c4eb8a540fab74b5712d744a30886867fa1f68f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:59:33 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 17 +++++++++++------ .../core/tests/test_node_pipeline.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index a617272753..a72808e176 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -87,10 +87,16 @@ def get_trace_margin(self): def get_dtype(self): return base_peak_dtype - def get_peak_slice(self, segment_index, start_frame, end_frame, ): + def get_peak_slice( + self, + segment_index, + start_frame, + end_frame, + ): # not needed for PeakDetector raise NotImplementedError + # this is used in sorting components class PeakDetector(PeakSource): pass @@ -474,7 +480,7 @@ def run_node_pipeline( nodes, job_kwargs, job_name="pipeline", - #mp_context=None, + # mp_context=None, gather_mode="memory", gather_kwargs={}, squeeze_output=True, @@ -506,7 +512,7 @@ def run_node_pipeline( The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector. 
These vectors can be in "memory" or in files ("npy") - + Parameters ---------- @@ -533,7 +539,7 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. - + Returns ------- outputs: tuple of np.array | np.array @@ -596,7 +602,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c recording_segment = recording._recording_segments[segment_index] node0 = nodes[0] - + if isinstance(node0, (SpikeRetriever, PeakRetriever)): # in this case PeakSource could have no peaks and so no need to load traces just skip peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin) @@ -676,7 +682,6 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c return - class GatherToMemory: """ Gather output of nodes into list and then demultiplex and np.concatenate diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 3d3a642371..deef2291c6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -87,7 +87,7 @@ def test_run_node_pipeline(cache_folder_creation): peak_retriever = PeakRetriever(recording, peaks) # this test when no spikes in last chunks - peak_retriever_few = PeakRetriever(recording, peaks[:peaks.size//2]) + peak_retriever_few = PeakRetriever(recording, peaks[: peaks.size // 2]) # channel index is from template spike_retriever_T = SpikeRetriever( @@ -212,13 +212,14 @@ def test_skip_after_n_peaks(): nodes = [node0, node1] skip_after_n_peaks = 30 - some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks) + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks + ) assert some_amplitudes.size >= skip_after_n_peaks assert some_amplitudes.size < spikes.size - # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core")