From 1939b936e94d30c8437633f89c49fd006ca71a80 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 4 Oct 2023 10:19:11 +0200 Subject: [PATCH 1/8] Diff for SC2 --- src/spikeinterface/sorters/internal/spyking_circus2.py | 7 ++++--- .../sortingcomponents/clustering/clustering_tools.py | 7 +++++-- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a0a4d0823c..db06287f6c 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,7 +6,7 @@ from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter try: import hdbscan @@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, - "filtering": {"dtype": "float32"}, + "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, @@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = bandpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: recording_f = recording + #recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 891c355448..6dba4b7f0f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -598,14 +598,17 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()['unit_index'], return_counts=True) + indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): + for i in np.arange(nb_templates)[indices]: t_start = padding + i * duration t_stop = padding + (i + 1) * duration diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..d7ceef2561 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -33,7 +33,7 @@ class RandomProjectionClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), - "cluster_selection_method": "leaf", + "cluster_selection_method": "leaf" }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, From 4cd3747786728e2942bef43b5c9d5ecba8d102fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 06:25:31 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index db06287f6c..6cf925e852 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -65,7 +65,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: recording_f = recording - #recording_f = whiten(recording_f, dtype="float32") + # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6dba4b7f0f..72cfd71791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -602,7 +602,7 @@ def remove_duplicates_via_matching( } ) - spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()['unit_index'], return_counts=True) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) ignore_ids = [] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d7ceef2561..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -33,7 +33,7 @@ class RandomProjectionClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), - "cluster_selection_method": "leaf" + "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, From 22c0eb426507be87790cbcd68427e3d3764721ee Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 08:29:18 +0200 Subject: [PATCH 3/8] Fix bug while reloading --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6dba4b7f0f..d94345f56b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -664,7 +664,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs + del recording, sub_recording, method_kwargs, waveform_extractor os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d7ceef2561..4d1dd1f9d5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -223,6 +223,8 @@ def sigmoid(x, L, x0, k, b): ) del we, sorting + import gc + gc.collect() if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) From f69d7e3dbd013c52564b79c1f6ce5c87a3f67af0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 06:30:11 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 7cb882409d..620346a875 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -224,6 +224,7 @@ def sigmoid(x, L, x0, k, b): del we, sorting import gc + gc.collect() if params["tmp_folder"] is None: From 403890ce83b065a76bcc1542a562d1a73e6e04be Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 09:01:02 +0200 Subject: [PATCH 5/8] Found it! --- .../clustering/clustering_tools.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index ce29c47113..734ceff1a3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,9 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - method_kwargs.update( + local_params = method_kwargs.copy() + + local_params.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, @@ -613,12 +615,12 @@ def remove_duplicates_via_matching( t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) + local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs ) if method == "circus-omp-svd": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -632,7 +634,7 @@ def remove_duplicates_via_matching( } ) elif method == "circus-omp": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -664,7 +666,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs, waveform_extractor + del recording, sub_recording, local_params, waveform_extractor os.remove(tmp_filename) return labels, new_labels From 6951e856c0794e78108be180d6f16e0fde6af6e2 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 09:54:27 +0200 Subject: [PATCH 6/8] WIP --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 734ceff1a3..b4938717f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -610,7 +610,7 @@ def remove_duplicates_via_matching( ignore_ids = [] similar_templates = [[], []] - for i in np.arange(nb_templates)[indices]: + for i in indices: t_start = padding + i * duration t_stop = padding + (i + 1) * duration diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 620346a875..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -223,9 +223,6 @@ def sigmoid(x, L, x0, k, b): ) del we, sorting - import gc - - gc.collect() if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) From fdebd12b09654796a177f4ab91b8e614409f5ac7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 10:43:20 +0200 Subject: [PATCH 7/8] Sparse waveforms were not handled --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- .../sortingcomponents/clustering/random_projections.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6cf925e852..0c3b9f95d1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -99,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params["waveforms_kwargs"] = params["waveforms"] + clustering_params["waveforms"] = params["waveforms"].copy() for k in ["ms_before", "ms_after"]: - clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params["waveforms"][k] = params["general"][k] clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..ffb868f682 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -199,9 +199,8 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], **params["job_kwargs"], + **params['waveforms'], return_scaled=False, mode=mode, ) From b6f9235a7cf9c2ad106ec0e4cb6be365a243d2af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 08:44:20 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index ffb868f682..a81458d7a8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -200,7 +200,7 @@ def sigmoid(x, L, x0, k, b): sorting, waveform_folder, **params["job_kwargs"], - **params['waveforms'], + **params["waveforms"], return_scaled=False, mode=mode, )