-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Spike location with true channel #1950
Changes from 2 commits
5370884
68df573
87c1ccb
5fef79c
5e58fa5
384809b
8e316ef
0c790f4
6887c97
f94c594
e238608
57ba043
59f0473
3ed9e5f
8a987b8
4e843cb
1af5722
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift | ||
|
||
from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension | ||
from spikeinterface.core.node_pipeline import SpikeRetriever | ||
|
||
|
||
class SpikeLocationsCalculator(BaseWaveformExtractorExtension): | ||
|
@@ -25,9 +26,14 @@ def __init__(self, waveform_extractor): | |
extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") | ||
self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) | ||
|
||
def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}): | ||
params = dict(ms_before=ms_before, ms_after=ms_after, method=method) | ||
def _set_params( | ||
self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={} | ||
): | ||
params = dict( | ||
ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method | ||
) | ||
params.update(**method_kwargs) | ||
print(params) | ||
return params | ||
|
||
def _select_extension_data(self, unit_ids): | ||
|
@@ -44,13 +50,28 @@ def _run(self, **job_kwargs): | |
uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate | ||
spike locations. | ||
""" | ||
from spikeinterface.sortingcomponents.peak_localization import localize_peaks | ||
from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source | ||
|
||
job_kwargs = fix_job_kwargs(job_kwargs) | ||
|
||
we = self.waveform_extractor | ||
|
||
spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs) | ||
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") | ||
|
||
params = self._params.copy() | ||
channel_from_template = params.pop("channel_from_template") | ||
|
||
# @alessio @pierre: where do we expose the parameters of radius for the retriever (this is not the same as the one for locatization it is smaller) ??? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would expose it in the compute_spike_location function, otherwise this has no sense. In addition, to ensure comparison with/witout this extra mecanism, we should make sure that if radius_um is set to 0, a classical PeakRetriever is used here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PeakRetriever will be not used here this is postprocessing. |
||
spike_retriever = SpikeRetriever( | ||
we.recording, | ||
we.sorting, | ||
channel_from_template=channel_from_template, | ||
extremum_channel_inds=extremum_channel_inds, | ||
radius_um=50, | ||
peak_sign=self._params.get("peaks_sign", "neg"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This neg should not be hardcoded, and peak_sign should be an argument of compute_spike_locations |
||
) | ||
spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) | ||
|
||
self._extension_data["spike_locations"] = spike_locations | ||
|
||
def get_data(self, outputs="concatenated"): | ||
|
@@ -95,12 +116,16 @@ def get_extension_function(): | |
|
||
WaveformExtractor.register_extension(SpikeLocationsCalculator) | ||
|
||
# @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate | ||
# what do we put by default ? | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would go for the new behavior as a default, but we need to think on the impact with the metrics |
||
|
||
def compute_spike_locations( | ||
waveform_extractor, | ||
load_if_exists=False, | ||
ms_before=0.5, | ||
ms_after=0.5, | ||
channel_from_template=True, | ||
method="center_of_mass", | ||
method_kwargs={}, | ||
outputs="concatenated", | ||
alejoe91 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -119,6 +144,10 @@ def compute_spike_locations( | |
The left window, before a peak, in milliseconds. | ||
ms_after : float | ||
The right window, after a peak, in milliseconds. | ||
channel_from_template: bool, default True | ||
For each spike is the maximum channel computed from template or re estimated at every spikes. | ||
channel_from_template = True is old behavior but less acurate | ||
channel_from_template = False is slower but more accurate | ||
method : str | ||
'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' | ||
method_kwargs : dict | ||
|
@@ -138,7 +167,13 @@ def compute_spike_locations( | |
slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) | ||
else: | ||
slc = SpikeLocationsCalculator(waveform_extractor) | ||
slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs) | ||
slc.set_params( | ||
ms_before=ms_before, | ||
ms_after=ms_after, | ||
channel_from_template=channel_from_template, | ||
method=method, | ||
method_kwargs=method_kwargs, | ||
) | ||
slc.run(**job_kwargs) | ||
|
||
locs = slc.get_data(outputs=outputs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the print here