From b584e676a0cfdb04a0b38cf23337f2cf36d132c8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 8 Nov 2024 14:24:10 +0100 Subject: [PATCH] Resolving merging while iterative auto_merge --- .../benchmark/benchmark_merging.py | 18 +---- src/spikeinterface/curation/auto_merge.py | 75 +++++++++++-------- 2 files changed, 45 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_merging.py b/src/spikeinterface/benchmark/benchmark_merging.py index 83d85a4243..5239a201cb 100644 --- a/src/spikeinterface/benchmark/benchmark_merging.py +++ b/src/spikeinterface/benchmark/benchmark_merging.py @@ -29,7 +29,7 @@ def run(self, **job_kwargs): ) # sorting_analyzer.compute(['random_spikes', 'templates']) # sorting_analyzer.compute('template_similarity', max_lag_ms=0.1, method="l2", **job_kwargs) - merged_analyzer, self.result["merges"], self.result["outs"] = auto_merge_units( + merged_analyzer, self.result["merged_pairs"], self.result["merges"], self.result["outs"] = auto_merge_units( sorting_analyzer, extra_outputs=True, **self.method_kwargs, **job_kwargs ) @@ -40,7 +40,7 @@ def compute_result(self, **result_params): comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) self.result["gt_comparison"] = comp - _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("outs", "pickle")] + _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("merged_pairs", "pickle"), ("outs", "pickle")] _result_key_saved = [("gt_comparison", "pickle")] @@ -166,19 +166,7 @@ def plot_performed_merges(self, case_key, backend="ipywidgets"): if analyzer.get_extension("correlograms") is None: analyzer.compute(["correlograms"]) - all_merges = self.benchmarks[case_key].result["merges"] - # if recursive: - # final_merges = {} - # for merges in all_merges: - # for merge in merges: - # for m in merge: - # new_list = m - # for k in m: - # if k in final_merges: - # new_list.remove(k) - # new_list += final_merges[k] - # final_merges[m[0]] = new_list - # all_merges = list(final_merges.values()) + all_merges = list(self.benchmarks[case_key].result["merged_pairs"].values()) from spikeinterface.widgets import plot_potential_merges diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 570d6076c6..84d8e72def 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -371,6 +371,24 @@ def compute_merge_unit_groups( return merge_unit_groups +def resolve_pairs(existing_merges, new_merges): + if existing_merges is None: + return new_merges.copy() + else: + resolved_merges = existing_merges.copy() + old_keys = list(existing_merges.keys()) + for key, pair in new_merges.items(): + nested_merge = np.flatnonzero([i in pair for i in old_keys]) + if len(nested_merge) == 0: + resolved_merges.update({key : pair}) + else: + for n in nested_merge: + previous_merges = resolved_merges.pop(old_keys[n]) + pair.remove(old_keys[n]) + pair += previous_merges + resolved_merges.update({key : pair}) + return resolved_merges + def auto_merge_units_internal( sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, @@ -423,15 +441,20 @@ def auto_merge_units_internal( if extra_outputs: merge_unit_groups, outs = merge_unit_groups - - merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) - + + merged_analyzer, new_unit_ids = sorting_analyzer.merge_units(merge_unit_groups, + return_new_unit_ids=True, + **apply_merge_kwargs, **job_kwargs) + resolved_merges = {key : value for (key, value) in zip(new_unit_ids, merge_unit_groups)} else: merged_units = True merged_analyzer = sorting_analyzer + if extra_outputs: all_merging_groups = [] + resolved_merges = {} all_outs = [] + while merged_units: merge_unit_groups = compute_merge_unit_groups( merged_analyzer, **compute_merge_kwargs, extra_outputs=extra_outputs, force_copy=False, **job_kwargs @@ -439,27 +462,22 @@ def auto_merge_units_internal( if extra_outputs: merge_unit_groups, outs = merge_unit_groups - all_merging_groups += [merge_unit_groups] - all_outs += [outs] - merged_analyzer = merged_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) merged_units = len(merge_unit_groups) > 0 - - if extra_outputs: - merge_unit_groups = {} - for merges in all_merging_groups: - for m in merges: - new_list = m - for k in m: - if k in merge_unit_groups: - new_list.remove(k) - new_list += merge_unit_groups[k] - merge_unit_groups[m[0]] = new_list - merge_unit_groups = list(merge_unit_groups.values()) - outs = all_outs + + if merged_units: + merged_analyzer, new_unit_ids = merged_analyzer.merge_units(merge_unit_groups, + return_new_unit_ids=True, + **apply_merge_kwargs, **job_kwargs) + + if extra_outputs: + all_merging_groups += [merge_unit_groups] + new_merges = {key : value for (key, value) in zip(new_unit_ids, merge_unit_groups)} + resolved_merges = resolve_pairs(resolved_merges, new_merges) + all_outs += [outs] if extra_outputs: - return merged_analyzer, merge_unit_groups, outs + return merged_analyzer, resolved_merges, merge_unit_groups, outs else: return merged_analyzer @@ -728,6 +746,7 @@ def auto_merge_units( if extra_outputs: all_merging_groups = [] all_outs = [] + resolved_merges = {} if force_copy: sorting_analyzer = sorting_analyzer.copy() @@ -752,23 +771,13 @@ def auto_merge_units( ) if extra_outputs: - sorting_analyzer, merge_unit_groups, outs = sorting_analyzer + sorting_analyzer, new_merges, merge_unit_groups, outs = sorting_analyzer all_merging_groups += [merge_unit_groups] + resolved_merges = resolve_pairs(resolved_merges, new_merges) all_outs += [outs] if extra_outputs: - merge_unit_groups = {} - for merges in all_merging_groups[::-1]: - for m in merges: - new_list = m - for k in m: - if k in merge_unit_groups: - new_list.remove(k) - new_list += merge_unit_groups[k] - merge_unit_groups[m[0]] = new_list - merge_unit_groups = list(merge_unit_groups.values()) - - return sorting_analyzer, merge_unit_groups, all_outs + return sorting_analyzer, resolved_merges, merge_unit_groups, all_outs else: return sorting_analyzer