Merge pull request #3356 from samuelgarcia/node_pipeline_skip_no_peaks
node_pipeline : skip chunks when no peaks are inside, and add skip_after_n_peaks
yger authored Oct 7, 2024
2 parents c7ac344 + 5d84f6c commit 69bf6e4
Showing 2 changed files with 226 additions and 84 deletions.
263 changes: 183 additions & 80 deletions src/spikeinterface/core/node_pipeline.py
@@ -1,22 +1,6 @@
"""
Pipeline on spikes/peaks/detected peaks
Functions that can be chained:
* after peak detection
* already detected peaks
* spikes (labeled peaks)
to compute some additional features on-the-fly:
* peak localization
* peak-to-peak
* pca
* amplitude
* amplitude scaling
* ...
There are several ways to use these "plugin nodes":
* during `peak_detect()`
* when peaks are already detected and reduced with `select_peaks()`
* on a sorting object
"""

from __future__ import annotations
@@ -103,6 +87,15 @@ def get_trace_margin(self):
def get_dtype(self):
return base_peak_dtype

def get_peak_slice(
self,
segment_index,
start_frame,
end_frame,
):
# not needed for PeakDetector
raise NotImplementedError


# this is used in sorting components
class PeakDetector(PeakSource):
@@ -127,11 +120,18 @@ def get_trace_margin(self):
def get_dtype(self):
return base_peak_dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
return i0, i1
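For readers skimming the diff, here is a small standalone sketch (not part of this commit) of the np.searchsorted logic that get_peak_slice relies on; the structured peak array and the chunk boundaries below are invented example data:

import numpy as np

# stand-in structured array mimicking the "sample_index" field of base_peak_dtype
peaks_in_segment = np.zeros(5, dtype=[("sample_index", "int64")])
peaks_in_segment["sample_index"] = [10, 50, 120, 300, 990]

# chunk boundaries for one traces chunk
start_frame, end_frame = 100, 400

# same call as in get_peak_slice: i0/i1 bound the peaks that fall inside the chunk
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
assert (i0, i1) == (2, 4)  # the peaks at samples 120 and 300 are inside the chunk

# a chunk containing no peaks yields an empty slice, which the pipeline can now skip
j0, j1 = np.searchsorted(peaks_in_segment["sample_index"], [400, 900])
assert j0 == j1  # i0 == i1 -> no peaks, no need to load traces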

def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
@@ -212,8 +212,7 @@ def get_trace_margin(self):
def get_dtype(self):
return self._dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
if self.include_spikes_in_margin:
@@ -222,6 +221,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
)
else:
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
return i0, i1

def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# if self.include_spikes_in_margin:
# i0, i1 = np.searchsorted(
# peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
# )
# else:
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice

local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
@@ -467,31 +480,91 @@ def run_node_pipeline(
nodes,
job_kwargs,
job_name="pipeline",
mp_context=None,
# mp_context=None,
gather_mode="memory",
gather_kwargs={},
squeeze_output=True,
folder=None,
names=None,
verbose=False,
skip_after_n_peaks=None,
):
"""
Common function to run a pipeline with a peak detector or already detected peaks.
Machinery to compute operations on peaks and traces in parallel.
This is useful in several use cases:
* in sortingcomponents : detect peaks and run some computation on them (localize, pca, ...)
* in sortingcomponents : replay some peaks and run some computation on them (localize, pca, ...)
* postprocessing : replay some spikes and run some computation on them (localize, pca, ...)
Here a "peak" is a detected spike without any label.
Here a "spike" is a peak with a label, i.e. already sorted.
The main idea is to have a graph of nodes.
Every node performs a computation on some peaks and the related traces.
The first node is a PeakSource: either a peak detector (PeakDetector) or a peak/spike replay (PeakRetriever/SpikeRetriever).
Every node can have one or several outputs that can be directed to other nodes (i.e. nodes have parents).
Every node can optionally have a global output that will be gathered by the main process.
This is controlled by return_output = True.
The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector.
These vectors can be kept in "memory" or written to files ("npy").
Parameters
----------
recording: Recording
nodes: a list of PipelineNode
job_kwargs: dict
The classical job_kwargs
job_name : str
The name of the pipeline used for the progress_bar
gather_mode : "memory" | "npy"
gather_kwargs : dict
Options to control the "gather engine". See GatherToMemory or GatherToNpy.
squeeze_output : bool, default True
If there is only one output node, squeeze the tuple.
folder : str | Path | None
Used when gather_mode="npy"
names : list of str
Names of outputs.
verbose : bool, default False
Verbosity.
skip_after_n_peaks : None | int
Skip the computation once n_peaks have been processed.
This is not exact, because internally the skip is applied per worker, on average.
Returns
-------
outputs: tuple of np.array | np.array
A tuple of vectors, one for the output of each node having return_output=True.
If squeeze_output=True and there is only one output, a single np.array is returned directly.
"""

check_graph(nodes)

job_kwargs = fix_job_kwargs(job_kwargs)
assert all(isinstance(node, PipelineNode) for node in nodes)

if skip_after_n_peaks is not None:
skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"]
else:
skip_after_n_peaks_per_worker = None

if gather_mode == "memory":
gather_func = GatherToMemory()
elif gather_mode == "npy":
gather_func = GatherToNpy(folder, names, **gather_kwargs)
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

init_args = (recording, nodes)
init_args = (recording, nodes, skip_after_n_peaks_per_worker)

processor = ChunkRecordingExecutor(
recording,
@@ -510,79 +583,103 @@ def run_node_pipeline(
return outs
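As a usage reference (not part of this commit), here is a minimal, hedged sketch of calling run_node_pipeline with the new skip_after_n_peaks argument; the PeakRetriever constructor call, the example job_kwargs values, and the recording/peaks placeholders are assumptions about the surrounding spikeinterface API rather than code from this diff:

from spikeinterface.core.node_pipeline import run_node_pipeline, PeakRetriever

def replay_peaks(recording, peaks):
    # `recording` is a Recording object and `peaks` a structured array with at
    # least a "sample_index" field (see base_peak_dtype above)
    peak_source = PeakRetriever(recording, peaks)
    # a real pipeline would append feature nodes (localization, pca, ...) with
    # parents=[peak_source] and return_output=True after the peak source
    nodes = [peak_source]
    job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)  # example values
    outputs = run_node_pipeline(
        recording,
        nodes,
        job_kwargs,
        job_name="replay peaks",
        gather_mode="memory",
        # new in this commit: stop (approximately) once ~10_000 peaks have been processed
        skip_after_n_peaks=10_000,
    )
    return outputs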


def _init_peak_pipeline(recording, nodes):
def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
# create a local dict per worker
worker_ctx = {}
worker_ctx["recording"] = recording
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
worker_ctx["num_peaks"] = 0
return worker_ctx


def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
recording = worker_ctx["recording"]
max_margin = worker_ctx["max_margin"]
nodes = worker_ctx["nodes"]
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

recording_segment = recording._recording_segments[segment_index]
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
)
node0 = nodes[0]

# compute the graph
pipeline_outputs = {}
for node in nodes:
node_parents = node.parents if node.parents else list()
node_input_args = tuple()
for parent in node_parents:
parent_output = pipeline_outputs[parent]
parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
node_input_args += parent_outputs_tuple
if isinstance(node, PeakDetector):
# to handle compatibility peak detector is a special case
# with specific margin
# TODO later when in master: change this later
extra_margin = max_margin - node.get_trace_margin()
if extra_margin:
trace_detection = traces_chunk[extra_margin:-extra_margin]
if isinstance(node0, (SpikeRetriever, PeakRetriever)):
# in this case the PeakSource could have no peaks, so there is no need to load traces: just skip
peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
load_trace_and_compute = i0 < i1
else:
# PeakDetector always needs traces
load_trace_and_compute = True

if skip_after_n_peaks_per_worker is not None:
if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
load_trace_and_compute = False

if load_trace_and_compute:
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
)
# compute the graph
pipeline_outputs = {}
for node in nodes:
node_parents = node.parents if node.parents else list()
node_input_args = tuple()
for parent in node_parents:
parent_output = pipeline_outputs[parent]
parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
node_input_args += parent_outputs_tuple
if isinstance(node, PeakDetector):
# to handle compatibility peak detector is a special case
# with specific margin
# TODO later when in master: change this later
extra_margin = max_margin - node.get_trace_margin()
if extra_margin:
trace_detection = traces_chunk[extra_margin:-extra_margin]
else:
trace_detection = traces_chunk
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
else:
trace_detection = traces_chunk
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
pipeline_outputs[node] = node_output

# propagate the output
pipeline_outputs_tuple = tuple()
for node in nodes:
# handle which buffer are given to the output
# this is controlled by node.return_output being a bool or tuple of bool
out = pipeline_outputs[node]
if isinstance(out, tuple):
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += out
elif isinstance(node.return_output, tuple):
for flag, e in zip(node.return_output, out):
if flag:
pipeline_outputs_tuple += (e,)
else:
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += (out,)
elif isinstance(node.return_output, tuple):
# this should not happen : maybe add a checker somewhere before ?
pass
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
pipeline_outputs[node] = node_output

if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
worker_ctx["num_peaks"] += node_output[0].size

# propagate the output
pipeline_outputs_tuple = tuple()
for node in nodes:
# handle which buffer are given to the output
# this is controlled by node.return_output being a bool or tuple of bool
out = pipeline_outputs[node]
if isinstance(out, tuple):
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += out
elif isinstance(node.return_output, tuple):
for flag, e in zip(node.return_output, out):
if flag:
pipeline_outputs_tuple += (e,)
else:
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += (out,)
elif isinstance(node.return_output, tuple):
# this should not happen : maybe add a checker somewhere before ?
pass

if isinstance(nodes[0], PeakDetector):
# the first out element is the peak vector
# we need to go back to absolute sample index
pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin

if isinstance(nodes[0], PeakDetector):
# the first out element is the peak vector
# we need to go back to absolute sample index
pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
return pipeline_outputs_tuple

return pipeline_outputs_tuple
else:
# the gather will skip this output and not concatenate it
return
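To make the "not exact" caveat concrete, here is a hedged, self-contained illustration (not part of this commit) of the per-worker skipping arithmetic implemented in _init_peak_pipeline and _compute_peak_pipeline_chunk above; the job count and the per-chunk peak counts are invented numbers:

# the global budget is split evenly across workers (as in run_node_pipeline)
skip_after_n_peaks = 10_000
n_jobs = 4
skip_after_n_peaks_per_worker = skip_after_n_peaks / n_jobs  # 2500.0

# one worker's loop over its chunks (schematic): the counter is checked *before*
# a chunk is processed and incremented *after*, so the last processed chunk can
# overshoot the per-worker budget; skipped chunks return None, which
# GatherToMemory / GatherToNpy simply ignore
num_peaks = 0
processed_chunks = 0
for peaks_in_chunk in [900, 900, 900, 900, 900]:  # invented per-chunk peak counts
    if num_peaks > skip_after_n_peaks_per_worker:
        continue  # chunk skipped: no traces loaded, nothing gathered
    num_peaks += peaks_in_chunk
    processed_chunks += 1
print(num_peaks, processed_chunks)  # -> 2700 3 (slightly above the 2500 per-worker budget)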


class GatherToMemory:
@@ -595,6 +692,9 @@ def __init__(self):
self.tuple_mode = None

def __call__(self, res):
if res is None:
return

if self.tuple_mode is None:
# first loop only
self.tuple_mode = isinstance(res, tuple)
@@ -655,6 +755,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
self.final_shapes.append(None)

def __call__(self, res):
if res is None:
return

if self.tuple_mode is None:
# first loop only
self.tuple_mode = isinstance(res, tuple)