From a2da5e9854501e645609aaf85b8bb2ef6ca90c72 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 26 Apr 2024 16:44:58 +0200 Subject: [PATCH 01/18] Squashing improvements --- .../comparison/groundtruthstudy.py | 5 +- src/spikeinterface/core/generate.py | 1 - src/spikeinterface/curation/auto_merge.py | 83 ++++-- src/spikeinterface/preprocessing/motion.py | 7 +- .../sorters/internal/spyking_circus2.py | 130 ++++++--- .../benchmark/benchmark_clustering.py | 5 +- .../benchmark/benchmark_matching.py | 4 +- .../benchmark/benchmark_tools.py | 5 +- .../sortingcomponents/clustering/circus.py | 131 +++++++--- .../clustering/clustering_tools.py | 246 ++++++++++++++---- .../clustering/random_projections.py | 17 +- .../clustering/sliding_nn.py | 9 +- .../sortingcomponents/clustering/split.py | 52 ++-- src/spikeinterface/sortingcomponents/tools.py | 12 + 14 files changed, 510 insertions(+), 197 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 55acf76203..fa2f0944d1 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -141,7 +141,10 @@ def scan_folder(self): comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") if comparison_file.exists(): with open(comparison_file, mode="rb") as f: - self.comparisons[key] = pickle.load(f) + try: + self.comparisons[key] = pickle.load(f) + except Exception: + pass def __repr__(self): t = f"{self.__class__.__name__} {self.folder.stem} \n" diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ec76fcbaa9..fcf9137d61 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -26,7 +26,6 @@ def _ensure_seed(seed): seed = np.random.default_rng(seed=None).integers(0, 2**63) return seed - def generate_recording( num_channels: Optional[int] = 2, sampling_frequency: Optional[float] = 30000.0, diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 77d6e54b15..326d783846 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -3,13 +3,13 @@ import numpy as np from ..core import create_sorting_analyzer +from ..core.template import Templates from ..core.template_tools import get_template_extremum_channel from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting - def get_potential_auto_merge( sorting_analyzer, minimum_spikes=1000, @@ -30,6 +30,7 @@ def get_potential_auto_merge( firing_contamination_balance=1.5, extra_outputs=False, steps=None, + template_metric='l1' ): """ Algorithm to find and check potential merges between units. @@ -63,7 +64,7 @@ def get_potential_auto_merge( Minimum number of spikes for each unit to consider a potential merge. Enough spikes are needed to estimate the correlogram maximum_distance_um: float, default: 150 - Minimum distance between units for considering a merge + Maximum distance between units for considering a merge peak_sign: "neg" | "pos" | "both", default: "neg" Peak sign used to estimate the maximum channel of a template bin_ms: float, default: 0.25 @@ -76,6 +77,8 @@ def get_potential_auto_merge( template_diff_thresh: float, default: 0.25 The threshold on the "template distance metric" for considering a merge. 
It needs to be between 0 and 1 + template_metric: 'l1' + The metric to be used when comparing templates. Default is l1 norm censored_period_ms: float, default: 0.3 Used to compute the refractory period violations aka "contamination" refractory_period_ms: float, default: 1 @@ -101,6 +104,8 @@ def get_potential_auto_merge( If None all steps are done. Pontential steps: "min_spikes", "remove_contaminated", "unit_positions", "correlogram", "template_similarity", "check_increase_score". Please check steps explanations above! + template_metric: 'l1', 'l2' or 'cosine' + The metric to consider when measuring the distances between templates Returns ------- @@ -114,6 +119,7 @@ def get_potential_auto_merge( import scipy sorting = sorting_analyzer.sorting + recording = sorting_analyzer.recording unit_ids = sorting.unit_ids # to get fast computation we will not analyse pairs when: @@ -140,7 +146,7 @@ def get_potential_auto_merge( to_remove = num_spikes < minimum_spikes pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False - + # STEP 2 : remove contaminated auto corr if "remove_contaminated" in steps: contaminations, nb_violations = compute_refrac_period_violations( @@ -154,15 +160,20 @@ def get_potential_auto_merge( # STEP 3 : unit positions are estimated roughly with channel if "unit_positions" in steps: - chan_loc = sorting_analyzer.get_channel_locations() - unit_max_chan = get_template_extremum_channel( - sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index" - ) - unit_max_chan = list(unit_max_chan.values()) - unit_locations = chan_loc[unit_max_chan, :] + positions_ext = sorting_analyzer.get_extension("unit_locations") + if positions_ext is not None: + unit_locations = positions_ext.get_data()[:, :2] + else: + chan_loc = sorting_analyzer.get_channel_locations() + unit_max_chan = get_template_extremum_channel( + sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index" + ) + unit_max_chan = list(unit_max_chan.values()) + unit_locations = chan_loc[unit_max_chan, :] + unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") pair_mask = pair_mask & (unit_distances <= maximum_distance_um) - + # STEP 4 : potential auto merge by correlogram if "correlogram" in steps: correlograms, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba") @@ -194,10 +205,14 @@ def get_potential_auto_merge( templates_ext is not None ), "auto_merge with template_similarity requires a SortingAnalyzer with extension templates" - templates = templates_ext.get_templates(operator="average") + templates = templates_ext.get_data(outputs='Templates') + templates = templates.to_sparse(sorting_analyzer.sparsity) + templates_diff = compute_templates_diff( - sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask + sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask, + template_metric=template_metric ) + pair_mask = pair_mask & (templates_diff < template_diff_thresh) # STEP 6 : validate the potential merges with CC increase the contamination quality metrics @@ -378,16 +393,16 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float): return win_size -def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None): +def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric='l1'): """ - Computes normalilzed template differences. 
+ Computes normalized template differences. Parameters ---------- sorting : BaseSorting The sorting object - templates : np.array - The templates array (num_units, num_samples, num_channels) + templates : np.array or Templates + The templates array (num_units, num_samples, num_channels) or a Templates objects num_channels: int, default: 5 Number of channel to use for template similarity computation num_shift: int, default: 5 @@ -407,29 +422,51 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair if pair_mask is None: pair_mask = np.ones((n, n), dtype="bool") + if isinstance(templates, Templates): + adaptative_masks = (num_channels == None) and (templates.sparsity is not None) + if templates.sparsity is not None: + sparsity = templates.sparsity.mask + templates_array = templates.get_dense_templates() + else: + templates_array = templates + templates_diff = np.full((n, n), np.nan, dtype="float64") for unit_ind1 in range(n): for unit_ind2 in range(unit_ind1 + 1, n): if not pair_mask[unit_ind1, unit_ind2]: continue - template1 = templates[unit_ind1] - template2 = templates[unit_ind2] + template1 = templates_array[unit_ind1] + template2 = templates_array[unit_ind2] # take best channels - chan_inds = np.argsort(np.max(np.abs(template1 + template2), axis=0))[::-1][:num_channels] + if not adaptative_masks: + chan_inds = np.argsort(np.max(np.abs(template1 + template2), axis=0))[::-1][:num_channels] + else: + chan_inds = np.intersect1d(np.where(sparsity[unit_ind1])[0], np.where(sparsity[unit_ind2])[0]) + template1 = template1[:, chan_inds] template2 = template2[:, chan_inds] num_samples = template1.shape[0] - norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2)) + if template_metric == 'l1': + norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2)) + elif template_metric == 'l2': + norm = np.sum(template1**2) + np.sum(template2**2) + elif template_metric == 'cosine': + norm = np.linalg.norm(template1) * np.linalg.norm(template2) all_shift_diff = [] for shift in range(-num_shift, num_shift + 1): temp1 = template1[num_shift : num_samples - num_shift, :] temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :] - d = np.sum(np.abs(temp1 - temp2)) / (norm) + if template_metric == 'l1': + d = np.sum(np.abs(temp1 - temp2)) / norm + elif template_metric == 'l2': + d = np.linalg.norm(temp1 - temp2) / norm + elif template_metric == 'cosine': + d = min(1, 1 - np.sum(temp1 * temp2) / norm) all_shift_diff.append(d) templates_diff[unit_ind1, unit_ind2] = np.min(all_shift_diff) - + return templates_diff @@ -437,7 +474,7 @@ def check_improve_contaminations_score( sorting_analyzer, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms ): """ - Check that the score is improve afeter a potential merge + Check that the score is improve after a potential merge The score is a balance between: * contamination decrease diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 1b182a6436..587666ad3e 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -132,9 +132,10 @@ ), "interpolate_motion_kwargs": dict( direction=1, - border_mode="remove_channels", - spatial_interpolation_method="idw", - num_closest=3, + border_mode="force_extrapolate", + spatial_interpolation_method="kriging", + sigma_um=np.sqrt(2)*20.0, + p=2 ), }, # This preset is a super fast rigid estimation with center of mass diff --git 
a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 436ba1c26d..4ed8d9d2a8 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,15 +6,19 @@ import shutil import numpy as np -from spikeinterface.core import NumpySorting, load_extractor, BaseRecording +from spikeinterface.core import NumpySorting from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates +from spikeinterface.core.template_tools import get_template_extremum_amplitude from spikeinterface.core.waveform_tools import estimate_templates -from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter +from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates +from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction +from spikeinterface.sortingcomponents.clustering.clustering_tools import final_cleaning_circus try: import hdbscan @@ -30,18 +34,22 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "filtering": {"freq_min": 150}, + "filtering": {"freq_min": 150, "freq_max": 7000, "ftype" : "bessel", "filter_order" : 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { - "method": "smart_sampling_amplitudes", + "method": "uniform", "n_peaks_per_channel": 5000, "min_n_peaks": 100000, "select_per_channel": False, "seed": 42, }, - "clustering": {"legacy": False}, + "drift_correction" : {"preset" : "nonrigid_fast_and_accurate"}, + "merging" : {"minimum_spikes" : 10, "corr_diff_thresh" : 0.5, "template_metric" : 'cosine', + "censor_correlograms_ms" : 0.4, "num_channels" : 5}, + "clustering": {"legacy": True}, "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, + "matched_filtering": False, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.8}, @@ -62,14 +70,15 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): True, one other clustering called circus will be used, similar to the one used in Spyking Circus 1", "matching": "A dictionary to specify the matching engine used to recover spikes. The method default is circus-omp-svd, but other engines\ can be used", + "merging" : "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", + "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ median reference + zscore", - "shared_memory": "Boolean to specify if the code should, as much as possible, use an internal data structure in memory (faster)", "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. 
In case of memory (default), \ memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting", "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)", "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used", - "debug": "Boolean to specify if the internal data structure should be kept for debugging", + "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging", } sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework @@ -99,44 +108,73 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sampling_frequency = recording.get_sampling_frequency() num_channels = recording.get_num_channels() + ms_before = params["general"].get("ms_before", 2) + ms_after = params["general"].get("ms_after", 2) + radius_um = params["general"].get("radius_um", 100) ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = highpass_filter(recording, **filtering_params, dtype="float32") + recording_f = bandpass_filter(recording, **filtering_params, dtype="float32") if num_channels > 1: recording_f = common_reference(recording_f) else: recording_f = recording recording_f.annotate(is_filtered=True) + + valid_geometry = check_probe_for_drift_correction(recording_f) + if params["drift_correction"] is not None: + if not valid_geometry: + print("Geometry of the probe does not allow 1D drift correction") + else: + print("Motion correction activated (probe geometry compatible)") + motion_folder = sorter_output_folder / "motion" + params['drift_correction'].update({'folder' : motion_folder}) + recording_f = correct_motion(recording_f, **params['drift_correction']) + + ## We need to whiten before the template matching step, to boost the results + recording_w = whiten(recording_f, mode='local', radius_um=radius_um, dtype="float32", regularize=True) - recording_f = zscore(recording_f, dtype="float32") - noise_levels = np.ones(recording_f.get_num_channels(), dtype=np.float32) + noise_levels = get_noise_levels(recording_w, return_scaled=False) - if recording_f.check_serializability("json"): - recording_f.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) - elif recording_f.check_serializability("pickle"): - recording_f.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None) + if recording_w.check_serializability("json"): + recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) + elif recording_w.check_serializability("pickle"): + recording_w.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None) - recording_f = cache_preprocessing(recording_f, **job_kwargs, **params["cache_preprocessing"]) + recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"]) ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) - radius_um = params["general"].get("radius_um", 100) - if "radius_um" not in detection_params: - detection_params["radius_um"] = radius_um - if "exclude_sweep_ms" not in detection_params: - detection_params["exclude_sweep_ms"] = max(params["general"]["ms_before"], params["general"]["ms_after"]) + + detection_params["radius_um"] = detection_params.get('radius_um', 
50) + detection_params["exclude_sweep_ms"] = detection_params.get('exclude_sweep_ms', 0.5) detection_params["noise_levels"] = noise_levels - peaks = detect_peaks(recording_f, method="locally_exclusive", **detection_params) + fs = recording_w.get_sampling_frequency() + nbefore = int(ms_before * fs / 1000.0) + nafter = int(ms_after * fs / 1000.0) + + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params) + + if params["matched_filtering"]: + prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) + detection_params["prototype"] = prototype + + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in detection_params: + detection_params.pop(value) + + detection_params["chunk_duration"] = "100ms" + + peaks = detect_peaks(recording_w, "matched_filtering", **detection_params) if verbose: print("We found %d peaks in total" % len(peaks)) if params["multi_units_only"]: - sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_f.unit_ids) + sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids) else: ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible @@ -156,25 +194,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["waveforms"] = {} clustering_params["sparsity"] = params["sparsity"] clustering_params["radius_um"] = radius_um - - for k in ["ms_before", "ms_after"]: - clustering_params["waveforms"][k] = params["general"][k] - + clustering_params["waveforms"]["ms_before"] = ms_before + clustering_params["waveforms"]["ms_after"] = ms_after clustering_params["job_kwargs"] = job_kwargs clustering_params["noise_levels"] = noise_levels clustering_params["tmp_folder"] = sorter_output_folder / "clustering" - legacy = clustering_params.get("legacy", False) + legacy = clustering_params.get("legacy", True) if legacy: - if verbose: - print("We are using the legacy mode for the clustering") clustering_method = "circus" else: clustering_method = "random_projections" labels, peak_labels = find_cluster_from_peaks( - recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params + recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params ) ## We get the labels for our peaks @@ -198,11 +232,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "labels", labels) np.save(clustering_folder / "peaks", selected_peaks) - nbefore = int(params["general"]["ms_before"] * sampling_frequency / 1000.0) - nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0) - - recording_w = whiten(recording_f, mode="local", radius_um=100.0) - templates_array = estimate_templates( recording_w, labeled_peaks, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs ) @@ -260,15 +289,42 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting_folder = sorter_output_folder / "sorting" if sorting_folder.exists(): shutil.rmtree(sorting_folder) + + merging_params = params["merging"].copy() + + if len(merging_params) > 0: + if params['drift_correction']: + from spikeinterface.preprocessing.motion import load_motion_info + motion_info = load_motion_info(motion_folder) + merging_params['maximum_distance_um'] = max(50, 2*np.abs(motion_info['motion']).max()) + + # peak_sign = params['detection'].get('peak_sign', 'neg') + # best_amplitudes = 
get_template_extremum_amplitude(templates, peak_sign=peak_sign) + # guessed_amplitudes = spikes['amplitude'].copy() + # for ind in unit_ids: + # mask = spikes['cluster_index'] == ind + # guessed_amplitudes[mask] *= best_amplitudes[ind] + + if params["debug"]: + curation_folder = sorter_output_folder / "curation" + if curation_folder.exists(): + shutil.rmtree(curation_folder) + sorting.save(folder=curation_folder) + #np.save(fitting_folder / "amplitudes", guessed_amplitudes) + + sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) + + if verbose: + print(f"Final merging, keeping {len(sorting.unit_ids)} units") folder_to_delete = None cache_mode = params["cache_preprocessing"].get("mode", "memory") delete_cache = params["cache_preprocessing"].get("delete_cache", True) if cache_mode in ["folder", "zarr"] and delete_cache: - folder_to_delete = recording_f._kwargs["folder_path"] + folder_to_delete = recording_w._kwargs["folder_path"] - del recording_f + del recording_w if folder_to_delete is not None: shutil.rmtree(folder_to_delete) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 9d7d202098..03d2a86345 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -49,7 +49,6 @@ def run(self, **job_kwargs): def compute_result(self, **result_params): self.noise = self.result["peak_labels"] < 0 - spikes = self.gt_sorting.to_spike_vector() self.result["sliced_gt_sorting"] = NumpySorting( spikes[self.indices], self.recording.sampling_frequency, self.gt_sorting.unit_ids @@ -301,8 +300,8 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs result = self.get_result(key) scores = result["gt_comparison"].agreement_scores - # positions = result["gt_comparison"].sorting1.get_property('gt_unit_locations') - positions = self.datasets[key[1]][1].get_property("gt_unit_locations") + positions = result["sliced_gt_sorting"].get_property('gt_unit_locations') + #positions = self.datasets[key[1]][1].get_property("gt_unit_locations") depth = positions[:, 1] analyzer = self.get_sorting_analyzer(key) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 5dd0778f76..3346be662d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -42,13 +42,13 @@ def compute_result(self, **result_params): sorting = self.result["sorting"] comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) self.result["gt_comparison"] = comp - self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + #self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) _run_key_saved = [ ("sorting", "sorting"), ("templates", "zarr_templates"), ] - _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] + _result_key_saved = [("gt_comparison", "pickle")] class MatchingStudy(BenchmarkStudy): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 811673e525..5ce25685ab 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ 
b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -229,11 +229,12 @@ def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_ folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] sorting_analyzer = create_sorting_analyzer( - gt_sorting, recording, format="binary_folder", folder=folder, return_scaled=return_scaled - ) + gt_sorting, recording, format="binary_folder", folder=folder) sorting_analyzer.compute("random_spikes", **random_params) + sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates", **job_kwargs) sorting_analyzer.compute("noise_levels") + def get_sorting_analyzer(self, case_key=None, dataset_key=None): if case_key is not None: diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 02d23c2a84..5b7675f0e6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -17,11 +17,11 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection -from sklearn.decomposition import TruncatedSVD +from sklearn.decomposition import TruncatedSVD, PCA from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates @@ -31,6 +31,7 @@ ExtractSparseWaveforms, PeakRetriever, ) + from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel @@ -41,19 +42,23 @@ class CircusClustering: _default_params = { "hdbscan_kwargs": { - "min_cluster_size": 20, - "min_samples": 1, + "min_cluster_size": 25, "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "eom", + #"cluster_selection_epsilon" : 5 ## To be optimized }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, + "recursive_kwargs" : {"recursive" : True, + "recursive_depth" : 3, + "returns_split_count" : True, + }, "radius_um": 100, - "n_svd": [5, 10], - "ms_before": 0.5, - "ms_after": 0.5, + "n_svd": [5, 2], + "ms_before": 2, + "ms_after": 2, "noise_levels": None, "tmp_folder": None, "job_kwargs": {}, @@ -66,14 +71,13 @@ def main_function(cls, recording, peaks, params): job_kwargs = fix_job_kwargs(params["job_kwargs"]) d = params - verbose = job_kwargs.get("verbose", False) + verbose = job_kwargs.get("verbose", True) fs = recording.get_sampling_frequency() ms_before = params["ms_before"] ms_after = params["ms_after"] nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - if params["tmp_folder"] is None: name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) tmp_folder = get_global_tmp_folder() / name @@ -126,42 +130,89 @@ def main_function(cls, recording, peaks, params): pipeline_nodes = [node0, node1, node2] - all_pc_data = run_node_pipeline( - recording, - pipeline_nodes, - params["job_kwargs"], - 
job_name="extracting features", - ) - - peak_labels = -1 * np.ones(len(peaks), dtype=int) - nb_clusters = 0 - for c in np.unique(peaks["channel_index"]): - mask = peaks["channel_index"] == c - if all_pc_data.shape[1] > params["n_svd"][1]: - tsvd = TruncatedSVD(params["n_svd"][1]) - else: - tsvd = TruncatedSVD(all_pc_data.shape[1]) - sub_data = all_pc_data[mask] - hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) - try: - clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) - local_labels = clustering[0] - except Exception: - local_labels = np.zeros(len(hdbscan_data)) - valid_clusters = local_labels > -1 - if np.sum(valid_clusters) > 0: - local_labels[valid_clusters] += nb_clusters - peak_labels[mask] = local_labels - nb_clusters += len(np.unique(local_labels[valid_clusters])) - - labels = np.unique(peak_labels) - labels = labels[labels >= 0] + if len(params["recursive_kwargs"]) == 0: + + all_pc_data = run_node_pipeline( + recording, + pipeline_nodes, + params["job_kwargs"], + job_name="extracting features", + ) + + peak_labels = -1 * np.ones(len(peaks), dtype=int) + nb_clusters = 0 + for c in np.unique(peaks["channel_index"]): + mask = peaks["channel_index"] == c + sub_data = all_pc_data[mask] + sub_data = sub_data.reshape(len(sub_data), -1) + + if all_pc_data.shape[1] > params["n_svd"][1]: + tsvd = TruncatedSVD(params["n_svd"][1]) + else: + tsvd = TruncatedSVD(all_pc_data.shape[1]) + + hdbscan_data = tsvd.fit_transform(sub_data) + try: + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) + local_labels = clustering[0] + except Exception: + local_labels = np.zeros(len(hdbscan_data)) + valid_clusters = local_labels > -1 + if np.sum(valid_clusters) > 0: + local_labels[valid_clusters] += nb_clusters + peak_labels[mask] = local_labels + nb_clusters += len(np.unique(local_labels[valid_clusters])) + else: + + features_folder = tmp_folder / "tsvd_features" + features_folder.mkdir(exist_ok=True) + + _ = run_node_pipeline( + recording, + pipeline_nodes, + params["job_kwargs"], + job_name="extracting features", + gather_mode="npy", + gather_kwargs=dict(exist_ok=True), + folder=features_folder, + names=["sparse_tsvd"], + ) + + sparse_mask = node1.neighbours_mask + neighbours_mask = get_channel_distances(recording) < radius_um + + #np.save(features_folder / "sparse_mask.npy", sparse_mask) + np.save(features_folder / "peaks.npy", peaks) + + original_labels = peaks["channel_index"] + from spikeinterface.sortingcomponents.clustering.split import split_clusters + peak_labels, _ = split_clusters( + original_labels, + recording, + features_folder, + method="local_feature_clustering", + method_kwargs=dict( + clusterer="hdbscan", + feature_name="sparse_tsvd", + neighbours_mask=neighbours_mask, + waveforms_sparse_mask=sparse_mask, + min_size_split=50, + clusterer_kwargs=d["hdbscan_kwargs"], + n_pca_features=params["n_svd"][1], + scale_n_pca_by_depth=True + ), + **params["recursive_kwargs"], + **job_kwargs, + ) + + labels, inverse = np.unique(peak_labels[peak_labels > -1], return_inverse=True) + labels = np.arange(len(labels)) spikes = np.zeros(np.sum(peak_labels > -1), dtype=minimum_spike_dtype) mask = peak_labels > -1 spikes["sample_index"] = peaks[mask]["sample_index"] spikes["segment_index"] = peaks[mask]["segment_index"] - spikes["unit_index"] = peak_labels[mask] + spikes["unit_index"] = inverse unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py 
b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index f5cd3f9eea..a4af25e59a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -8,6 +8,7 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap +from spikeinterface.curation.mergeunitssorting import merge_units_sorting def _split_waveforms( @@ -535,13 +536,13 @@ def remove_duplicates( return labels, new_labels -def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None): + +def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, + tmp_folder=None, rank=5, multiple_passes=False): + from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, SharedMemoryRecording - from spikeinterface.core import NumpySorting - from spikeinterface.core import get_global_tmp_folder import os - from pathlib import Path job_kwargs = fix_job_kwargs(job_kwargs) @@ -553,10 +554,24 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job fs = templates.sampling_frequency num_chans = len(templates.channel_ids) - padding = 2 * duration + if rank is not None: + templates_array = templates.get_dense_templates().copy() + templates_array -= templates_array.mean(axis=(1, 2))[:, None, None] + + # Then we keep only the strongest components + temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False) + temporal = temporal[:, :, :rank] + singular = singular[:, :rank] + spatial = spatial[:, :rank, :] + + templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial) + + norms = np.linalg.norm(templates_array, axis=(1, 2)) + margin = max(templates.nbefore, templates.nafter) tmp_filename = None - zdata = templates_array.reshape(nb_templates * duration, num_chans) - blank = np.zeros((2 * duration, num_chans), dtype=zdata.dtype) + zdata = np.hstack(((templates_array, np.zeros((nb_templates, margin, num_chans))))) + zdata = zdata.reshape(nb_templates * (duration + margin), num_chans) + blank = np.zeros((margin, num_chans), dtype=zdata.dtype) zdata = np.vstack((blank, zdata, blank)) if tmp_folder is not None: @@ -575,49 +590,101 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job recording = recording.set_probe(templates.probe) recording.annotate(is_filtered=True) - margin = 2 * max(templates.nbefore, templates.nafter) - half_marging = margin // 2 - local_params = method_kwargs.copy() + amplitudes = [0.95, 1.05] - local_params.update({"templates": templates, "amplitudes": [0.95, 1.05]}) + local_params.update({"templates": templates, + "amplitudes": amplitudes, + "stop_criteria": "omp_min_sps", + "omp_min_sps" : 0.5}) ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): - t_start = padding + i * duration - t_stop = padding + (i + 1) * duration + keep_searching = True + + DEBUG = False + while keep_searching: + + keep_searching = False + + for i in list(set(range(nb_templates)).difference(ignore_ids)): + + ## Could be speed up by only computing the values for templates that are + ## nearby + + t_start = i*(duration + margin) + t_stop = margin + (i + 1) * (duration + margin) + + sub_recording = recording.frame_slice(t_start, t_stop) + local_params.update({"ignored_ids": ignore_ids + 
[i]}) + spikes, computed = find_spikes_from_templates( + sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs + ) + local_params.update( + { + "overlaps": computed["overlaps"], + "normed_templates": computed["normed_templates"], + "norms": computed["norms"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + } + ) + valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2*margin) + + if np.sum(valid) > 0: + ref_norm = norms[i] - sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - local_params.update({"ignored_ids": ignore_ids + [i]}) - spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs - ) - local_params.update( - { - "overlaps": computed["overlaps"], - "normed_templates": computed["normed_templates"], - "norms": computed["norms"], - "temporal": computed["temporal"], - "spatial": computed["spatial"], - "singular": computed["singular"], - "units_overlaps": computed["units_overlaps"], - "unit_overlaps_indices": computed["unit_overlaps_indices"], - } - ) - valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) - if np.sum(valid) > 0: - if np.sum(valid) == 1: j = spikes[valid]["cluster_index"][0] - ignore_ids += [i] - similar_templates[1] += [i] - similar_templates[0] += [j] - elif np.sum(valid) > 1: - similar_templates[0] += [-1] - ignore_ids += [i] - similar_templates[1] += [i] + sum = templates_array[j] + for k in range(1, np.sum(valid)): + j = spikes[valid]["cluster_index"][k] + a = spikes[valid]["amplitude"][k] + sum += a*templates_array[j] + + tgt_norm = np.linalg.norm(sum) + ratio = tgt_norm / ref_norm + + if (amplitudes[0] < ratio) and (ratio < amplitudes[1]): + if multiple_passes: + keep_searching = True + if np.sum(valid) == 1: + ignore_ids += [i] + similar_templates[1] += [i] + similar_templates[0] += [j] + elif np.sum(valid) > 1: + similar_templates[0] += [-1] + ignore_ids += [i] + similar_templates[1] += [i] + + if DEBUG: + import pylab as plt + fig, axes = plt.subplots(1, 2) + from spikeinterface.widgets import plot_traces + plot_traces(sub_recording, ax=axes[0]) + axes[1].plot(templates_array[i].flatten(), label=f'{ref_norm}') + axes[1].plot(sum.flatten(), label=f'{tgt_norm}') + axes[1].legend() + plt.show() + print(i, spikes[valid]["cluster_index"], spikes[valid]["amplitude"]) + + del recording, sub_recording, local_params, templates + if tmp_filename is not None: + os.remove(tmp_filename) + + return similar_templates + +def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, + tmp_folder=None, rank=5, multiple_passes=False): + + + similar_templates = detect_mixtures(templates, method_kwargs, job_kwargs, + tmp_folder=tmp_folder, rank=rank, multiple_passes=multiple_passes) + new_labels = peak_labels.copy() labels = np.unique(new_labels) @@ -630,13 +697,100 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, local_params, templates - if tmp_filename is not None: - os.remove(tmp_filename) - return labels, new_labels +def resolve_merging_graph(sorting, potential_merges): + """ + Function to provide, given a list of 
potential_merges, a resolved merging + graph based on the connected components. + """ + from scipy.sparse.csgraph import connected_components + from scipy.sparse import lil_matrix + n = len(sorting.unit_ids) + graph = lil_matrix((n, n)) + for i, j in potential_merges: + graph[sorting.id_to_index(i), sorting.id_to_index(j)] = 1 + + n_components, labels = connected_components(graph, directed=True, connection='weak', return_labels=True) + final_merges = [] + for i in range(n_components): + merges = labels == i + if merges.sum() > 1: + src = np.where(merges)[0][0] + tgts = np.where(merges)[0][1:] + final_merges += [(sorting.unit_ids[src], sorting.unit_ids[tgts])] + + return final_merges + +def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): + """ + Function to apply a resolved representation of the merges to a sorting object. If censor_ms is not None, + duplicated spikes violating the censor_ms refractory period are removed + """ + spikes = sorting.to_spike_vector().copy() + to_keep = np.ones(len(spikes), dtype=bool) + + segment_slices = {} + for segment_index in range(sorting.get_num_segments()): + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") + segment_slices[segment_index] = (s0, s1) + + if censor_ms is not None: + rpv = int(sorting.sampling_frequency * censor_ms / 1000) + + for src, targets in merges: + mask = np.in1d(spikes['unit_index'], sorting.ids_to_indices([src] + list(targets))) + spikes['unit_index'][mask] = sorting.id_to_index(src) + + if censor_ms is not None: + for segment_index in range(sorting.get_num_segments()): + s0, s1 = segment_slices[segment_index] + (indices,) = s0 + np.nonzero(mask[s0:s1]) + to_keep[indices[1:]] = np.logical_or( + to_keep[indices[1:]], np.diff(spikes[indices]["sample_index"]) > rpv + ) + + from spikeinterface.core.numpyextractors import NumpySorting + times_list = [] + labels_list = [] + for segment_index in range(sorting.get_num_segments()): + s0, s1 = segment_slices[segment_index] + if censor_ms is not None: + times_list += [spikes['sample_index'][s0:s1][to_keep[s0:s1]]] + labels_list += [spikes['unit_index'][s0:s1][to_keep[s0:s1]]] + else: + times_list += [spikes['sample_index'][s0:s1]] + labels_list += [spikes['unit_index'][s0:s1]] + + sorting = NumpySorting.from_times_labels(times_list, labels_list, sorting.sampling_frequency) + return sorting + + + +def final_cleaning_circus(recording, sorting, templates, + **merging_kwargs): + + from spikeinterface.core.sortinganalyzer import create_sorting_analyzer + from spikeinterface.curation.auto_merge import get_potential_auto_merge + + sparsity = templates.sparsity + templates_array = templates.get_dense_templates().copy() + + sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity) + from spikeinterface.core.analyzer_extension_core import ComputeTemplates + sa.extensions['templates'] = ComputeTemplates(sa) + sa.extensions['templates'].params = {'nbefore' : templates.nbefore} + sa.extensions['templates'].data['average'] = templates_array + sa.compute('unit_locations', method='monopolar_triangulation') + merges = get_potential_auto_merge(sa, **merging_kwargs) + merges = resolve_merging_graph(sorting, merges) + sorting = apply_merges_to_sorting(sorting, merges) + #sorting = merge_units_sorting(sorting, merges) + + return sorting + + def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_threshold=None): import sklearn @@ -731,4 +885,4 @@ def remove_duplicates_via_dip(wfs_arrays, peak_labels, 
dip_threshold=1, cosine_t labels = np.unique(new_labels) labels = labels[labels >= 0] - return labels, new_labels + return labels, new_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index dc483a5b96..f906cd2945 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -16,7 +16,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature @@ -37,16 +37,16 @@ class RandomProjectionClustering: _default_params = { "hdbscan_kwargs": { - "min_cluster_size": 20, + "min_cluster_size": 10, "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "leaf", - "cluster_selection_epsilon": 2, + "cluster_selection_epsilon": 1, }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "radius_um": 100, + "radius_um": 30, "nb_projections": 10, "feature": "energy", "ms_before": 0.5, @@ -65,7 +65,7 @@ def main_function(cls, recording, peaks, params): job_kwargs = fix_job_kwargs(params["job_kwargs"]) d = params - verbose = job_kwargs.get("verbose", False) + verbose = job_kwargs.get("verbose", True) fs = recording.get_sampling_frequency() radius_um = params["radius_um"] @@ -87,10 +87,9 @@ def main_function(cls, recording, peaks, params): node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) num_projections = min(num_chans, d["nb_projections"]) - projections = rng.randn(num_chans, num_projections) - if num_chans > 1: - projections -= projections.mean() - projections /= projections.std() + projections = rng.normal( + loc=0.0, scale=1.0 / np.sqrt(num_chans), size=(num_chans, num_projections) + ) nbefore = int(params["ms_before"] * fs / 1000) nafter = int(params["ms_after"] * fs / 1000) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index b8f29fcd9a..d9c325d4e6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -26,7 +26,7 @@ HAVE_HDBSCAN = True except: HAVE_HDBSCAN = False -import copy + from scipy.sparse import coo_matrix try: @@ -57,8 +57,8 @@ class SlidingNNClustering: "time_window_s": 5, "hdbscan_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True}, "margin_ms": 100, - "ms_before": 1, - "ms_after": 1, + "ms_before": 0.5, + "ms_after": 0.5, "n_channel_neighbors": 8, "n_neighbors": 5, "embedding_dim": None, @@ -141,13 +141,10 @@ def main_function(cls, recording, peaks, params): # prepare neighborhood parameters fs = recording.get_sampling_frequency() n_frames = recording.get_num_frames() - duration = n_frames / fs time_window_frames = fs * d["time_window_s"] margin_frames = int(d["margin_ms"] / 1000 * fs) 
spike_pre_frames = int(d["ms_before"] / 1000 * fs) spike_post_frames = int(d["ms_after"] / 1000 * fs) - n_channels = recording.get_num_channels() - n_samples = spike_pre_frames + spike_post_frames if d["embedding_dim"] is None: d["embedding_dim"] = recording.get_num_channels() diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 3861e7fe83..cddc57103a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -4,7 +4,7 @@ from threadpoolctl import threadpool_limits from tqdm.auto import tqdm -from sklearn.decomposition import TruncatedSVD +from sklearn.decomposition import TruncatedSVD, PCA import numpy as np @@ -83,7 +83,7 @@ def split_clusters( for label in labels_set: peak_indices = np.flatnonzero(peak_labels == label) if peak_indices.size > 0: - jobs.append(pool.submit(split_function_wrapper, peak_indices)) + jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level=1)) if progress_bar: iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) @@ -104,9 +104,10 @@ def split_clusters( current_max_label += np.max(local_labels[mask]) + 1 if recursive: + recursion_level = np.max(split_count[peak_indices]) if recursive_depth is not None: # stop reccursivity when recursive_depth is reach - extra_ball = np.max(split_count[peak_indices]) < recursive_depth + extra_ball = recursion_level < recursive_depth else: # reccurssive always extra_ball = True @@ -116,7 +117,7 @@ def split_clusters( for label in new_labels_set: peak_indices = np.flatnonzero(peak_labels == label) if peak_indices.size > 0: - jobs.append(pool.submit(split_function_wrapper, peak_indices)) + jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level)) if progress_bar: iterator.total += 1 @@ -146,11 +147,11 @@ def split_worker_init( _ctx["peaks"] = _ctx["features"]["peaks"] -def split_function_wrapper(peak_indices): +def split_function_wrapper(peak_indices, recursion_level): global _ctx with threadpool_limits(limits=_ctx["max_threads_per_process"]): is_split, local_labels = _ctx["method_class"].split( - peak_indices, _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, **_ctx["method_kwargs"] ) return is_split, local_labels, peak_indices @@ -163,7 +164,7 @@ class LocalFeatureClustering: The idea simple : * agregate features (svd or even waveforms) with sparse channel. 
- * run a local feature reduction (pca or svd) + * run a local feature reduction (pca or svd) * try a new split (hdscan or isocut5) """ @@ -174,14 +175,15 @@ def split( peak_indices, peaks, features, + recursion_level=1, clusterer="hdbscan", feature_name="sparse_tsvd", neighbours_mask=None, waveforms_sparse_mask=None, + clusterer_kwargs={'min_cluster_size' : 25}, min_size_split=25, - min_cluster_size=25, - min_samples=25, n_pca_features=2, + scale_n_pca_by_depth=False, minimum_common_channels=2, ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -190,7 +192,7 @@ def split( sparse_features = features[feature_name] assert waveforms_sparse_mask is not None - + # target channel subset is done intersect local channels + neighbours local_chans = np.unique(peaks["channel_index"][peak_indices]) @@ -213,24 +215,27 @@ def split( aligned_wfs = aligned_wfs[kept, :, :] flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) - - # final_features = PCA(n_pca_features, whiten=True).fit_transform(flatten_features) - # final_features = PCA(n_pca_features, whiten=False).fit_transform(flatten_features) - final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) + + if flatten_features.shape[1] > n_pca_features: + if scale_n_pca_by_depth: + #tsvd = TruncatedSVD(n_pca_features * recursion_level) + tsvd = PCA(n_pca_features * recursion_level, whiten=True) + else: + #tsvd = TruncatedSVD(n_pca_features) + tsvd = PCA(n_pca_features, whiten=True) + final_features = tsvd.fit_transform(flatten_features) + else: + final_features = flatten_features + if clusterer == "hdbscan": from hdbscan import HDBSCAN - - clust = HDBSCAN( - min_cluster_size=min_cluster_size, - min_samples=min_samples, - allow_single_cluster=True, - cluster_selection_method="leaf", - ) + clust = HDBSCAN(**clusterer_kwargs) clust.fit(final_features) possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": + min_cluster_size = clusterer_kwargs['min_cluster_size'] dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) if dipscore > 1.5: @@ -243,7 +248,7 @@ def split( else: raise ValueError(f"wrong clusterer {clusterer}") - # DEBUG = True + #DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -265,8 +270,7 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) - axs[0].set_title(f"{clusterer} {is_split}") - + axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {np.unique(possible_labels)}") plt.show() if not is_split: diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 3d7e40da14..e5c7789da6 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -84,6 +84,18 @@ def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks= return prototype +def check_probe_for_drift_correction(recording, dist_x_max=60): + num_channels = recording.get_num_channels() + if num_channels < 32: + return False + else: + locations = recording.get_channel_locations() + x_min = locations[:, 0].min() + x_max = locations[:, 0].max() + if np.abs(x_max - x_min) > dist_x_max: + return False + return True + def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs): save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) From 57b36366e98ff9b48eeb714aa2ff35128c3a14fa Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Apr 2024 14:46:42 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 1 + src/spikeinterface/curation/auto_merge.py | 39 ++++---- src/spikeinterface/preprocessing/motion.py | 4 +- .../sorters/internal/spyking_circus2.py | 38 ++++---- .../benchmark/benchmark_clustering.py | 4 +- .../benchmark/benchmark_matching.py | 2 +- .../benchmark/benchmark_tools.py | 4 +- .../sortingcomponents/clustering/circus.py | 22 +++-- .../clustering/clustering_tools.py | 92 ++++++++++--------- .../clustering/random_projections.py | 4 +- .../sortingcomponents/clustering/split.py | 16 ++-- src/spikeinterface/sortingcomponents/tools.py | 3 +- 12 files changed, 121 insertions(+), 108 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index fcf9137d61..ec76fcbaa9 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -26,6 +26,7 @@ def _ensure_seed(seed): seed = np.random.default_rng(seed=None).integers(0, 2**63) return seed + def generate_recording( num_channels: Optional[int] = 2, sampling_frequency: Optional[float] = 30000.0, diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 326d783846..f28ec16fb3 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -10,6 +10,7 @@ from .mergeunitssorting import MergeUnitsSorting + def get_potential_auto_merge( sorting_analyzer, minimum_spikes=1000, @@ -30,7 +31,7 @@ def get_potential_auto_merge( firing_contamination_balance=1.5, extra_outputs=False, steps=None, - template_metric='l1' + template_metric="l1", ): """ Algorithm to find and check potential merges between units. 
@@ -146,7 +147,7 @@ def get_potential_auto_merge( to_remove = num_spikes < minimum_spikes pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False - + # STEP 2 : remove contaminated auto corr if "remove_contaminated" in steps: contaminations, nb_violations = compute_refrac_period_violations( @@ -170,10 +171,10 @@ def get_potential_auto_merge( ) unit_max_chan = list(unit_max_chan.values()) unit_locations = chan_loc[unit_max_chan, :] - + unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") pair_mask = pair_mask & (unit_distances <= maximum_distance_um) - + # STEP 4 : potential auto merge by correlogram if "correlogram" in steps: correlograms, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba") @@ -205,14 +206,18 @@ def get_potential_auto_merge( templates_ext is not None ), "auto_merge with template_similarity requires a SortingAnalyzer with extension templates" - templates = templates_ext.get_data(outputs='Templates') + templates = templates_ext.get_data(outputs="Templates") templates = templates.to_sparse(sorting_analyzer.sparsity) - + templates_diff = compute_templates_diff( - sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask, - template_metric=template_metric + sorting, + templates, + num_channels=num_channels, + num_shift=num_shift, + pair_mask=pair_mask, + template_metric=template_metric, ) - + pair_mask = pair_mask & (templates_diff < template_diff_thresh) # STEP 6 : validate the potential merges with CC increase the contamination quality metrics @@ -393,7 +398,7 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float): return win_size -def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric='l1'): +def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric="l1"): """ Computes normalized template differences. 
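For orientation, the three metrics handled by compute_templates_diff reduce, for two aligned template snippets, to the expressions reformatted in the hunk below. A minimal standalone NumPy sketch (illustrative only — the name template_distance and the arguments t1/t2 are placeholders; the real code additionally restricts the comparison to the best channels and scans a range of temporal shifts, keeping the minimum per pair):

    import numpy as np

    def template_distance(t1, t2, template_metric="l1"):
        # 0 means identical templates; for "l1" and "cosine" the result stays in [0, 1],
        # which is the range expected by template_diff_thresh
        if template_metric == "l1":
            norm = np.sum(np.abs(t1)) + np.sum(np.abs(t2))
            return np.sum(np.abs(t1 - t2)) / norm
        elif template_metric == "l2":
            norm = np.sum(t1**2) + np.sum(t2**2)
            return np.linalg.norm(t1 - t2) / norm
        elif template_metric == "cosine":
            norm = np.linalg.norm(t1) * np.linalg.norm(t2)
            return min(1, 1 - np.sum(t1 * t2) / norm)

Callers select the metric through the new template_metric keyword of get_potential_auto_merge, e.g. get_potential_auto_merge(sorting_analyzer, template_diff_thresh=0.25, template_metric="cosine").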
@@ -448,25 +453,25 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair template2 = template2[:, chan_inds] num_samples = template1.shape[0] - if template_metric == 'l1': + if template_metric == "l1": norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2)) - elif template_metric == 'l2': + elif template_metric == "l2": norm = np.sum(template1**2) + np.sum(template2**2) - elif template_metric == 'cosine': + elif template_metric == "cosine": norm = np.linalg.norm(template1) * np.linalg.norm(template2) all_shift_diff = [] for shift in range(-num_shift, num_shift + 1): temp1 = template1[num_shift : num_samples - num_shift, :] temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :] - if template_metric == 'l1': + if template_metric == "l1": d = np.sum(np.abs(temp1 - temp2)) / norm - elif template_metric == 'l2': + elif template_metric == "l2": d = np.linalg.norm(temp1 - temp2) / norm - elif template_metric == 'cosine': + elif template_metric == "cosine": d = min(1, 1 - np.sum(temp1 * temp2) / norm) all_shift_diff.append(d) templates_diff[unit_ind1, unit_ind2] = np.min(all_shift_diff) - + return templates_diff diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 587666ad3e..9434b15360 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -134,8 +134,8 @@ direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", - sigma_um=np.sqrt(2)*20.0, - p=2 + sigma_um=np.sqrt(2) * 20.0, + p=2, ), }, # This preset is a super fast rigid estimation with center of mass diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4ed8d9d2a8..3cedc430d3 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -34,7 +34,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "filtering": {"freq_min": 150, "freq_max": 7000, "ftype" : "bessel", "filter_order" : 2}, + "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "uniform", @@ -43,9 +43,14 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "select_per_channel": False, "seed": 42, }, - "drift_correction" : {"preset" : "nonrigid_fast_and_accurate"}, - "merging" : {"minimum_spikes" : 10, "corr_diff_thresh" : 0.5, "template_metric" : 'cosine', - "censor_correlograms_ms" : 0.4, "num_channels" : 5}, + "drift_correction": {"preset": "nonrigid_fast_and_accurate"}, + "merging": { + "minimum_spikes": 10, + "corr_diff_thresh": 0.5, + "template_metric": "cosine", + "censor_correlograms_ms": 0.4, + "num_channels": 5, + }, "clustering": {"legacy": True}, "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, @@ -70,7 +75,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): True, one other clustering called circus will be used, similar to the one used in Spyking Circus 1", "matching": "A dictionary to specify the matching engine used to recover spikes. 
The method default is circus-omp-svd, but other engines\ can be used", - "merging" : "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", + "merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ median reference + zscore", @@ -121,7 +126,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: recording_f = recording recording_f.annotate(is_filtered=True) - + valid_geometry = check_probe_for_drift_correction(recording_f) if params["drift_correction"] is not None: if not valid_geometry: @@ -129,11 +134,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: print("Motion correction activated (probe geometry compatible)") motion_folder = sorter_output_folder / "motion" - params['drift_correction'].update({'folder' : motion_folder}) - recording_f = correct_motion(recording_f, **params['drift_correction']) + params["drift_correction"].update({"folder": motion_folder}) + recording_f = correct_motion(recording_f, **params["drift_correction"]) ## We need to whiten before the template matching step, to boost the results - recording_w = whiten(recording_f, mode='local', radius_um=radius_um, dtype="float32", regularize=True) + recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) noise_levels = get_noise_levels(recording_w, return_scaled=False) @@ -147,9 +152,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) - - detection_params["radius_um"] = detection_params.get('radius_um', 50) - detection_params["exclude_sweep_ms"] = detection_params.get('exclude_sweep_ms', 0.5) + + detection_params["radius_um"] = detection_params.get("radius_um", 50) + detection_params["exclude_sweep_ms"] = detection_params.get("exclude_sweep_ms", 0.5) detection_params["noise_levels"] = noise_levels fs = recording_w.get_sampling_frequency() @@ -289,14 +294,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting_folder = sorter_output_folder / "sorting" if sorting_folder.exists(): shutil.rmtree(sorting_folder) - + merging_params = params["merging"].copy() if len(merging_params) > 0: - if params['drift_correction']: + if params["drift_correction"]: from spikeinterface.preprocessing.motion import load_motion_info + motion_info = load_motion_info(motion_folder) - merging_params['maximum_distance_um'] = max(50, 2*np.abs(motion_info['motion']).max()) + merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max()) # peak_sign = params['detection'].get('peak_sign', 'neg') # best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign) @@ -310,7 +316,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if curation_folder.exists(): shutil.rmtree(curation_folder) sorting.save(folder=curation_folder) - #np.save(fitting_folder / "amplitudes", guessed_amplitudes) + # np.save(fitting_folder / "amplitudes", guessed_amplitudes) sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) diff --git 
a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 03d2a86345..7c66fecb44 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -300,8 +300,8 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs result = self.get_result(key) scores = result["gt_comparison"].agreement_scores - positions = result["sliced_gt_sorting"].get_property('gt_unit_locations') - #positions = self.datasets[key[1]][1].get_property("gt_unit_locations") + positions = result["sliced_gt_sorting"].get_property("gt_unit_locations") + # positions = self.datasets[key[1]][1].get_property("gt_unit_locations") depth = positions[:, 1] analyzer = self.get_sorting_analyzer(key) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 3346be662d..717cc7d7c5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -42,7 +42,7 @@ def compute_result(self, **result_params): sorting = self.result["sorting"] comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) self.result["gt_comparison"] = comp - #self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + # self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) _run_key_saved = [ ("sorting", "sorting"), diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 5ce25685ab..358a30b5aa 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -228,13 +228,11 @@ def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_ # the waveforms depend on the dataset key folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - sorting_analyzer = create_sorting_analyzer( - gt_sorting, recording, format="binary_folder", folder=folder) + sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) sorting_analyzer.compute("random_spikes", **random_params) sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates", **job_kwargs) sorting_analyzer.compute("noise_levels") - def get_sorting_analyzer(self, case_key=None, dataset_key=None): if case_key is not None: diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 5b7675f0e6..38e6169ee8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -46,15 +46,16 @@ class CircusClustering: "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "eom", - #"cluster_selection_epsilon" : 5 ## To be optimized + # "cluster_selection_epsilon" : 5 ## To be optimized }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "recursive_kwargs" : {"recursive" : True, - "recursive_depth" : 3, - "returns_split_count" : True, - }, + "recursive_kwargs": { + 
"recursive": True, + "recursive_depth": 3, + "returns_split_count": True, + }, "radius_um": 100, "n_svd": [5, 2], "ms_before": 2, @@ -143,14 +144,14 @@ def main_function(cls, recording, peaks, params): nb_clusters = 0 for c in np.unique(peaks["channel_index"]): mask = peaks["channel_index"] == c - sub_data = all_pc_data[mask] + sub_data = all_pc_data[mask] sub_data = sub_data.reshape(len(sub_data), -1) if all_pc_data.shape[1] > params["n_svd"][1]: tsvd = TruncatedSVD(params["n_svd"][1]) else: tsvd = TruncatedSVD(all_pc_data.shape[1]) - + hdbscan_data = tsvd.fit_transform(sub_data) try: clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) @@ -163,7 +164,7 @@ def main_function(cls, recording, peaks, params): peak_labels[mask] = local_labels nb_clusters += len(np.unique(local_labels[valid_clusters])) else: - + features_folder = tmp_folder / "tsvd_features" features_folder.mkdir(exist_ok=True) @@ -181,11 +182,12 @@ def main_function(cls, recording, peaks, params): sparse_mask = node1.neighbours_mask neighbours_mask = get_channel_distances(recording) < radius_um - #np.save(features_folder / "sparse_mask.npy", sparse_mask) + # np.save(features_folder / "sparse_mask.npy", sparse_mask) np.save(features_folder / "peaks.npy", peaks) original_labels = peaks["channel_index"] from spikeinterface.sortingcomponents.clustering.split import split_clusters + peak_labels, _ = split_clusters( original_labels, recording, @@ -199,7 +201,7 @@ def main_function(cls, recording, peaks, params): min_size_split=50, clusterer_kwargs=d["hdbscan_kwargs"], n_pca_features=params["n_svd"][1], - scale_n_pca_by_depth=True + scale_n_pca_by_depth=True, ), **params["recursive_kwargs"], **job_kwargs, diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index a4af25e59a..8fa20e48fe 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -536,9 +536,7 @@ def remove_duplicates( return labels, new_labels - -def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, - tmp_folder=None, rank=5, multiple_passes=False): +def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, rank=5, multiple_passes=False): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, SharedMemoryRecording @@ -565,7 +563,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, spatial = spatial[:, :rank, :] templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial) - + norms = np.linalg.norm(templates_array, axis=(1, 2)) margin = max(templates.nbefore, templates.nafter) tmp_filename = None @@ -593,10 +591,9 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, local_params = method_kwargs.copy() amplitudes = [0.95, 1.05] - local_params.update({"templates": templates, - "amplitudes": amplitudes, - "stop_criteria": "omp_min_sps", - "omp_min_sps" : 0.5}) + local_params.update( + {"templates": templates, "amplitudes": amplitudes, "stop_criteria": "omp_min_sps", "omp_min_sps": 0.5} + ) ignore_ids = [] similar_templates = [[], []] @@ -609,11 +606,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, keep_searching = False for i in list(set(range(nb_templates)).difference(ignore_ids)): - + ## Could be speed up by only computing the values for templates that are ## nearby - 
t_start = i*(duration + margin) + t_start = i * (duration + margin) t_stop = margin + (i + 1) * (duration + margin) sub_recording = recording.frame_slice(t_start, t_stop) @@ -633,7 +630,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, "unit_overlaps_indices": computed["unit_overlaps_indices"], } ) - valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2*margin) + valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin) if np.sum(valid) > 0: ref_norm = norms[i] @@ -643,10 +640,10 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, for k in range(1, np.sum(valid)): j = spikes[valid]["cluster_index"][k] a = spikes[valid]["amplitude"][k] - sum += a*templates_array[j] - + sum += a * templates_array[j] + tgt_norm = np.linalg.norm(sum) - ratio = tgt_norm / ref_norm + ratio = tgt_norm / ref_norm if (amplitudes[0] < ratio) and (ratio < amplitudes[1]): if multiple_passes: @@ -659,32 +656,35 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, similar_templates[0] += [-1] ignore_ids += [i] similar_templates[1] += [i] - + if DEBUG: import pylab as plt + fig, axes = plt.subplots(1, 2) from spikeinterface.widgets import plot_traces + plot_traces(sub_recording, ax=axes[0]) - axes[1].plot(templates_array[i].flatten(), label=f'{ref_norm}') - axes[1].plot(sum.flatten(), label=f'{tgt_norm}') + axes[1].plot(templates_array[i].flatten(), label=f"{ref_norm}") + axes[1].plot(sum.flatten(), label=f"{tgt_norm}") axes[1].legend() plt.show() print(i, spikes[valid]["cluster_index"], spikes[valid]["amplitude"]) - + del recording, sub_recording, local_params, templates if tmp_filename is not None: os.remove(tmp_filename) - + return similar_templates -def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, - tmp_folder=None, rank=5, multiple_passes=False): - +def remove_duplicates_via_matching( + templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, rank=5, multiple_passes=False +): + + similar_templates = detect_mixtures( + templates, method_kwargs, job_kwargs, tmp_folder=tmp_folder, rank=rank, multiple_passes=multiple_passes + ) - similar_templates = detect_mixtures(templates, method_kwargs, job_kwargs, - tmp_folder=tmp_folder, rank=rank, multiple_passes=multiple_passes) - new_labels = peak_labels.copy() labels = np.unique(new_labels) @@ -707,12 +707,13 @@ def resolve_merging_graph(sorting, potential_merges): """ from scipy.sparse.csgraph import connected_components from scipy.sparse import lil_matrix + n = len(sorting.unit_ids) graph = lil_matrix((n, n)) for i, j in potential_merges: graph[sorting.id_to_index(i), sorting.id_to_index(j)] = 1 - - n_components, labels = connected_components(graph, directed=True, connection='weak', return_labels=True) + + n_components, labels = connected_components(graph, directed=True, connection="weak", return_labels=True) final_merges = [] for i in range(n_components): merges = labels == i @@ -723,9 +724,10 @@ def resolve_merging_graph(sorting, potential_merges): return final_merges + def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): """ - Function to apply a resolved representation of the merges to a sorting object. If censor_ms is not None, + Function to apply a resolved representation of the merges to a sorting object. 
If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed """ spikes = sorting.to_spike_vector().copy() @@ -738,11 +740,11 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): if censor_ms is not None: rpv = int(sorting.sampling_frequency * censor_ms / 1000) - + for src, targets in merges: - mask = np.in1d(spikes['unit_index'], sorting.ids_to_indices([src] + list(targets))) - spikes['unit_index'][mask] = sorting.id_to_index(src) - + mask = np.in1d(spikes["unit_index"], sorting.ids_to_indices([src] + list(targets))) + spikes["unit_index"][mask] = sorting.id_to_index(src) + if censor_ms is not None: for segment_index in range(sorting.get_num_segments()): s0, s1 = segment_slices[segment_index] @@ -752,24 +754,23 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): ) from spikeinterface.core.numpyextractors import NumpySorting + times_list = [] labels_list = [] for segment_index in range(sorting.get_num_segments()): s0, s1 = segment_slices[segment_index] if censor_ms is not None: - times_list += [spikes['sample_index'][s0:s1][to_keep[s0:s1]]] - labels_list += [spikes['unit_index'][s0:s1][to_keep[s0:s1]]] + times_list += [spikes["sample_index"][s0:s1][to_keep[s0:s1]]] + labels_list += [spikes["unit_index"][s0:s1][to_keep[s0:s1]]] else: - times_list += [spikes['sample_index'][s0:s1]] - labels_list += [spikes['unit_index'][s0:s1]] + times_list += [spikes["sample_index"][s0:s1]] + labels_list += [spikes["unit_index"][s0:s1]] sorting = NumpySorting.from_times_labels(times_list, labels_list, sorting.sampling_frequency) return sorting - -def final_cleaning_circus(recording, sorting, templates, - **merging_kwargs): +def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.curation.auto_merge import get_potential_auto_merge @@ -779,14 +780,15 @@ def final_cleaning_circus(recording, sorting, templates, sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity) from spikeinterface.core.analyzer_extension_core import ComputeTemplates - sa.extensions['templates'] = ComputeTemplates(sa) - sa.extensions['templates'].params = {'nbefore' : templates.nbefore} - sa.extensions['templates'].data['average'] = templates_array - sa.compute('unit_locations', method='monopolar_triangulation') + + sa.extensions["templates"] = ComputeTemplates(sa) + sa.extensions["templates"].params = {"nbefore": templates.nbefore} + sa.extensions["templates"].data["average"] = templates_array + sa.compute("unit_locations", method="monopolar_triangulation") merges = get_potential_auto_merge(sa, **merging_kwargs) merges = resolve_merging_graph(sorting, merges) sorting = apply_merges_to_sorting(sorting, merges) - #sorting = merge_units_sorting(sorting, merges) + # sorting = merge_units_sorting(sorting, merges) return sorting @@ -885,4 +887,4 @@ def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_t labels = np.unique(new_labels) labels = labels[labels >= 0] - return labels, new_labels \ No newline at end of file + return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f906cd2945..efd63be55f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -87,9 +87,7 @@ def main_function(cls, 
recording, peaks, params): node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) num_projections = min(num_chans, d["nb_projections"]) - projections = rng.normal( - loc=0.0, scale=1.0 / np.sqrt(num_chans), size=(num_chans, num_projections) - ) + projections = rng.normal(loc=0.0, scale=1.0 / np.sqrt(num_chans), size=(num_chans, num_projections)) nbefore = int(params["ms_before"] * fs / 1000) nafter = int(params["ms_after"] * fs / 1000) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index cddc57103a..30378c80c3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -180,7 +180,7 @@ def split( feature_name="sparse_tsvd", neighbours_mask=None, waveforms_sparse_mask=None, - clusterer_kwargs={'min_cluster_size' : 25}, + clusterer_kwargs={"min_cluster_size": 25}, min_size_split=25, n_pca_features=2, scale_n_pca_by_depth=False, @@ -192,7 +192,7 @@ def split( sparse_features = features[feature_name] assert waveforms_sparse_mask is not None - + # target channel subset is done intersect local channels + neighbours local_chans = np.unique(peaks["channel_index"][peak_indices]) @@ -215,27 +215,27 @@ def split( aligned_wfs = aligned_wfs[kept, :, :] flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) - + if flatten_features.shape[1] > n_pca_features: if scale_n_pca_by_depth: - #tsvd = TruncatedSVD(n_pca_features * recursion_level) + # tsvd = TruncatedSVD(n_pca_features * recursion_level) tsvd = PCA(n_pca_features * recursion_level, whiten=True) else: - #tsvd = TruncatedSVD(n_pca_features) + # tsvd = TruncatedSVD(n_pca_features) tsvd = PCA(n_pca_features, whiten=True) final_features = tsvd.fit_transform(flatten_features) else: final_features = flatten_features - if clusterer == "hdbscan": from hdbscan import HDBSCAN + clust = HDBSCAN(**clusterer_kwargs) clust.fit(final_features) possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": - min_cluster_size = clusterer_kwargs['min_cluster_size'] + min_cluster_size = clusterer_kwargs["min_cluster_size"] dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) if dipscore > 1.5: @@ -248,7 +248,7 @@ def split( else: raise ValueError(f"wrong clusterer {clusterer}") - #DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index e5c7789da6..66e5e87119 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -86,7 +86,7 @@ def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks= def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels() - if num_channels < 32: + if num_channels < 32: return False else: locations = recording.get_channel_locations() @@ -96,6 +96,7 @@ def check_probe_for_drift_correction(recording, dist_x_max=60): return False return True + def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs): save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) From 9ac08c39801737a895a6e9e05c406c3bf1393d97 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 26 Apr 2024 17:10:31 +0200 Subject: [PATCH 03/18] Removing 
useless changes --- src/spikeinterface/preprocessing/motion.py | 7 ++-- .../sorters/internal/spyking_circus2.py | 27 ++++++++++++- .../benchmark/benchmark_matching.py | 4 +- .../benchmark/benchmark_tools.py | 3 +- .../clustering/clustering_tools.py | 39 ++++--------------- 5 files changed, 39 insertions(+), 41 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 587666ad3e..1b182a6436 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -132,10 +132,9 @@ ), "interpolate_motion_kwargs": dict( direction=1, - border_mode="force_extrapolate", - spatial_interpolation_method="kriging", - sigma_um=np.sqrt(2)*20.0, - p=2 + border_mode="remove_channels", + spatial_interpolation_method="idw", + num_closest=3, ), }, # This preset is a super fast rigid estimation with center of mass diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4ed8d9d2a8..e0395c5eb3 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -18,7 +18,10 @@ from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction -from spikeinterface.sortingcomponents.clustering.clustering_tools import final_cleaning_circus +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.curation.auto_merge import get_potential_auto_merge +from spikeinterface.sortingcomponents.clustering.clustering_tools import resolve_merging_graph, apply_merges_to_sorting +from spikeinterface.core.analyzer_extension_core import ComputeTemplates try: import hdbscan @@ -331,3 +334,25 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorting_folder) return sorting + + +def final_cleaning_circus(recording, sorting, templates, + **merging_kwargs): + + + + sparsity = templates.sparsity + templates_array = templates.get_dense_templates().copy() + + sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity) + + sa.extensions['templates'] = ComputeTemplates(sa) + sa.extensions['templates'].params = {'nbefore' : templates.nbefore} + sa.extensions['templates'].data['average'] = templates_array + sa.compute('unit_locations', method='monopolar_triangulation') + merges = get_potential_auto_merge(sa, **merging_kwargs) + merges = resolve_merging_graph(sorting, merges) + sorting = apply_merges_to_sorting(sorting, merges) + #sorting = merge_units_sorting(sorting, merges) + + return sorting \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 3346be662d..5dd0778f76 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -42,13 +42,13 @@ def compute_result(self, **result_params): sorting = self.result["sorting"] comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) self.result["gt_comparison"] = comp - #self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, 
sorting, exhaustive_gt=True) _run_key_saved = [ ("sorting", "sorting"), ("templates", "zarr_templates"), ] - _result_key_saved = [("gt_comparison", "pickle")] + _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] class MatchingStudy(BenchmarkStudy): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 5ce25685ab..52322ec0b8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -229,9 +229,8 @@ def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_ folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] sorting_analyzer = create_sorting_analyzer( - gt_sorting, recording, format="binary_folder", folder=folder) + gt_sorting, recording, format="binary_folder", folder=folder, return_scaled=return_scaled) sorting_analyzer.compute("random_spikes", **random_params) - sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates", **job_kwargs) sorting_analyzer.compute("noise_levels") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index a4af25e59a..7747e5dd66 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -8,7 +8,7 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap -from spikeinterface.curation.mergeunitssorting import merge_units_sorting +from spikeinterface.core import NumpySorting def _split_waveforms( @@ -717,12 +717,11 @@ def resolve_merging_graph(sorting, potential_merges): for i in range(n_components): merges = labels == i if merges.sum() > 1: - src = np.where(merges)[0][0] - tgts = np.where(merges)[0][1:] - final_merges += [(sorting.unit_ids[src], sorting.unit_ids[tgts])] + final_merges += [list(sorting.unit_ids[np.flatnonzero(merges)]))] return final_merges + def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): """ Function to apply a resolved representation of the merges to a sorting object. 
If censor_ms is not None, @@ -739,9 +738,9 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): if censor_ms is not None: rpv = int(sorting.sampling_frequency * censor_ms / 1000) - for src, targets in merges: - mask = np.in1d(spikes['unit_index'], sorting.ids_to_indices([src] + list(targets))) - spikes['unit_index'][mask] = sorting.id_to_index(src) + for connected in merges: + mask = np.in1d(spikes['unit_index'], sorting.ids_to_indices(connected)) + spikes['unit_index'][mask] = sorting.id_to_index(connected[0]) if censor_ms is not None: for segment_index in range(sorting.get_num_segments()): @@ -751,7 +750,7 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): to_keep[indices[1:]], np.diff(spikes[indices]["sample_index"]) > rpv ) - from spikeinterface.core.numpyextractors import NumpySorting + times_list = [] labels_list = [] for segment_index in range(sorting.get_num_segments()): @@ -767,30 +766,6 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): return sorting - -def final_cleaning_circus(recording, sorting, templates, - **merging_kwargs): - - from spikeinterface.core.sortinganalyzer import create_sorting_analyzer - from spikeinterface.curation.auto_merge import get_potential_auto_merge - - sparsity = templates.sparsity - templates_array = templates.get_dense_templates().copy() - - sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity) - from spikeinterface.core.analyzer_extension_core import ComputeTemplates - sa.extensions['templates'] = ComputeTemplates(sa) - sa.extensions['templates'].params = {'nbefore' : templates.nbefore} - sa.extensions['templates'].data['average'] = templates_array - sa.compute('unit_locations', method='monopolar_triangulation') - merges = get_potential_auto_merge(sa, **merging_kwargs) - merges = resolve_merging_graph(sorting, merges) - sorting = apply_merges_to_sorting(sorting, merges) - #sorting = merge_units_sorting(sorting, merges) - - return sorting - - def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_threshold=None): import sklearn From 87df257696c92e87faee3416045c7d395f2108c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Apr 2024 15:15:07 +0000 Subject: [PATCH 04/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 19 ++++++++----------- .../benchmark/benchmark_tools.py | 3 ++- .../clustering/clustering_tools.py | 6 +++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index fdb85f0243..503a6dcc50 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -342,23 +342,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): return sorting -def final_cleaning_circus(recording, sorting, templates, - **merging_kwargs): - - +def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): sparsity = templates.sparsity templates_array = templates.get_dense_templates().copy() sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity) - - sa.extensions['templates'] = ComputeTemplates(sa) - sa.extensions['templates'].params = {'nbefore' : templates.nbefore} - sa.extensions['templates'].data['average'] = templates_array - sa.compute('unit_locations', 
method='monopolar_triangulation') + + sa.extensions["templates"] = ComputeTemplates(sa) + sa.extensions["templates"].params = {"nbefore": templates.nbefore} + sa.extensions["templates"].data["average"] = templates_array + sa.compute("unit_locations", method="monopolar_triangulation") merges = get_potential_auto_merge(sa, **merging_kwargs) merges = resolve_merging_graph(sorting, merges) sorting = apply_merges_to_sorting(sorting, merges) - #sorting = merge_units_sorting(sorting, merges) + # sorting = merge_units_sorting(sorting, merges) - return sorting \ No newline at end of file + return sorting diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index c35d961708..811673e525 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -229,7 +229,8 @@ def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_ folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] sorting_analyzer = create_sorting_analyzer( - gt_sorting, recording, format="binary_folder", folder=folder, return_scaled=return_scaled) + gt_sorting, recording, format="binary_folder", folder=folder, return_scaled=return_scaled + ) sorting_analyzer.compute("random_spikes", **random_params) sorting_analyzer.compute("templates", **job_kwargs) sorting_analyzer.compute("noise_levels") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 855167c394..30d99e1d37 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -738,11 +738,11 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): if censor_ms is not None: rpv = int(sorting.sampling_frequency * censor_ms / 1000) - + for connected in merges: mask = np.in1d(spikes['unit_index'], sorting.ids_to_indices(connected)) spikes['unit_index'][mask] = sorting.id_to_index(connected[0]) - + if censor_ms is not None: for segment_index in range(sorting.get_num_segments()): s0, s1 = segment_slices[segment_index] @@ -751,7 +751,7 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): to_keep[indices[1:]], np.diff(spikes[indices]["sample_index"]) > rpv ) - + times_list = [] labels_list = [] for segment_index in range(sorting.get_num_segments()): From cea56cf8a94ef368b6e0dd3e1a1152fc27bd358f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 18:30:40 +0200 Subject: [PATCH 05/18] repair tridesclous2 after cicurs2 changes --- src/spikeinterface/sorters/internal/tridesclous2.py | 3 +-- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index fe618d42f3..e7bb1027e3 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -234,8 +234,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, min_size_split=min_cluster_size, - min_cluster_size=min_cluster_size, - min_samples=50, + clusterer_kwargs={"min_cluster_size": min_cluster_size}, n_pca_features=3, ), recursive=True, diff --git 
a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 30d99e1d37..9b6d2e504a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -718,7 +718,7 @@ def resolve_merging_graph(sorting, potential_merges): for i in range(n_components): merges = labels == i if merges.sum() > 1: - final_merges += [list(sorting.unit_ids[np.flatnonzero(merges)]))] + final_merges += [list(sorting.unit_ids[np.flatnonzero(merges)])] return final_merges From ac6e6d0cb9685491996d9fc0c7b8f160c17a2b67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:31:20 +0000 Subject: [PATCH 06/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 9b6d2e504a..d3a00c4e6e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -740,8 +740,8 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): rpv = int(sorting.sampling_frequency * censor_ms / 1000) for connected in merges: - mask = np.in1d(spikes['unit_index'], sorting.ids_to_indices(connected)) - spikes['unit_index'][mask] = sorting.id_to_index(connected[0]) + mask = np.in1d(spikes["unit_index"], sorting.ids_to_indices(connected)) + spikes["unit_index"][mask] = sorting.id_to_index(connected[0]) if censor_ms is not None: for segment_index in range(sorting.get_num_segments()): @@ -751,7 +751,6 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=0.4): to_keep[indices[1:]], np.diff(spikes[indices]["sample_index"]) > rpv ) - times_list = [] labels_list = [] for segment_index in range(sorting.get_num_segments()): From bcf2d61e4951082854944964bd09ee80d4d8566c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 18:38:01 +0200 Subject: [PATCH 07/18] fix circus2 no motion --- src/spikeinterface/sorters/internal/spyking_circus2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4a5930b044..98974fcc16 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -136,14 +136,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["drift_correction"] is not None: if not valid_geometry: print("Geometry of the probe does not allow 1D drift correction") + motion_folder = None else: print("Motion correction activated (probe geometry compatible)") motion_folder = sorter_output_folder / "motion" params["drift_correction"].update({"folder": motion_folder}) recording_f = correct_motion(recording_f, **params["drift_correction"]) + else: + motion_folder = None ## We need to whiten before the template matching step, to boost the results - recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) + # TODO add , regularize=True chen ready + recording_w = whiten(recording_f, mode="local", radius_um=radius_um, 
dtype="float32") noise_levels = get_noise_levels(recording_w, return_scaled=False) @@ -303,7 +307,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): merging_params = params["merging"].copy() if len(merging_params) > 0: - if params["drift_correction"]: + if params["drift_correction"] and motion_folder is not None: from spikeinterface.preprocessing.motion import load_motion_info motion_info = load_motion_info(motion_folder) From b226fe68bc040b71abc875e0b8b76f1f19d016e4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 18:44:37 +0200 Subject: [PATCH 08/18] local sklearn import --- .../sortingcomponents/clustering/sliding_hdbscan.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 871d486b9c..19e9640383 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -6,7 +6,7 @@ import random import string -import sklearn.decomposition + import numpy as np @@ -164,6 +164,9 @@ def _initialize_folder(cls, recording, peaks, params): @classmethod def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): + + import sklearn.decomposition + num_chans = recording.get_num_channels() fs = recording.get_sampling_frequency() nbefore = int(d["ms_before"] * fs / 1000.0) From 5b7ce16156b6fba3dd4f98595518a3a1a55af55e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:45:01 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/sliding_hdbscan.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 19e9640383..7e7a8de1d7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -7,7 +7,6 @@ import string - import numpy as np try: @@ -166,7 +165,7 @@ def _initialize_folder(cls, recording, peaks, params): def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): import sklearn.decomposition - + num_chans = recording.get_num_channels() fs = recording.get_sampling_frequency() nbefore = int(d["ms_before"] * fs / 1000.0) From 2facb8e3e7a0b12030201307c990b209c825bc28 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 19:12:43 +0200 Subject: [PATCH 10/18] some fixes for tests dur to circus2 --- src/spikeinterface/sortingcomponents/clustering/sliding_nn.py | 4 +++- src/spikeinterface/sortingcomponents/clustering/split.py | 2 +- src/spikeinterface/sortingcomponents/tests/test_clustering.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index d9c325d4e6..f0fc1d45da 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -8,7 +8,7 @@ HAVE_NUMBA = False import numpy as np -from sklearn.utils import check_random_state + try: from pynndescent import NNDescent @@ -469,6 +469,8 @@ def 
get_spike_nearest_neighbors( https://github.com/facebookresearch/pysparnn """ + from sklearn.utils import check_random_state + # helper functions for nearest-neighbors search tree def get_n_trees_iters(X): n_trees = min(64, 5 + int(round((X.shape[0]) ** 0.5 / 20.0))) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 30378c80c3..ceeaeb6633 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -83,7 +83,7 @@ def split_clusters( for label in labels_set: peak_indices = np.flatnonzero(peak_labels == label) if peak_indices.size > 0: - jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level=1)) + jobs.append(pool.submit(split_function_wrapper, peak_indices, 1)) if progress_bar: iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) diff --git a/src/spikeinterface/sortingcomponents/tests/test_clustering.py b/src/spikeinterface/sortingcomponents/tests/test_clustering.py index be481aac4c..3092becc94 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_clustering.py +++ b/src/spikeinterface/sortingcomponents/tests/test_clustering.py @@ -76,6 +76,7 @@ def test_find_cluster_from_peaks(clustering_method, recording, peaks, peak_locat recording, sorting = make_dataset() peaks = run_peaks(recording, job_kwargs) peak_locations = run_peak_locations(recording, peaks, job_kwargs) - method = "position_and_pca" + # method = "position_and_pca" + method = "circus" test_find_cluster_from_peaks(method, recording, peaks, peak_locations) From bf497fd7d097ace42729865932eb32d53f891c96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:13:33 +0000 Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/sliding_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index f0fc1d45da..2466f8bba9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -470,7 +470,7 @@ def get_spike_nearest_neighbors( """ from sklearn.utils import check_random_state - + # helper functions for nearest-neighbors search tree def get_n_trees_iters(X): n_trees = min(64, 5 + int(round((X.shape[0]) ** 0.5 / 20.0))) From 0fd4ab6382d797f41a5a5f738a3d093f3b1b5589 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 20:18:03 +0000 Subject: [PATCH 12/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/sliding_nn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index 58a60b959d..a6ffa5fdc2 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -28,7 +28,6 @@ HAVE_HDBSCAN = False - try: import pymde @@ -573,7 +572,7 @@ def merge_nn_dicts(peaks, n_neighbors, peaks_in_chunk_idx_list, 
knn_indices_list def construct_symmetric_graph_from_idx_vals(graph_idx, graph_vals): from scipy.sparse import coo_matrix - + rows = graph_idx.flatten() cols = np.repeat(np.arange(len(graph_idx)), graph_idx.shape[1]) rows_ = np.concatenate([rows, cols]) From 06d67ae5d804cd920705e2b7459741291246f7ca Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 22:26:04 +0200 Subject: [PATCH 13/18] fix import --- src/spikeinterface/sortingcomponents/clustering/sliding_nn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index 2466f8bba9..58a60b959d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -27,7 +27,7 @@ except: HAVE_HDBSCAN = False -from scipy.sparse import coo_matrix + try: import pymde @@ -572,6 +572,8 @@ def merge_nn_dicts(peaks, n_neighbors, peaks_in_chunk_idx_list, knn_indices_list def construct_symmetric_graph_from_idx_vals(graph_idx, graph_vals): + from scipy.sparse import coo_matrix + rows = graph_idx.flatten() cols = np.repeat(np.arange(len(graph_idx)), graph_idx.shape[1]) rows_ = np.concatenate([rows, cols]) From 0d15ab62425810fae5fbced271e210d15e29e810 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 20:28:48 +0000 Subject: [PATCH 14/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b71d72aaae..6c00d357f8 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -105,7 +105,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates - from spikeinterface.sortingcomponents.clustering.clustering_tools import resolve_merging_graph, apply_merges_to_sorting + from spikeinterface.sortingcomponents.clustering.clustering_tools import ( + resolve_merging_graph, + apply_merges_to_sorting, + ) from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction from spikeinterface.sortingcomponents.tools import get_prototype_spike From ccf297313f57e978501a910ece785066377d14ab Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 22:33:39 +0200 Subject: [PATCH 15/18] remove sortingcomponents form test_imports actions --- .github/import_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/import_test.py b/.github/import_test.py index c1bebbd4e4..eb578c6102 100644 --- a/.github/import_test.py +++ b/.github/import_test.py @@ -9,7 +9,8 @@ "import spikeinterface.preprocessing", "import spikeinterface.comparison", "import spikeinterface.postprocessing", - "import spikeinterface.sortingcomponents", + # sorting components has too non core import everywhere + #"import spikeinterface.sortingcomponents", 
"import spikeinterface.curation", "import spikeinterface.exporters", "import spikeinterface.widgets", From 87d1bdf482169ffe2c88fa4c6c091e78b7ed748f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 22:37:04 +0200 Subject: [PATCH 16/18] fix import in circus2 --- src/spikeinterface/sorters/internal/spyking_circus2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 98974fcc16..b71d72aaae 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -16,14 +16,10 @@ from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.sortingcomponents.tools import remove_empty_templates -from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.curation.auto_merge import get_potential_auto_merge -from spikeinterface.sortingcomponents.clustering.clustering_tools import resolve_merging_graph, apply_merges_to_sorting from spikeinterface.core.analyzer_extension_core import ComputeTemplates -from spikeinterface.sortingcomponents.tools import get_prototype_spike try: import hdbscan @@ -109,6 +105,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates + from spikeinterface.sortingcomponents.clustering.clustering_tools import resolve_merging_graph, apply_merges_to_sorting + from spikeinterface.sortingcomponents.tools import remove_empty_templates + from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction + from spikeinterface.sortingcomponents.tools import get_prototype_spike job_kwargs = params["job_kwargs"] job_kwargs = fix_job_kwargs(job_kwargs) From 937d811f6fda6b26a66d8ef8b4294a7d496cbb93 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 22:47:15 +0200 Subject: [PATCH 17/18] put back sortingcomponents form test_imports actions --- .github/import_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/import_test.py b/.github/import_test.py index eb578c6102..c1bebbd4e4 100644 --- a/.github/import_test.py +++ b/.github/import_test.py @@ -9,8 +9,7 @@ "import spikeinterface.preprocessing", "import spikeinterface.comparison", "import spikeinterface.postprocessing", - # sorting components has too non core import everywhere - #"import spikeinterface.sortingcomponents", + "import spikeinterface.sortingcomponents", "import spikeinterface.curation", "import spikeinterface.exporters", "import spikeinterface.widgets", From faad05d21c55c39bf95ae923033e51d71a4b8f3d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 30 Apr 2024 22:49:52 +0200 Subject: [PATCH 18/18] oups --- src/spikeinterface/sorters/internal/spyking_circus2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6c00d357f8..ba6870eef2 100644 --- 
a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -105,10 +105,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates - from spikeinterface.sortingcomponents.clustering.clustering_tools import ( - resolve_merging_graph, - apply_merges_to_sorting, - ) from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction from spikeinterface.sortingcomponents.tools import get_prototype_spike @@ -353,6 +349,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): + from spikeinterface.sortingcomponents.clustering.clustering_tools import ( + resolve_merging_graph, + apply_merges_to_sorting, + ) + sparsity = templates.sparsity templates_array = templates.get_dense_templates().copy()
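As an aside on the merging path used by final_cleaning_circus above (get_potential_auto_merge, then resolve_merging_graph, then apply_merges_to_sorting), the graph-resolution step simply groups the suggested pairs into connected components. A minimal standalone sketch with plain unit ids instead of a sorting object (illustration only, not part of the patch; resolve_merges and the index mapping are hypothetical names used here for brevity):

import numpy as np
from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import connected_components

def resolve_merges(unit_ids, potential_merges):
    # potential_merges: list of (unit_id_a, unit_id_b) pairs suggested by get_potential_auto_merge
    index = {uid: i for i, uid in enumerate(unit_ids)}
    graph = lil_matrix((len(unit_ids), len(unit_ids)))
    for a, b in potential_merges:
        graph[index[a], index[b]] = 1
    n_components, labels = connected_components(graph, directed=True, connection="weak", return_labels=True)
    groups = []
    for c in range(n_components):
        members = np.flatnonzero(labels == c)
        if members.size > 1:  # keep only components that actually merge something
            groups.append([unit_ids[i] for i in members])
    return groups

# pairs (0, 1), (1, 2) and (5, 6) collapse into two groups: [[0, 1, 2], [5, 6]]
print(resolve_merges(list(range(8)), [(0, 1), (1, 2), (5, 6)]))

apply_merges_to_sorting then relabels every spike of a group onto the group's first unit and, when censor_ms is not None, drops the duplicated spikes of the merged unit that fall within the censor period.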