diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b6b98b8b37..493b750164 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -194,13 +194,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) - if params["matched_filtering"]: - prototype = get_prototype(recording_w, - n_peaks=5000, - ms_before=ms_before, - ms_after=ms_after, - **detection_params, - **job_kwargs) + if params["matched_filtering"]: + prototype = get_prototype( + recording_w, n_peaks=5000, ms_before=ms_before, ms_after=ms_after, **detection_params, **job_kwargs + ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before if skip_peaks: @@ -431,26 +428,24 @@ def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): def get_prototype(recording, n_peaks, ms_before, ms_after, **all_kwargs): from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ExtractSparseWaveforms - + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) - + node = ExtractSparseWaveforms( - recording, - parents=None, - return_output=True, - ms_before=ms_before, - ms_after=ms_after, - radius_um=0, - ) + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) nbefore = int(ms_before * recording.sampling_frequency / 1000.0) pipeline_nodes = [node] - res = detect_peaks(recording, - pipeline_nodes=pipeline_nodes, - skip_after_n_peaks=n_peaks, - **detection_kwargs, - **job_kwargs) + res = detect_peaks( + recording, pipeline_nodes=pipeline_nodes, skip_after_n_peaks=n_peaks, **detection_kwargs, **job_kwargs + ) waveforms = res[1] with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype \ No newline at end of file + return prototype