diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9ea5ad59e7..b11f40a441 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar raise NotImplementedError -# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) +# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) # as first element they play the same role in pipeline : give some peaks (and eventually more) @@ -138,7 +138,99 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): - pass + """ + This class is useful to inject a sorting object in the node pipeline mechanism. + It allows to compute some post-processing steps with the same machinery used for sorting components. + This is used by: + * compute_spike_locations() + * compute_amplitude_scalings() + * compute_spike_amplitudes() + * compute_principal_components() + + recording : BaseRecording + The recording object. + sorting: BaseSorting + The sorting object. + channel_from_template: bool, default: True + If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. + If False, the max channel is computed for each spike given a radius around the template max channel. + extremum_channel_inds: dict of int + The extremum channel index dict given from template. + radius_um: float (default 50.) + The radius to find the real max channel. + Used only when channel_from_template=False + peak_sign: str (default "neg") + Peak sign to find the max channel. 
+ Used only when channel_from_template=False + """ + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): + PipelineNode.__init__(self, recording, return_output=False) + + self.channel_from_template = channel_from_template + + assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary" + + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) + + if not channel_from_template: + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.peak_sign = peak_sign + + # precompute segment slice + self.segment_slices = [] + for segment_index in range(recording.get_num_segments()): + i0 = np.searchsorted(self.peaks["segment_index"], segment_index) + i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + self.segment_slices.append(slice(i0, i1)) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return base_peak_dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) + i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + local_peaks = peaks_in_segment[i0:i1] + + # make sample index local to traces + local_peaks = local_peaks.copy() + local_peaks["sample_index"] -= start_frame - max_margin + + if not self.channel_from_template: + # handle channel spike per spike + for i, peak in enumerate(local_peaks): + chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) + sparse_wfs = traces[peak["sample_index"], chans] + if self.peak_sign == "neg": + local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] + elif self.peak_sign == "pos": + local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] + elif self.peak_sign == "both": + 
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + + # TODO: "amplitude" ??? + + return (local_peaks,) + + +def sorting_to_peaks(sorting, extremum_channel_inds): + spikes = sorting.to_spike_vector() + peaks = np.zeros(spikes.size, dtype=base_peak_dtype) + peaks["sample_index"] = spikes["sample_index"] + extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) + peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] + peaks["amplitude"] = 0.0 + peaks["segment_index"] = spikes["segment_index"] + return peaks class WaveformsNode(PipelineNode): @@ -423,7 +515,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) # set sample index to local node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakRetriever): + elif isinstance(node, PeakSource): node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) else: # TODO later when in master: change the signature of all nodes (or maybe not!) 
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 85f41924c1..bcb15b6455 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -12,9 +12,10 @@ from spikeinterface.core.node_pipeline import ( run_node_pipeline, PeakRetriever, + SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - base_peak_dtype, + sorting_to_peaks, ) @@ -78,99 +79,107 @@ def test_run_node_pipeline(): # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - # print(extremum_channel_inds) - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) - # print(ext_channel_inds) - peaks = np.zeros(spikes.size, dtype=base_peak_dtype) - peaks["sample_index"] = spikes["sample_index"] - peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] - peaks["amplitude"] = 0.0 - peaks["segment_index"] = 0 - - # one step only : squeeze output - peak_retriever = PeakRetriever(recording, peaks) - nodes = [ - peak_retriever, - AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6), - ] - step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) - - # 3 nodes two have outputs - ms_before = 0.5 - ms_after = 1.0 + peaks = sorting_to_peaks(sorting, extremum_channel_inds) + peak_retriever = PeakRetriever(recording, peaks) - dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False - ) - waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False) - amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, 
return_output=True) - waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True) - denoised_waveforms_rms = WaveformsRootMeanSquare( - recording, parents=[peak_retriever, waveform_denoiser], return_output=True + # channel index is from template + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) - - nodes = [ - peak_retriever, - dense_waveforms, - waveform_denoiser, - amplitue_extraction, - waveforms_rms, - denoised_waveforms_rms, - ] - - # gather memory mode - output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") - amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) - - num_peaks = peaks.shape[0] - num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - # gather npy mode - folder = cache_folder / "pipeline_folder" - if folder.is_dir(): - shutil.rmtree(folder) - - output = run_node_pipeline( + # channel index is per spike + spike_retriever_S = SpikeRetriever( recording, - nodes, - job_kwargs, - gather_mode="npy", - folder=folder, - names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", ) - amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output - - amplitudes_file = folder / "amplitudes.npy" - assert amplitudes_file.is_file() - amplitudes3 = np.load(amplitudes_file) - assert np.array_equal(amplitudes, amplitudes2) - assert np.array_equal(amplitudes2, amplitudes3) - - waveforms_rms_file = folder / "waveforms_rms.npy" - assert waveforms_rms_file.is_file() - waveforms_rms3 = np.load(waveforms_rms_file) - assert 
np.array_equal(waveforms_rms, waveforms_rms2) - assert np.array_equal(waveforms_rms2, waveforms_rms3) - - denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" - assert denoised_waveforms_rms_file.is_file() - denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) - assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) - assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) - - # Test pickle mechanism - for node in nodes: - import pickle - - pickled_node = pickle.dumps(node) - unpickled_node = pickle.loads(pickled_node) + + # test with 3 different first nodes + for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + # one step only : squeeze output + nodes = [ + peak_source, + AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), + ] + step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + + # 3 nodes two have outputs + ms_before = 0.5 + ms_after = 1.0 + peak_retriever = PeakRetriever(recording, peaks) + dense_waveforms = ExtractDenseWaveforms( + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False) + amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True) + waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True) + denoised_waveforms_rms = WaveformsRootMeanSquare( + recording, parents=[peak_source, waveform_denoiser], return_output=True + ) + + nodes = [ + peak_source, + dense_waveforms, + waveform_denoiser, + amplitue_extraction, + waveforms_rms, + denoised_waveforms_rms, + ] + + # gather memory mode + output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") + 
amplitudes, waveforms_rms, denoised_waveforms_rms = output + assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) + + num_peaks = peaks.shape[0] + num_channels = recording.get_num_channels() + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + # gather npy mode + folder = cache_folder / f"pipeline_folder_{loop}" + if folder.is_dir(): + shutil.rmtree(folder) + output = run_node_pipeline( + recording, + nodes, + job_kwargs, + gather_mode="npy", + folder=folder, + names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + ) + amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output + + amplitudes_file = folder / "amplitudes.npy" + assert amplitudes_file.is_file() + amplitudes3 = np.load(amplitudes_file) + assert np.array_equal(amplitudes, amplitudes2) + assert np.array_equal(amplitudes2, amplitudes3) + + waveforms_rms_file = folder / "waveforms_rms.npy" + assert waveforms_rms_file.is_file() + waveforms_rms3 = np.load(waveforms_rms_file) + assert np.array_equal(waveforms_rms, waveforms_rms2) + assert np.array_equal(waveforms_rms2, waveforms_rms3) + + denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" + assert denoised_waveforms_rms_file.is_file() + denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) + assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) + assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) + + # Test pickle mechanism + for node in nodes: + import pickle + + pickled_node = pickle.dumps(node) + unpickled_node = pickle.loads(pickled_node) if __name__ == "__main__":