Skip to content

Commit

Permalink
Merge pull request #1950 from samuelgarcia/spike_location_with_true_c…
Browse files Browse the repository at this point in the history
…hannel

Spike location with true channel
  • Loading branch information
alejoe91 authored Oct 26, 2023
2 parents 503d7c8 + 1af5722 commit ef095c2
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 49 deletions.
4 changes: 3 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
elif self.peak_sign == "both":
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

# TODO: "amplitude" ???
# handle amplitude
for i, peak in enumerate(local_peaks):
local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]

return (local_peaks,)

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def compute(self, traces, peaks, waveforms):
def test_run_node_pipeline():
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
# job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

spikes = sorting.to_spike_vector()

Expand Down Expand Up @@ -104,7 +105,8 @@ def test_run_node_pipeline():
AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])
if loop == 0:
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

# 3 nodes two have outputs
ms_before = 0.5
Expand Down Expand Up @@ -132,7 +134,6 @@ def test_run_node_pipeline():
# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
Expand Down
55 changes: 50 additions & 5 deletions src/spikeinterface/postprocessing/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift

from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
from spikeinterface.core.node_pipeline import SpikeRetriever


class SpikeLocationsCalculator(BaseWaveformExtractorExtension):
Expand All @@ -25,8 +26,21 @@ def __init__(self, waveform_extractor):
extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
def _set_params(
self,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
),
method="center_of_mass",
method_kwargs={},
):
params = dict(
ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method
)
params.update(**method_kwargs)
return params

Expand All @@ -44,13 +58,22 @@ def _run(self, **job_kwargs):
uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate
spike locations.
"""
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source

job_kwargs = fix_job_kwargs(job_kwargs)

we = self.waveform_extractor

spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

params = self._params.copy()
spike_retriver_kwargs = params.pop("spike_retriver_kwargs")

spike_retriever = SpikeRetriever(
we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs
)
spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)

self._extension_data["spike_locations"] = spike_locations

def get_data(self, outputs="concatenated"):
Expand Down Expand Up @@ -101,6 +124,11 @@ def compute_spike_locations(
load_if_exists=False,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
),
method="center_of_mass",
method_kwargs={},
outputs="concatenated",
Expand All @@ -119,6 +147,17 @@ def compute_spike_locations(
The left window, before a peak, in milliseconds.
ms_after : float
The right window, after a peak, in milliseconds.
spike_retriver_kwargs: dict
A dictionary to control the behavior for getting the maximum channel for each spike.
This dictionary contains:
* channel_from_template: bool, default True
For each spike is the maximum channel computed from template or re estimated at every spikes.
channel_from_template = True is old behavior but less acurate
channel_from_template = False is slower but more accurate
* radius_um: float, default 50
In case channel_from_template=False, this is the radius to get the true peak.
* peak_sign="neg"
In case channel_from_template=False, this is the peak sign.
method : str
'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
method_kwargs : dict
Expand All @@ -138,7 +177,13 @@ def compute_spike_locations(
slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name)
else:
slc = SpikeLocationsCalculator(waveform_extractor)
slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs)
slc.set_params(
ms_before=ms_before,
ms_after=ms_after,
spike_retriver_kwargs=spike_retriver_kwargs,
method=method,
method_kwargs=method_kwargs,
)
slc.run(**job_kwargs)

locs = slc.get_data(outputs=outputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
extension_class = SpikeLocationsCalculator
extension_data_names = ["spike_locations"]
extension_function_kwargs_list = [
dict(method="center_of_mass", chunk_size=10000, n_jobs=1),
dict(
method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)
),
dict(
method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)
),
dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
Expand Down
88 changes: 49 additions & 39 deletions src/spikeinterface/sortingcomponents/peak_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
run_node_pipeline,
find_parent_of_type,
PeakRetriever,
SpikeRetriever,
PipelineNode,
WaveformsNode,
ExtractDenseWaveforms,
Expand All @@ -27,72 +28,49 @@
from .tools import get_prototype_spike


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
"""Localize peak (spike) in 2D or 3D depending the method.
When a probe is 2D then:
* X is axis 0 of the probe
* Y is axis 1 of the probe
* Z is orthogonal to the plane of the probe
Parameters
----------
recording: RecordingExtractor
The recording extractor object.
peaks: array
Peaks array, as returned by detect_peaks() in "compact_numpy" way.
{method_doc}
{job_doc}
Returns
-------
peak_locations: ndarray
Array with estimated location for each spike.
The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
"""
def _run_localization_from_peak_source(
recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs
):
# use by localize_peaks() and compute_spike_locations()
assert (
method in possible_localization_methods
), f"Method {method} is not supported. Choose from {possible_localization_methods}"

method_kwargs, job_kwargs = split_job_kwargs(kwargs)

peak_retriever = PeakRetriever(recording, peaks)
if method == "center_of_mass":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeCenterOfMass(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
LocalizeCenterOfMass(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]
elif method == "monopolar_triangulation":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeMonopolarTriangulation(
recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs
),
LocalizeMonopolarTriangulation(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]
elif method == "peak_channel":
pipeline_nodes = [peak_retriever, LocalizePeakChannel(recording, parents=[peak_retriever], **method_kwargs)]
pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)]
elif method == "grid_convolution":
if "prototype" not in method_kwargs:
assert isinstance(peak_source, (PeakRetriever, SpikeRetriever))
method_kwargs["prototype"] = get_prototype_spike(
recording, peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
)
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeGridConvolution(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]

job_name = f"localize peaks using {method}"
Expand All @@ -101,6 +79,38 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_
return peak_locations


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
"""Localize peak (spike) in 2D or 3D depending the method.
When a probe is 2D then:
* X is axis 0 of the probe
* Y is axis 1 of the probe
* Z is orthogonal to the plane of the probe
Parameters
----------
recording: RecordingExtractor
The recording extractor object.
peaks: array
Peaks array, as returned by detect_peaks() in "compact_numpy" way.
{method_doc}
{job_doc}
Returns
-------
peak_locations: ndarray
Array with estimated location for each spike.
The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
"""
peak_retriever = PeakRetriever(recording, peaks)
peak_locations = _run_localization_from_peak_source(
recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs
)
return peak_locations


class LocalizeBase(PipelineNode):
def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def make_multi_method_doc(methods, ident=" "):


def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5):
# TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few
# spikes, which means that all traces need to be accesses! Please find a better way
nb_peaks = min(len(peaks), nb_peaks)
idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False))
peak_retriever = PeakRetriever(recording, peaks[idx])
Expand Down

0 comments on commit ef095c2

Please sign in to comment.