From 7f1294f6286eaf5c8c8616545c08f91235af9b9d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 09:14:38 +0200 Subject: [PATCH 01/10] oups --- src/spikeinterface/sortingcomponents/clustering/merge.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 4c79383542..076ed8438f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -674,8 +674,8 @@ def merge( final_shift = 0 merge_value = np.nan - # DEBUG = False - DEBUG = True + DEBUG = False + # DEBUG = True if DEBUG and normed_diff < 0.2: # if DEBUG: @@ -683,8 +683,8 @@ def merge( fig, ax = plt.subplots() - m0 = template0.flatten() - m1 = template1.flatten() + m0 = template0.T.flatten() + m1 = template1.T.flatten() ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") From 4421cb890ee867cd15d6d685aaa8efa5290f59d4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 2 Nov 2023 15:18:22 +0100 Subject: [PATCH 02/10] try omp on tdc2 --- .../sorters/internal/tridesclous2.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index e256915fa6..b6a1a9da0a 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -12,7 +12,7 @@ from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten from spikeinterface.core.basesorting import minimum_spike_dtype import numpy as np @@ -91,6 +91,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # TODO what is the best about zscore>common_reference or the reverse recording = common_reference(recording) recording = zscore(recording, dtype="float32") + # recording = whiten(recording, dtype="float32") noise_levels = np.ones(num_chans, dtype="float32") else: recording = recording_raw @@ -289,17 +290,29 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): **job_kwargs, ) + # matching_params = params["matching"].copy() + # matching_params["waveform_extractor"] = we + # matching_params["noise_levels"] = noise_levels + # matching_params["peak_sign"] = params["detection"]["peak_sign"] + # matching_params["detect_threshold"] = params["detection"]["detect_threshold"] + # matching_params["radius_um"] = params["detection"]["radius_um"] + + # spikes = find_spikes_from_templates( + # recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs + # ) + matching_params = params["matching"].copy() matching_params["waveform_extractor"] = we matching_params["noise_levels"] = noise_levels - matching_params["peak_sign"] = params["detection"]["peak_sign"] - matching_params["detect_threshold"] = params["detection"]["detect_threshold"] - matching_params["radius_um"] = params["detection"]["radius_um"] + # matching_params["peak_sign"] = params["detection"]["peak_sign"] + # matching_params["detect_threshold"] = params["detection"]["detect_threshold"] + # matching_params["radius_um"] = params["detection"]["radius_um"] spikes = find_spikes_from_templates( - recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs + recording, method="circus-omp-svd", method_kwargs=matching_params, **job_kwargs ) + if params["save_array"]: np.save(sorter_output_folder / "noise_levels.npy", noise_levels) np.save(sorter_output_folder / "all_peaks.npy", all_peaks) From 08829752d88026a83c28c12cf5d69f027d6106b3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 29 Nov 2023 10:57:44 +0100 Subject: [PATCH 03/10] GTstudy change key separator (it break SC) --- src/spikeinterface/comparison/groundtruthstudy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 7269960dc1..16b71b7c6e 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -21,7 +21,8 @@ # This is to separate names when the key are tuples when saving folders -_key_separator = "_##_" +# _key_separator = "_##_" +_key_separator = "_-°°-_" class GroundTruthStudy: From 3081ec7009825aa599dd6b215bd71703c423e16c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Dec 2023 11:59:47 +0100 Subject: [PATCH 04/10] Put _params_description for Tridesclous2Sorter --- .../sorters/internal/tridesclous2.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 7914b322a0..588e79a92e 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -36,7 +36,6 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "filtering": {"freq_min": 300.0, "freq_max": 12000.0}, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "features": {}, "svd": {"n_components": 6}, "clustering": { "split_radius_um": 40.0, @@ -52,6 +51,20 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "save_array": True, } + _params_description = { + "apply_preprocessing": "Apply internal preprocessing or not", + "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", + "filtering": "A dictonary containing filtering params: freq_min, freq_max", + "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", + "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", + "svd": "A dictonary containing svd params: n_components", + "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", + "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", + "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", + "job_kwargs": "A dictionnary containing job kwargs", + "save_array": "Save or not intermediate arrays", + } + handle_multi_segment = True @classmethod From 1fb42824b4e4e8a53ca3f2baf09b6ac3e3d31e6c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 20 Jan 2024 18:30:51 +0100 Subject: [PATCH 05/10] tdc2 : cache_preprocessing --- .../sorters/internal/tridesclous2.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 588e79a92e..47e31da2cb 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -15,7 +15,7 @@ from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing import numpy as np @@ -28,6 +28,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): _default_params = { "apply_preprocessing": True, + "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "waveforms": { "ms_before": 0.5, "ms_after": 1.5, @@ -46,13 +47,16 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "ms_after": 2.5, # "peak_shift_ms": 0.2, }, - "matching": {"peak_shift_ms": 0.2, "radius_um": 100.0}, + # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}}, + "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, + "job_kwargs": {"n_jobs": -1}, "save_array": True, } _params_description = { "apply_preprocessing": "Apply internal preprocessing or not", + "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ", "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", "filtering": "A dictonary containing filtering params: freq_min, freq_max", "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", @@ -109,6 +113,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = common_reference(recording) recording = zscore(recording, dtype="float32") # recording = whiten(recording, dtype="float32") + + # used only if "folder" or "zarr" + cache_folder = sorter_output_folder / "cache_preprocessing" + recording = cache_preprocessing(recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"]) + noise_levels = np.ones(num_chans, dtype="float32") else: recording = recording_raw @@ -321,18 +330,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs # ) - matching_params = params["matching"].copy() + matching_method = params["matching"]["method"] + matching_params = params["matching"]["method_kwargs"].copy() + matching_params["waveform_extractor"] = we matching_params["noise_levels"] = noise_levels # matching_params["peak_sign"] = params["detection"]["peak_sign"] # matching_params["detect_threshold"] = params["detection"]["detect_threshold"] # matching_params["radius_um"] = params["detection"]["radius_um"] + # spikes = find_spikes_from_templates( + # recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs + # ) + # ) + + if matching_method == "circus-omp-svd": + job_kwargs = job_kwargs.copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in job_kwargs: + job_kwargs.pop(value) + job_kwargs["chunk_duration"] = "100ms" + spikes = find_spikes_from_templates( - recording, method="circus-omp-svd", method_kwargs=matching_params, **job_kwargs + recording, method=matching_method, method_kwargs=matching_params, **job_kwargs ) - if params["save_array"]: np.save(sorter_output_folder / "noise_levels.npy", noise_levels) np.save(sorter_output_folder / "all_peaks.npy", all_peaks) From 01f82e23fbd99b7ac8914de84ee6317486c34588 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 23 Jan 2024 08:59:37 +0100 Subject: [PATCH 06/10] test on auto merge --- .../sortingcomponents/clustering/merge.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 5f580d4d99..353990d9d3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -647,6 +647,8 @@ def merge( num_samples = template0.shape[0] # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) norm = np.mean(np.abs(template0) + np.abs(template1)) + # norm = np.median(np.abs(template0) + np.abs(template1)) + all_shift_diff = [] for shift in range(-num_shift, num_shift + 1): temp0 = template0[num_shift : num_samples - num_shift, :] @@ -663,18 +665,19 @@ def merge( final_shift = 0 merge_value = np.nan - DEBUG = False - # DEBUG = True - if DEBUG and normed_diff < 0.2: - # if DEBUG: + # DEBUG = False + DEBUG = True + # if DEBUG and normed_diff < 0.2: + if DEBUG: import matplotlib.pyplot as plt - fig, ax = plt.subplots() + fig, axs = plt.subplots(nrows=2) m0 = template0.T.flatten() m1 = template1.T.flatten() - + + ax = axs[0] ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") @@ -682,6 +685,21 @@ def merge( f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}" ) ax.legend() + + ax = axs[1] + + #~ temp0 = template0[num_shift : num_samples - num_shift, :] + #~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + ax.plot(np.abs(m0 - m1)) + ax.axhline(norm) + # ax.plot(np.abs(m0) + np.abs(m1)) + + # ax.plot(np.abs(m0 - m1) / (np.abs(m0) + np.abs(m1))) + + ax.set_title(f"{norm}") + + + plt.show() return is_merge, label0, label1, final_shift, merge_value From c32f3fc8465e259712a5c938c04db211759d48c4 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 2 Feb 2024 12:09:53 +0100 Subject: [PATCH 07/10] Update src/spikeinterface/sorters/internal/tridesclous2.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 47e31da2cb..c661485237 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -65,7 +65,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", - "job_kwargs": "A dictionnary containing job kwargs", + "job_kwargs": "A dictionary containing job kwargs", "save_array": "Save or not intermediate arrays", } From b24e40362d035176bab1ca6f5915a4cffb609fe8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 5 Feb 2024 17:27:00 +0100 Subject: [PATCH 08/10] wip merge --- .../sorters/internal/tridesclous2.py | 6 +-- .../sortingcomponents/clustering/merge.py | 40 +++++++++++++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index c661485237..b65074b589 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -1,9 +1,6 @@ -import shutil from .si_based import ComponentsBasedSorter from spikeinterface.core import ( - load_extractor, - BaseRecording, get_noise_levels, extract_waveforms, NumpySorting, @@ -272,7 +269,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): method="normalized_template_diff", method_kwargs=dict( waveforms_sparse_mask=sparse_mask, - threshold_diff=0.2, + # threshold_diff=0.2, + threshold_diff=3, min_cluster_size=min_cluster_size + 1, num_shift=5, ), diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 3661d8dbe2..a450a7d0eb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -647,32 +647,47 @@ def merge( num_samples = template0.shape[0] # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) norm = np.mean(np.abs(template0) + np.abs(template1)) + # norm = np.mean(np.abs(template0) + np.abs(template1), axis=0) # norm = np.median(np.abs(template0) + np.abs(template1)) all_shift_diff = [] + # all_shift_diff_by_channel = [] for shift in range(-num_shift, num_shift + 1): temp0 = template0[num_shift : num_samples - num_shift, :] temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] - d = np.mean(np.abs(temp0 - temp1)) / (norm) + #d = np.mean(np.abs(temp0 - temp1)) / (norm) + d = np.max(np.abs(temp0 - temp1)) / (norm) all_shift_diff.append(d) + # diff_by_channel = np.mean(np.abs(temp0 - temp1), axis=0) / (norm) + # all_shift_diff_by_channel.append(diff_by_channel) + # d = np.mean(diff_by_channel) + # all_shift_diff.append(d) normed_diff = np.min(all_shift_diff) + is_merge = normed_diff < threshold_diff + if is_merge: merge_value = normed_diff final_shift = np.argmin(all_shift_diff) - num_shift + + # diff_by_channel = all_shift_diff_by_channel[np.argmin(all_shift_diff)] else: final_shift = 0 merge_value = np.nan + - # DEBUG = False - DEBUG = True - # if DEBUG and normed_diff < 0.2: - if DEBUG: + # print('merge_value', merge_value, 'final_shift', final_shift, 'is_merge', is_merge) + + DEBUG = False + # DEBUG = True + if DEBUG and ( 0. < normed_diff < 5): + # if 0.5 < normed_diff < 1: + # if DEBUG and is_merge: import matplotlib.pyplot as plt - fig, axs = plt.subplots(nrows=2) + fig, axs = plt.subplots(nrows=3) m0 = template0.T.flatten() m1 = template1.T.flatten() @@ -691,12 +706,21 @@ def merge( #~ temp0 = template0[num_shift : num_samples - num_shift, :] #~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] ax.plot(np.abs(m0 - m1)) - ax.axhline(norm) + ax.axhline(norm, ls='--', color='k') + ax = axs[2] + ax.plot(np.abs(m0 - m1) / norm) + ax.axhline(normed_diff) + + + # ax.axhline(normed_diff, ls='-', color='b') + # ax.plot(norm, ls='--') + # ax.plot(diff_by_channel) + # ax.plot(np.abs(m0) + np.abs(m1)) # ax.plot(np.abs(m0 - m1) / (np.abs(m0) + np.abs(m1))) - ax.set_title(f"{norm}") + ax.set_title(f"{norm=:.3f}") From 9a2c6ef223c43eff5cd8c9944df4dec2c1444593 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 5 Feb 2024 21:37:05 +0100 Subject: [PATCH 09/10] A quick and drty clean for tridesclous2 --- .../sorters/internal/tridesclous2.py | 53 +++++++------------ .../sortingcomponents/clustering/merge.py | 48 ++++++++++++----- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index b65074b589..70f9cb029c 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -14,6 +14,8 @@ from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing +# from spikeinterface.qualitymetrics import compute_snrs + import numpy as np import pickle @@ -38,10 +40,12 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "clustering": { "split_radius_um": 40.0, "merge_radius_um": 40.0, + "threshold_diff": 1.5, }, "templates": { - "ms_before": 1.5, - "ms_after": 2.5, + "ms_before": 2., + "ms_after": 3., + "max_spikes_per_unit" : 400, # "peak_shift_ms": 0.2, }, # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}}, @@ -169,22 +173,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): features_folder = sorter_output_folder / "features" node0 = PeakRetriever(recording, peaks) - # node1 = ExtractDenseWaveforms(rec, parents=[node0], return_output=False, - # ms_before=0.5, - # ms_after=1.5, - # ) - - # node2 = LocalizeCenterOfMass(rec, parents=[node0, node1], return_output=True, - # local_radius_um=75.0, - # feature="ptp", ) - - # node2 = LocalizeGridConvolution(rec, parents=[node0, node1], return_output=True, - # local_radius_um=40., - # upsampling_um=5.0, - # ) - radius_um = params["waveforms"]["radius_um"] - node3 = ExtractSparseWaveforms( + node1 = ExtractSparseWaveforms( recording, parents=[node0], return_output=True, @@ -195,12 +185,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): model_folder_path = sorter_output_folder / "tsvd_model" - node4 = TemporalPCAProjection( - recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder_path + node2 = TemporalPCAProjection( + recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path ) - # pipeline_nodes = [node0, node1, node2, node3, node4] - pipeline_nodes = [node0, node3, node4] + pipeline_nodes = [node0, node1, node2] output = run_node_pipeline( recording, @@ -213,7 +202,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) # TODO make this generic in GatherNPY ??? - sparse_mask = node3.neighbours_mask + sparse_mask = node1.neighbours_mask np.save(features_folder / "sparse_mask.npy", sparse_mask) np.save(features_folder / "peaks.npy", peaks) @@ -249,6 +238,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) merge_radius_um = params["clustering"]["merge_radius_um"] + threshold_diff = params["clustering"]["threshold_diff"] + post_merge_label, peak_shifts = merge_clusters( peaks, @@ -269,8 +260,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): method="normalized_template_diff", method_kwargs=dict( waveforms_sparse_mask=sparse_mask, - # threshold_diff=0.2, - threshold_diff=3, + threshold_diff=threshold_diff, min_cluster_size=min_cluster_size + 1, num_shift=5, ), @@ -303,19 +293,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp") - ms_before = params["templates"]["ms_before"] - ms_after = params["templates"]["ms_after"] - max_spikes_per_unit = 300 - we = extract_waveforms( recording, sorting_temp, sorter_output_folder / "waveforms_temp", - ms_before=ms_before, - ms_after=ms_after, - max_spikes_per_unit=max_spikes_per_unit, - **job_kwargs, - ) + **params["templates"]) + + # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum") + # print(snrs) + + # matching_params = params["matching"].copy() # matching_params["waveform_extractor"] = we diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index a450a7d0eb..af2748e47a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -615,7 +615,7 @@ def merge( peaks, features, waveforms_sparse_mask=None, - threshold_diff=0.05, + threshold_diff=1.5, min_cluster_size=50, num_shift=5, ): @@ -647,8 +647,11 @@ def merge( num_samples = template0.shape[0] # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) norm = np.mean(np.abs(template0) + np.abs(template1)) - # norm = np.mean(np.abs(template0) + np.abs(template1), axis=0) - # norm = np.median(np.abs(template0) + np.abs(template1)) + + # norm_per_channel = np.max(np.abs(template0) + np.abs(template1), axis=0) / 2. + norm_per_channel = (np.max(np.abs(template0), axis=0) + np.max(np.abs(template1), axis=0)) * 0.5 + # norm_per_channel = np.max(np.abs(template0)) + np.max(np.abs(template1)) / 2. + # print(norm_per_channel) all_shift_diff = [] # all_shift_diff_by_channel = [] @@ -656,8 +659,15 @@ def merge( temp0 = template0[num_shift : num_samples - num_shift, :] temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] #d = np.mean(np.abs(temp0 - temp1)) / (norm) - d = np.max(np.abs(temp0 - temp1)) / (norm) - all_shift_diff.append(d) + # d = np.max(np.abs(temp0 - temp1)) / (norm) + diff_per_channel = np.abs(temp0 - temp1) / norm + + diff_max = np.max(diff_per_channel, axis=0) + + # diff = np.max(diff_per_channel) + diff = np.average(diff_max, weights=norm_per_channel) + # diff = np.average(diff_max) + all_shift_diff.append(diff) # diff_by_channel = np.mean(np.abs(temp0 - temp1), axis=0) / (norm) # all_shift_diff_by_channel.append(diff_by_channel) # d = np.mean(diff_by_channel) @@ -681,17 +691,26 @@ def merge( DEBUG = False # DEBUG = True - if DEBUG and ( 0. < normed_diff < 5): - # if 0.5 < normed_diff < 1: - # if DEBUG and is_merge: + # if DEBUG and ( 0. < normed_diff < .4): + # if 0.5 < normed_diff < 4: + if DEBUG and is_merge: + # if DEBUG: import matplotlib.pyplot as plt fig, axs = plt.subplots(nrows=3) - m0 = template0.T.flatten() - m1 = template1.T.flatten() - + temp0 = template0[num_shift : num_samples - num_shift, :] + temp1 = template1[num_shift + final_shift : num_samples - num_shift + final_shift, :] + + diff_per_channel = np.abs(temp0 - temp1) / norm + diff = np.max(diff_per_channel) + + m0 = temp0.T.flatten() + m1 = temp1.T.flatten() + + + ax = axs[0] ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") @@ -706,9 +725,10 @@ def merge( #~ temp0 = template0[num_shift : num_samples - num_shift, :] #~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] ax.plot(np.abs(m0 - m1)) - ax.axhline(norm, ls='--', color='k') + # ax.axhline(norm, ls='--', color='k') ax = axs[2] - ax.plot(np.abs(m0 - m1) / norm) + ax.plot(diff_per_channel.T.flatten()) + ax.axhline(threshold_diff, ls='--') ax.axhline(normed_diff) @@ -720,7 +740,7 @@ def merge( # ax.plot(np.abs(m0 - m1) / (np.abs(m0) + np.abs(m1))) - ax.set_title(f"{norm=:.3f}") + # ax.set_title(f"{norm=:.3f}") From ecb7e3e3c35d3c7a673857545b58ef528100be61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 20:38:40 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comparison/groundtruthstudy.py | 2 +- .../sorters/internal/tridesclous2.py | 32 ++++++++----------- .../sortingcomponents/clustering/merge.py | 21 ++++-------- 3 files changed, 21 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index a07b06349c..448ac3b361 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -23,7 +23,7 @@ # This is to separate names when the key are tuples when saving folders -# _key_separator = "_##_" +# _key_separator = "_##_" _key_separator = "_-°°-_" diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 75a8de8f9a..782758178e 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -47,14 +47,13 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "threshold_diff": 1.5, }, "templates": { - "ms_before": 2., - "ms_after": 3., - "max_spikes_per_unit" : 400, + "ms_before": 2.0, + "ms_after": 3.0, + "max_spikes_per_unit": 400, # "peak_shift_ms": 0.2, }, # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}}, "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, - "job_kwargs": {"n_jobs": -1}, "save_array": True, } @@ -63,16 +62,16 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "apply_preprocessing": "Apply internal preprocessing or not", "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ", "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", - "filtering": "A dictonary containing filtering params: freq_min, freq_max", - "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", - "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", - "svd": "A dictonary containing svd params: n_components", + "filtering": "A dictonary containing filtering params: freq_min, freq_max", + "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", + "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", + "svd": "A dictonary containing svd params: n_components", "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", "job_kwargs": "A dictionary containing job kwargs", "save_array": "Save or not intermediate arrays", - } + } handle_multi_segment = True @@ -118,10 +117,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = common_reference(recording) recording = zscore(recording, dtype="float32") # recording = whiten(recording, dtype="float32") - + # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" - recording = cache_preprocessing(recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"]) + recording = cache_preprocessing( + recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"] + ) noise_levels = np.ones(num_chans, dtype="float32") else: @@ -243,7 +244,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): merge_radius_um = params["clustering"]["merge_radius_um"] threshold_diff = params["clustering"]["threshold_diff"] - post_merge_label, peak_shifts = merge_clusters( peaks, @@ -297,17 +297,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp") - we = extract_waveforms( - recording, - sorting_temp, - sorter_output_folder / "waveforms_temp", - **params["templates"]) + we = extract_waveforms(recording, sorting_temp, sorter_output_folder / "waveforms_temp", **params["templates"]) # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum") # print(snrs) - - # matching_params = params["matching"].copy() # matching_params["waveform_extractor"] = we # matching_params["noise_levels"] = noise_levels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 63b78250ed..ba2792bfd5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -660,10 +660,10 @@ def merge( for shift in range(-num_shift, num_shift + 1): temp0 = template0[num_shift : num_samples - num_shift, :] temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] - #d = np.mean(np.abs(temp0 - temp1)) / (norm) + # d = np.mean(np.abs(temp0 - temp1)) / (norm) # d = np.max(np.abs(temp0 - temp1)) / (norm) diff_per_channel = np.abs(temp0 - temp1) / norm - + diff_max = np.max(diff_per_channel, axis=0) # diff = np.max(diff_per_channel) @@ -675,7 +675,6 @@ def merge( # d = np.mean(diff_by_channel) # all_shift_diff.append(d) normed_diff = np.min(all_shift_diff) - is_merge = normed_diff < threshold_diff @@ -687,7 +686,6 @@ def merge( else: final_shift = 0 merge_value = np.nan - # print('merge_value', merge_value, 'final_shift', final_shift, 'is_merge', is_merge) @@ -696,7 +694,7 @@ def merge( # if DEBUG and ( 0. < normed_diff < .4): # if 0.5 < normed_diff < 4: if DEBUG and is_merge: - # if DEBUG: + # if DEBUG: import matplotlib.pyplot as plt @@ -711,8 +709,6 @@ def merge( m0 = temp0.T.flatten() m1 = temp1.T.flatten() - - ax = axs[0] ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") @@ -724,16 +720,15 @@ def merge( ax = axs[1] - #~ temp0 = template0[num_shift : num_samples - num_shift, :] - #~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + # ~ temp0 = template0[num_shift : num_samples - num_shift, :] + # ~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] ax.plot(np.abs(m0 - m1)) # ax.axhline(norm, ls='--', color='k') ax = axs[2] ax.plot(diff_per_channel.T.flatten()) - ax.axhline(threshold_diff, ls='--') + ax.axhline(threshold_diff, ls="--") ax.axhline(normed_diff) - - + # ax.axhline(normed_diff, ls='-', color='b') # ax.plot(norm, ls='--') # ax.plot(diff_by_channel) @@ -744,8 +739,6 @@ def merge( # ax.set_title(f"{norm=:.3f}") - - plt.show() return is_merge, label0, label1, final_shift, merge_value