diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 1934b65ef4..8b27d8d026 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -133,10 +133,10 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc return count_units - def plot_unit_counts(self, case_keys=None, figsize=None): - from spikeinterface.widgets.widget_list import plot_study_unit_counts + def plot_unit_counts(self, case_keys=None, **kwargs): + from .benchmark_plot_tools import plot_unit_counts - plot_study_unit_counts(self, case_keys, figsize=figsize) + return plot_unit_counts(self, case_keys, **kwargs) def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): import matplotlib.pyplot as plt diff --git a/src/spikeinterface/benchmark/benchmark_merging.py b/src/spikeinterface/benchmark/benchmark_merging.py new file mode 100644 index 0000000000..5239a201cb --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_merging.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from spikeinterface.curation.auto_merge import auto_merge_units +from spikeinterface.comparison import compare_sorter_to_ground_truth +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.widgets import ( + plot_unit_templates, + plot_amplitudes, + plot_crosscorrelograms, +) + +import numpy as np +from .benchmark_base import Benchmark, BenchmarkStudy + + +class MergingBenchmark(Benchmark): + + def __init__(self, recording, splitted_sorting, params, gt_sorting, splitted_cells=None): + self.recording = recording + self.splitted_sorting = splitted_sorting + self.gt_sorting = gt_sorting + self.splitted_cells = splitted_cells + self.method_kwargs = params["method_kwargs"] + self.result = {} + + def run(self, **job_kwargs): + sorting_analyzer = create_sorting_analyzer( + self.splitted_sorting, self.recording, format="memory", sparse=True, **job_kwargs + ) + # sorting_analyzer.compute(['random_spikes', 'templates']) + # sorting_analyzer.compute('template_similarity', max_lag_ms=0.1, method="l2", **job_kwargs) + merged_analyzer, self.result["merged_pairs"], self.result["merges"], self.result["outs"] = auto_merge_units( + sorting_analyzer, extra_outputs=True, **self.method_kwargs, **job_kwargs + ) + + self.result["sorting"] = merged_analyzer.sorting + + def compute_result(self, **result_params): + sorting = self.result["sorting"] + comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + self.result["gt_comparison"] = comp + + _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("merged_pairs", "pickle"), ("outs", "pickle")] + _result_key_saved = [("gt_comparison", "pickle")] + + +class MergingStudy(BenchmarkStudy): + + benchmark_class = MergingBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = MergingBenchmark(recording, gt_sorting, params, **init_kwargs) + return benchmark + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = 
pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.get_result(case_keys[0])["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + assert comp is not None, "You need to do study.run_comparisons() first" + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix + + return plot_agreement_matrix(self, **kwargs) + + def plot_unit_counts(self, case_keys=None, **kwargs): + from .benchmark_plot_tools import plot_unit_counts + + return plot_unit_counts(self, case_keys, **kwargs) + + def get_splitted_pairs(self, case_key): + return self.benchmarks[case_key].splitted_cells + + def get_splitted_pairs_index(self, case_key, pair): + for count, i in enumerate(self.benchmarks[case_key].splitted_cells): + if i == pair: + return count + + def plot_splitted_amplitudes(self, case_key, pair_index=0, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + plot_amplitudes(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index], backend=backend) + + def plot_splitted_correlograms(self, case_key, pair_index=0, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + if analyzer.get_extension("correlograms") is None: + analyzer.compute(["correlograms"]) + if analyzer.get_extension("template_similarity") is None: + analyzer.compute(["template_similarity"]) + plot_crosscorrelograms(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index]) + + def plot_splitted_templates(self, case_key, pair_index=0, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + plot_unit_templates(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index], backend=backend) + + def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + mylist = self.get_splitted_pairs(case_key) + + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + if analyzer.get_extension("correlograms") is None: + analyzer.compute(["correlograms"]) + + if min_snr is not None: + select_from = analyzer.sorting.unit_ids + if analyzer.get_extension("noise_levels") is None: + analyzer.compute("noise_levels") + if analyzer.get_extension("quality_metrics") is None: + analyzer.compute("quality_metrics", metric_names=["snr"]) + + snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values 
+ select_from = select_from[snr > min_snr] + mylist_selection = [] + for i in mylist: + if (i[0] in select_from) or (i[1] in select_from): + mylist_selection += [i] + mylist = mylist_selection + + from spikeinterface.widgets import plot_potential_merges + + plot_potential_merges(analyzer, mylist, backend=backend) + + def plot_performed_merges(self, case_key, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + if analyzer.get_extension("correlograms") is None: + analyzer.compute(["correlograms"]) + + all_merges = list(self.benchmarks[case_key].result["merged_pairs"].values()) + + from spikeinterface.widgets import plot_potential_merges + + plot_potential_merges(analyzer, all_merges, backend=backend) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_merging.py b/src/spikeinterface/benchmark/tests/test_benchmark_merging.py new file mode 100644 index 0000000000..c868eb76e5 --- /dev/null +++ b/src/spikeinterface/benchmark/tests/test_benchmark_merging.py @@ -0,0 +1,74 @@ +import pytest +from pathlib import Path +import numpy as np + +import shutil + +from spikeinterface.benchmark.benchmark_merging import MergingStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset + +from spikeinterface.generation.splitting_tools import split_sorting_by_amplitudes, split_sorting_by_times + + +@pytest.mark.skip() +def test_benchmark_merging(create_cache_folder): + cache_folder = create_cache_folder + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") + + recording, gt_sorting, gt_analyzer = make_dataset() + + # create study + study_folder = cache_folder / "study_clustering" + # datasets = {"toy": (recording, gt_sorting)} + datasets = {"toy": gt_analyzer} + + gt_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) + + splitted_sorting = {} + splitted_sorting["times"] = split_sorting_by_times(gt_analyzer) + splitted_sorting["amplitudes"] = split_sorting_by_amplitudes(gt_analyzer) + + cases = {} + for splits in ["times", "amplitudes"]: + for method in ["circus", "lussac"]: + cases[(method, splits)] = { + "label": f"{method}", + "dataset": "toy", + "init_kwargs": {"gt_sorting": gt_sorting, "splitted_cells": splitted_sorting[splits][1]}, + "params": {"method": method, "splitted_sorting": splitted_sorting[splits][0], "method_kwargs": {}}, + } + + if study_folder.exists(): + shutil.rmtree(study_folder) + study = MergingStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) + + # this study needs analyzer + # study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + study = MergingStudy(study_folder) + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = MergingStudy(study_folder) + print(study) + + # plots + # study.plot_performances_vs_snr() + study.plot_agreements() + study.plot_unit_counts() + # study.plot_error_metrics() + # study.plot_metrics_vs_snr() + # study.plot_run_times() + # study.plot_metrics_vs_snr("cosine") + # study.homogeneity_score(ignore_noise=False) + # import matplotlib.pyplot as plt + # plt.show() + + +if __name__ == "__main__": + test_benchmark_merging() diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 0302ffe5b7..57d63e340c 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,11 @@ from .remove_redundant 
import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge +from .auto_merge import ( + compute_merge_unit_groups, + auto_merge_units, + get_potential_auto_merge, +) # manual sorting, diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 4f4cff144e..c3301b86b4 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -13,7 +13,7 @@ except ImportError: HAVE_NUMBA = False -from ..core import SortingAnalyzer, Templates +from ..core import SortingAnalyzer from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting @@ -234,7 +234,6 @@ def compute_merge_unit_groups( # STEP : remove units with too few spikes if step == "num_spikes": - num_spikes = sorting.count_num_spikes_per_unit(outputs="array") to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False @@ -356,6 +355,9 @@ def compute_merge_unit_groups( ) outs["pairs_decreased_score"] = pairs_decreased_score + # ind1, ind2 = np.nonzero(pair_mask) + # print(step, len(ind1)) + # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) @@ -369,19 +371,122 @@ def compute_merge_unit_groups( return merge_unit_groups -def auto_merge_units( - sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs +def resolve_pairs(existing_merges, new_merges): + if existing_merges is None: + return new_merges.copy() + else: + resolved_merges = existing_merges.copy() + old_keys = list(existing_merges.keys()) + for key, pair in new_merges.items(): + nested_merge = np.flatnonzero([i in pair for i in old_keys]) + if len(nested_merge) == 0: + resolved_merges.update({key: pair}) + else: + for n in nested_merge: + previous_merges = resolved_merges.pop(old_keys[n]) + pair.remove(old_keys[n]) + pair += previous_merges + resolved_merges.update({key: pair}) + return resolved_merges + + +def auto_merge_units_internal( + sorting_analyzer: SortingAnalyzer, + compute_merge_kwargs: dict = {}, + apply_merge_kwargs: dict = {}, + recursive: bool = False, + extra_outputs: bool = False, + force_copy: bool = True, + **job_kwargs, ) -> SortingAnalyzer: """ Compute merge unit groups and apply it on a SortingAnalyzer. Internally uses `compute_merge_unit_groups()` + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + compute_merge_kwargs : dict + The params that should be given to auto_merge_units + apply_merge_kwargs : dict + The paramaters that should be used while merging units after each preset + recursive : bool, default: False + If True, then merges are performed recursively until no more merges can be performed, given the + compute_merge_kwargs + extra_outputs : bool, default: False + If True, additional list of merges applied, and dictionary (`outs`) with processed data are returned. + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. 
False if you want to overwrite + + Returns + ------- + sorting_analyzer: + The new sorting analyzer where all the merges from all the presets have been applied + + merges, outs: + Returned only when extra_outputs=True + A list with the merges performed, and dictionaries that contains data for debugging and plotting. + Note that if recursive, then you are receiving list of lists (for all merges and outs at every step) """ - merge_unit_groups = compute_merge_unit_groups( - sorting_analyzer, extra_outputs=False, **compute_merge_kwargs, **job_kwargs - ) - merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) - return merged_analyzer + if force_copy: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() + + if not recursive: + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, **compute_merge_kwargs, extra_outputs=extra_outputs, force_copy=False, **job_kwargs + ) + + if extra_outputs: + merge_unit_groups, outs = merge_unit_groups + + merged_units = len(merge_unit_groups) > 0 + if merged_units: + merged_analyzer, new_unit_ids = sorting_analyzer.merge_units( + merge_unit_groups, return_new_unit_ids=True, **apply_merge_kwargs, **job_kwargs + ) + else: + merged_analyzer = sorting_analyzer + new_unit_ids = [] + + resolved_merges = {key: value for (key, value) in zip(new_unit_ids, merge_unit_groups)} + else: + merged_units = True + merged_analyzer = sorting_analyzer + + if extra_outputs: + all_merging_groups = [] + resolved_merges = {} + all_outs = [] + + while merged_units: + merge_unit_groups = compute_merge_unit_groups( + merged_analyzer, **compute_merge_kwargs, extra_outputs=extra_outputs, force_copy=False, **job_kwargs + ) + + if extra_outputs: + merge_unit_groups, outs = merge_unit_groups + + merged_units = len(merge_unit_groups) > 0 + + if merged_units: + merged_analyzer, new_unit_ids = merged_analyzer.merge_units( + merge_unit_groups, return_new_unit_ids=True, **apply_merge_kwargs, **job_kwargs + ) + + if extra_outputs: + all_merging_groups += [merge_unit_groups] + new_merges = {key: value for (key, value) in zip(new_unit_ids, merge_unit_groups)} + resolved_merges = resolve_pairs(resolved_merges, new_merges) + all_outs += [outs] + + if extra_outputs: + return merged_analyzer, resolved_merges, merge_unit_groups, outs + else: + return merged_analyzer def get_potential_auto_merge( @@ -572,6 +677,118 @@ def get_potential_auto_merge( ) +def auto_merge_units( + sorting_analyzer: SortingAnalyzer, + presets: list | None = ["similarity_correlograms"], + steps_params: dict = None, + steps: list[str] | None = None, + apply_merge_kwargs: dict = {}, + recursive: bool = False, + extra_outputs: bool = False, + force_copy: bool = True, + **job_kwargs, +) -> SortingAnalyzer: + """ + Wrapper to conveniently be able to launch several presets for auto_merge_units in a row, as a list. + Merges are applied sequentially or until no more merges are done, one preset at a time, and extensions + are not recomputed thanks to the merging units. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + presets : str or list, default = "similarity_correlograms" + A single preset or a list of presets, that should be applied iteratively to the data + steps_params : dict or list of dict, default None + The params that should be used for the steps or presets. 
Should be a single dict if only one steps, + or a list of dict is multiples steps (same size as presets) + steps : list or list of list, default None + The list of steps that should be applied. If list of list is provided, then these lists will be applied + iteratively. Mutually exclusive with presets + apply_merge_kwargs : dict + The paramaters that should be used while merging units after each preset + recursive : bool, default: False + If True, then each presets of the list is applied until no further merges can be done, before trying + the next one + extra_outputs : bool, default: False + If True, additional list of merges applied at every preset, and dictionary (`outs`) with processed data are returned. + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. False if you want to overwrite + + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge + + If you have errors on sparsity_threshold, this is because you are trying to perform soft_merges for units + that are barely overlapping. While in theory this should + + Returns + ------- + sorting_analyzer: + The new sorting analyzer where all the merges from all the presets have been applied + merges, outs: + Returned only when extra_outputs=True + A list with all the merges performed at every steps, and dictionaries that contains data for debugging and plotting. + """ + + if isinstance(presets, str): + presets = [presets] + + if (steps is not None) and (presets is not None): + raise Exception("presets and steps are mutually exclusive") + + if presets is not None: + to_be_launched = presets + launch_mode = "presets" + elif steps is not None: + to_be_launched = steps + launch_mode = "steps" + + if steps_params is not None: + assert len(steps_params) == len(to_be_launched), f"steps params should have the same size as {launch_mode}" + else: + steps_params = [None] * len(to_be_launched) + + if extra_outputs: + all_merging_groups = [] + all_outs = [] + resolved_merges = {} + + if force_copy: + sorting_analyzer = sorting_analyzer.copy() + + for to_launch, params in zip(to_be_launched, steps_params): + + if launch_mode == "presets": + compute_merge_kwargs = {"preset": to_launch} + elif launch_mode == "steps": + compute_merge_kwargs = {"steps": to_launch} + + compute_merge_kwargs.update({"steps_params": params}) + # print(compute_merge_kwargs) + sorting_analyzer = auto_merge_units_internal( + sorting_analyzer, + compute_merge_kwargs, + apply_merge_kwargs=apply_merge_kwargs, + recursive=recursive, + extra_outputs=extra_outputs, + force_copy=False, + **job_kwargs, + ) + + if extra_outputs: + sorting_analyzer, new_merges, merge_unit_groups, outs = sorting_analyzer + all_merging_groups += [merge_unit_groups] + resolved_merges = resolve_pairs(resolved_merges, new_merges) + all_outs += [outs] + + if extra_outputs: + return sorting_analyzer, resolved_merges, merge_unit_groups, all_outs + else: + return sorting_analyzer + + def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): sorting = sorting_analyzer.sorting @@ -1218,7 +1435,6 @@ def estimate_cross_contamination( n_violations = compute_nb_coincidence(spike_train1, spike_train2, t_r) - compute_nb_coincidence( 
spike_train1, spike_train2, t_c ) - estimation = 1 - ((n_violations * T) / (2 * N1 * N2 * t_r) - 1.0) / (C1 - 1.0) if C1 != 1.0 else -np.inf if limit is None: return estimation diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 9cd20f4bfc..8805171c1f 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -3,17 +3,16 @@ import pytest from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer -from spikeinterface.qualitymetrics import compute_quality_metrics job_kwargs = dict(n_jobs=-1) -def make_sorting_analyzer(sparse=True): +def make_sorting_analyzer(sparse=True, num_units=5): recording, sorting = generate_ground_truth_recording( durations=[300.0], sampling_frequency=30000.0, num_channels=4, - num_units=5, + num_units=num_units, generate_sorting_kwargs=dict(firing_rates=20.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), seed=2205, diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 4c05f41a4c..2f0e51c295 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,7 +3,8 @@ from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units -from spikeinterface.curation import compute_merge_unit_groups, auto_merge +from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units +from spikeinterface.generation import split_sorting_by_times from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation @@ -81,60 +82,82 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): **job_kwargs, ) - # DEBUG - # import matplotlib.pyplot as plt - # from spikeinterface.curation.auto_merge import normalize_correlogram - # templates_diff = outs['templates_diff'] - # correlogram_diff = outs['correlogram_diff'] - # bins = outs['bins'] - # correlograms_smoothed = outs['correlograms_smoothed'] - # correlograms = outs['correlograms'] - # win_sizes = outs['win_sizes'] - # fig, ax = plt.subplots() - # ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05)) +# DEBUG +# import matplotlib.pyplot as plt +# from spikeinterface.curation.auto_merge import normalize_correlogram +# templates_diff = outs['templates_diff'] +# correlogram_diff = outs['correlogram_diff'] +# bins = outs['bins'] +# correlograms_smoothed = outs['correlograms_smoothed'] +# correlograms = outs['correlograms'] +# win_sizes = outs['win_sizes'] - # fig, ax = plt.subplots() - # ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05)) +# fig, ax = plt.subplots() +# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05)) - # m = correlograms.shape[2] // 2 +# fig, ax = plt.subplots() +# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05)) - # for unit_id1, unit_id2 in merge_unit_groups[:5]: - # unit_ind1 = sorting_with_split.id_to_index(unit_id1) - # unit_ind2 = sorting_with_split.id_to_index(unit_id2) +# m = correlograms.shape[2] // 2 - # bins2 = bins[:-1] + np.mean(np.diff(bins)) - # fig, axs = plt.subplots(ncols=3) - # ax = axs[0] - # ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b') - # ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r') - # ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b') - # ax.plot(bins2, 
correlograms_smoothed[unit_ind2, unit_ind2, :], color='r') +# for unit_id1, unit_id2 in merge_unit_groups[:5]: +# unit_ind1 = sorting_with_split.id_to_index(unit_id1) +# unit_ind2 = sorting_with_split.id_to_index(unit_id2) - # ax.set_title(f'{unit_id1} {unit_id2}') - # ax = axs[1] - # ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g') +# bins2 = bins[:-1] + np.mean(np.diff(bins)) +# fig, axs = plt.subplots(ncols=3) +# ax = axs[0] +# ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b') +# ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r') +# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b') +# ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r') - # auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :]) - # auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :]) - # cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :]) +# ax.set_title(f'{unit_id1} {unit_id2}') +# ax = axs[1] +# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g') - # ax = axs[2] - # ax.plot(bins2, auto_corr1, color='b') - # ax.plot(bins2, auto_corr2, color='r') - # ax.plot(bins2, cross_corr, color='g') +# auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :]) +# auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :]) +# cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :]) - # ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b') - # ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b') - # ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r') - # ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r') +# ax = axs[2] +# ax.plot(bins2, auto_corr1, color='b') +# ax.plot(bins2, auto_corr2, color='r') +# ax.plot(bins2, cross_corr, color='g') - # ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}') - # plt.show() +# ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b') +# ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b') +# ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r') +# ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r') + +# ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}') +# plt.show() + + +def test_auto_merge_units(sorting_analyzer_for_curation): + recording = sorting_analyzer_for_curation.recording + job_kwargs = dict(n_jobs=-1) + new_sorting, _ = split_sorting_by_times(sorting_analyzer_for_curation) + new_sorting_analyzer = create_sorting_analyzer(new_sorting, recording, format="memory") + merged_analyzer = auto_merge_units(new_sorting_analyzer, presets="x_contaminations", **job_kwargs) + assert len(merged_analyzer.unit_ids) < len(new_sorting_analyzer.unit_ids) + + +def test_auto_merge_units_iterative(sorting_analyzer_for_curation): + recording = sorting_analyzer_for_curation.recording + job_kwargs = dict(n_jobs=-1) + new_sorting, _ = split_sorting_by_times(sorting_analyzer_for_curation) + new_sorting_analyzer = create_sorting_analyzer(new_sorting, recording, format="memory") + merged_analyzer = auto_merge_units( + new_sorting_analyzer, presets=["x_contaminations", "x_contaminations"], **job_kwargs + ) + assert len(merged_analyzer.unit_ids) < len(new_sorting_analyzer.unit_ids) if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - # preset = "x_contaminations" 
preset = None test_compute_merge_unit_groups(sorting_analyzer, preset=preset) + test_auto_merge_units(sorting_analyzer) + test_auto_merge_units_iterative(sorting_analyzer) diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 5bf42ecf0f..5d18ce5676 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -15,6 +15,8 @@ ) from .noise_tools import generate_noise +from .splitting_tools import split_sorting_by_amplitudes, split_sorting_by_times + from .drifting_generator import ( make_one_displacement_vector, generate_displacement_vector, diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 9d28340352..e920725e2a 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -5,7 +5,6 @@ import numpy as np from numpy.typing import ArrayLike -from probeinterface import Probe from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, Templates diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py new file mode 100644 index 0000000000..2dbbf99593 --- /dev/null +++ b/src/spikeinterface/generation/splitting_tools.py @@ -0,0 +1,147 @@ +import numpy as np +from spikeinterface.core.numpyextractors import NumpySorting +from spikeinterface.core.sorting_tools import spike_vector_to_indices + + +def split_sorting_by_times( + sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None +): + """ + Fonction used to split a sorting based on the times of the units. This + might be used for benchmarking meta merging step (see components) + + Parameters + ---------- + sorting_analyzer : A sortingAnalyzer object + The sortingAnalyzer object whose sorting should be splitted + splitting_probability : float, default 0.5 + probability of being splitted, for any cell in the provided sorting + partial_split_prob : float, default 0.95 + The percentage of spikes that will belong to pre/post splits + unit_ids : list of unit_ids, default None + The list of unit_ids to be splitted, if prespecified + min_snr : float, default=None + If specified, only cells with a snr higher than min_snr might be splitted + seed : int | None, default: None + The seed for random generator. 
+ + Returns + ------- + new_sorting, splitted_pairs : The new splitted sorting, and the pairs that have been splitted + """ + + sorting = sorting_analyzer.sorting + rng = np.random.RandomState(seed) + fs = sorting_analyzer.sampling_frequency + + nb_splits = int(splitting_probability * len(sorting.unit_ids)) + if unit_ids is None: + select_from = sorting.unit_ids + if min_snr is not None: + if sorting_analyzer.get_extension("noise_levels") is None: + sorting_analyzer.compute("noise_levels") + if sorting_analyzer.get_extension("quality_metrics") is None: + sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + + snr = sorting_analyzer.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] + + to_split_ids = rng.choice(select_from, nb_splits, replace=False) + else: + to_split_ids = unit_ids + + spikes = sorting_analyzer.sorting.to_spike_vector(concatenated=False) + new_spikes = spikes[0].copy() + max_index = np.max(new_spikes["unit_index"]) + new_unit_ids = list(sorting_analyzer.sorting.unit_ids.copy()) + spike_indices = spike_vector_to_indices(spikes, sorting_analyzer.unit_ids, absolute_index=True) + splitted_pairs = [] + for unit_id in to_split_ids: + ind_mask = spike_indices[0][unit_id] + m = np.median(spikes[0][ind_mask]["sample_index"]) + time_mask = spikes[0][ind_mask]["sample_index"] > m + mask = time_mask & (rng.rand(len(ind_mask)) <= partial_split_prob).astype(bool) + new_index = unit_id * np.ones(len(mask)) + new_index[mask] = max_index + 1 + new_spikes["unit_index"][ind_mask] = new_index + new_unit_ids += [max_index + 1] + splitted_pairs += [(unit_id, new_unit_ids[-1])] + max_index += 1 + + new_sorting = NumpySorting(new_spikes, sampling_frequency=fs, unit_ids=new_unit_ids) + return new_sorting, splitted_pairs + + +def split_sorting_by_amplitudes( + sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None +): + """ + Fonction used to split a sorting based on the amplitudes of the units. This + might be used for benchmarking meta merging step (see components) + + Parameters + ---------- + sorting_analyzer : A sortingAnalyzer object + The sortingAnalyzer object whose sorting should be splitted + splitting_probability : float, default 0.5 + probability of being splitted, for any cell in the provided sorting + partial_split_prob : float, default 0.95 + The percentage of spikes that will belong to pre/post splits + unit_ids : list of unit_ids, default None + The list of unit_ids to be splitted, if prespecified + min_snr : float, default=None + If specified, only cells with a snr higher than min_snr might be splitted + seed : int | None, default: None + The seed for random generator. 
+ + Returns + ------- + new_sorting, splitted_pairs : The new splitted sorting, and the pairs that have been splitted + """ + + if sorting_analyzer.get_extension("spike_amplitudes") is None: + sorting_analyzer.compute("spike_amplitudes") + + rng = np.random.RandomState(seed) + fs = sorting_analyzer.sampling_frequency + from spikeinterface.core.template_tools import get_template_extremum_channel + + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) + new_spikes = spikes[0].copy() + amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + nb_splits = int(splitting_probability * len(sorting_analyzer.sorting.unit_ids)) + + if unit_ids is None: + select_from = sorting_analyzer.sorting.unit_ids + if min_snr is not None: + if sorting_analyzer.get_extension("noise_levels") is None: + sorting_analyzer.compute("noise_levels") + if sorting_analyzer.get_extension("quality_metrics") is None: + sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + + snr = sorting_analyzer.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] + to_split_ids = rng.choice(select_from, nb_splits, replace=False) + else: + to_split_ids = unit_ids + + max_index = np.max(new_spikes["unit_index"]) + new_unit_ids = list(sorting_analyzer.sorting.unit_ids.copy()) + splitted_pairs = [] + spike_indices = spike_vector_to_indices(spikes, sorting_analyzer.unit_ids, absolute_index=True) + + for unit_id in to_split_ids: + ind_mask = spike_indices[0][unit_id] + thresh = np.median(amplitudes[ind_mask]) + amplitude_mask = amplitudes[ind_mask] > thresh + mask = amplitude_mask & (rng.rand(len(ind_mask)) <= partial_split_prob).astype(bool) + new_index = unit_id * np.ones(len(mask)) + new_index[mask] = max_index + 1 + new_spikes["unit_index"][ind_mask] = new_index + new_unit_ids += [max_index + 1] + splitted_pairs += [(unit_id, new_unit_ids[-1])] + max_index += 1 + + new_sorting = NumpySorting(new_spikes, sampling_frequency=fs, unit_ids=new_unit_ids) + return new_sorting, splitted_pairs diff --git a/src/spikeinterface/generation/tests/test_splitting_tools.py b/src/spikeinterface/generation/tests/test_splitting_tools.py new file mode 100644 index 0000000000..b95fb46bb5 --- /dev/null +++ b/src/spikeinterface/generation/tests/test_splitting_tools.py @@ -0,0 +1,30 @@ +import probeinterface + +from spikeinterface.generation import split_sorting_by_amplitudes, split_sorting_by_times + +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.core.generate import generate_ground_truth_recording + + +def test_split_by_times(): + rec, sorting = generate_ground_truth_recording() + sa = create_sorting_analyzer(sorting, rec) + new_sorting, splitted_pairs = split_sorting_by_times(sa) + assert len(new_sorting.unit_ids) == len(sorting.unit_ids) + len(splitted_pairs) + for pair in splitted_pairs: + p1 = new_sorting.get_unit_spike_train(pair[0]).mean() + p2 = new_sorting.get_unit_spike_train(pair[1]).mean() + assert p1 < p2 + + +def test_split_by_amplitudes(): + rec, sorting = generate_ground_truth_recording() + sa = create_sorting_analyzer(sorting, rec) + sa.compute(["random_spikes", "templates", "spike_amplitudes"]) + new_sorting, splitted_pairs = split_sorting_by_amplitudes(sa) + assert len(new_sorting.unit_ids) == len(sorting.unit_ids) + len(splitted_pairs) + + +if __name__ == "__main__": + 
test_split_by_times() + test_split_by_amplitudes() diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index eed693b343..8ccd1d3fd1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -14,10 +14,6 @@ from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.curation.auto_merge import get_potential_auto_merge -from spikeinterface.core.analyzer_extension_core import ComputeTemplates -from spikeinterface.core.sparsity import ChannelSparsity class Spykingcircus2Sorter(ComponentsBasedSorter): @@ -38,14 +34,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): }, "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, - "merging": { - "similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2}, - "correlograms_kwargs": {}, - "auto_merge": { - "min_spikes": 10, - "corr_diff_thresh": 0.25, - }, - }, + "merging": {"max_distance_um": 50}, "clustering": {"legacy": True}, "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, @@ -71,7 +60,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): True, one other clustering called circus will be used, similar to the one used in Spyking Circus 1", "matching": "A dictionary to specify the matching engine used to recover spikes. The method default is circus-omp-svd, but other engines\ can be used", - "merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", + "merging": "A dictionary to specify the final merging param to group cells after template matching (auto_merge_units)", "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ median reference + zscore", @@ -100,14 +89,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): except: HAVE_HDBSCAN = False + assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" + try: import torch except ImportError: HAVE_TORCH = False print("spykingcircus2 could benefit from using torch. 
Consider installing it") - assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" - # this is importanted only on demand because numba import are too heavy from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -321,14 +310,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): max_motion = max( np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) ) - merging_params["max_distance_um"] = max(50, 2 * max_motion) - - # peak_sign = params['detection'].get('peak_sign', 'neg') - # best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign) - # guessed_amplitudes = spikes['amplitude'].copy() - # for ind in unit_ids: - # mask = spikes['cluster_index'] == ind - # guessed_amplitudes[mask] *= best_amplitudes[ind] + max_distance_um = merging_params.get("max_distance_um", 50) + merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion) if params["debug"]: curation_folder = sorter_output_folder / "curation" @@ -337,7 +320,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting.save(folder=curation_folder) # np.save(fitting_folder / "amplitudes", guessed_amplitudes) - sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) + sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs) if verbose: print(f"Final merging, keeping {len(sorting.unit_ids)} units") @@ -358,40 +341,38 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): return sorting -def create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty=True): - sparsity = templates.sparsity - templates_array = templates.get_dense_templates().copy() - - if remove_empty: - non_empty_unit_ids = sorting.get_non_empty_unit_ids() - non_empty_sorting = sorting.remove_empty_units() - non_empty_unit_indices = sorting.ids_to_indices(non_empty_unit_ids) - templates_array = templates_array[non_empty_unit_indices] - sparsity_mask = sparsity.mask[non_empty_unit_indices, :] - sparsity = ChannelSparsity(sparsity_mask, non_empty_unit_ids, sparsity.channel_ids) - else: - non_empty_sorting = sorting - - sa = create_sorting_analyzer(non_empty_sorting, recording, format="memory", sparsity=sparsity) - sa.extensions["templates"] = ComputeTemplates(sa) - sa.extensions["templates"].params = {"ms_before": templates.ms_before, "ms_after": templates.ms_after} - sa.extensions["templates"].data["average"] = templates_array - return sa - - -def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): - - from spikeinterface.core.sorting_tools import apply_merges_to_sorting - - sa = create_sorting_analyzer_with_templates(sorting, recording, templates) - - sa.compute("unit_locations", method="monopolar_triangulation") - similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {}) - sa.compute("template_similarity", **similarity_kwargs) - correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {}) - sa.compute("correlograms", **correlograms_kwargs) - auto_merge_kwargs = merging_kwargs.pop("auto_merge", {}) - merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs) - sorting = apply_merges_to_sorting(sa.sorting, merges) - - return sorting +def final_cleaning_circus( + recording, + sorting, + templates, + similarity_kwargs={"method": "l2", "support": "union", "max_lag_ms": 0.1}, + 
apply_merge_kwargs={"sparsity_overlap": 0.1, "censor_ms": 3.0}, + # correlograms_kwargs={}, + max_distance_um=50, + template_diff_thresh=np.arange(0.05, 0.5, 0.05), + **job_kwargs, +): + + from spikeinterface.sortingcomponents.tools import create_sorting_analyzer_with_existing_templates + from spikeinterface.curation.auto_merge import auto_merge_units + + # First we compute the needed extensions + analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates) + analyzer.compute("unit_locations", method="monopolar_triangulation") + analyzer.compute("template_similarity", **similarity_kwargs) + # analyzer.compute("correlograms", **correlograms_kwargs) + + presets = ["x_contaminations"] * len(template_diff_thresh) + steps_params = [ + {"template_similarity": {"template_diff_thresh": i}, "unit_locations": {"max_distance_um": max_distance_um}} + for i in template_diff_thresh + ] + final_sa = auto_merge_units( + analyzer, + presets=presets, + steps_params=steps_params, + apply_merge_kwargs=apply_merge_kwargs, + recursive=True, + **job_kwargs, + ) + return final_sa.sorting diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d03744f8f9..2240357d27 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -118,7 +118,7 @@ def detect_peaks( squeeze_output = True else: squeeze_output = False - job_name += f" + {len(pipeline_nodes)} nodes" + job_name += f" + {len(pipeline_nodes)} nodes" # because node are modified inplace (insert parent) they need to copy incase # the same pipeline is run several times diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1501582336..439aee6db8 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -11,10 +11,11 @@ from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.template import Templates - -from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.core.job_tools import split_job_kwargs +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.core.sparsity import ChannelSparsity +from spikeinterface.core.analyzer_extension_core import ComputeTemplates def make_multi_method_doc(methods, ident=" "): @@ -151,3 +152,26 @@ def fit_sigmoid(xdata, ydata, p0=None): popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) return popt + + +def create_sorting_analyzer_with_existing_templates(sorting, recording, templates, remove_empty=True): + sparsity = templates.sparsity + templates_array = templates.get_dense_templates().copy() + + if remove_empty: + non_empty_unit_ids = sorting.get_non_empty_unit_ids() + non_empty_sorting = sorting.remove_empty_units() + non_empty_unit_indices = sorting.ids_to_indices(non_empty_unit_ids) + templates_array = templates_array[non_empty_unit_indices] + sparsity_mask = sparsity.mask[non_empty_unit_indices, :] + sparsity = ChannelSparsity(sparsity_mask, non_empty_unit_ids, sparsity.channel_ids) + else: + non_empty_sorting = sorting + + sa = create_sorting_analyzer(non_empty_sorting, recording, format="memory", sparsity=sparsity) + sa.compute("random_spikes") + sa.extensions["templates"] = ComputeTemplates(sa) + sa.extensions["templates"].params = 
{"ms_before": templates.ms_before, "ms_after": templates.ms_after} + sa.extensions["templates"].data["average"] = templates_array + sa.extensions["templates"].data["std"] = np.zeros(templates_array.shape, dtype=np.float32) + return sa