Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Nov 5, 2024
1 parent 42754ae commit de4faf8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
26 changes: 13 additions & 13 deletions src/spikeinterface/benchmark/benchmark_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"):

plot_potential_merges(analyzer, mylist, backend=backend)

def plot_performed_merges(self, case_key, recursive=False, backend="ipywidgets"):
def plot_performed_merges(self, case_key, backend="ipywidgets"):
analyzer = self.get_sorting_analyzer(case_key)

if analyzer.get_extension("spike_amplitudes") is None:
Expand All @@ -167,18 +167,18 @@ def plot_performed_merges(self, case_key, recursive=False, backend="ipywidgets")
analyzer.compute(["correlograms"])

all_merges = self.benchmarks[case_key].result["merges"]
if recursive:
final_merges = {}
for merges in all_merges:
for merge in merges:
for m in merge:
new_list = m
for k in m:
if k in final_merges:
new_list.remove(k)
new_list += final_merges[k]
final_merges[m[0]] = new_list
all_merges = list(final_merges.values())
# if recursive:
# final_merges = {}
# for merges in all_merges:
# for merge in merges:
# for m in merge:
# new_list = m
# for k in m:
# if k in final_merges:
# new_list.remove(k)
# new_list += final_merges[k]
# final_merges[m[0]] = new_list
# all_merges = list(final_merges.values())

from spikeinterface.widgets import plot_potential_merges

Expand Down
25 changes: 23 additions & 2 deletions src/spikeinterface/curation/auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,16 @@ def auto_merge_units_internal(
merged_units = len(merge_unit_groups) > 0

if extra_outputs:
merge_unit_groups = all_merging_groups
merge_unit_groups = {}
for merges in all_merging_groups:
for m in merges:
new_list = m
for k in m:
if k in merge_unit_groups:
new_list.remove(k)
new_list += merge_unit_groups[k]
merge_unit_groups[m[0]] = new_list
merge_unit_groups = list(merge_unit_groups.values())
outs = all_outs

if extra_outputs:
Expand Down Expand Up @@ -751,7 +760,19 @@ def auto_merge_units(
if len(to_be_launched) == 1:
all_merging_groups = all_merging_groups[0]
all_outs = all_outs[0]
return sorting_analyzer, all_merging_groups, all_outs

merge_unit_groups = {}
for merges in all_merging_groups:
for m in merges:
new_list = m
for k in m:
if k in merge_unit_groups:
new_list.remove(k)
new_list += merge_unit_groups[k]
merge_unit_groups[m[0]] = new_list
merge_unit_groups = list(merge_unit_groups.values())

return sorting_analyzer, merge_unit_groups, all_outs
else:
return sorting_analyzer

Expand Down

0 comments on commit de4faf8

Please sign in to comment.