diff --git a/src/spikeinterface/benchmark/benchmark_merging.py b/src/spikeinterface/benchmark/benchmark_merging.py index a1085455fe..83d85a4243 100644 --- a/src/spikeinterface/benchmark/benchmark_merging.py +++ b/src/spikeinterface/benchmark/benchmark_merging.py @@ -158,7 +158,7 @@ def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"): plot_potential_merges(analyzer, mylist, backend=backend) - def plot_performed_merges(self, case_key, recursive=False, backend="ipywidgets"): + def plot_performed_merges(self, case_key, backend="ipywidgets"): analyzer = self.get_sorting_analyzer(case_key) if analyzer.get_extension("spike_amplitudes") is None: @@ -167,18 +167,18 @@ def plot_performed_merges(self, case_key, recursive=False, backend="ipywidgets") analyzer.compute(["correlograms"]) all_merges = self.benchmarks[case_key].result["merges"] - if recursive: - final_merges = {} - for merges in all_merges: - for merge in merges: - for m in merge: - new_list = m - for k in m: - if k in final_merges: - new_list.remove(k) - new_list += final_merges[k] - final_merges[m[0]] = new_list - all_merges = list(final_merges.values()) + # if recursive: + # final_merges = {} + # for merges in all_merges: + # for merge in merges: + # for m in merge: + # new_list = m + # for k in m: + # if k in final_merges: + # new_list.remove(k) + # new_list += final_merges[k] + # final_merges[m[0]] = new_list + # all_merges = list(final_merges.values()) from spikeinterface.widgets import plot_potential_merges diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index e686b7ae96..e2b219bdbe 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -446,7 +446,16 @@ def auto_merge_units_internal( merged_units = len(merge_unit_groups) > 0 if extra_outputs: - merge_unit_groups = all_merging_groups + merge_unit_groups = {} + for merges in all_merging_groups: + for m in merges: + new_list = m + for k in m: + if k in merge_unit_groups: + new_list.remove(k) + new_list += merge_unit_groups[k] + merge_unit_groups[m[0]] = new_list + merge_unit_groups = list(merge_unit_groups.values()) outs = all_outs if extra_outputs: @@ -751,7 +760,19 @@ def auto_merge_units( if len(to_be_launched) == 1: all_merging_groups = all_merging_groups[0] all_outs = all_outs[0] - return sorting_analyzer, all_merging_groups, all_outs + + merge_unit_groups = {} + for merges in all_merging_groups: + for m in merges: + new_list = m + for k in m: + if k in merge_unit_groups: + new_list.remove(k) + new_list += merge_unit_groups[k] + merge_unit_groups[m[0]] = new_list + merge_unit_groups = list(merge_unit_groups.values()) + + return sorting_analyzer, merge_unit_groups, all_outs else: return sorting_analyzer