Commit fd9ee0b

One pass to get the prototype
yger committed Nov 15, 2024
1 parent bd525b2 commit fd9ee0b
Showing 1 changed file with 37 additions and 8 deletions.
src/spikeinterface/sorters/internal/spyking_circus2.py (45 changes: 37 additions & 8 deletions)
@@ -6,7 +6,7 @@
 import numpy as np
 
 from spikeinterface.core import NumpySorting
-from spikeinterface.core.job_tools import fix_job_kwargs
+from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from spikeinterface.core.recording_tools import get_noise_levels
 from spikeinterface.core.template import Templates
 from spikeinterface.core.waveform_tools import estimate_templates
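
Note on the new import: split_job_kwargs is the spikeinterface helper that separates generic parallel-processing arguments (n_jobs, chunk_duration, progress_bar, ...) from method-specific keyword arguments, so both can be accepted through a single **kwargs and forwarded independently. A minimal sketch, assuming only the (specific_kwargs, job_kwargs) return order that the new get_prototype helper at the bottom of this diff also relies on; the example values are illustrative:

from spikeinterface.core.job_tools import split_job_kwargs

# mixed keyword arguments: some belong to the detection method, some to the job engine
all_kwargs = dict(radius_um=50, detect_threshold=5, n_jobs=4, progress_bar=True)

detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs)
# detection_kwargs -> {"radius_um": 50, "detect_threshold": 5}
# job_kwargs       -> {"n_jobs": 4, "progress_bar": True}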
@@ -179,8 +179,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_params = params["detection"].copy()
         selection_params = params["selection"].copy()
-
-        detection_params["radius_um"] = detection_params.get("radius_um", 50)
+        detection_params["radius_um"] = radius_um
         detection_params["exclude_sweep_ms"] = exclude_sweep_ms
         detection_params["noise_levels"] = noise_levels
 
@@ -195,11 +194,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels
         n_peaks = max(selection_params["min_n_peaks"], max_n_peaks)
 
-        if params["matched_filtering"]:
-            peaks = detect_peaks(
-                recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000, **job_kwargs
-            )
-            prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
+        if params["matched_filtering"]:
+            prototype = get_prototype(recording_w,
+                                      n_peaks=5000,
+                                      ms_before=ms_before,
+                                      ms_after=ms_after,
+                                      **detection_params,
+                                      **job_kwargs)
             detection_params["prototype"] = prototype
             detection_params["ms_before"] = ms_before
             if skip_peaks:
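
This hunk is where the commit title applies: previously the prototype needed two passes over the traces (a full detect_peaks call, then a separate waveform-extraction step in get_prototype_spike), whereas the new get_prototype helper added at the bottom of this diff attaches waveform extraction to the detection pipeline, so a single pass capped at 5000 peaks is enough. A rough before/after sketch, wrapped in a hypothetical function only so that the names it uses (recording_w, detection_params, job_kwargs) are explicit rather than taken from the surrounding method:

from spikeinterface.sorters.internal.spyking_circus2 import get_prototype


def build_prototype(recording_w, ms_before, ms_after, detection_params, job_kwargs):
    # before this commit (two passes, old helper kept only as a comment):
    #   peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params,
    #                        skip_after_n_peaks=5000, **job_kwargs)
    #   prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)

    # after this commit (one pass, detection and sparse waveform extraction fused):
    prototype = get_prototype(recording_w, n_peaks=5000, ms_before=ms_before,
                              ms_after=ms_after, **detection_params, **job_kwargs)
    return prototype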
@@ -425,3 +426,31 @@ def final_cleaning_circus(recording, sorting, templates, **merging_kwargs):
     sorting = apply_merges_to_sorting(sa.sorting, merges)
 
     return sorting
+
+
+def get_prototype(recording, n_peaks, ms_before, ms_after, **all_kwargs):
+    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
+    from spikeinterface.core.node_pipeline import ExtractSparseWaveforms
+
+    detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs)
+
+    node = ExtractSparseWaveforms(
+        recording,
+        parents=None,
+        return_output=True,
+        ms_before=ms_before,
+        ms_after=ms_after,
+        radius_um=0,
+    )
+
+    nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
+    pipeline_nodes = [node]
+    res = detect_peaks(recording,
+                       pipeline_nodes=pipeline_nodes,
+                       skip_after_n_peaks=n_peaks,
+                       **detection_kwargs,
+                       **job_kwargs)
+    waveforms = res[1]
+    with np.errstate(divide="ignore", invalid="ignore"):
+        prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0)
+    return prototype

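Usage sketch for the new helper. get_prototype is internal to the sorter module (not public API), and every parameter value below is an illustrative assumption, not a sorter default; method="locally_exclusive" is passed explicitly here and simply forwarded to detect_peaks through the detection kwargs. The returned prototype is the per-sample nanmedian, across the collected peaks, of each main-channel waveform divided by the absolute value of its sample at the peak time (nbefore), i.e. a unit-amplitude template that the sorter then hands to the matched-filtering detector.

from spikeinterface.core import generate_recording, get_noise_levels
from spikeinterface.sorters.internal.spyking_circus2 import get_prototype

# toy recording; any preprocessed BaseRecording would do
recording = generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[10.0])
noise_levels = get_noise_levels(recording, return_scaled=False)

# detection kwargs and job kwargs can be mixed freely: split_job_kwargs
# separates them inside the helper before calling detect_peaks
prototype = get_prototype(
    recording,
    n_peaks=5000,                  # stop reading new chunks once ~5000 peaks are collected
    ms_before=0.5,
    ms_after=1.5,
    method="locally_exclusive",    # assumed detection method, needs numba installed
    detect_threshold=5,
    noise_levels=noise_levels,
    n_jobs=1,                      # job kwarg, routed to the chunked processing engine
)
print(prototype.shape)             # (nbefore + nafter,) samples on the main channel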