Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 15, 2024
1 parent fd9ee0b commit 8e37d54
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels
n_peaks = max(selection_params["min_n_peaks"], max_n_peaks)

if params["matched_filtering"]:
prototype = get_prototype(recording_w,
n_peaks=5000,
ms_before=ms_before,
ms_after=ms_after,
**detection_params,
**job_kwargs)
if params["matched_filtering"]:
prototype = get_prototype(
recording_w, n_peaks=5000, ms_before=ms_before, ms_after=ms_after, **detection_params, **job_kwargs
)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before
if skip_peaks:
Expand Down Expand Up @@ -431,26 +428,24 @@ def final_cleaning_circus(recording, sorting, templates, **merging_kwargs):
def get_prototype(recording, n_peaks, ms_before, ms_after, **all_kwargs):
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import ExtractSparseWaveforms

detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs)

node = ExtractSparseWaveforms(
recording,
parents=None,
return_output=True,
ms_before=ms_before,
ms_after=ms_after,
radius_um=0,
)
recording,
parents=None,
return_output=True,
ms_before=ms_before,
ms_after=ms_after,
radius_um=0,
)

nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
pipeline_nodes = [node]
res = detect_peaks(recording,
pipeline_nodes=pipeline_nodes,
skip_after_n_peaks=n_peaks,
**detection_kwargs,
**job_kwargs)
res = detect_peaks(
recording, pipeline_nodes=pipeline_nodes, skip_after_n_peaks=n_peaks, **detection_kwargs, **job_kwargs
)
waveforms = res[1]
with np.errstate(divide="ignore", invalid="ignore"):
prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0)
return prototype
return prototype

0 comments on commit 8e37d54

Please sign in to comment.