From 8e37d54f4a0da1610ef5e5e044488924981ed54c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Nov 2024 09:19:07 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 39 ++++++++----------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b6b98b8b37..493b750164 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -194,13 +194,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) - if params["matched_filtering"]: - prototype = get_prototype(recording_w, - n_peaks=5000, - ms_before=ms_before, - ms_after=ms_after, - **detection_params, - **job_kwargs) + if params["matched_filtering"]: + prototype = get_prototype( + recording_w, n_peaks=5000, ms_before=ms_before, ms_after=ms_after, **detection_params, **job_kwargs + ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before if skip_peaks: @@ -431,26 +428,24 @@ def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): def get_prototype(recording, n_peaks, ms_before, ms_after, **all_kwargs): from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ExtractSparseWaveforms - + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) - + node = ExtractSparseWaveforms( - recording, - parents=None, - return_output=True, - ms_before=ms_before, - ms_after=ms_after, - radius_um=0, - ) + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) nbefore = int(ms_before * recording.sampling_frequency / 1000.0) pipeline_nodes = [node] - res = detect_peaks(recording, - pipeline_nodes=pipeline_nodes, - skip_after_n_peaks=n_peaks, - **detection_kwargs, - **job_kwargs) + res = detect_peaks( + recording, pipeline_nodes=pipeline_nodes, skip_after_n_peaks=n_peaks, **detection_kwargs, **job_kwargs + ) waveforms = res[1] with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype \ No newline at end of file + return prototype