Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spike location with true channel #1950

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
elif self.peak_sign == "both":
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

# TODO: "amplitude" ???
# handle amplitude
for i, peak in enumerate(local_peaks):
local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]

return (local_peaks,)

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def compute(self, traces, peaks, waveforms):
def test_run_node_pipeline():
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
# job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

spikes = sorting.to_spike_vector()

Expand Down Expand Up @@ -104,7 +105,8 @@ def test_run_node_pipeline():
AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])
if loop == 0:
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

# 3 nodes two have outputs
ms_before = 0.5
Expand Down Expand Up @@ -132,7 +134,6 @@ def test_run_node_pipeline():
# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
Expand Down
55 changes: 50 additions & 5 deletions src/spikeinterface/postprocessing/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift

from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
from spikeinterface.core.node_pipeline import SpikeRetriever


class SpikeLocationsCalculator(BaseWaveformExtractorExtension):
Expand All @@ -25,8 +26,21 @@ def __init__(self, waveform_extractor):
extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
def _set_params(
self,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
),
method="center_of_mass",
method_kwargs={},
):
params = dict(
ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method
)
params.update(**method_kwargs)
return params

Expand All @@ -44,13 +58,22 @@ def _run(self, **job_kwargs):
uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate
spike locations.
"""
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source

job_kwargs = fix_job_kwargs(job_kwargs)

we = self.waveform_extractor

spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

params = self._params.copy()
spike_retriver_kwargs = params.pop("spike_retriver_kwargs")

spike_retriever = SpikeRetriever(
we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs
)
spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)

self._extension_data["spike_locations"] = spike_locations

def get_data(self, outputs="concatenated"):
Expand Down Expand Up @@ -101,6 +124,11 @@ def compute_spike_locations(
load_if_exists=False,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
),
method="center_of_mass",
method_kwargs={},
outputs="concatenated",
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -119,6 +147,17 @@ def compute_spike_locations(
The left window, before a peak, in milliseconds.
ms_after : float
The right window, after a peak, in milliseconds.
spike_retriver_kwargs: dict
A dictionary to control the behavior for getting the maximum channel for each spike.
This dictionary contains:
* channel_from_template: bool, default True
For each spike is the maximum channel computed from template or re estimated at every spikes.
channel_from_template = True is old behavior but less acurate
channel_from_template = False is slower but more accurate
* radius_um: float, default 50
In case channel_from_template=False, this is the radius to get the true peak.
* peak_sign="neg"
In case channel_from_template=False, this is the peak sign.
method : str
'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
method_kwargs : dict
Expand All @@ -138,7 +177,13 @@ def compute_spike_locations(
slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name)
else:
slc = SpikeLocationsCalculator(waveform_extractor)
slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs)
slc.set_params(
ms_before=ms_before,
ms_after=ms_after,
spike_retriver_kwargs=spike_retriver_kwargs,
method=method,
method_kwargs=method_kwargs,
)
slc.run(**job_kwargs)

locs = slc.get_data(outputs=outputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
extension_class = SpikeLocationsCalculator
extension_data_names = ["spike_locations"]
extension_function_kwargs_list = [
dict(method="center_of_mass", chunk_size=10000, n_jobs=1),
dict(
method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)
),
dict(
method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)
),
dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
Expand Down
88 changes: 49 additions & 39 deletions src/spikeinterface/sortingcomponents/peak_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
run_node_pipeline,
find_parent_of_type,
PeakRetriever,
SpikeRetriever,
PipelineNode,
WaveformsNode,
ExtractDenseWaveforms,
Expand All @@ -27,72 +28,49 @@
from .tools import get_prototype_spike


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
"""Localize peak (spike) in 2D or 3D depending the method.

When a probe is 2D then:
* X is axis 0 of the probe
* Y is axis 1 of the probe
* Z is orthogonal to the plane of the probe

Parameters
----------
recording: RecordingExtractor
The recording extractor object.
peaks: array
Peaks array, as returned by detect_peaks() in "compact_numpy" way.

{method_doc}

{job_doc}

Returns
-------
peak_locations: ndarray
Array with estimated location for each spike.
The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
"""
def _run_localization_from_peak_source(
recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs
):
# use by localize_peaks() and compute_spike_locations()
assert (
method in possible_localization_methods
), f"Method {method} is not supported. Choose from {possible_localization_methods}"

method_kwargs, job_kwargs = split_job_kwargs(kwargs)

peak_retriever = PeakRetriever(recording, peaks)
if method == "center_of_mass":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeCenterOfMass(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
LocalizeCenterOfMass(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]
elif method == "monopolar_triangulation":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeMonopolarTriangulation(
recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs
),
LocalizeMonopolarTriangulation(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]
elif method == "peak_channel":
pipeline_nodes = [peak_retriever, LocalizePeakChannel(recording, parents=[peak_retriever], **method_kwargs)]
pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)]
elif method == "grid_convolution":
if "prototype" not in method_kwargs:
assert isinstance(peak_source, (PeakRetriever, SpikeRetriever))
method_kwargs["prototype"] = get_prototype_spike(
recording, peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
)
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeGridConvolution(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]

job_name = f"localize peaks using {method}"
Expand All @@ -101,6 +79,38 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_
return peak_locations


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
"""Localize peak (spike) in 2D or 3D depending the method.

When a probe is 2D then:
* X is axis 0 of the probe
* Y is axis 1 of the probe
* Z is orthogonal to the plane of the probe

Parameters
----------
recording: RecordingExtractor
The recording extractor object.
peaks: array
Peaks array, as returned by detect_peaks() in "compact_numpy" way.

{method_doc}

{job_doc}

Returns
-------
peak_locations: ndarray
Array with estimated location for each spike.
The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
"""
peak_retriever = PeakRetriever(recording, peaks)
peak_locations = _run_localization_from_peak_source(
recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs
)
return peak_locations


class LocalizeBase(PipelineNode):
def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def make_multi_method_doc(methods, ident=" "):


def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5):
# TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few
# spikes, which means that all traces need to be accesses! Please find a better way
nb_peaks = min(len(peaks), nb_peaks)
idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False))
peak_retriever = PeakRetriever(recording, peaks[idx])
Expand Down