Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spike location with true channel #1950

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class SpikeRetriever(PeakSource):
"""
This class is useful to inject a sorting object into the node pipeline mechanism.
It allows computing some post-processing steps with the same machinery used for sorting components.
This is a first step to totally refactor:
This is a first step to totally refactor:
* compute_spike_locations()
* compute_amplitude_scalings()
* compute_spike_amplitudes()
Expand All @@ -164,25 +164,22 @@ class SpikeRetriever(PeakSource):
Peak sign to find the max channel.
Used only when channel_from_template=False
"""
def __init__(self, recording, sorting,
channel_from_template=True,
extremum_channel_inds=None,
radius_um=50,
peak_sign="neg"
):

def __init__(
self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg"
):
PipelineNode.__init__(self, recording, return_output=False)

self.channel_from_template = channel_from_template

assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds"

self.peaks = sorting_to_peak(sorting, extremum_channel_inds)

if not channel_from_template:
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance < radius_um
self.peak_sign = peak_sign

self.peak_sign = peak_sign

# precompute segment slice
self.segment_slices = []
Expand Down Expand Up @@ -219,7 +216,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
elif self.peak_sign == "pos":
local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)]
elif self.peak_sign == "both":
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

# TODO: "amplitude" ???

Expand Down
25 changes: 12 additions & 13 deletions src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,24 @@ def test_run_node_pipeline():
we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
peaks = sorting_to_peak(sorting, extremum_channel_inds)

peak_retriever = PeakRetriever(recording, peaks)
# channel index is from template
spike_retriever_T = SpikeRetriever(recording, sorting,
channel_from_template=True,
extremum_channel_inds=extremum_channel_inds)
spike_retriever_T = SpikeRetriever(
recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)
# channel index is per spike
spike_retriever_S = SpikeRetriever(recording, sorting,
channel_from_template=False,
extremum_channel_inds=extremum_channel_inds,
radius_um=50,
peak_sign="neg")
spike_retriever_S = SpikeRetriever(
recording,
sorting,
channel_from_template=False,
extremum_channel_inds=extremum_channel_inds,
radius_um=50,
peak_sign="neg",
)

# test with 2 different first nodes
for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S):




# one step only : squeeze output
nodes = [
peak_source,
Expand Down
45 changes: 40 additions & 5 deletions src/spikeinterface/postprocessing/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift

from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
from spikeinterface.core.node_pipeline import SpikeRetriever


class SpikeLocationsCalculator(BaseWaveformExtractorExtension):
Expand All @@ -25,9 +26,14 @@ def __init__(self, waveform_extractor):
extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
def _set_params(
self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={}
):
params = dict(
ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method
)
params.update(**method_kwargs)
print(params)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the print here

return params

def _select_extension_data(self, unit_ids):
Expand All @@ -44,13 +50,28 @@ def _run(self, **job_kwargs):
uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate
spike locations.
"""
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source

job_kwargs = fix_job_kwargs(job_kwargs)

we = self.waveform_extractor

spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

params = self._params.copy()
channel_from_template = params.pop("channel_from_template")

# @alessio @pierre: where do we expose the radius parameter for the retriever (this is not the same as the one for localization; it is smaller) ???
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expose it in the compute_spike_location function, otherwise this makes no sense. In addition, to ensure comparison with/without this extra mechanism, we should make sure that if radius_um is set to 0, a classical PeakRetriever is used here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PeakRetriever will be not used here this is postprocessing.

spike_retriever = SpikeRetriever(
we.recording,
we.sorting,
channel_from_template=channel_from_template,
extremum_channel_inds=extremum_channel_inds,
radius_um=50,
peak_sign=self._params.get("peaks_sign", "neg"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This neg should not be hardcoded, and peak_sign should be an argument of compute_spike_locations

)
spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)

self._extension_data["spike_locations"] = spike_locations

def get_data(self, outputs="concatenated"):
Expand Down Expand Up @@ -95,12 +116,16 @@ def get_extension_function():

WaveformExtractor.register_extension(SpikeLocationsCalculator)

# @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate
# what do we put by default ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would go for the new behavior as a default, but we need to think on the impact with the metrics


def compute_spike_locations(
waveform_extractor,
load_if_exists=False,
ms_before=0.5,
ms_after=0.5,
channel_from_template=True,
method="center_of_mass",
method_kwargs={},
outputs="concatenated",
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -119,6 +144,10 @@ def compute_spike_locations(
The left window, before a peak, in milliseconds.
ms_after : float
The right window, after a peak, in milliseconds.
channel_from_template: bool, default True
For each spike, determines whether the maximum channel is taken from the template or re-estimated at every spike.
channel_from_template = True is the old behavior but less accurate
channel_from_template = False is slower but more accurate
method : str
'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
method_kwargs : dict
Expand All @@ -138,7 +167,13 @@ def compute_spike_locations(
slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name)
else:
slc = SpikeLocationsCalculator(waveform_extractor)
slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs)
slc.set_params(
ms_before=ms_before,
ms_after=ms_after,
channel_from_template=channel_from_template,
method=method,
method_kwargs=method_kwargs,
)
slc.run(**job_kwargs)

locs = slc.get_data(outputs=outputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
extension_class = SpikeLocationsCalculator
extension_data_names = ["spike_locations"]
extension_function_kwargs_list = [
dict(method="center_of_mass", chunk_size=10000, n_jobs=1),
dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=True),
dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=False),
dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
Expand Down
88 changes: 49 additions & 39 deletions src/spikeinterface/sortingcomponents/peak_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
run_node_pipeline,
find_parent_of_type,
PeakRetriever,
SpikeRetriever,
PipelineNode,
WaveformsNode,
ExtractDenseWaveforms,
Expand All @@ -27,72 +28,49 @@
from .tools import get_prototype_spike


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
"""Localize peak (spike) in 2D or 3D depending the method.

When a probe is 2D then:
* X is axis 0 of the probe
* Y is axis 1 of the probe
* Z is orthogonal to the plane of the probe

Parameters
----------
recording: RecordingExtractor
The recording extractor object.
peaks: array
Peaks array, as returned by detect_peaks() in "compact_numpy" way.

{method_doc}

{job_doc}

Returns
-------
peak_locations: ndarray
Array with estimated location for each spike.
The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
"""
def _run_localization_from_peak_source(
recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs
):
# use by localize_peaks() and compute_spike_locations()
assert (
method in possible_localization_methods
), f"Method {method} is not supported. Choose from {possible_localization_methods}"

method_kwargs, job_kwargs = split_job_kwargs(kwargs)

peak_retriever = PeakRetriever(recording, peaks)
if method == "center_of_mass":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeCenterOfMass(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
LocalizeCenterOfMass(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]
elif method == "monopolar_triangulation":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeMonopolarTriangulation(
recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs
),
LocalizeMonopolarTriangulation(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]
elif method == "peak_channel":
pipeline_nodes = [peak_retriever, LocalizePeakChannel(recording, parents=[peak_retriever], **method_kwargs)]
pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)]
elif method == "grid_convolution":
if "prototype" not in method_kwargs:
assert isinstance(peak_source, (PeakRetriever, SpikeRetriever))
method_kwargs["prototype"] = get_prototype_spike(
recording, peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs
)
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
pipeline_nodes = [
peak_retriever,
peak_source,
extract_dense_waveforms,
LocalizeGridConvolution(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs),
LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs),
]

job_name = f"localize peaks using {method}"
Expand All @@ -101,6 +79,38 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_
return peak_locations


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
    """Localize peaks (spikes) in 2D or 3D depending on the method.

    When a probe is 2D then:
      * X is axis 0 of the probe
      * Y is axis 1 of the probe
      * Z is orthogonal to the plane of the probe

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object.
    peaks: array
        Peaks array, as returned by detect_peaks() in "compact_numpy" way.

    {method_doc}

    {job_doc}

    Returns
    -------
    peak_locations: ndarray
        Array with estimated location for each spike.
        The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
    """
    # Wrap the raw peaks array in a PeakRetriever node and delegate to the
    # shared localization helper (also used by compute_spike_locations(),
    # which feeds it a SpikeRetriever instead).
    peak_retriever = PeakRetriever(recording, peaks)
    peak_locations = _run_localization_from_peak_source(
        recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs
    )
    return peak_locations


class LocalizeBase(PipelineNode):
def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
Expand Down
3 changes: 3 additions & 0 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def make_multi_method_doc(methods, ident=" "):


def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5):
# TODO for Pierre: this function is really inefficient because it runs a full pipeline for only a few
# spikes, which means that traces are entirely computed!!!!
# Please find a better way
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
nb_peaks = min(len(peaks), nb_peaks)
idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False))
peak_retriever = PeakRetriever(recording, peaks[idx])
Expand Down