diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index 2b1a80fade..448ac3b361 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -23,7 +23,8 @@
 
 
 # This is to separate names when the keys are tuples when saving folders
-_key_separator = "_##_"
+# _key_separator = "_##_"
+_key_separator = "_-°°-_"
 
 
 class GroundTruthStudy:
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 6bed022273..782758178e 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 import shutil
+
 from .si_based import ComponentsBasedSorter
 
 from spikeinterface.core import (
-    load_extractor,
-    BaseRecording,
     get_noise_levels,
     extract_waveforms,
     NumpySorting,
@@ -14,10 +13,12 @@
 
 from spikeinterface.core.job_tools import fix_job_kwargs
 
-from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore
+from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
 from spikeinterface.core.basesorting import minimum_spike_dtype
 
-from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
+from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing
+
+# from spikeinterface.qualitymetrics import compute_snrs
 
 import numpy as np
 
@@ -30,6 +31,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
 
     _default_params = {
        "apply_preprocessing": True,
+        "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "waveforms": {
             "ms_before": 0.5,
             "ms_after": 1.5,
@@ -38,22 +40,39 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
         "filtering": {"freq_min": 300.0, "freq_max": 12000.0},
         "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0},
         "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
-        "features": {},
         "svd": {"n_components": 6},
         "clustering": {
             "split_radius_um": 40.0,
             "merge_radius_um": 40.0,
+            "threshold_diff": 1.5,
         },
         "templates": {
-            "ms_before": 1.5,
-            "ms_after": 2.5,
+            "ms_before": 2.0,
+            "ms_after": 3.0,
+            "max_spikes_per_unit": 400,
             # "peak_shift_ms": 0.2,
         },
-        "matching": {"peak_shift_ms": 0.2, "radius_um": 100.0},
+        # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
+        "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
         "job_kwargs": {"n_jobs": -1},
         "save_array": True,
     }
 
+    _params_description = {
+        "apply_preprocessing": "Apply internal preprocessing or not",
+        "cache_preprocessing": "A dict containing how to cache the preprocessed recording. mode='memory' | 'folder' | 'zarr'",
+        "waveforms": "A dictionary containing waveforms params: ms_before, ms_after, radius_um",
+        "filtering": "A dictionary containing filtering params: freq_min, freq_max",
+        "detection": "A dictionary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um",
+        "selection": "A dictionary containing selection params: n_peaks_per_channel, min_n_peaks",
+        "svd": "A dictionary containing svd params: n_components",
+        "clustering": "A dictionary containing clustering params: split_radius_um, merge_radius_um",
+        "templates": "A dictionary containing waveforms params for the peeler: ms_before, ms_after, max_spikes_per_unit",
+        "matching": "A dictionary containing matching params: method, method_kwargs",
+        "job_kwargs": "A dictionary containing job kwargs",
+        "save_array": "Whether to save intermediate arrays",
+    }
+
     handle_multi_segment = True
 
     @classmethod
@@ -97,6 +116,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             # TODO what is best: zscore>common_reference or the reverse
             recording = common_reference(recording)
             recording = zscore(recording, dtype="float32")
+            # recording = whiten(recording, dtype="float32")
+
+            # used only if "folder" or "zarr"
+            cache_folder = sorter_output_folder / "cache_preprocessing"
+            recording = cache_preprocessing(
+                recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"]
+            )
+
             noise_levels = np.ones(num_chans, dtype="float32")
         else:
             recording = recording_raw
@@ -151,22 +178,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         features_folder = sorter_output_folder / "features"
         node0 = PeakRetriever(recording, peaks)
 
-        # node1 = ExtractDenseWaveforms(rec, parents=[node0], return_output=False,
-        #                               ms_before=0.5,
-        #                               ms_after=1.5,
-        #                               )
-
-        # node2 = LocalizeCenterOfMass(rec, parents=[node0, node1], return_output=True,
-        #                              local_radius_um=75.0,
-        #                              feature="ptp", )
-
-        # node2 = LocalizeGridConvolution(rec, parents=[node0, node1], return_output=True,
-        #                                 local_radius_um=40.,
-        #                                 upsampling_um=5.0,
-        #                                 )
-
         radius_um = params["waveforms"]["radius_um"]
-        node3 = ExtractSparseWaveforms(
+        node1 = ExtractSparseWaveforms(
             recording,
             parents=[node0],
             return_output=True,
@@ -177,12 +190,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         model_folder_path = sorter_output_folder / "tsvd_model"
 
-        node4 = TemporalPCAProjection(
-            recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder_path
+        node2 = TemporalPCAProjection(
+            recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path
         )
 
-        # pipeline_nodes = [node0, node1, node2, node3, node4]
-        pipeline_nodes = [node0, node3, node4]
+        pipeline_nodes = [node0, node1, node2]
 
         output = run_node_pipeline(
             recording,
@@ -195,7 +207,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         )
 
         # TODO make this generic in GatherNPY ???
-        sparse_mask = node3.neighbours_mask
+        sparse_mask = node1.neighbours_mask
         np.save(features_folder / "sparse_mask.npy", sparse_mask)
         np.save(features_folder / "peaks.npy", peaks)
 
@@ -231,6 +243,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         )
 
         merge_radius_um = params["clustering"]["merge_radius_um"]
+        threshold_diff = params["clustering"]["threshold_diff"]
 
         post_merge_label, peak_shifts = merge_clusters(
             peaks,
@@ -251,7 +264,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             method="normalized_template_diff",
             method_kwargs=dict(
                 waveforms_sparse_mask=sparse_mask,
-                threshold_diff=0.2,
+                threshold_diff=threshold_diff,
                 min_cluster_size=min_cluster_size + 1,
                 num_shift=5,
             ),
@@ -284,29 +297,44 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         )
         sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp")
 
-        ms_before = params["templates"]["ms_before"]
-        ms_after = params["templates"]["ms_after"]
-        max_spikes_per_unit = 300
+        we = extract_waveforms(recording, sorting_temp, sorter_output_folder / "waveforms_temp", **params["templates"])
 
-        we = extract_waveforms(
-            recording,
-            sorting_temp,
-            sorter_output_folder / "waveforms_temp",
-            ms_before=ms_before,
-            ms_after=ms_after,
-            max_spikes_per_unit=max_spikes_per_unit,
-            **job_kwargs,
-        )
+        # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
+        # print(snrs)
+
+        # matching_params = params["matching"].copy()
+        # matching_params["waveform_extractor"] = we
+        # matching_params["noise_levels"] = noise_levels
+        # matching_params["peak_sign"] = params["detection"]["peak_sign"]
+        # matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
+        # matching_params["radius_um"] = params["detection"]["radius_um"]
+
+        # spikes = find_spikes_from_templates(
+        #     recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
+        # )
+
+        matching_method = params["matching"]["method"]
+        matching_params = params["matching"]["method_kwargs"].copy()
 
-        matching_params = params["matching"].copy()
         matching_params["waveform_extractor"] = we
         matching_params["noise_levels"] = noise_levels
-        matching_params["peak_sign"] = params["detection"]["peak_sign"]
-        matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
-        matching_params["radius_um"] = params["detection"]["radius_um"]
+        # matching_params["peak_sign"] = params["detection"]["peak_sign"]
+        # matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
+        # matching_params["radius_um"] = params["detection"]["radius_um"]
+
+        # spikes = find_spikes_from_templates(
+        #     recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
+        # )
+
+        if matching_method == "circus-omp-svd":
+            job_kwargs = job_kwargs.copy()
+            for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
+                if value in job_kwargs:
+                    job_kwargs.pop(value)
+            job_kwargs["chunk_duration"] = "100ms"
 
         spikes = find_spikes_from_templates(
-            recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
+            recording, method=matching_method, method_kwargs=matching_params, **job_kwargs
        )
 
         if params["save_array"]:
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index a407fcf01c..ba2792bfd5 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -617,7 +617,7 @@ def merge(
         peaks,
         features,
         waveforms_sparse_mask=None,
-        threshold_diff=0.05,
+        threshold_diff=1.5,
         min_cluster_size=50,
         num_shift=5,
     ):
@@ -649,32 +649,67 @@ def merge(
 
         num_samples = template0.shape[0]
         # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1))
         norm = np.mean(np.abs(template0) + np.abs(template1))
+
+        # norm_per_channel = np.max(np.abs(template0) + np.abs(template1), axis=0) / 2.
+        norm_per_channel = (np.max(np.abs(template0), axis=0) + np.max(np.abs(template1), axis=0)) * 0.5
+        # norm_per_channel = np.max(np.abs(template0)) + np.max(np.abs(template1)) / 2.
+        # print(norm_per_channel)
+
         all_shift_diff = []
+        # all_shift_diff_by_channel = []
         for shift in range(-num_shift, num_shift + 1):
             temp0 = template0[num_shift : num_samples - num_shift, :]
             temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
-            d = np.mean(np.abs(temp0 - temp1)) / (norm)
-            all_shift_diff.append(d)
+            # d = np.mean(np.abs(temp0 - temp1)) / (norm)
+            # d = np.max(np.abs(temp0 - temp1)) / (norm)
+            diff_per_channel = np.abs(temp0 - temp1) / norm
+
+            diff_max = np.max(diff_per_channel, axis=0)
+
+            # diff = np.max(diff_per_channel)
+            diff = np.average(diff_max, weights=norm_per_channel)
+            # diff = np.average(diff_max)
+            all_shift_diff.append(diff)
+
+            # diff_by_channel = np.mean(np.abs(temp0 - temp1), axis=0) / (norm)
+            # all_shift_diff_by_channel.append(diff_by_channel)
+            # d = np.mean(diff_by_channel)
+            # all_shift_diff.append(d)
         normed_diff = np.min(all_shift_diff)
 
         is_merge = normed_diff < threshold_diff
+
         if is_merge:
             merge_value = normed_diff
             final_shift = np.argmin(all_shift_diff) - num_shift
+
+            # diff_by_channel = all_shift_diff_by_channel[np.argmin(all_shift_diff)]
         else:
             final_shift = 0
             merge_value = np.nan
 
-        if DEBUG and normed_diff < 0.2:
+        # print('merge_value', merge_value, 'final_shift', final_shift, 'is_merge', is_merge)
+
+        DEBUG = False
+        # DEBUG = True
+        # if DEBUG and ( 0. < normed_diff < .4):
+        # if 0.5 < normed_diff < 4:
+        if DEBUG and is_merge:
             # if DEBUG:
             import matplotlib.pyplot as plt
 
-            fig, ax = plt.subplots()
+            fig, axs = plt.subplots(nrows=3)
 
-            m0 = template0.flatten()
-            m1 = template1.flatten()
+            temp0 = template0[num_shift : num_samples - num_shift, :]
+            temp1 = template1[num_shift + final_shift : num_samples - num_shift + final_shift, :]
+
+            diff_per_channel = np.abs(temp0 - temp1) / norm
+            diff = np.max(diff_per_channel)
+
+            m0 = temp0.T.flatten()
+            m1 = temp1.T.flatten()
+
+            ax = axs[0]
             ax.plot(m0, color="C0", label=f"{label0} {inds0.size}")
             ax.plot(m1, color="C1", label=f"{label1} {inds1.size}")
@@ -682,6 +717,28 @@ def merge(
                 f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}"
             )
             ax.legend()
+
+            ax = axs[1]
+
+            # ~ temp0 = template0[num_shift : num_samples - num_shift, :]
+            # ~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
+            ax.plot(np.abs(m0 - m1))
+            # ax.axhline(norm, ls='--', color='k')
+
+            ax = axs[2]
+            ax.plot(diff_per_channel.T.flatten())
+            ax.axhline(threshold_diff, ls="--")
+            ax.axhline(normed_diff)
+
+            # ax.axhline(normed_diff, ls='-', color='b')
+            # ax.plot(norm, ls='--')
+            # ax.plot(diff_by_channel)
+
+            # ax.plot(np.abs(m0) + np.abs(m1))
+
+            # ax.plot(np.abs(m0 - m1) / (np.abs(m0) + np.abs(m1)))
+
+            # ax.set_title(f"{norm=:.3f}")
+
             plt.show()
 
         return is_merge, label0, label1, final_shift, merge_value
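
Reviewer note: the heart of this diff is the new merge criterion in merge.py. Below is a minimal, self-contained sketch of the logic introduced above, for evaluating the metric in isolation; the function name and the toy templates are illustrative only, not part of the codebase:

```python
import numpy as np


def normalized_template_diff_sketch(template0, template1, num_shift=5, threshold_diff=1.5):
    """Condensed restatement of the merge test above; templates are (num_samples, num_channels)."""
    num_samples = template0.shape[0]
    # global normalization over both templates (scalar)
    norm = np.mean(np.abs(template0) + np.abs(template1))
    # per-channel weights: mean of the two per-channel peak amplitudes
    norm_per_channel = (np.max(np.abs(template0), axis=0) + np.max(np.abs(template1), axis=0)) * 0.5
    all_shift_diff = []
    for shift in range(-num_shift, num_shift + 1):
        temp0 = template0[num_shift : num_samples - num_shift, :]
        temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
        # per-channel peak of the normalized difference...
        diff_max = np.max(np.abs(temp0 - temp1) / norm, axis=0)
        # ...averaged so that high-amplitude channels count more
        all_shift_diff.append(np.average(diff_max, weights=norm_per_channel))
    # keep the best alignment over all tested shifts
    normed_diff = np.min(all_shift_diff)
    final_shift = int(np.argmin(all_shift_diff)) - num_shift
    return normed_diff < threshold_diff, final_shift, normed_diff


# toy check: a template against a time-shifted, slightly noisy copy of itself
rng = np.random.default_rng(42)
t0 = rng.normal(size=(90, 4))
t1 = np.roll(t0, 2, axis=0) + 0.05 * rng.normal(size=(90, 4))
print(normalized_template_diff_sketch(t0, t1))  # expected: (True, 2, small value)
```

Because the old criterion averaged the normalized difference over all samples while the new one takes the per-channel maximum and then an amplitude-weighted average across channels, the metric now lives on a larger scale, which is why the default threshold_diff moves from 0.2 (and 0.05 in the signature) to 1.5.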