From 5370884331aa3b4ebdf25c10bf4d103fd502f28a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 1 Sep 2023 12:05:41 +0200 Subject: [PATCH 01/11] Refactor compite_spike_location using SpikeRetriver. --- .../core/tests/test_node_pipeline.py | 3 - .../postprocessing/spike_locations.py | 37 ++++++-- .../tests/test_spike_locations.py | 3 +- .../sortingcomponents/peak_localization.py | 84 +++++++++++-------- src/spikeinterface/sortingcomponents/tools.py | 3 + 5 files changed, 84 insertions(+), 46 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 339167f673..db5305c313 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -97,9 +97,6 @@ def test_run_node_pipeline(): # test with 2 diffrents first node for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S): - - - # one step only : squeeze output nodes = [ peak_source, diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..32443d44d0 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -5,6 +5,7 @@ from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +from spikeinterface.core.node_pipeline import SpikeRetriever class SpikeLocationsCalculator(BaseWaveformExtractorExtension): @@ -25,9 +26,12 @@ def __init__(self, waveform_extractor): extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}): - params = dict(ms_before=ms_before, ms_after=ms_after, method=method) + + + def _set_params(self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={}): + params = dict(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method) params.update(**method_kwargs) + print(params) return params def _select_extension_data(self, unit_ids): @@ -44,13 +48,28 @@ def _run(self, **job_kwargs): uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate spike locations. """ - from spikeinterface.sortingcomponents.peak_localization import localize_peaks + from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source job_kwargs = fix_job_kwargs(job_kwargs) we = self.waveform_extractor - spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs) + extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") + + params = self._params.copy() + channel_from_template = params.pop("channel_from_template") + + # @alessio @pierre: where do we expose the parameters of radius for the retriever (this is not the same as the one for locatization it is smaller) ??? + spike_retriever = SpikeRetriever( + we.recording, + we.sorting, + channel_from_template=channel_from_template, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign=self._params.get("peaks_sign", "neg") + ) + spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) + self._extension_data["spike_locations"] = spike_locations def get_data(self, outputs="concatenated"): @@ -95,12 +114,15 @@ def get_extension_function(): WaveformExtractor.register_extension(SpikeLocationsCalculator) +# @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate +# what do we put by default ? def compute_spike_locations( waveform_extractor, load_if_exists=False, ms_before=0.5, ms_after=0.5, + channel_from_template=True, method="center_of_mass", method_kwargs={}, outputs="concatenated", @@ -119,6 +141,10 @@ def compute_spike_locations( The left window, before a peak, in milliseconds. ms_after : float The right window, after a peak, in milliseconds. + channel_from_template: bool, default True + For each spike is the maximum channel computed from template or re estimated at every spikes. + channel_from_template = True is old behavior but less acurate + channel_from_template = False is slower but more accurate method : str 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' method_kwargs : dict @@ -138,7 +164,8 @@ def compute_spike_locations( slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) else: slc = SpikeLocationsCalculator(waveform_extractor) - slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs) + slc.set_params(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, + method=method, method_kwargs=method_kwargs) slc.run(**job_kwargs) locs = slc.get_data(outputs=outputs) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 521b49e6cd..ab2345b1f5 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -10,7 +10,8 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes extension_class = SpikeLocationsCalculator extension_data_names = ["spike_locations"] extension_function_kwargs_list = [ - dict(method="center_of_mass", chunk_size=10000, n_jobs=1), + dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=True), + dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=False), dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index fa6101f896..b638e8ed3a 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -7,6 +7,7 @@ run_node_pipeline, find_parent_of_type, PeakRetriever, + SpikeRetriever, PipelineNode, WaveformsNode, ExtractDenseWaveforms, @@ -27,72 +28,49 @@ from .tools import get_prototype_spike -def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): - """Localize peak (spike) in 2D or 3D depending the method. - - When a probe is 2D then: - * X is axis 0 of the probe - * Y is axis 1 of the probe - * Z is orthogonal to the plane of the probe - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object. - peaks: array - Peaks array, as returned by detect_peaks() in "compact_numpy" way. - - {method_doc} - - {job_doc} - - Returns - ------- - peak_locations: ndarray - Array with estimated location for each spike. - The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha'). - """ +def _run_localization_from_peak_source(recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): + # use by localize_peaks() and compute_spike_locations() assert ( method in possible_localization_methods ), f"Method {method} is not supported. Choose from {possible_localization_methods}" method_kwargs, job_kwargs = split_job_kwargs(kwargs) - peak_retriever = PeakRetriever(recording, peaks) if method == "center_of_mass": extract_dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False ) pipeline_nodes = [ - peak_retriever, + peak_source, extract_dense_waveforms, - LocalizeCenterOfMass(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs), + LocalizeCenterOfMass(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] elif method == "monopolar_triangulation": extract_dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False ) pipeline_nodes = [ - peak_retriever, + peak_source, extract_dense_waveforms, LocalizeMonopolarTriangulation( - recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs + recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs ), ] elif method == "peak_channel": - pipeline_nodes = [peak_retriever, LocalizePeakChannel(recording, parents=[peak_retriever], **method_kwargs)] + pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)] elif method == "grid_convolution": if "prototype" not in method_kwargs: + assert isinstance(peak_source, (PeakRetriever, SpikeRetriever)) method_kwargs["prototype"] = get_prototype_spike( - recording, peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs + recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs ) extract_dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False ) pipeline_nodes = [ - peak_retriever, + peak_source, extract_dense_waveforms, - LocalizeGridConvolution(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs), + LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] job_name = f"localize peaks using {method}" @@ -101,6 +79,38 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ return peak_locations + +def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): + """Localize peak (spike) in 2D or 3D depending the method. + + When a probe is 2D then: + * X is axis 0 of the probe + * Y is axis 1 of the probe + * Z is orthogonal to the plane of the probe + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object. + peaks: array + Peaks array, as returned by detect_peaks() in "compact_numpy" way. + + {method_doc} + + {job_doc} + + Returns + ------- + peak_locations: ndarray + Array with estimated location for each spike. + The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha'). + """ + peak_retriever = PeakRetriever(recording, peaks) + peak_locations = _run_localization_from_peak_source(recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs) + return peak_locations + + + class LocalizeBase(PipelineNode): def __init__(self, recording, return_output=True, parents=None, radius_um=75.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 45b9079ea9..576732baa2 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -19,6 +19,9 @@ def make_multi_method_doc(methods, ident=" "): def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): + # TODO for Pierre: this function is really unefficient because it runa full pipeline only for a few + # spikes, which leans that traces are entirally computed!!!!! + # Please find a better way nb_peaks = min(len(peaks), nb_peaks) idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) peak_retriever = PeakRetriever(recording, peaks[idx]) From 68df57384e5ca424c6b07aacf3e48933c1b5fa55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Sep 2023 10:06:45 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 19 +++++++--------- .../core/tests/test_node_pipeline.py | 22 ++++++++++--------- .../postprocessing/spike_locations.py | 22 +++++++++++++------ .../sortingcomponents/peak_localization.py | 14 ++++++------ 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ff747fe2a0..610ae42398 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -141,7 +141,7 @@ class SpikeRetriever(PeakSource): """ This class is usefull to inject a sorting object in the node pipepline mechanisim. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is a first step to totaly refactor: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() @@ -164,16 +164,14 @@ class SpikeRetriever(PeakSource): Peak sign to find the max channel. Used only when channel_from_template=False """ - def __init__(self, recording, sorting, - channel_from_template=True, - extremum_channel_inds=None, - radius_um=50, - peak_sign="neg" - ): + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): PipelineNode.__init__(self, recording, return_output=False) self.channel_from_template = channel_from_template - + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" self.peaks = sorting_to_peak(sorting, extremum_channel_inds) @@ -181,8 +179,7 @@ def __init__(self, recording, sorting, if not channel_from_template: channel_distance = get_channel_distances(recording) self.neighbours_mask = channel_distance < radius_um - self.peak_sign = peak_sign - + self.peak_sign = peak_sign # precompute segment slice self.segment_slices = [] @@ -219,7 +216,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "pos": local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] elif self.peak_sign == "both": - local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] # TODO: "amplitude" ??? diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index db5305c313..f271e81869 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -81,22 +81,24 @@ def test_run_node_pipeline(): we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") peaks = sorting_to_peak(sorting, extremum_channel_inds) - + peak_retriever = PeakRetriever(recording, peaks) # channel index is from template - spike_retriever_T = SpikeRetriever(recording, sorting, - channel_from_template=True, - extremum_channel_inds=extremum_channel_inds) + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds + ) # channel index is per spike - spike_retriever_S = SpikeRetriever(recording, sorting, - channel_from_template=False, - extremum_channel_inds=extremum_channel_inds, - radius_um=50, - peak_sign="neg") + spike_retriever_S = SpikeRetriever( + recording, + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", + ) # test with 2 diffrents first node for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S): - # one step only : squeeze output nodes = [ peak_source, diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 32443d44d0..1da2858142 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -26,10 +26,12 @@ def __init__(self, waveform_extractor): extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - - - def _set_params(self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={}): - params = dict(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method) + def _set_params( + self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={} + ): + params = dict( + ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method + ) params.update(**method_kwargs) print(params) return params @@ -66,7 +68,7 @@ def _run(self, **job_kwargs): channel_from_template=channel_from_template, extremum_channel_inds=extremum_channel_inds, radius_um=50, - peak_sign=self._params.get("peaks_sign", "neg") + peak_sign=self._params.get("peaks_sign", "neg"), ) spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) @@ -117,6 +119,7 @@ def get_extension_function(): # @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate # what do we put by default ? + def compute_spike_locations( waveform_extractor, load_if_exists=False, @@ -164,8 +167,13 @@ def compute_spike_locations( slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) else: slc = SpikeLocationsCalculator(waveform_extractor) - slc.set_params(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, - method=method, method_kwargs=method_kwargs) + slc.set_params( + ms_before=ms_before, + ms_after=ms_after, + channel_from_template=channel_from_template, + method=method, + method_kwargs=method_kwargs, + ) slc.run(**job_kwargs) locs = slc.get_data(outputs=outputs) diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index b638e8ed3a..6495503b43 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -28,7 +28,9 @@ from .tools import get_prototype_spike -def _run_localization_from_peak_source(recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): +def _run_localization_from_peak_source( + recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs +): # use by localize_peaks() and compute_spike_locations() assert ( method in possible_localization_methods @@ -52,9 +54,7 @@ def _run_localization_from_peak_source(recording, peak_source, method="center_of pipeline_nodes = [ peak_source, extract_dense_waveforms, - LocalizeMonopolarTriangulation( - recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs - ), + LocalizeMonopolarTriangulation(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] elif method == "peak_channel": pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)] @@ -79,7 +79,6 @@ def _run_localization_from_peak_source(recording, peak_source, method="center_of return peak_locations - def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): """Localize peak (spike) in 2D or 3D depending the method. @@ -106,11 +105,12 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha'). """ peak_retriever = PeakRetriever(recording, peaks) - peak_locations = _run_localization_from_peak_source(recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs) + peak_locations = _run_localization_from_peak_source( + recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs + ) return peak_locations - class LocalizeBase(PipelineNode): def __init__(self, recording, return_output=True, parents=None, radius_um=75.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) From 8e316efd8bf280374598c33c490e8b3c6c90dc3a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 21:29:16 +0200 Subject: [PATCH 03/11] wip compute_spike_location with true channel --- src/spikeinterface/core/node_pipeline.py | 4 +- .../postprocessing/spike_locations.py | 46 +++++++++++-------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index a0ded216d1..e17c5d5caa 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -213,7 +213,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "both": local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] - # TODO: "amplitude" ??? + # handle amplitude + for i, peak in enumerate(local_peaks): + local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]] return (local_peaks,) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index e4b60d401e..2807fb992c 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -27,13 +27,18 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) def _set_params( - self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={} + self, ms_before=0.5, ms_after=0.5, + spike_retriver_kwargs=dict( + channel_from_template=False, + radius_um=50, + peaks_sign="neg", + ), + method="center_of_mass", method_kwargs={} ): params = dict( - ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method + ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method ) params.update(**method_kwargs) - print(params) return params def _select_extension_data(self, unit_ids): @@ -59,16 +64,13 @@ def _run(self, **job_kwargs): extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") params = self._params.copy() - channel_from_template = params.pop("channel_from_template") + spike_retriver_kwargs = params.pop("spike_retriver_kwargs") - # @alessio @pierre: where do we expose the parameters of radius for the retriever (this is not the same as the one for locatization it is smaller) ??? spike_retriever = SpikeRetriever( we.recording, we.sorting, - channel_from_template=channel_from_template, extremum_channel_inds=extremum_channel_inds, - radius_um=50, - peak_sign=self._params.get("peaks_sign", "neg"), + **spike_retriver_kwargs ) spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) @@ -116,16 +118,17 @@ def get_extension_function(): WaveformExtractor.register_extension(SpikeLocationsCalculator) -# @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate -# what do we put by default ? - - def compute_spike_locations( waveform_extractor, load_if_exists=False, ms_before=0.5, ms_after=0.5, - channel_from_template=True, + spike_retriver_kwargs=dict( + channel_from_template=False, + radius_um=50, + peaks_sign="neg", + ), + method="center_of_mass", method_kwargs={}, outputs="concatenated", @@ -144,10 +147,17 @@ def compute_spike_locations( The left window, before a peak, in milliseconds. ms_after : float The right window, after a peak, in milliseconds. - channel_from_template: bool, default True - For each spike is the maximum channel computed from template or re estimated at every spikes. - channel_from_template = True is old behavior but less acurate - channel_from_template = False is slower but more accurate + spike_retriver_kwargs: dict + A dict that contains the behavior for getting the maximum channel for each spike. + This contain dict contains: + * channel_from_template: bool, default True + For each spike is the maximum channel computed from template or re estimated at every spikes. + channel_from_template = True is old behavior but less acurate + channel_from_template = False is slower but more accurate + * radius_um: float, default 50 + In case channel_from_template=False, this is the radius to get the true peak. + * peaks_sign="neg" + In case channel_from_template=False, this is the peak sign. method : str 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' method_kwargs : dict @@ -170,7 +180,7 @@ def compute_spike_locations( slc.set_params( ms_before=ms_before, ms_after=ms_after, - channel_from_template=channel_from_template, + spike_retriver_kwargs=spike_retriver_kwargs, method=method, method_kwargs=method_kwargs, ) From 0c790f4687251803ab1fbad96712126ef1f49a2a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 19:29:39 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 2 +- .../postprocessing/spike_locations.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index e17c5d5caa..9b61ec0dab 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -215,7 +215,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # handle amplitude for i, peak in enumerate(local_peaks): - local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]] + local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]] return (local_peaks,) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 2807fb992c..6f8a8aabcb 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -27,13 +27,16 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) def _set_params( - self, ms_before=0.5, ms_after=0.5, + self, + ms_before=0.5, + ms_after=0.5, spike_retriver_kwargs=dict( channel_from_template=False, radius_um=50, peaks_sign="neg", ), - method="center_of_mass", method_kwargs={} + method="center_of_mass", + method_kwargs={}, ): params = dict( ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method @@ -67,10 +70,7 @@ def _run(self, **job_kwargs): spike_retriver_kwargs = params.pop("spike_retriver_kwargs") spike_retriever = SpikeRetriever( - we.recording, - we.sorting, - extremum_channel_inds=extremum_channel_inds, - **spike_retriver_kwargs + we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs ) spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) @@ -118,6 +118,7 @@ def get_extension_function(): WaveformExtractor.register_extension(SpikeLocationsCalculator) + def compute_spike_locations( waveform_extractor, load_if_exists=False, @@ -128,7 +129,6 @@ def compute_spike_locations( radius_um=50, peaks_sign="neg", ), - method="center_of_mass", method_kwargs={}, outputs="concatenated", From 6887c97bf81e502f8daee5eb60a049d2fb387c73 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 09:50:01 +0200 Subject: [PATCH 05/11] oups --- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/postprocessing/spike_locations.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9b61ec0dab..a6dabf77b5 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -215,7 +215,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # handle amplitude for i, peak in enumerate(local_peaks): - local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]] + local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]] return (local_peaks,) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 6f8a8aabcb..0e471444d8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -31,7 +31,7 @@ def _set_params( ms_before=0.5, ms_after=0.5, spike_retriver_kwargs=dict( - channel_from_template=False, + channel_from_template=True, radius_um=50, peaks_sign="neg", ), @@ -125,7 +125,7 @@ def compute_spike_locations( ms_before=0.5, ms_after=0.5, spike_retriver_kwargs=dict( - channel_from_template=False, + channel_from_template=True, radius_um=50, peaks_sign="neg", ), From e238608afde4b3f06e6dac8276fb8c398b1beeab Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 10:39:39 +0200 Subject: [PATCH 06/11] less strict on amplitude for spikeretreiver tests --- src/spikeinterface/core/tests/test_node_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4b86c538a9..9b65eba726 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -105,7 +105,8 @@ def test_run_node_pipeline(): AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), ] step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + if loop ==0: + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) # 3 nodes two have outputs ms_before = 0.5 @@ -133,7 +134,6 @@ def test_run_node_pipeline(): # gather memory mode output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) num_peaks = peaks.shape[0] num_channels = recording.get_num_channels() From 59f0473e71fb7b1ea029006021c55adf34b7b69f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 08:42:49 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_node_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 9b65eba726..e5a6dd055c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -105,7 +105,7 @@ def test_run_node_pipeline(): AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), ] step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - if loop ==0: + if loop == 0: assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) # 3 nodes two have outputs From 3ed9e5f0a56cac3b8005c2d7fe5e78ead4078635 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 21:11:55 +0200 Subject: [PATCH 08/11] oups --- src/spikeinterface/postprocessing/spike_locations.py | 6 +++--- .../postprocessing/tests/test_spike_locations.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 0e471444d8..ccf321ba80 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -33,7 +33,7 @@ def _set_params( spike_retriver_kwargs=dict( channel_from_template=True, radius_um=50, - peaks_sign="neg", + peak_sign="neg", ), method="center_of_mass", method_kwargs={}, @@ -127,7 +127,7 @@ def compute_spike_locations( spike_retriver_kwargs=dict( channel_from_template=True, radius_um=50, - peaks_sign="neg", + peak_sign="neg", ), method="center_of_mass", method_kwargs={}, @@ -156,7 +156,7 @@ def compute_spike_locations( channel_from_template = False is slower but more accurate * radius_um: float, default 50 In case channel_from_template=False, this is the radius to get the true peak. - * peaks_sign="neg" + * peak_sign="neg" In case channel_from_template=False, this is the peak sign. method : str 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index ab2345b1f5..89b015f1da 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -10,8 +10,8 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes extension_class = SpikeLocationsCalculator extension_data_names = ["spike_locations"] extension_function_kwargs_list = [ - dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=True), - dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=False), + dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)), + dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), From 8a987b87d9b5e6fad4d9e1e03036898b598c8939 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:12:17 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/tests/test_spike_locations.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 89b015f1da..d047a2f67e 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -10,8 +10,12 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes extension_class = SpikeLocationsCalculator extension_data_names = ["spike_locations"] extension_function_kwargs_list = [ - dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)), - dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)), + dict( + method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True) + ), + dict( + method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False) + ), dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), From 4e843cb09ac9a9217af0f2a9476514609400d789 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:01:19 +0200 Subject: [PATCH 10/11] Update src/spikeinterface/sortingcomponents/tools.py --- src/spikeinterface/sortingcomponents/tools.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 576732baa2..cd9226d5e8 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -19,9 +19,8 @@ def make_multi_method_doc(methods, ident=" "): def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - # TODO for Pierre: this function is really unefficient because it runa full pipeline only for a few - # spikes, which leans that traces are entirally computed!!!!! - # Please find a better way + # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few + # spikes, which means that all traces need to be accesses! Please find a better way nb_peaks = min(len(peaks), nb_peaks) idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) peak_retriever = PeakRetriever(recording, peaks[idx]) From 1af5722a356c733d9afabe69bab354362273a4f8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:03:12 +0200 Subject: [PATCH 11/11] Update src/spikeinterface/postprocessing/spike_locations.py --- src/spikeinterface/postprocessing/spike_locations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index ccf321ba80..28eed131cd 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -148,8 +148,8 @@ def compute_spike_locations( ms_after : float The right window, after a peak, in milliseconds. spike_retriver_kwargs: dict - A dict that contains the behavior for getting the maximum channel for each spike. - This contain dict contains: + A dictionary to control the behavior for getting the maximum channel for each spike. + This dictionary contains: * channel_from_template: bool, default True For each spike is the maximum channel computed from template or re estimated at every spikes. channel_from_template = True is old behavior but less acurate