From 8d5e408387923484325d13ad8fb3b7f4f0dacff1 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Mon, 28 Aug 2023 18:29:28 +0200
Subject: [PATCH 01/10] move peak_pipeline into core and rename it as
 node_pipeline. Change tests accordingly

---
 src/spikeinterface/core/node_pipeline.py      | 602 ++++++++++++++++++
 .../core/tests/test_node_pipeline.py          | 186 ++++++
 src/spikeinterface/preprocessing/motion.py    |   2 +-
 .../sortingcomponents/features_from_peaks.py  |   2 +-
 .../sortingcomponents/peak_detection.py       |   3 +-
 .../sortingcomponents/peak_localization.py    |   3 +-
 .../sortingcomponents/peak_pipeline.py        | 582 +----------------
 .../tests/test_motion_estimation.py           |   3 +-
 .../tests/test_peak_detection.py              |   7 +-
 .../tests/test_peak_pipeline.py               |   3 +-
 .../test_neural_network_denoiser.py           |   2 +-
 .../test_waveforms/test_savgol_denoiser.py    |   3 +-
 .../tests/test_waveforms/test_temporal_pca.py |   2 +-
 .../test_waveform_thresholder.py              |   8 +-
 src/spikeinterface/sortingcomponents/tools.py |   3 +-
 .../waveforms/neural_network_denoiser.py      |   2 +-
 .../waveforms/savgol_denoiser.py              |   2 +-
 .../waveforms/temporal_pca.py                 |   2 +-
 .../waveforms/waveform_thresholder.py         |   2 +-
 19 files changed, 812 insertions(+), 607 deletions(-)
 create mode 100644 src/spikeinterface/core/node_pipeline.py
 create mode 100644 src/spikeinterface/core/tests/test_node_pipeline.py

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
new file mode 100644
index 0000000000..4157365ffd
--- /dev/null
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -0,0 +1,602 @@
+"""
+Pipeline on spikes/peaks/detected peaks
+
+Functions that can be chained:
+  * after peak detection
+  * on already detected peaks
+  * on spikes (labeled peaks)
+to compute some additional features on-the-fly:
+  * peak localization
+  * peak-to-peak
+  * pca
+  * amplitude
+  * amplitude scaling
+  * ...
+
+There are three ways of using these "plugin nodes":
+  * during `peak_detect()`
+  * when peaks are already detected and reduced with `select_peaks()`
+  * on a sorting object
+"""
+
+from typing import Optional, List, Type
+
+import struct
+
+from pathlib import Path
+
+
+import numpy as np
+
+from spikeinterface.core import BaseRecording, get_chunk_with_margin
+from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc
+from spikeinterface.core import get_channel_distances
+
+
+base_peak_dtype = [
+    ("sample_index", "int64"),
+    ("channel_index", "int64"),
+    ("amplitude", "float64"),
+    ("segment_index", "int64"),
+]
+
+class PipelineNode:
+    def __init__(
+        self, recording: BaseRecording, return_output: bool = True, parents: Optional[List[Type["PipelineNode"]]] = None
+    ):
+        """
+        This is a generic object that performs some computation on peaks given a buffer of traces.
+        Typically used for extracting features (amplitudes, localization, ...)
+
+        A Node can optionally connect to other nodes with the parents and receive inputs from them.
+
+        Parameters
+        ----------
+        recording : BaseRecording
+            The recording object.
+        parents : Optional[List[PipelineNode]], optional
+            Pass parent nodes to perform a previous computation, by default None
+        return_output : bool or tuple of bool, optional
+            Whether or not the output of the node is returned by the pipeline, by default True.
+            When a Node has several outputs, this can be a tuple of bool.
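+
+        Examples
+        --------
+        A minimal sketch of a custom node (hypothetical, modeled on the test suite;
+        the node is assumed to have a PeakSource as parent)::
+
+            class MyAmplitudeNode(PipelineNode):
+                def get_dtype(self):
+                    return np.dtype("float64")
+
+                def compute(self, traces, peaks):
+                    # "peaks" is the output of the parent PeakSource node
+                    return np.abs(traces[peaks["sample_index"], peaks["channel_index"]])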
+
+
+        """
+
+        self.recording = recording
+        self.return_output = return_output
+        if isinstance(parents, str):
+            # only one parent is allowed
+            parents = [parents]
+        self.parents = parents
+
+        self._kwargs = dict()
+
+    def get_trace_margin(self):
+        # can optionally be overwritten
+        return 0
+
+    def get_dtype(self):
+        raise NotImplementedError
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *args):
+        raise NotImplementedError
+
+
+# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)
+# as first element they play the same role in pipeline : give some peaks (and eventually more)
+
+class PeakSource(PipelineNode):
+    # base class for peak detector
+    def get_trace_margin(self):
+        raise NotImplementedError
+
+    def get_dtype(self):
+        return base_peak_dtype
+
+
+# this is used in sorting components
+class PeakDetector(PeakSource):
+    pass
+
+
+class PeakRetriever(PeakSource):
+    def __init__(self, recording, peaks):
+        PipelineNode.__init__(self, recording, return_output=False)
+
+        self.peaks = peaks
+
+        # precompute segment slice
+        self.segment_slices = []
+        for segment_index in range(recording.get_num_segments()):
+            i0 = np.searchsorted(peaks["segment_index"], segment_index)
+            i1 = np.searchsorted(peaks["segment_index"], segment_index + 1)
+            self.segment_slices.append(slice(i0, i1))
+
+    def get_trace_margin(self):
+        return 0
+
+    def get_dtype(self):
+        return base_peak_dtype
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
+        # get local peaks
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
+        i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
+        local_peaks = peaks_in_segment[i0:i1]
+
+        # make sample index local to traces
+        local_peaks = local_peaks.copy()
+        local_peaks["sample_index"] -= start_frame - max_margin
+
+        return (local_peaks,)
+
+# this is not implemented yet, it will be done in a separate PR
+class SpikeRetriever(PeakSource):
+    pass
+
+
+class WaveformsNode(PipelineNode):
+    """
+    Base class for waveforms in a node pipeline.
+
+    Nodes that output waveforms, either extracting them from the traces
+    (e.g., ExtractDenseWaveforms/ExtractSparseWaveforms) or modifying existing
+    waveforms (e.g., Denoisers), need to inherit from this base class.
+    """
+
+    def __init__(
+        self,
+        recording: BaseRecording,
+        ms_before: float,
+        ms_after: float,
+        parents: Optional[List[PipelineNode]] = None,
+        return_output: bool = False,
+    ):
+        """
+        Base class for waveform extractors. Contains logic to handle the temporal interval in which to extract the
+        waveforms.
+
+        Parameters
+        ----------
+        recording : BaseRecording
+            The recording object.
+        parents : Optional[List[PipelineNode]], optional
+            Pass parent nodes to perform a previous computation, by default None
+        return_output : bool, optional
+            Whether or not the output of the node is returned by the pipeline, by default False
+        ms_before : float
+            The number of milliseconds to include before the peak of the spike.
+        ms_after : float
+            The number of milliseconds to include after the peak of the spike.
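+
+        Notes
+        -----
+        The interval is converted to sample counts using the recording's sampling
+        frequency, i.e. ``nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)``
+        and likewise for ``nafter``.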
+        """
+
+        PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output)
+        self.ms_before = ms_before
+        self.ms_after = ms_after
+        self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)
+        self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0)
+
+
+class ExtractDenseWaveforms(WaveformsNode):
+    def __init__(
+        self,
+        recording: BaseRecording,
+        ms_before: float,
+        ms_after: float,
+        parents: Optional[List[PipelineNode]] = None,
+        return_output: bool = False,
+    ):
+        """
+        Extract dense waveforms from a recording. This is the default waveform extractor which extracts the waveforms
+        for further computation on them.
+
+
+        Parameters
+        ----------
+        recording : BaseRecording
+            The recording object.
+        parents : Optional[List[PipelineNode]], optional
+            Pass parent nodes to perform a previous computation, by default None
+        return_output : bool, optional
+            Whether or not the output of the node is returned by the pipeline, by default False
+        ms_before : float
+            The number of milliseconds to include before the peak of the spike.
+        ms_after : float
+            The number of milliseconds to include after the peak of the spike.
+        """
+
+        WaveformsNode.__init__(
+            self,
+            recording=recording,
+            parents=parents,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            return_output=return_output,
+        )
+        # this is a bad hack to let children differentiate whether the parent is dense or not.
+        self.neighbours_mask = None
+
+    def get_trace_margin(self):
+        return max(self.nbefore, self.nafter)
+
+    def compute(self, traces, peaks):
+        waveforms = traces[peaks["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)]
+        return waveforms
+
+
+class ExtractSparseWaveforms(WaveformsNode):
+    def __init__(
+        self,
+        recording: BaseRecording,
+        ms_before: float,
+        ms_after: float,
+        parents: Optional[List[PipelineNode]] = None,
+        return_output: bool = False,
+        radius_um: float = 100.0,
+    ):
+        """
+        Extract sparse waveforms from a recording. The strategy in this specific node is to reshape the waveforms
+        to eliminate their inactive channels. This is achieved by changing their shape from
+        (num_waveforms, num_time_samples, num_channels) to (num_waveforms, num_time_samples, max_num_active_channels),
+        where max_num_active_channels is the maximum number of active channels over all waveforms. It is obtained by
+        taking the maximum number of non-zero entries in the sparsity neighbourhood mask.
+
+        Note that not all waveforms will have the same number of active channels, so even in the reduced form some of
+        the channels will be inactive and are filled with zeros.
+
+        Parameters
+        ----------
+        recording : BaseRecording
+            The recording object.
+        parents : Optional[List[PipelineNode]], optional
+            Pass parent nodes to perform a previous computation, by default None
+        return_output : bool, optional
+            Whether or not the output of the node is returned by the pipeline, by default False
+        ms_before : float
+            The number of milliseconds to include before the peak of the spike.
+        ms_after : float
+            The number of milliseconds to include after the peak of the spike.
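+        radius_um : float, optional
+            The radius (in micrometers) around the peak channel used to select the
+            neighbour channels whose waveforms are kept, by default 100.0.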
+
+
+        """
+        WaveformsNode.__init__(
+            self,
+            recording=recording,
+            parents=parents,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            return_output=return_output,
+        )
+
+        self.radius_um = radius_um
+        self.contact_locations = recording.get_channel_locations()
+        self.channel_distance = get_channel_distances(recording)
+        self.neighbours_mask = self.channel_distance < radius_um
+        self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))
+
+    def get_trace_margin(self):
+        return max(self.nbefore, self.nafter)
+
+    def compute(self, traces, peaks):
+        sparse_wfs = np.zeros((peaks.shape[0], self.nbefore + self.nafter, self.max_num_chans), dtype=traces.dtype)
+
+        for i, peak in enumerate(peaks):
+            (chans,) = np.nonzero(self.neighbours_mask[peak["channel_index"]])
+            sparse_wfs[i, :, : len(chans)] = traces[
+                peak["sample_index"] - self.nbefore : peak["sample_index"] + self.nafter, :
+            ][:, chans]
+
+        return sparse_wfs
+
+
+
+def find_parent_of_type(list_of_parents, parent_type, unique=True):
+    if list_of_parents is None:
+        return None
+
+    parents = []
+    for parent in list_of_parents:
+        if isinstance(parent, parent_type):
+            parents.append(parent)
+
+    if unique and len(parents) == 1:
+        return parents[0]
+    elif not unique and len(parents) > 1:
+        return parents[0]
+    else:
+        return None
+
+
+def check_graph(nodes):
+    """
+    Check that the node list is ordered in a good way (parents are placed before their children).
+    """
+
+    node0 = nodes[0]
+    if not isinstance(node0, PeakSource):
+        raise ValueError("Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever")
+
+    for i, node in enumerate(nodes):
+        assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode"
+        # check that parents exist and come before in the chain
+        node_parents = node.parents if node.parents else []
+        for parent in node_parents:
+            assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes"
+            assert (
+                nodes.index(parent) < i
+            ), f"Nodes are ordered incorrectly: {node} before {parent} in the pipeline definition."
+
+    return nodes
+
+
+def run_node_pipeline(
+    recording,
+    nodes,
+    job_kwargs,
+    job_name="pipeline",
+    mp_context=None,
+    gather_mode="memory",
+    squeeze_output=True,
+    folder=None,
+    names=None,
+):
+    """
+    Common function to run a pipeline with a peak detector or already detected peaks.
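+
+    Example
+    -------
+    A minimal sketch assuming peaks were already detected (adapted from the tests;
+    ``recording``, ``peaks`` and ``job_kwargs`` are assumed to exist)::
+
+        peak_retriever = PeakRetriever(recording, peaks)
+        dense_waveforms = ExtractDenseWaveforms(
+            recording, ms_before=1.0, ms_after=2.0, parents=[peak_retriever], return_output=True
+        )
+        waveforms = run_node_pipeline(recording, [peak_retriever, dense_waveforms], job_kwargs)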
+    """
+
+    check_graph(nodes)
+
+    job_kwargs = fix_job_kwargs(job_kwargs)
+    assert all(isinstance(node, PipelineNode) for node in nodes)
+
+    if gather_mode == "memory":
+        gather_func = GatherToMemory()
+    elif gather_mode == "npy":
+        gather_func = GatherToNpy(folder, names)
+    else:
+        raise ValueError(f"wrong gather_mode: {gather_mode}")
+
+    init_args = (recording, nodes)
+
+    processor = ChunkRecordingExecutor(
+        recording,
+        _compute_peak_pipeline_chunk,
+        _init_peak_pipeline,
+        init_args,
+        gather_func=gather_func,
+        job_name=job_name,
+        **job_kwargs,
+    )
+
+    processor.run()
+
+    outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
+    return outs
+
+
+def _init_peak_pipeline(recording, nodes):
+    # create a local dict per worker
+    worker_ctx = {}
+    worker_ctx["recording"] = recording
+    worker_ctx["nodes"] = nodes
+    worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
+
+    return worker_ctx
+
+
+def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
+    recording = worker_ctx["recording"]
+    max_margin = worker_ctx["max_margin"]
+    nodes = worker_ctx["nodes"]
+
+    recording_segment = recording._recording_segments[segment_index]
+    traces_chunk, left_margin, right_margin = get_chunk_with_margin(
+        recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
+    )
+
+    # compute the graph
+    pipeline_outputs = {}
+    for node in nodes:
+        node_parents = node.parents if node.parents else list()
+        node_input_args = tuple()
+        for parent in node_parents:
+            parent_output = pipeline_outputs[parent]
+            parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
+            node_input_args += parent_outputs_tuple
+        if isinstance(node, PeakDetector):
+            # to handle compatibility, the peak detector is a special case
+            # with a specific margin
+            # TODO later when in master: change this
+            extra_margin = max_margin - node.get_trace_margin()
+            if extra_margin:
+                trace_detection = traces_chunk[extra_margin:-extra_margin]
+            else:
+                trace_detection = traces_chunk
+            node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
+            # set sample index to local
+            node_output[0]["sample_index"] += extra_margin
+        elif isinstance(node, PeakRetriever):
+            node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
+        else:
+            # TODO later when in master: change the signature of all nodes (or maybe not!)
+            node_output = node.compute(traces_chunk, *node_input_args)
+        pipeline_outputs[node] = node_output
+
+    # propagate the output
+    pipeline_outputs_tuple = tuple()
+    for node in nodes:
+        # handle which buffers are given to the output
+        # this is controlled by node.return_output being a bool or a tuple of bool
+        out = pipeline_outputs[node]
+        if isinstance(out, tuple):
+            if isinstance(node.return_output, bool) and node.return_output:
+                pipeline_outputs_tuple += out
+            elif isinstance(node.return_output, tuple):
+                for flag, e in zip(node.return_output, out):
+                    if flag:
+                        pipeline_outputs_tuple += (e,)
+        else:
+            if isinstance(node.return_output, bool) and node.return_output:
+                pipeline_outputs_tuple += (out,)
+            elif isinstance(node.return_output, tuple):
+                # this should not happen: maybe add a check somewhere before?
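+                # (a tuple return_output is only meaningful for nodes with several outputs)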
+ pass + + if isinstance(nodes[0], PeakDetector): + # the first out element is the peak vector + # we need to go back to absolut sample index + pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin + + return pipeline_outputs_tuple + + + +class GatherToMemory: + """ + Gather output of nodes into list and then demultiplex and np.concatenate + """ + + def __init__(self): + self.outputs = [] + self.tuple_mode = None + + def __call__(self, res): + if self.tuple_mode is None: + # first loop only + self.tuple_mode = isinstance(res, tuple) + + # res is a tuple + self.outputs.append(res) + + def finalize_buffers(self, squeeze_output=False): + # concatenate + if self.tuple_mode: + # list of tuple of numpy array + outs_concat = () + for output_step in zip(*self.outputs): + outs_concat += (np.concatenate(output_step, axis=0),) + + if len(outs_concat) == 1 and squeeze_output: + # when tuple size ==1 then remove the tuple + return outs_concat[0] + else: + # always a tuple even of size 1 + return outs_concat + else: + # list of numpy array + return np.concatenate(self.outputs) + + +class GatherToNpy: + """ + Gather output of nodes into npy file and then open then as memmap. + + + The trick is: + * speculate on a header length (1024) + * accumulate in C order the buffer + * create the npy v1.0 header at the end with the correct shape and dtype + """ + + def __init__(self, folder, names, npy_header_size=1024): + self.folder = Path(folder) + self.folder.mkdir(parents=True, exist_ok=False) + assert names is not None + self.names = names + self.npy_header_size = npy_header_size + + self.tuple_mode = None + + self.files = [] + self.dtypes = [] + self.shapes0 = [] + self.final_shapes = [] + for name in names: + filename = folder / (name + ".npy") + f = open(filename, "wb+") + f.seek(npy_header_size) + self.files.append(f) + self.dtypes.append(None) + self.shapes0.append(0) + self.final_shapes.append(None) + + def __call__(self, res): + if self.tuple_mode is None: + # first loop only + self.tuple_mode = isinstance(res, tuple) + if self.tuple_mode: + assert len(self.names) == len(res) + else: + assert len(self.names) == 1 + + # distribute binary buffer to npy files + for i in range(len(self.names)): + f = self.files[i] + buf = res[i] + buf = np.require(buf, requirements="C") + if self.dtypes[i] is None: + # first loop only + self.dtypes[i] = buf.dtype + if buf.ndim > 1: + self.final_shapes[i] = buf.shape[1:] + f.write(buf.tobytes()) + self.shapes0[i] += buf.shape[0] + + def finalize_buffers(self, squeeze_output=False): + # close and post write header to files + for f in self.files: + f.close() + + for i, name in enumerate(self.names): + filename = self.folder / (name + ".npy") + + shape = (self.shapes0[i],) + if self.final_shapes[i] is not None: + shape += self.final_shapes[i] + + # create header npy v1.0 in bytes + # see https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format + # magic + header = b"\x93NUMPY" + # version npy 1.0 + header += b"\x01\x00" + # size except 10 first bytes + header += struct.pack(" 1: - return parents[0] - else: - return None - - -def check_graph(nodes): - """ - Check that node list is orderd in a good (parents are before children) - """ - - node0 = nodes[0] - if not (isinstance(node0, PeakDetector) or isinstance(node0, PeakRetriever)): - raise ValueError("Peak pipeline graph must contain PeakDetector or PeakRetriever as first element") - - for i, node in enumerate(nodes): - assert isinstance(node, PipelineNode), f"Node {node} is 
not an instance of PipelineNode" - # check that parents exists and are before in chain - node_parents = node.parents if node.parents else [] - for parent in node_parents: - assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes" - assert ( - nodes.index(parent) < i - ), f"Node are ordered incorrectly: {node} before {parent} in the pipeline definition." - - return nodes - - -def run_node_pipeline( - recording, - nodes, - job_kwargs, - job_name="peak_pipeline", - mp_context=None, - gather_mode="memory", - squeeze_output=True, - folder=None, - names=None, -): - """ - Common function to run pipeline with peak detector or already detected peak. - """ - - check_graph(nodes) - - job_kwargs = fix_job_kwargs(job_kwargs) - assert all(isinstance(node, PipelineNode) for node in nodes) - - if gather_mode == "memory": - gather_func = GatherToMemory() - elif gather_mode == "npy": - gather_func = GatherToNpy(folder, names) - else: - raise ValueError(f"wrong gather_mode : {gather_mode}") - - init_args = (recording, nodes) - - processor = ChunkRecordingExecutor( - recording, - _compute_peak_pipeline_chunk, - _init_peak_pipeline, - init_args, - gather_func=gather_func, - job_name=job_name, - **job_kwargs, - ) - - processor.run() - - outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) - return outs - - -def _init_peak_pipeline(recording, nodes): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["nodes"] = nodes - worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) - - return worker_ctx - - -def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx): - recording = worker_ctx["recording"] - max_margin = worker_ctx["max_margin"] - nodes = worker_ctx["nodes"] - - recording_segment = recording._recording_segments[segment_index] - traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True - ) - - # compute the graph - pipeline_outputs = {} - for node in nodes: - node_parents = node.parents if node.parents else list() - node_input_args = tuple() - for parent in node_parents: - parent_output = pipeline_outputs[parent] - parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) - node_input_args += parent_outputs_tuple - if isinstance(node, PeakDetector): - # to handle compatibility peak detector is a special case - # with specific margin - # TODO later when in master: change this later - extra_margin = max_margin - node.get_trace_margin() - if extra_margin: - trace_detection = traces_chunk[extra_margin:-extra_margin] - else: - trace_detection = traces_chunk - node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) - # set sample index to local - node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakRetriever): - node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) - else: - # TODO later when in master: change the signature of all nodes (or maybe not!) 
- node_output = node.compute(traces_chunk, *node_input_args) - pipeline_outputs[node] = node_output - - # propagate the output - pipeline_outputs_tuple = tuple() - for node in nodes: - # handle which buffer are given to the output - # this is controlled by node.return_output being a bool or tuple of bool - out = pipeline_outputs[node] - if isinstance(out, tuple): - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += out - elif isinstance(node.return_output, tuple): - for flag, e in zip(node.return_output, out): - if flag: - pipeline_outputs_tuple += (e,) - else: - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += (out,) - elif isinstance(node.return_output, tuple): - # this should not apppend : maybe a checker somewhere before ? - pass - - if isinstance(nodes[0], PeakDetector): - # the first out element is the peak vector - # we need to go back to absolut sample index - pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin - - return pipeline_outputs_tuple def run_peak_pipeline( @@ -480,149 +46,3 @@ def run_peak_pipeline( ) return outs - -class GatherToMemory: - """ - Gather output of nodes into list and then demultiplex and np.concatenate - """ - - def __init__(self): - self.outputs = [] - self.tuple_mode = None - - def __call__(self, res): - if self.tuple_mode is None: - # first loop only - self.tuple_mode = isinstance(res, tuple) - - # res is a tuple - self.outputs.append(res) - - def finalize_buffers(self, squeeze_output=False): - # concatenate - if self.tuple_mode: - # list of tuple of numpy array - outs_concat = () - for output_step in zip(*self.outputs): - outs_concat += (np.concatenate(output_step, axis=0),) - - if len(outs_concat) == 1 and squeeze_output: - # when tuple size ==1 then remove the tuple - return outs_concat[0] - else: - # always a tuple even of size 1 - return outs_concat - else: - # list of numpy array - return np.concatenate(self.outputs) - - -class GatherToNpy: - """ - Gather output of nodes into npy file and then open then as memmap. 
- - - The trick is: - * speculate on a header length (1024) - * accumulate in C order the buffer - * create the npy v1.0 header at the end with the correct shape and dtype - """ - - def __init__(self, folder, names, npy_header_size=1024): - self.folder = Path(folder) - self.folder.mkdir(parents=True, exist_ok=False) - assert names is not None - self.names = names - self.npy_header_size = npy_header_size - - self.tuple_mode = None - - self.files = [] - self.dtypes = [] - self.shapes0 = [] - self.final_shapes = [] - for name in names: - filename = folder / (name + ".npy") - f = open(filename, "wb+") - f.seek(npy_header_size) - self.files.append(f) - self.dtypes.append(None) - self.shapes0.append(0) - self.final_shapes.append(None) - - def __call__(self, res): - if self.tuple_mode is None: - # first loop only - self.tuple_mode = isinstance(res, tuple) - if self.tuple_mode: - assert len(self.names) == len(res) - else: - assert len(self.names) == 1 - - # distribute binary buffer to npy files - for i in range(len(self.names)): - f = self.files[i] - buf = res[i] - buf = np.require(buf, requirements="C") - if self.dtypes[i] is None: - # first loop only - self.dtypes[i] = buf.dtype - if buf.ndim > 1: - self.final_shapes[i] = buf.shape[1:] - f.write(buf.tobytes()) - self.shapes0[i] += buf.shape[0] - - def finalize_buffers(self, squeeze_output=False): - # close and post write header to files - for f in self.files: - f.close() - - for i, name in enumerate(self.names): - filename = self.folder / (name + ".npy") - - shape = (self.shapes0[i],) - if self.final_shapes[i] is not None: - shape += self.final_shapes[i] - - # create header npy v1.0 in bytes - # see https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format - # magic - header = b"\x93NUMPY" - # version npy 1.0 - header += b"\x01\x00" - # size except 10 first bytes - header += struct.pack(" Date: Mon, 28 Aug 2023 16:30:08 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 13 ++++++++----- src/spikeinterface/core/tests/test_node_pipeline.py | 6 +++--- .../sortingcomponents/peak_detection.py | 8 +++++++- .../sortingcomponents/peak_pipeline.py | 5 ----- .../sortingcomponents/tests/test_peak_detection.py | 1 - src/spikeinterface/sortingcomponents/tools.py | 1 - 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 4157365ffd..9ea5ad59e7 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -16,7 +16,7 @@ There are two ways for using theses "plugin nodes": * during `peak_detect()` * when peaks are already detected and reduced with `select_peaks()` - * on a sorting object + * on a sorting object """ from typing import Optional, List, Type @@ -40,6 +40,7 @@ ("segment_index", "int64"), ] + class PipelineNode: def __init__( self, recording: BaseRecording, return_output: bool = True, parents: Optional[List[Type["PipelineNode"]]] = None @@ -86,6 +87,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar # nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) # as first element they play the same role in pipeline : give some peaks (and eventually more) + class PeakSource(PipelineNode): # base class for peak detector def get_trace_margin(self): @@ -132,7 +134,8 @@ def compute(self, 
traces, start_frame, end_frame, segment_index, max_margin): local_peaks["sample_index"] -= start_frame - max_margin return (local_peaks,) - + + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): pass @@ -293,7 +296,6 @@ def compute(self, traces, peaks): return sparse_wfs - def find_parent_of_type(list_of_parents, parent_type, unique=True): if list_of_parents is None: return None @@ -318,7 +320,9 @@ def check_graph(nodes): node0 = nodes[0] if not isinstance(node0, PeakSource): - raise ValueError("Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever") + raise ValueError( + "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" + ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode" @@ -454,7 +458,6 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c return pipeline_outputs_tuple - class GatherToMemory: """ Gather output of nodes into list and then demultiplex and np.concatenate diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index e40a820c85..e9dfb43a66 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -6,6 +6,7 @@ import scipy.signal from spikeinterface import download_dataset, BaseSorting, extract_waveforms, get_template_extremum_channel + # from spikeinterface.extractors import MEArecRecordingExtractor from spikeinterface.extractors import read_mearec @@ -15,7 +16,7 @@ PeakRetriever, PipelineNode, ExtractDenseWaveforms, - base_peak_dtype + base_peak_dtype, ) @@ -93,9 +94,8 @@ def test_run_node_pipeline(): peaks = np.zeros(spikes.size, dtype=base_peak_dtype) peaks["sample_index"] = spikes["sample_index"] peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] - peaks["amplitude"] = 0. 
+ peaks["amplitude"] = 0.0 peaks["segment_index"] = 0 - # one step only : squeeze output peak_retriever = PeakRetriever(recording, peaks) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index bc8889e274..f3719b934b 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -13,7 +13,13 @@ from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances from spikeinterface.core.baserecording import BaseRecording -from spikeinterface.core.node_pipeline import PeakDetector, WaveformsNode, ExtractSparseWaveforms, run_node_pipeline, base_peak_dtype +from spikeinterface.core.node_pipeline import ( + PeakDetector, + WaveformsNode, + ExtractSparseWaveforms, + run_node_pipeline, + base_peak_dtype, +) from ..core import get_chunk_with_margin diff --git a/src/spikeinterface/sortingcomponents/peak_pipeline.py b/src/spikeinterface/sortingcomponents/peak_pipeline.py index c235e18558..f72e827a09 100644 --- a/src/spikeinterface/sortingcomponents/peak_pipeline.py +++ b/src/spikeinterface/sortingcomponents/peak_pipeline.py @@ -3,10 +3,6 @@ from spikeinterface.core.node_pipeline import PeakRetriever, run_node_pipeline - - - - def run_peak_pipeline( recording, peaks, @@ -45,4 +41,3 @@ def run_peak_pipeline( names=names, ) return outs - diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 7a37e4da02..9f9377ee53 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -26,7 +26,6 @@ from spikeinterface.core.node_pipeline import run_node_pipeline - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "sortingcomponents" else: diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 69768a7fca..45b9079ea9 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -19,7 +19,6 @@ def make_multi_method_doc(methods, ident=" "): def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - nb_peaks = min(len(peaks), nb_peaks) idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) peak_retriever = PeakRetriever(recording, peaks[idx]) From a516c634d6e5f8902bbf2fb59a4d3bd665249de6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 28 Aug 2023 18:42:56 +0200 Subject: [PATCH 03/10] oups --- .../tests/test_waveforms/test_neural_network_denoiser.py | 1 - .../tests/test_waveforms/test_temporal_pca.py | 2 +- .../tests/test_waveforms/test_waveform_thresholder.py | 3 ++- .../sortingcomponents/waveforms/neural_network_denoiser.py | 2 +- .../sortingcomponents/waveforms/savgol_denoiser.py | 2 +- .../sortingcomponents/waveforms/waveform_thresholder.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py index 8a3c8235f5..f40a54cb81 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py @@ -4,7 +4,6 @@ from spikeinterface.extractors 
import MEArecRecordingExtractor from spikeinterface import download_dataset - from spikeinterface.core.node_pipeline import run_node_pipeline, PeakRetriever, ExtractDenseWaveforms from spikeinterface.sortingcomponents.waveforms.neural_network_denoiser import SingleChannelToyDenoiser diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py index ea045a2f0d..2be1692f7b 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py @@ -2,7 +2,7 @@ from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection, TemporalPCADenoising -from spikeinterface.core.node_pipeline import import ( +from spikeinterface.core.node_pipeline import ( PeakRetriever, ExtractDenseWaveforms, ExtractSparseWaveforms, diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py index 84adc4686d..3737988ee9 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py @@ -4,7 +4,8 @@ from spikeinterface.sortingcomponents.waveforms.waveform_thresholder import WaveformThresholder -from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_peak_pipeline +from spikeinterface.core.node_pipeline import ExtractDenseWaveforms +from spikeinterface.sortingcomponents.peak_pipeline import run_peak_pipeline @pytest.fixture(scope="module") diff --git a/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py index 50a36651a6..d094bae3e0 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py @@ -17,7 +17,7 @@ HAVE_HUGGINFACE = False from spikeinterface.core import BaseRecording -from spikeinterface.core.node_pipeline import import PipelineNode, WaveformsNode, find_parent_of_type +from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type from .waveform_utils import to_temporal_representation, from_temporal_representation diff --git a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py index 7a1cc100fd..df6dd81a97 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py @@ -4,7 +4,7 @@ import scipy.signal from spikeinterface.core import BaseRecording -from spikeinterface.core.node_pipeline import import PipelineNode, WaveformsNode, find_parent_of_type +from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type class SavGolDenoiser(WaveformsNode): diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py index b700efc94b..36875148d4 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py @@ -7,7 +7,7 @@ from typing import Literal from spikeinterface.core import 
BaseRecording, get_noise_levels -from spikeinterface.core.node_pipeline import import PipelineNode, WaveformsNode, find_parent_of_type +from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type class WaveformThresholder(WaveformsNode): From e7a4c86bf4b2d72de6d141b307d4ae6e7b5c2d88 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Aug 2023 16:44:16 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/tests/test_waveforms/test_temporal_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py index 2be1692f7b..fcd7ddae18 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py @@ -2,7 +2,7 @@ from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection, TemporalPCADenoising -from spikeinterface.core.node_pipeline import ( +from spikeinterface.core.node_pipeline import ( PeakRetriever, ExtractDenseWaveforms, ExtractSparseWaveforms, From da7a68bd7019a3e3ecd4b10ba6457013c81eb1ed Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 29 Aug 2023 11:29:21 +0200 Subject: [PATCH 05/10] remove scipy from core test --- src/spikeinterface/core/tests/test_node_pipeline.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index e9dfb43a66..395259610a 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,8 +3,6 @@ from pathlib import Path import shutil -import scipy.signal - from spikeinterface import download_dataset, BaseSorting, extract_waveforms, get_template_extremum_channel # from spikeinterface.extractors import MEArecRecordingExtractor @@ -53,8 +51,8 @@ def get_dtype(self): return np.dtype("float32") def compute(self, traces, peaks, waveforms): - kernel = np.array([0.1, 0.8, 0.1])[np.newaxis, :, np.newaxis] - denoised_waveforms = scipy.signal.fftconvolve(waveforms, kernel, axes=1, mode="same") + kernel = np.array([0.1, 0.8, 0.1]) + denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode='same'), axis=1, arr=waveforms) return denoised_waveforms From e8bae07f176c08f5088ad61bad80762ef929dab3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 09:30:12 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_node_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 395259610a..bd5c8b3c5f 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -52,7 +52,7 @@ def get_dtype(self): def compute(self, traces, peaks, waveforms): kernel = np.array([0.1, 0.8, 0.1]) - denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode='same'), axis=1, 
arr=waveforms) + denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=waveforms) return denoised_waveforms From b50bc902964b09f774b879bffb88c7292baca967 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 1 Sep 2023 09:00:40 +0200 Subject: [PATCH 07/10] Remove download from test_node_pipeline.py when in core. --- .../core/tests/test_node_pipeline.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index bd5c8b3c5f..7de62a64cb 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import download_dataset, BaseSorting, extract_waveforms, get_template_extremum_channel +from spikeinterface import extract_waveforms, get_template_extremum_channel, generate_ground_truth_recording # from spikeinterface.extractors import MEArecRecordingExtractor from spikeinterface.extractors import read_mearec @@ -69,26 +69,18 @@ def compute(self, traces, peaks, waveforms): def test_run_node_pipeline(): - repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data" - remote_path = "mearec/mearec_test_10s.h5" - local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None) - # recording = MEArecRecordingExtractor(local_path) - recording, sorting = read_mearec(local_path) + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.]) job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) spikes = sorting.to_spike_vector() - # peaks = detect_peaks( - # recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs - # ) - # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - print(extremum_channel_inds) + # print(extremum_channel_inds) ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) - print(ext_channel_inds) + # print(ext_channel_inds) peaks = np.zeros(spikes.size, dtype=base_peak_dtype) peaks["sample_index"] = spikes["sample_index"] peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] From d07da4fcb1bdaaccd376e37bfe258b7404c311eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Sep 2023 07:01:04 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/__init__.py | 7 +- src/spikeinterface/core/generate.py | 322 ++++++++++-------- .../core/tests/test_core_tools.py | 21 +- .../core/tests/test_generate.py | 138 +++++--- .../core/tests/test_node_pipeline.py | 2 +- .../curation/tests/test_auto_merge.py | 3 +- .../curation/tests/test_remove_redundant.py | 3 +- src/spikeinterface/extractors/toy_example.py | 61 ++-- .../tests/test_metrics_functions.py | 18 +- .../tests/test_quality_metric_calculator.py | 2 - 10 files changed, 331 insertions(+), 246 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 36d011aef7..5b4a66244e 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -35,11 +35,12 @@ 
inject_some_split_units, synthetize_spike_train_bad_isi, generate_templates, - NoiseGeneratorRecording, noise_generator_recording, + NoiseGeneratorRecording, + noise_generator_recording, generate_recording_by_size, - InjectTemplatesRecording, inject_templates, + InjectTemplatesRecording, + inject_templates, generate_ground_truth_recording, - ) # utils to append and concatenate segment (equivalent to OLD MultiRecordingTimeExtractor) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 73cdd59ca7..e2e31ad9b7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -9,23 +9,18 @@ from probeinterface import Probe, generate_linear_probe -from spikeinterface.core import ( - BaseRecording, - BaseRecordingSegment, - BaseSorting -) +from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting from .snippets_tools import snippets_from_sorting from .core_tools import define_function_from_class - def _ensure_seed(seed): # when seed is None: # we want to set one to push it in the Recordind._kwargs to reconstruct the same signal # this is a better approach than having seed=42 or seed=my_dog_birth because we ensure to have # a new signal for all call with seed=None but the dump/load will still work if seed is None: - seed = np.random.default_rng(seed=None).integers(0, 2 ** 63) + seed = np.random.default_rng(seed=None).integers(0, 2**63) return seed @@ -72,19 +67,19 @@ def generate_recording( recording = _generate_recording_legacy(num_channels, sampling_frequency, durations, seed) elif mode == "lazy": recording = NoiseGeneratorRecording( - num_channels=num_channels, - sampling_frequency=sampling_frequency, - durations=durations, - dtype="float32", - seed=seed, - strategy="tile_pregenerated", - # block size is fixed to one second - noise_block_size=int(sampling_frequency) + num_channels=num_channels, + sampling_frequency=sampling_frequency, + durations=durations, + dtype="float32", + seed=seed, + strategy="tile_pregenerated", + # block size is fixed to one second + noise_block_size=int(sampling_frequency), ) else: raise ValueError("generate_recording() : wrong mode") - + recording.annotate(is_filtered=True) if set_probe: @@ -96,7 +91,6 @@ def generate_recording( probe = generate_linear_probe(num_elec=num_channels) return recording - def _generate_recording_legacy(num_channels, sampling_frequency, durations, seed): @@ -121,9 +115,9 @@ def generate_sorting( num_units=5, sampling_frequency=30000.0, # in Hz durations=[10.325, 3.5], #  in s for 2 segments - firing_rates=3., + firing_rates=3.0, empty_units=None, - refractory_period_ms=3., # in ms + refractory_period_ms=3.0, # in ms seed=None, ): seed = _ensure_seed(seed) @@ -145,7 +139,7 @@ def generate_sorting( keep = ~np.in1d(labels, empty_units) times = times[keep] labels = labels[keep] - + spikes_in_seg = np.zeros(times.size, dtype=minimum_spike_dtype) spikes_in_seg["sample_index"] = times spikes_in_seg["unit_index"] = labels @@ -213,9 +207,15 @@ def generate_snippets( ## spiketrain zone ## + def synthesize_random_firings( - num_units=20, sampling_frequency=30000.0, duration=60, refractory_period_ms=4.0, firing_rates=3.0, add_shift_shuffle=False, - seed=None + num_units=20, + sampling_frequency=30000.0, + duration=60, + refractory_period_ms=4.0, + firing_rates=3.0, + add_shift_shuffle=False, + seed=None, ): """ " Generate some spiketrain with random firing for one segment. 
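+
+    A rough usage sketch (argument values here are illustrative only)::
+
+        times, labels = synthesize_random_firings(num_units=4, duration=10.0, firing_rates=5.0, seed=42)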
@@ -276,7 +276,7 @@ def synthesize_random_firings( if add_shift_shuffle: ## make an interesting autocorrelogram shape # this replace the previous rand_distr2() - some = rng.choice(spike_times.size, spike_times.size//2, replace=False) + some = rng.choice(spike_times.size, spike_times.size // 2, replace=False) x = rng.random(some.size) a = refractory_sample b = refractory_sample * 20 @@ -284,7 +284,7 @@ def synthesize_random_firings( spike_times[some] += shift times0 = times0[(0 <= times0) & (times0 < N)] - violations, = np.nonzero(np.diff(spike_times) < refractory_sample) + (violations,) = np.nonzero(np.diff(spike_times) < refractory_sample) spike_times = np.delete(spike_times, violations) if len(spike_times) > n_spikes: spike_times = rng.choice(spike_times, n_spikes, replace=False) @@ -463,6 +463,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol ## Noise generator zone ## + class NoiseGeneratorRecording(BaseRecording): """ A lazy recording that generates random samples if and only if `get_traces` is called. @@ -501,41 +502,47 @@ class NoiseGeneratorRecording(BaseRecording): ---- If modifying this function, ensure that only one call to malloc is made per call get_traces to maintain the optimized memory profile. - """ + """ + def __init__( self, num_channels: int, sampling_frequency: float, durations: List[float], - noise_level: float = 5., + noise_level: float = 5.0, dtype: Optional[Union[np.dtype, str]] = "float32", seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): - channel_ids = np.arange(num_channels) dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") - BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) num_segments = len(durations) # very important here when multiprocessing and dump/load seed = _ensure_seed(seed) - + # we need one seed per segment rng = np.random.default_rng(seed) - segments_seeds = [rng.integers(0, 2 ** 63) for i in range(num_segments)] + segments_seeds = [rng.integers(0, 2**63) for i in range(num_segments)] for i in range(num_segments): num_samples = int(durations[i] * sampling_frequency) - rec_segment = NoiseGeneratorRecordingSegment(num_samples, num_channels, sampling_frequency, - noise_block_size, noise_level, dtype, - segments_seeds[i], strategy) + rec_segment = NoiseGeneratorRecordingSegment( + num_samples, + num_channels, + sampling_frequency, + noise_block_size, + noise_level, + dtype, + segments_seeds[i], + strategy, + ) self.add_recording_segment(rec_segment) self._kwargs = { @@ -550,10 +557,11 @@ def __init__( class NoiseGeneratorRecordingSegment(BaseRecordingSegment): - def __init__(self, num_samples, num_channels, sampling_frequency, noise_block_size, noise_level, dtype, seed, strategy): + def __init__( + self, num_samples, num_channels, sampling_frequency, noise_block_size, noise_level, dtype, seed, strategy + ): assert seed is not None - - + BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency) self.num_samples = num_samples @@ -566,12 +574,14 @@ def __init__(self, num_samples, num_channels, sampling_frequency, noise_block_si if self.strategy == "tile_pregenerated": rng = np.random.default_rng(seed=self.seed) - self.noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) * noise_level + 
self.noise_block = ( + rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) * noise_level + ) elif self.strategy == "on_the_fly": pass - + def get_num_samples(self): - return self.num_samples + return self.num_samples def get_traces( self, @@ -579,7 +589,6 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) @@ -608,12 +617,12 @@ def get_traces( pos += end_first_block else: # special case when unique block - traces[:] = noise_block[start_frame_mod:start_frame_mod + traces.shape[0]] + traces[:] = noise_block[start_frame_mod : start_frame_mod + traces.shape[0]] elif block_index == end_block_index: if end_frame_mod > 0: traces[pos:] = noise_block[:end_frame_mod] else: - traces[pos:pos + self.noise_block_size] = noise_block + traces[pos : pos + self.noise_block_size] = noise_block pos += self.noise_block_size # slice channels @@ -622,12 +631,14 @@ def get_traces( return traces -noise_generator_recording = define_function_from_class(source_class=NoiseGeneratorRecording, name="noise_generator_recording") +noise_generator_recording = define_function_from_class( + source_class=NoiseGeneratorRecording, name="noise_generator_recording" +) def generate_recording_by_size( full_traces_size_GiB: float, - num_channels:int = 1024, + num_channels: int = 1024, seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: @@ -675,65 +686,71 @@ def generate_recording_by_size( return recording + ## Waveforms zone ## + def exp_growth(start_amp, end_amp, duration_ms, tau_ms, sampling_frequency, flip=False): if flip: start_amp, end_amp = end_amp, start_amp - size = int(duration_ms * sampling_frequency / 1000.) - times_ms = np.arange(size + 1) / sampling_frequency * 1000. + size = int(duration_ms * sampling_frequency / 1000.0) + times_ms = np.arange(size + 1) / sampling_frequency * 1000.0 y = np.exp(times_ms / tau_ms) y = y / (y[-1] - y[0]) * (end_amp - start_amp) y = y - y[0] + start_amp if flip: - y =y[::-1] + y = y[::-1] return y[:-1] def generate_single_fake_waveform( - sampling_frequency=None, - ms_before=1.0, - ms_after=3.0, - negative_amplitude=-1, - positive_amplitude=.15, - depolarization_ms=.1, - repolarization_ms=0.6, - hyperpolarization_ms=1.1, - smooth_ms=0.05, - dtype="float32", - ): + sampling_frequency=None, + ms_before=1.0, + ms_after=3.0, + negative_amplitude=-1, + positive_amplitude=0.15, + depolarization_ms=0.1, + repolarization_ms=0.6, + hyperpolarization_ms=1.1, + smooth_ms=0.05, + dtype="float32", +): """ Very naive spike waveforms generator with 3 exponentials. """ assert ms_after > depolarization_ms + repolarization_ms assert ms_before > depolarization_ms - - nbefore = int(sampling_frequency * ms_before / 1000.) - nafter = int(sampling_frequency * ms_after/ 1000.) + nbefore = int(sampling_frequency * ms_before / 1000.0) + nafter = int(sampling_frequency * ms_after / 1000.0) width = nbefore + nafter wf = np.zeros(width, dtype=dtype) # depolarization - ndepo = int(depolarization_ms * sampling_frequency / 1000.) 
+ ndepo = int(depolarization_ms * sampling_frequency / 1000.0) assert ndepo < nafter, "ms_before is too short" - tau_ms = depolarization_ms * .2 - wf[nbefore - ndepo:nbefore] = exp_growth(0, negative_amplitude, depolarization_ms, tau_ms, sampling_frequency, flip=False) + tau_ms = depolarization_ms * 0.2 + wf[nbefore - ndepo : nbefore] = exp_growth( + 0, negative_amplitude, depolarization_ms, tau_ms, sampling_frequency, flip=False + ) # repolarization - nrepol = int(repolarization_ms * sampling_frequency / 1000.) - tau_ms = repolarization_ms * .5 - wf[nbefore:nbefore + nrepol] = exp_growth(negative_amplitude, positive_amplitude, repolarization_ms, tau_ms, sampling_frequency, flip=True) + nrepol = int(repolarization_ms * sampling_frequency / 1000.0) + tau_ms = repolarization_ms * 0.5 + wf[nbefore : nbefore + nrepol] = exp_growth( + negative_amplitude, positive_amplitude, repolarization_ms, tau_ms, sampling_frequency, flip=True + ) # hyperpolarization - nrefac = int(hyperpolarization_ms * sampling_frequency / 1000.) + nrefac = int(hyperpolarization_ms * sampling_frequency / 1000.0) assert nrefac + nrepol < nafter, "ms_after is too short" tau_ms = hyperpolarization_ms * 0.5 - wf[nbefore + nrepol:nbefore + nrepol + nrefac] = exp_growth(positive_amplitude, 0., hyperpolarization_ms, tau_ms, sampling_frequency, flip=True) - + wf[nbefore + nrepol : nbefore + nrepol + nrefac] = exp_growth( + positive_amplitude, 0.0, hyperpolarization_ms, tau_ms, sampling_frequency, flip=True + ) # gaussian smooth - smooth_size = smooth_ms / (1 / sampling_frequency * 1000.) + smooth_size = smooth_ms / (1 / sampling_frequency * 1000.0) n = int(smooth_size * 4) bins = np.arange(-n, n + 1) smooth_kernel = np.exp(-(bins**2) / (2 * smooth_size**2)) @@ -754,26 +771,27 @@ def generate_single_fake_waveform( default_unit_params_range = dict( - alpha=(5_000., 15_000.), - depolarization_ms=(.09, .14), + alpha=(5_000.0, 15_000.0), + depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), - hyperpolarization_ms=(1., 1.5), + hyperpolarization_ms=(1.0, 1.5), positive_amplitude=(0.05, 0.15), smooth_ms=(0.03, 0.07), ) + def generate_templates( - channel_locations, - units_locations, - sampling_frequency, - ms_before, - ms_after, - seed=None, - dtype="float32", - upsample_factor=None, - unit_params=dict(), - unit_params_range=dict(), - ): + channel_locations, + units_locations, + sampling_frequency, + ms_before, + ms_after, + seed=None, + dtype="float32", + upsample_factor=None, + unit_params=dict(), + unit_params_range=dict(), +): """ Generate some template from given channel position and neuron position. @@ -817,11 +835,10 @@ def generate_templates( The template array with shape * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor) if upsample_factor is not None - + """ rng = np.random.default_rng(seed=seed) - # neuron location must be 3D assert units_locations.shape[1] == 3 @@ -833,8 +850,8 @@ def generate_templates( num_units = units_locations.shape[0] num_channels = channel_locations.shape[0] - nbefore = int(sampling_frequency * ms_before / 1000.) - nafter = int(sampling_frequency * ms_after/ 1000.) 
+    nbefore = int(sampling_frequency * ms_before / 1000.0)
+    nafter = int(sampling_frequency * ms_after / 1000.0)
     width = nbefore + nafter

     if upsample_factor is not None:
@@ -862,22 +879,21 @@
     for u in range(num_units):
         wf = generate_single_fake_waveform(
-                sampling_frequency=fs,
-                ms_before=ms_before,
-                ms_after=ms_after,
-                negative_amplitude=-1,
-                positive_amplitude=params["positive_amplitude"][u],
-                depolarization_ms=params["depolarization_ms"][u],
-                repolarization_ms=params["repolarization_ms"][u],
-                hyperpolarization_ms=params["hyperpolarization_ms"][u],
-                smooth_ms=params["smooth_ms"][u],
-                dtype=dtype,
-            )
-
-
+            sampling_frequency=fs,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            negative_amplitude=-1,
+            positive_amplitude=params["positive_amplitude"][u],
+            depolarization_ms=params["depolarization_ms"][u],
+            repolarization_ms=params["repolarization_ms"][u],
+            hyperpolarization_ms=params["hyperpolarization_ms"][u],
+            smooth_ms=params["smooth_ms"][u],
+            dtype=dtype,
+        )
+
         alpha = params["alpha"][u]
         # the epsilon avoids enormous factors
-        eps = 1.
+        eps = 1.0
         pow = 1.5
         # naive formula for spatial decay
         channel_factors = alpha / (distances[u, :] + eps) ** pow
@@ -890,11 +906,9 @@
     return templates

-
-
-
 ## template convolution zone ##

+
 class InjectTemplatesRecording(BaseRecording):
     """
     Class for creating a recording based on spike timings and templates.
@@ -942,9 +956,8 @@ def __init__(
         parent_recording: Union[BaseRecording, None] = None,
         num_samples: Optional[List[int]] = None,
         upsample_vector: Union[List[int], None] = None,
-        check_borbers: bool =True,
+        check_borbers: bool = True,
     ) -> None:
-
         templates = np.asarray(templates)
         if check_borbers:
             self._check_templates(templates)
@@ -1090,7 +1103,6 @@ def get_traces(
         end_frame: Union[int, None] = None,
         channel_indices: Union[List, None] = None,
     ) -> np.ndarray:
-
         start_frame = 0 if start_frame is None else start_frame
         end_frame = self.num_samples if end_frame is None else end_frame
@@ -1166,13 +1178,16 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
     j = 0
     for i in range(num_columns):
         channel_locations[j : j + num_contact_per_column, 0] = i * contact_spacing_um
-        channel_locations[j : j + num_contact_per_column, 1] = np.arange(num_contact_per_column) * contact_spacing_um
+        channel_locations[j : j + num_contact_per_column, 1] = (
+            np.arange(num_contact_per_column) * contact_spacing_um
+        )
         j += num_contact_per_column
     return channel_locations

-def generate_unit_locations(num_units, channel_locations, margin_um=20., minimum_z=5., maximum_z=40., seed=None):
+
+def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None):
     rng = np.random.default_rng(seed=seed)
-    units_locations = np.zeros((num_units, 3), dtype='float32')
+    units_locations = np.zeros((num_units, 3), dtype="float32")
     for dim in (0, 1):
         lim0 = np.min(channel_locations[:, dim]) - margin_um
         lim1 = np.max(channel_locations[:, dim]) + margin_um
@@ -1183,24 +1198,24 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20., minimum


 def generate_ground_truth_recording(
-    durations=[10.],
-    sampling_frequency=25000.0,
-    num_channels=4,
-    num_units=10,
-    sorting=None,
-    probe=None,
-    templates=None,
-    ms_before=1.,
-    ms_after=3.,
-    upsample_factor=None,
-    upsample_vector=None,
-    generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5),
-    noise_kwargs=dict(noise_level=5., strategy="on_the_fly"),
-
generate_unit_locations_kwargs=dict(margin_um=10., minimum_z=5., maximum_z=50.),
-    generate_templates_kwargs=dict(),
-    dtype="float32",
-    seed=None,
-    ):
+    durations=[10.0],
+    sampling_frequency=25000.0,
+    num_channels=4,
+    num_units=10,
+    sorting=None,
+    probe=None,
+    templates=None,
+    ms_before=1.0,
+    ms_after=3.0,
+    upsample_factor=None,
+    upsample_vector=None,
+    generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5),
+    noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
+    generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0),
+    generate_templates_kwargs=dict(),
+    dtype="float32",
+    seed=None,
+):
     """
     Generate a recording with spikes, given a probe+sorting+templates.
@@ -1220,7 +1235,7 @@ def generate_ground_truth_recording(
         An external Probe object. If not provided, a linear probe is generated.
     templates: np.array or None
         The templates of units.
-        If None they are generated.
+        If None they are generated.
         Shape can be:
             * (num_units, num_samples, num_channels): standard case
             * (num_units, num_samples, num_channels, upsample_factor): case with oversampled templates to introduce jitter.
@@ -1269,7 +1284,7 @@ def generate_ground_truth_recording(
         generate_sorting_kwargs["seed"] = seed
         sorting = generate_sorting(**generate_sorting_kwargs)
     else:
-        num_units = sorting.get_num_units()
+        num_units = sorting.get_num_units()
         assert sorting.sampling_frequency == sampling_frequency

     num_spikes = sorting.to_spike_vector().size
@@ -1281,9 +1296,20 @@
     if templates is None:
         channel_locations = probe.contact_positions
-        unit_locations = generate_unit_locations(num_units, channel_locations, seed=seed, **generate_unit_locations_kwargs)
-        templates = generate_templates(channel_locations, unit_locations, sampling_frequency, ms_before, ms_after,
-                                       upsample_factor=upsample_factor, seed=seed, dtype=dtype, **generate_templates_kwargs)
+        unit_locations = generate_unit_locations(
+            num_units, channel_locations, seed=seed, **generate_unit_locations_kwargs
+        )
+        templates = generate_templates(
+            channel_locations,
+            unit_locations,
+            sampling_frequency,
+            ms_before,
+            ms_after,
+            upsample_factor=upsample_factor,
+            seed=seed,
+            dtype=dtype,
+            **generate_templates_kwargs,
+        )
     else:
         assert templates.shape[0] == num_units
@@ -1294,27 +1320,29 @@
         upsample_factor = templates.shape[3]
         upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)

-    nbefore = int(ms_before * sampling_frequency / 1000.)
-    nafter = int(ms_after * sampling_frequency / 1000.)
+ nbefore = int(ms_before * sampling_frequency / 1000.0) + nafter = int(ms_after * sampling_frequency / 1000.0) assert (nbefore + nafter) == templates.shape[1] # construct recording noise_rec = NoiseGeneratorRecording( - num_channels=num_channels, - sampling_frequency=sampling_frequency, - durations=durations, - dtype=dtype, - seed=seed, - noise_block_size=int(sampling_frequency), - **noise_kwargs + num_channels=num_channels, + sampling_frequency=sampling_frequency, + durations=durations, + dtype=dtype, + seed=seed, + noise_block_size=int(sampling_frequency), + **noise_kwargs, ) recording = InjectTemplatesRecording( - sorting, templates, nbefore=nbefore, parent_recording=noise_rec, upsample_vector=upsample_vector, + sorting, + templates, + nbefore=nbefore, + parent_recording=noise_rec, + upsample_vector=upsample_vector, ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) - return recording, sorting - diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 6dc7ee864c..a3cd0caa92 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -25,7 +25,10 @@ def test_write_binary_recording(tmp_path): durations = [10.0] recording = NoiseGeneratorRecording( - durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, strategy="tile_pregenerated" + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw"] @@ -49,7 +52,10 @@ def test_write_binary_recording_offset(tmp_path): durations = [10.0] recording = NoiseGeneratorRecording( - durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, strategy="tile_pregenerated" + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw"] @@ -82,7 +88,7 @@ def test_write_binary_recording_parallel(tmp_path): num_channels=num_channels, sampling_frequency=sampling_frequency, dtype=dtype, - strategy="tile_pregenerated" + strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] @@ -109,7 +115,10 @@ def test_write_binary_recording_multiple_segment(tmp_path): durations = [10.30, 3.5] recording = NoiseGeneratorRecording( - durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, strategy="tile_pregenerated" + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] @@ -130,7 +139,9 @@ def test_write_binary_recording_multiple_segment(tmp_path): def test_write_memory_recording(): # 2 segments - recording = NoiseGeneratorRecording(num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated") + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) # make dumpable recording = recording.save() diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 550546d4f8..9ba5de42d6 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -4,10 +4,18 @@ import numpy as np from spikeinterface.core import load_extractor, 
extract_waveforms
-from spikeinterface.core.generate import (generate_recording, generate_sorting, NoiseGeneratorRecording, generate_recording_by_size,
-                                          InjectTemplatesRecording, generate_single_fake_waveform, generate_templates,
-                                          generate_channel_locations, generate_unit_locations, generate_ground_truth_recording,
-                                          )
+from spikeinterface.core.generate import (
+    generate_recording,
+    generate_sorting,
+    NoiseGeneratorRecording,
+    generate_recording_by_size,
+    InjectTemplatesRecording,
+    generate_single_fake_waveform,
+    generate_templates,
+    generate_channel_locations,
+    generate_unit_locations,
+    generate_ground_truth_recording,
+)

 from spikeinterface.core.core_tools import convert_bytes_to_str

@@ -21,10 +29,12 @@ def test_generate_recording():
     # TODO even though this is extensively tested in all other functions
     pass

+
 def test_generate_sorting():
     # TODO even though this is extensively tested in all other functions
     pass

+
 def measure_memory_allocation(measure_in_process: bool = True) -> float:
     """
     A local utility to measure memory allocation at a specific point in time.
@@ -49,7 +59,6 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float:

     return memory

-
 def test_noise_generator_memory():
     # Test that get_traces does not consume more memory than allocated.

@@ -69,7 +78,7 @@
     rec1 = NoiseGeneratorRecording(
         num_channels=num_channels,
         sampling_frequency=sampling_frequency,
-        durations=durations,
+        durations=durations,
         dtype=dtype,
         seed=seed,
         strategy="tile_pregenerated",
     )
     after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor

     memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB
     expected_allocation_MiB = dtype.itemsize * num_channels * noise_block_size / bytes_to_MiB_factor
     ratio = memory_usage_MiB / expected_allocation_MiB
-    assert ratio <= 1.0 + relative_tolerance, f"NoiseGeneratorRecording with 'tile_pregenerated' wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}"
+    assert (
+        ratio <= 1.0 + relative_tolerance
+    ), f"NoiseGeneratorRecording with 'tile_pregenerated' wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}"

     # case 2: no preallocation, very little memory (under 2 MiB)
     before_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
     rec2 = NoiseGeneratorRecording(
         num_channels=num_channels,
         sampling_frequency=sampling_frequency,
-        durations=durations,
+        durations=durations,
         dtype=dtype,
         seed=seed,
         strategy="on_the_fly",
     )
@@ -126,7 +137,7 @@ def test_noise_generator_correct_shape(strategy):
         num_channels=num_channels,
         sampling_frequency=sampling_frequency,
         durations=durations,
-        dtype=dtype,
+        dtype=dtype,
         seed=seed,
         strategy=strategy,
     )
@@ -161,7 +172,7 @@ def test_noise_generator_consistency_across_calls(strategy, start_frame, end_fra
     lazy_recording = NoiseGeneratorRecording(
         num_channels=num_channels,
         sampling_frequency=sampling_frequency,
-        durations=durations,
+        durations=durations,
         dtype=dtype,
         seed=seed,
         strategy=strategy,
     )
@@ -215,21 +226,20 @@ def test_noise_generator_consistency_after_dump(strategy, seed):
     # test same noise after dump even with seed=None
     rec0 = NoiseGeneratorRecording(
         num_channels=2,
-        sampling_frequency=30000.,
+        sampling_frequency=30000.0,
         durations=[2.0],
         dtype="float32",
         seed=seed,
         strategy=strategy,
     )
     traces0 = rec0.get_traces()
-
+
     rec1 = load_extractor(rec0.to_dict())
     traces1 = rec1.get_traces()

     assert np.allclose(traces0, traces1)

-
 def test_generate_recording():
     # check the high level function
     rec = generate_recording(mode="lazy")
@@ -237,9
+247,9 @@ def test_generate_recording():


 def test_generate_single_fake_waveform():
-    sampling_frequency = 30000.
-    ms_before = 1.
-    ms_after = 3.
+    sampling_frequency = 30000.0
+    ms_before = 1.0
+    ms_after = 3.0
     wf = generate_single_fake_waveform(ms_before=ms_before, ms_after=ms_after, sampling_frequency=sampling_frequency)

     # import matplotlib.pyplot as plt
@@ -249,52 +259,66 @@
     # ax.axvline(0)
     # plt.show()

+
 def test_generate_templates():
-    seed= 0
+    seed = 0

     num_chans = 12
     num_columns = 1
     num_units = 10
-    margin_um= 15.
-    channel_locations = generate_channel_locations(num_chans, num_columns, 20.)
+    margin_um = 15.0
+    channel_locations = generate_channel_locations(num_chans, num_columns, 20.0)
     unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed)
-
-    sampling_frequency = 30000.
-    ms_before = 1.
-    ms_after = 3.
+    sampling_frequency = 30000.0
+    ms_before = 1.0
+    ms_after = 3.0

     # standard case
-    templates = generate_templates(channel_locations, unit_locations, sampling_frequency, ms_before, ms_after,
-                                   upsample_factor=None,
-                                   seed=42,
-                                   dtype="float32",
-                                   )
+    templates = generate_templates(
+        channel_locations,
+        unit_locations,
+        sampling_frequency,
+        ms_before,
+        ms_after,
+        upsample_factor=None,
+        seed=42,
+        dtype="float32",
+    )
     assert templates.ndim == 3
     assert templates.shape[2] == num_chans
     assert templates.shape[0] == num_units

     # play with params
-    templates = generate_templates(channel_locations, unit_locations, sampling_frequency, ms_before, ms_after,
-                                   upsample_factor=None,
-                                   seed=42,
-                                   dtype="float32",
-                                   unit_params=dict(alpha=np.ones(num_units) * 8000.),
-                                   unit_params_range=dict(smooth_ms=(0.04, 0.05)),
-                                   )
+    templates = generate_templates(
+        channel_locations,
+        unit_locations,
+        sampling_frequency,
+        ms_before,
+        ms_after,
+        upsample_factor=None,
+        seed=42,
+        dtype="float32",
+        unit_params=dict(alpha=np.ones(num_units) * 8000.0),
+        unit_params_range=dict(smooth_ms=(0.04, 0.05)),
+    )

     # upsampling case
-    templates = generate_templates(channel_locations, unit_locations, sampling_frequency, ms_before, ms_after,
-                                   upsample_factor=3,
-                                   seed=42,
-                                   dtype="float32",
-                                   )
+    templates = generate_templates(
+        channel_locations,
+        unit_locations,
+        sampling_frequency,
+        ms_before,
+        ms_after,
+        upsample_factor=3,
+        seed=42,
+        dtype="float32",
+    )
     assert templates.ndim == 4
     assert templates.shape[2] == num_chans
     assert templates.shape[0] == num_units
     assert templates.shape[3] == 3
-
     # import matplotlib.pyplot as plt
     # fig, ax = plt.subplots()
     # for u in range(num_units):
@@ -315,12 +339,26 @@ def test_inject_templates():
     upsample_factor = 3

     # generate some stuff
-    rec_noise = generate_recording(num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency, mode="lazy", seed=42)
+    rec_noise = generate_recording(
+        num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency, mode="lazy", seed=42
+    )
     channel_locations = rec_noise.get_channel_locations()
-    sorting = generate_sorting(num_units=num_units, durations=durations, sampling_frequency=sampling_frequency, firing_rates=1., seed=42)
-    units_locations = generate_unit_locations(num_units, channel_locations, margin_um=10., seed=42)
-    templates_3d = generate_templates(channel_locations, units_locations, sampling_frequency, ms_before, ms_after, seed=42, upsample_factor=None)
-    templates_4d = generate_templates(channel_locations, units_locations, sampling_frequency, ms_before, ms_after, seed=42,
upsample_factor=upsample_factor) + sorting = generate_sorting( + num_units=num_units, durations=durations, sampling_frequency=sampling_frequency, firing_rates=1.0, seed=42 + ) + units_locations = generate_unit_locations(num_units, channel_locations, margin_um=10.0, seed=42) + templates_3d = generate_templates( + channel_locations, units_locations, sampling_frequency, ms_before, ms_after, seed=42, upsample_factor=None + ) + templates_4d = generate_templates( + channel_locations, + units_locations, + sampling_frequency, + ms_before, + ms_after, + seed=42, + upsample_factor=upsample_factor, + ) # Case 1: parent_recording = None rec1 = InjectTemplatesRecording( @@ -336,8 +374,9 @@ def test_inject_templates(): # Case 3: with parent_recording + upsample_factor rng = np.random.default_rng(seed=42) upsample_vector = rng.integers(0, upsample_factor, size=sorting.to_spike_vector().size) - rec3 = InjectTemplatesRecording(sorting, templates_4d, nbefore=nbefore, parent_recording=rec_noise, upsample_vector=upsample_vector) - + rec3 = InjectTemplatesRecording( + sorting, templates_4d, nbefore=nbefore, parent_recording=rec_noise, upsample_vector=upsample_vector + ) for rec in (rec1, rec2, rec3): assert rec.get_traces(end_frame=600, segment_index=0).shape == (600, 4) @@ -357,8 +396,6 @@ def test_generate_ground_truth_recording(): assert rec.templates.ndim == 4 - - if __name__ == "__main__": strategy = "tile_pregenerated" # strategy = "on_the_fly" @@ -373,4 +410,3 @@ def test_generate_ground_truth_recording(): # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 7de62a64cb..c1f2fbd4b9 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -69,7 +69,7 @@ def compute(self, traces, peaks, waveforms): def test_run_node_pipeline(): - recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.]) + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index cba53d53e8..068d3e824b 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -41,7 +41,7 @@ def test_get_auto_merge_list(): # shutil.rmtree(wf_folder) # we = extract_waveforms(rec, sorting_with_split, mode="folder", folder=wf_folder, n_jobs=1) - we = extract_waveforms(rec, sorting_with_split, mode='memory', folder=None, n_jobs=1) + we = extract_waveforms(rec, sorting_with_split, mode="memory", folder=None, n_jobs=1) # print(we) potential_merges, outs = get_potential_auto_merge( @@ -71,7 +71,6 @@ def test_get_auto_merge_list(): true_pair = tuple(true_pair) assert true_pair in potential_merges - # import matplotlib.pyplot as plt # templates_diff = outs['templates_diff'] # correlogram_diff = outs['correlogram_diff'] diff --git a/src/spikeinterface/curation/tests/test_remove_redundant.py b/src/spikeinterface/curation/tests/test_remove_redundant.py index e89115d9dc..9e27374de1 100644 --- a/src/spikeinterface/curation/tests/test_remove_redundant.py +++ b/src/spikeinterface/curation/tests/test_remove_redundant.py @@ -36,9 +36,8 @@ def test_remove_redundant_units(): # 
shutil.rmtree(wf_folder)
     # we = extract_waveforms(rec, sorting_with_dup, folder=wf_folder)
-    we = extract_waveforms(rec, sorting_with_dup, mode='memory', folder=None, n_jobs=1)
+    we = extract_waveforms(rec, sorting_with_dup, mode="memory", folder=None, n_jobs=1)

-    # print(we)

     for remove_strategy in ("max_spikes", "minimum_shift", "highest_amplitude"):
diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py
index 6fc7e3fa20..0b50d735ed 100644
--- a/src/spikeinterface/extractors/toy_example.py
+++ b/src/spikeinterface/extractors/toy_example.py
@@ -2,8 +2,13 @@
 from probeinterface import Probe

 from spikeinterface.core import NumpySorting
-from spikeinterface.core.generate import (generate_sorting, generate_channel_locations,
-                                          generate_unit_locations, generate_templates, generate_ground_truth_recording)
+from spikeinterface.core.generate import (
+    generate_sorting,
+    generate_channel_locations,
+    generate_unit_locations,
+    generate_templates,
+    generate_ground_truth_recording,
+)


 def toy_example(
@@ -14,7 +19,7 @@ def toy_example(
     num_segments=2,
     average_peak_amplitude=-100,
     upsample_factor=None,
-    contact_spacing_um=40.,
+    contact_spacing_um=40.0,
     num_columns=1,
     spike_times=None,
     spike_labels=None,
@@ -66,7 +71,9 @@ def toy_example(
     """

     if upsample_factor is not None:
-        raise NotImplementedError("InjectTemplatesRecording do not support yet upsample_factor but this will be done soon")
+        raise NotImplementedError(
+            "InjectTemplatesRecording does not yet support upsample_factor, but this will be added soon"
+        )

     assert num_channels > 0
     assert num_units > 0
@@ -88,24 +95,32 @@ def toy_example(
     channel_locations = generate_channel_locations(num_channels, num_columns, contact_spacing_um)
     probe = Probe(ndim=2)
     probe.set_contacts(positions=channel_locations, shapes="circle", shape_params={"radius": 5})
-    probe.create_auto_shape(probe_type="rect", margin=20.)
+    probe.create_auto_shape(probe_type="rect", margin=20.0)
     probe.set_device_channel_indices(np.arange(num_channels, dtype="int64"))

     # generate templates
     # this is hard-coded for now, but it used to be like this
     ms_before = 1.5
-    ms_after = 3.
+    ms_after = 3.0
     unit_locations = generate_unit_locations(
-        num_units, channel_locations, margin_um=15., minimum_z=5., maximum_z=50., seed=seed
+        num_units, channel_locations, margin_um=15.0, minimum_z=5.0, maximum_z=50.0, seed=seed
+    )
+    templates = generate_templates(
+        channel_locations,
+        unit_locations,
+        sampling_frequency,
+        ms_before,
+        ms_after,
+        upsample_factor=upsample_factor,
+        seed=seed,
+        dtype="float32",
     )
-    templates = generate_templates(channel_locations, unit_locations, sampling_frequency, ms_before, ms_after,
-                                   upsample_factor=upsample_factor, seed=seed, dtype="float32")

     if average_peak_amplitude is not None:
         # adjustment to the mean amplitude
         amps = np.min(templates, axis=(1, 2))
-        templates *= (average_peak_amplitude / np.mean(amps))
-
+        templates *= average_peak_amplitude / np.mean(amps)
+
     # construct sorting
     if spike_times is not None:
         assert isinstance(spike_times, list)
@@ -121,20 +136,20 @@ def toy_example(
             firing_rates=firing_rate,
             empty_units=None,
             refractory_period_ms=4.0,
-            seed=seed
+            seed=seed,
         )

     recording, sorting = generate_ground_truth_recording(
-        durations=durations,
-        sampling_frequency=sampling_frequency,
-        sorting=sorting,
-        probe=probe,
-        templates=templates,
-        ms_before=ms_before,
-        ms_after=ms_after,
-        dtype="float32",
-        seed=seed,
-        noise_kwargs=dict(noise_level=10., strategy="on_the_fly"),
-    )
+        durations=durations,
+        sampling_frequency=sampling_frequency,
+        sorting=sorting,
+        probe=probe,
+        templates=templates,
+        ms_before=ms_before,
+        ms_after=ms_after,
+        dtype="float32",
+        seed=seed,
+        noise_kwargs=dict(noise_level=10.0, strategy="on_the_fly"),
+    )

     return recording, sorting
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index c62770b7e8..99ca10ba8f 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -165,7 +165,7 @@ def simulated_data():


 def setup_dataset(spike_data, score_detection=1):
-# def setup_dataset(spike_data):
+    # def setup_dataset(spike_data):
     recording, sorting = toy_example(
         duration=[spike_data["duration"]],
         spike_times=[spike_data["times"]],
@@ -195,7 +195,7 @@ def test_calculate_firing_rate_num_spikes(simulated_data):
     firing_rates = compute_firing_rates(we)
     num_spikes = compute_num_spikes(we)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
     # firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09}
     # num_spikes_gt = {0: 1001, 1: 503, 2: 509}
     # assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05)
@@ -208,7 +208,7 @@ def test_calculate_amplitude_cutoff(simulated_data):
     amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10)
     print(amp_cuts)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
     # amp_cuts_gt = {0: 0.33067210050787543, 1: 0.43482247296942045, 2: 0.43482247296942045}
     # assert np.allclose(list(amp_cuts_gt.values()), list(amp_cuts.values()), rtol=0.05)
@@ -219,7 +219,7 @@ def test_calculate_amplitude_median(simulated_data):
     amp_medians = compute_amplitude_medians(we)
     print(amp_medians)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
# amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725}
     # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05)
@@ -229,7 +229,7 @@ def test_calculate_snrs(simulated_data):
     snrs = compute_snrs(we)
     print(snrs)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
     # snrs_gt = {0: 12.92, 1: 12.99, 2: 12.99}
     # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05)
@@ -239,7 +239,7 @@ def test_calculate_presence_ratio(simulated_data):
     ratios = compute_presence_ratios(we, bin_duration_s=10)
     print(ratios)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
     # ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0}
     # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values()))
@@ -249,7 +249,7 @@ def test_calculate_isi_violations(simulated_data):
     isi_viol, counts = compute_isi_violations(we, isi_threshold_ms=1, min_isi_ms=0.0)
     print(isi_viol)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
     # isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754}
     # counts_gt = {0: 2, 1: 4, 2: 10}
     # assert np.allclose(list(isi_viol_gt.values()), list(isi_viol.values()), rtol=0.05)
@@ -261,13 +261,12 @@ def test_calculate_sliding_rp_violations(simulated_data):
     contaminations = compute_sliding_rp_violations(we, bin_size_ms=0.25, window_size_s=1)
     print(contaminations)

-    # testing method accuracy with magic number is not a good pratcice, I remove this.
+    # testing method accuracy with magic numbers is not a good practice, I remove this.
     # contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325}
     # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05)


 def test_calculate_rp_violations(simulated_data):
-
     counts_gt = {0: 2, 1: 4, 2: 10}
     we = setup_dataset(simulated_data)
     rp_contamination, counts = compute_refrac_period_violations(we, refractory_period_ms=1, censored_period_ms=0.0)
@@ -289,7 +288,6 @@ def test_calculate_rp_violations(simulated_data):

 @pytest.mark.sortingcomponents
 def test_calculate_drift_metrics(simulated_data):
-
     we = setup_dataset(simulated_data)
     spike_locs = compute_spike_locations(we)
     drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(we, interval_s=10, min_spikes_per_interval=10)
diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
index 52807ebf4e..4fa65993d1 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
@@ -261,7 +261,6 @@ def test_nn_metrics(self):
             we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2
         )
         for metric_name in metrics.columns:
-
             assert np.allclose(metrics[metric_name], metrics_par[metric_name])

     def test_recordingless(self):
@@ -279,7 +278,6 @@ def test_recordingless(self):

         print(qm_rec)
         print(qm_no_rec)

-
         # check metrics are the same
         for metric_name in qm_rec.columns:
             # rtol is added for sliding_rp_violation, for a reason I do not have to explore now. Sam.
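
Note: before the last two patches below, a minimal usage sketch of the generator API reworked above may help orientation. It relies only on the signatures visible in these hunks (NoiseGeneratorRecording, generate_ground_truth_recording); the parameter values are illustrative and not part of the diff.

    from spikeinterface.core.generate import (
        NoiseGeneratorRecording,
        generate_ground_truth_recording,
    )

    # pure-noise lazy recording: "on_the_fly" regenerates each noise block from the
    # seed on request, while "tile_pregenerated" caches one block and tiles it
    noise_rec = NoiseGeneratorRecording(
        num_channels=4,
        sampling_frequency=25000.0,
        durations=[2.0],
        dtype="float32",
        seed=42,
        strategy="on_the_fly",
    )
    traces = noise_rec.get_traces(start_frame=0, end_frame=100)
    assert traces.shape == (100, 4)

    # full ground-truth pair: noise recording with injected templates, plus the sorting
    recording, sorting = generate_ground_truth_recording(
        durations=[10.0], num_channels=4, num_units=10, seed=42
    )
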
From 748751d72dcdab74fd4252f6be3792b52a60541c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 1 Sep 2023 14:09:19 +0200 Subject: [PATCH 09/10] remove test_peak_pipepeline.py from components (this is now in core) --- .../core/tests/test_node_pipeline.py | 1 + .../tests/test_peak_pipeline.py | 168 ------------------ 2 files changed, 1 insertion(+), 168 deletions(-) delete mode 100644 src/spikeinterface/sortingcomponents/tests/test_peak_pipeline.py diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index c1f2fbd4b9..84ffeb846c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -136,6 +136,7 @@ def test_run_node_pipeline(): folder = cache_folder / "pipeline_folder" if folder.is_dir(): shutil.rmtree(folder) + output = run_node_pipeline( recording, nodes, diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_pipeline.py b/src/spikeinterface/sortingcomponents/tests/test_peak_pipeline.py deleted file mode 100644 index 269848a753..0000000000 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_pipeline.py +++ /dev/null @@ -1,168 +0,0 @@ -import pytest -import numpy as np -from pathlib import Path -import shutil - -import scipy.signal - -from spikeinterface import download_dataset, BaseSorting -from spikeinterface.extractors import MEArecRecordingExtractor - -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.core.node_pipeline import ( - run_node_pipeline, - PeakRetriever, - PipelineNode, - ExtractDenseWaveforms, -) - - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sortingcomponents" -else: - cache_folder = Path("cache_folder") / "sortingcomponents" - - -class AmplitudeExtractionNode(PipelineNode): - def __init__(self, recording, parents=None, return_output=True, param0=5.5): - PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) - self.param0 = param0 - self._dtype = np.dtype([("abs_amplitude", recording.get_dtype())]) - - def get_dtype(self): - return self._dtype - - def compute(self, traces, peaks): - amps = np.zeros(peaks.size, dtype=self._dtype) - amps["abs_amplitude"] = np.abs(peaks["amplitude"]) - return amps - - def get_trace_margin(self): - return 5 - - -class WaveformDenoiser(PipelineNode): - # waveform smoother - def __init__(self, recording, return_output=True, parents=None): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - def get_dtype(self): - return np.dtype("float32") - - def compute(self, traces, peaks, waveforms): - kernel = np.array([0.1, 0.8, 0.1])[np.newaxis, :, np.newaxis] - denoised_waveforms = scipy.signal.fftconvolve(waveforms, kernel, axes=1, mode="same") - return denoised_waveforms - - -class WaveformsRootMeanSquare(PipelineNode): - def __init__(self, recording, return_output=True, parents=None): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - def get_dtype(self): - return np.dtype("float32") - - def compute(self, traces, peaks, waveforms): - rms_by_channels = np.sum(waveforms**2, axis=1) - return rms_by_channels - - -def test_run_node_pipeline(): - repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data" - remote_path = "mearec/mearec_test_10s.h5" - local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None) - recording = MEArecRecordingExtractor(local_path) - - job_kwargs 
= dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) - - peaks = detect_peaks( - recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs - ) - - # one step only : squeeze output - peak_retriever = PeakRetriever(recording, peaks) - nodes = [ - peak_retriever, - AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6), - ] - step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) - - # 3 nodes two have outputs - ms_before = 0.5 - ms_after = 1.0 - peak_retriever = PeakRetriever(recording, peaks) - extract_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False - ) - waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, extract_waveforms], return_output=False) - amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True) - waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, extract_waveforms], return_output=True) - denoised_waveforms_rms = WaveformsRootMeanSquare( - recording, parents=[peak_retriever, waveform_denoiser], return_output=True - ) - - nodes = [ - peak_retriever, - extract_waveforms, - waveform_denoiser, - amplitue_extraction, - waveforms_rms, - denoised_waveforms_rms, - ] - - # gather memory mode - output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") - amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) - - num_peaks = peaks.shape[0] - num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - # gather npy mode - folder = cache_folder / "pipeline_folder" - if folder.is_dir(): - shutil.rmtree(folder) - output = run_node_pipeline( - recording, - nodes, - job_kwargs, - gather_mode="npy", - folder=folder, - names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], - ) - amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output - - amplitudes_file = folder / "amplitudes.npy" - assert amplitudes_file.is_file() - amplitudes3 = np.load(amplitudes_file) - assert np.array_equal(amplitudes, amplitudes2) - assert np.array_equal(amplitudes2, amplitudes3) - - waveforms_rms_file = folder / "waveforms_rms.npy" - assert waveforms_rms_file.is_file() - waveforms_rms3 = np.load(waveforms_rms_file) - assert np.array_equal(waveforms_rms, waveforms_rms2) - assert np.array_equal(waveforms_rms2, waveforms_rms3) - - denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" - assert denoised_waveforms_rms_file.is_file() - denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) - assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) - assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) - - # Test pickle mechanism - for node in nodes: - import pickle - - pickled_node = pickle.dumps(node) - unpickled_node = pickle.loads(pickled_node) - - -if __name__ == "__main__": - test_run_node_pipeline() From 0ee1d1165d2d8adbf54f971dc8bca9b262346f97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Sep 2023 12:10:25 +0000 Subject: [PATCH 10/10] [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_node_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 84ffeb846c..85f41924c1 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -136,7 +136,7 @@ def test_run_node_pipeline(): folder = cache_folder / "pipeline_folder" if folder.is_dir(): shutil.rmtree(folder) - + output = run_node_pipeline( recording, nodes,
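
Note: to close the series, here is a minimal sketch of the node-pipeline API that patches 09 and 10 finish consolidating in core. It mirrors the deleted sortingcomponents test above; the AbsAmplitudeNode name is hypothetical, while the imports and call signatures follow the code shown in these patches.

    import numpy as np

    from spikeinterface.core.generate import generate_ground_truth_recording
    from spikeinterface.core.node_pipeline import run_node_pipeline, PeakRetriever, PipelineNode
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks


    class AbsAmplitudeNode(PipelineNode):
        # toy feature node, same pattern as AmplitudeExtractionNode in the tests:
        # one structured field per peak holding the absolute amplitude
        def __init__(self, recording, parents=None, return_output=True):
            PipelineNode.__init__(self, recording, parents=parents, return_output=return_output)
            self._dtype = np.dtype([("abs_amplitude", recording.get_dtype())])

        def get_dtype(self):
            return self._dtype

        def compute(self, traces, peaks):
            amps = np.zeros(peaks.size, dtype=self._dtype)
            amps["abs_amplitude"] = np.abs(peaks["amplitude"])
            return amps


    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
    job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
    peaks = detect_peaks(
        recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs
    )

    # chain the source node and the feature node, then run chunk by chunk
    peak_retriever = PeakRetriever(recording, peaks)
    nodes = [peak_retriever, AbsAmplitudeNode(recording, parents=[peak_retriever])]
    abs_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
    assert np.allclose(np.abs(peaks["amplitude"]), abs_amplitudes["abs_amplitude"])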