diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a0a4d0823c..0c3b9f95d1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,7 +6,7 @@ from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter try: import hdbscan @@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, - "filtering": {"dtype": "float32"}, + "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, @@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = bandpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: recording_f = recording + # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method @@ -98,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params["waveforms_kwargs"] = params["waveforms"] + clustering_params["waveforms"] = params["waveforms"].copy() for k in ["ms_before", "ms_after"]: - clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params["waveforms"][k] = params["general"][k] clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 891c355448..b4938717f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,29 +593,34 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - method_kwargs.update( + local_params = method_kwargs.copy() + + local_params.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) + indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): + for i in indices: t_start = padding + i * duration t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) + local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs ) if method == "circus-omp-svd": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -629,7 +634,7 @@ def remove_duplicates_via_matching( } ) elif method == "circus-omp": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -661,7 +666,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs + del recording, sub_recording, local_params, waveform_extractor os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..a81458d7a8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -199,9 +199,8 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], **params["job_kwargs"], + **params["waveforms"], return_scaled=False, mode=mode, )