Addition of a peak centerer to align peaks correctly #1879

Closed
wants to merge 2 commits
26 changes: 15 additions & 11 deletions src/spikeinterface/postprocessing/spike_locations.py
@@ -19,14 +19,13 @@ class SpikeLocationsCalculator(BaseWaveformExtractorExtension):

extension_name = "spike_locations"

def __init__(self, waveform_extractor):
def __init__(self, waveform_extractor, peak_sign="neg"):
BaseWaveformExtractorExtension.__init__(self, waveform_extractor)

extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, peak_sign, outputs="index")
self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}, radius_um=None):
params = dict(ms_before=ms_before, ms_after=ms_after, method=method, radius_um=radius_um)
params.update(**method_kwargs)
return params

@@ -47,12 +46,8 @@ def _run(self, **job_kwargs):
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

job_kwargs = fix_job_kwargs(job_kwargs)

we = self.waveform_extractor

extremum_channel_inds = get_template_extremum_channel(we, outputs="index")
self.spikes = we.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
self._extension_data["spike_locations"] = spike_locations

@@ -107,6 +102,8 @@ def compute_spike_locations(
method="center_of_mass",
method_kwargs={},
outputs="concatenated",
peak_sign="neg",
radius_um=None,
**job_kwargs,
):
"""
@@ -126,6 +123,11 @@
'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
method_kwargs : dict
Other kwargs depending on the method.
peak_sign : str
Sign of the template used to find the extremum (best) channel of each unit ('neg', 'pos', or 'both')
radius_um : float or None, default None
If not None, the radius (in um) used to re-center each peak, i.e. to find the channel
in the data where the peak actually occurs
outputs : str
'concatenated' (default) / 'by_unit'
{}
@@ -140,8 +142,10 @@ def compute_spike_locations(
if load_if_exists and waveform_extractor.is_extension(SpikeLocationsCalculator.extension_name):
slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name)
else:
slc = SpikeLocationsCalculator(waveform_extractor)
slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs)
slc = SpikeLocationsCalculator(waveform_extractor, peak_sign)
slc.set_params(
ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs, radius_um=radius_um
)
slc.run(**job_kwargs)

locs = slc.get_data(outputs=outputs)
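For context, a minimal usage sketch (not part of the diff) of how the new `peak_sign` and `radius_um` arguments of `compute_spike_locations` might be called; the `WaveformExtractor` named `we` is assumed to exist already:

```python
# Hypothetical usage sketch based on the signature added in this PR.
# `we` is an assumed, pre-existing WaveformExtractor.
from spikeinterface.postprocessing import compute_spike_locations

spike_locations = compute_spike_locations(
    we,
    ms_before=0.5,
    ms_after=0.5,
    method="center_of_mass",
    peak_sign="neg",   # sign used to pick each unit's extremum channel
    radius_um=50.0,    # re-center each spike on the strongest channel within 50 um
    outputs="concatenated",
)
```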
12 changes: 10 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_localization.py
@@ -6,6 +6,7 @@
run_node_pipeline,
find_parent_of_type,
PeakRetriever,
PeakCenterer,
PipelineNode,
WaveformsNode,
ExtractDenseWaveforms,
@@ -26,7 +27,7 @@
from .tools import get_prototype_spike


def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs):
def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, radius_um=None, **kwargs):
"""Localize peak (spike) in 2D or 3D depending the method.

When a probe is 2D then:
@@ -40,6 +41,9 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_
The recording extractor object.
peaks: array
Peaks array, as returned by detect_peaks() in "compact_numpy" way.
radius_um: float or None, default None
If not None, the radius (in um) used to re-center each peak, i.e. to find the channel where
the peak actually occurs

{method_doc}

@@ -57,7 +61,11 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_

method_kwargs, job_kwargs = split_job_kwargs(kwargs)

peak_retriever = PeakRetriever(recording, peaks)
if radius_um is not None:
peak_retriever = PeakCenterer(recording, peaks, radius_um=radius_um)
else:
peak_retriever = PeakRetriever(recording, peaks)

if method == "center_of_mass":
extract_dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
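A sketch (not part of the diff) of calling `localize_peaks` with the new `radius_um` argument, which makes the pipeline start from a `PeakCenterer` instead of a `PeakRetriever`; `recording` and `peaks` (e.g. from `detect_peaks`) are assumed to exist:

```python
# Hypothetical sketch: peak localization with re-centering enabled.
# `recording` and `peaks` are assumed to come from earlier detection steps.
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

peak_locations = localize_peaks(
    recording,
    peaks,
    method="center_of_mass",
    ms_before=0.5,
    ms_after=0.5,
    radius_um=50.0,   # not None, so a PeakCenterer replaces the PeakRetriever
    n_jobs=1,
    chunk_size=10_000,
)
```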
46 changes: 44 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_pipeline.py
@@ -126,6 +126,46 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
return (local_peaks,)


class PeakCenterer(PeakRetriever):
def __init__(self, recording, peaks, radius_um=50, peak_sign="neg"):
PeakRetriever.__init__(self, recording, peaks)
self.radius_um = radius_um
self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.peak_sign = peak_sign

def get_trace_margin(self):
return 0

def get_dtype(self):
return base_peak_dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
local_peaks = local_peaks.copy()
local_peaks["sample_index"] -= start_frame - max_margin

# re-center each peak on the extremum channel within radius_um of its detected channel
for i, peak in enumerate(local_peaks):
(chans,) = np.nonzero(self.neighbours_mask[peak["channel_index"]])
sparse_wfs = traces[peak["sample_index"], chans]
if self.peak_sign == "neg":
local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)]
elif self.peak_sign == "pos":
local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)]
elif self.peak_sign == "both":
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

return (local_peaks,)


class WaveformsNode(PipelineNode):
"""
Base class for waveforms in a node pipeline.
@@ -304,8 +344,10 @@ def check_graph(nodes):
"""

node0 = nodes[0]
if not (isinstance(node0, PeakDetector) or isinstance(node0, PeakRetriever)):
raise ValueError("Peak pipeline graph must contain PeakDetector or PeakRetriever as first element")
if not (isinstance(node0, PeakDetector) or isinstance(node0, PeakRetriever) or isinstance(node0, PeakCenterer)):
raise ValueError(
"Peak pipeline graph must contain PeakDetector or PeakRetriever or PeakCenterer as first element"
)

for i, node in enumerate(nodes):
assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode"
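To make the centering rule in `PeakCenterer.compute` concrete, here is a standalone numpy sketch (not part of the diff) of the same per-peak logic on toy data; all names, shapes and values below are invented for illustration:

```python
# Toy illustration of the re-centering rule: within the neighbourhood defined by
# radius_um, move each peak's channel index to the channel carrying the extremum
# at the peak's sample.
import numpy as np

rng = np.random.default_rng(0)
traces = rng.normal(0, 1, size=(1000, 8))   # (samples, channels) toy traces
traces[500, 3] = -12.0                      # the true extremum sits on channel 3

# channels laid out on a vertical line with 20 um pitch
channel_locations = np.column_stack([np.zeros(8), np.arange(8) * 20.0])
channel_distance = np.linalg.norm(
    channel_locations[:, None, :] - channel_locations[None, :, :], axis=2
)
radius_um = 50.0
neighbours_mask = channel_distance < radius_um

# a peak detected one channel off, at sample 500 on channel 2
sample_index, detected_chan, peak_sign = 500, 2, "neg"
(chans,) = np.nonzero(neighbours_mask[detected_chan])
values = traces[sample_index, chans]
if peak_sign == "neg":
    centered_chan = chans[np.argmin(values)]
elif peak_sign == "pos":
    centered_chan = chans[np.argmax(values)]
else:  # "both"
    centered_chan = chans[np.argmax(np.abs(values))]

print(centered_chan)  # -> 3: the peak is re-centered on its true channel
```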