diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index a07b06349c..448ac3b361 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -23,7 +23,7 @@ # This is to separate names when the key are tuples when saving folders -# _key_separator = "_##_" +# _key_separator = "_##_" _key_separator = "_-°°-_" diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 75a8de8f9a..782758178e 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -47,14 +47,13 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "threshold_diff": 1.5, }, "templates": { - "ms_before": 2., - "ms_after": 3., - "max_spikes_per_unit" : 400, + "ms_before": 2.0, + "ms_after": 3.0, + "max_spikes_per_unit": 400, # "peak_shift_ms": 0.2, }, # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}}, "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, - "job_kwargs": {"n_jobs": -1}, "save_array": True, } @@ -63,16 +62,16 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "apply_preprocessing": "Apply internal preprocessing or not", "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ", "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", - "filtering": "A dictonary containing filtering params: freq_min, freq_max", - "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", - "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", - "svd": "A dictonary containing svd params: n_components", + "filtering": "A dictonary containing filtering params: freq_min, freq_max", + "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", + "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", + "svd": "A dictonary containing svd params: n_components", "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", "job_kwargs": "A dictionary containing job kwargs", "save_array": "Save or not intermediate arrays", - } + } handle_multi_segment = True @@ -118,10 +117,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = common_reference(recording) recording = zscore(recording, dtype="float32") # recording = whiten(recording, dtype="float32") - + # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" - recording = cache_preprocessing(recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"]) + recording = cache_preprocessing( + recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"] + ) noise_levels = np.ones(num_chans, dtype="float32") else: @@ -243,7 +244,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): merge_radius_um = params["clustering"]["merge_radius_um"] threshold_diff = params["clustering"]["threshold_diff"] - post_merge_label, peak_shifts = merge_clusters( peaks, @@ -297,17 +297,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp") - we = extract_waveforms( - recording, - sorting_temp, - sorter_output_folder / "waveforms_temp", - **params["templates"]) + we = extract_waveforms(recording, sorting_temp, sorter_output_folder / "waveforms_temp", **params["templates"]) # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum") # print(snrs) - - # matching_params = params["matching"].copy() # matching_params["waveform_extractor"] = we # matching_params["noise_levels"] = noise_levels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 63b78250ed..ba2792bfd5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -660,10 +660,10 @@ def merge( for shift in range(-num_shift, num_shift + 1): temp0 = template0[num_shift : num_samples - num_shift, :] temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] - #d = np.mean(np.abs(temp0 - temp1)) / (norm) + # d = np.mean(np.abs(temp0 - temp1)) / (norm) # d = np.max(np.abs(temp0 - temp1)) / (norm) diff_per_channel = np.abs(temp0 - temp1) / norm - + diff_max = np.max(diff_per_channel, axis=0) # diff = np.max(diff_per_channel) @@ -675,7 +675,6 @@ def merge( # d = np.mean(diff_by_channel) # all_shift_diff.append(d) normed_diff = np.min(all_shift_diff) - is_merge = normed_diff < threshold_diff @@ -687,7 +686,6 @@ def merge( else: final_shift = 0 merge_value = np.nan - # print('merge_value', merge_value, 'final_shift', final_shift, 'is_merge', is_merge) @@ -696,7 +694,7 @@ def merge( # if DEBUG and ( 0. < normed_diff < .4): # if 0.5 < normed_diff < 4: if DEBUG and is_merge: - # if DEBUG: + # if DEBUG: import matplotlib.pyplot as plt @@ -711,8 +709,6 @@ def merge( m0 = temp0.T.flatten() m1 = temp1.T.flatten() - - ax = axs[0] ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") @@ -724,16 +720,15 @@ def merge( ax = axs[1] - #~ temp0 = template0[num_shift : num_samples - num_shift, :] - #~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + # ~ temp0 = template0[num_shift : num_samples - num_shift, :] + # ~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] ax.plot(np.abs(m0 - m1)) # ax.axhline(norm, ls='--', color='k') ax = axs[2] ax.plot(diff_per_channel.T.flatten()) - ax.axhline(threshold_diff, ls='--') + ax.axhline(threshold_diff, ls="--") ax.axhline(normed_diff) - - + # ax.axhline(normed_diff, ls='-', color='b') # ax.plot(norm, ls='--') # ax.plot(diff_by_channel) @@ -744,8 +739,6 @@ def merge( # ax.set_title(f"{norm=:.3f}") - - plt.show() return is_merge, label0, label1, final_shift, merge_value