Commit b584e67

Resolving merges during iterative auto_merge

Instead of reconstructing nested merge groups after the fact, merges are now tracked pass by pass through a new resolve_pairs helper, and auto_merge_units / auto_merge_units_internal return the resolved {new_unit_id: original unit ids} mapping as an extra output.
yger committed Nov 8, 2024
1 parent 397e54d commit b584e67
Showing 2 changed files with 45 additions and 48 deletions.
18 changes: 3 additions & 15 deletions src/spikeinterface/benchmark/benchmark_merging.py
@@ -29,7 +29,7 @@ def run(self, **job_kwargs):
         )
         # sorting_analyzer.compute(['random_spikes', 'templates'])
         # sorting_analyzer.compute('template_similarity', max_lag_ms=0.1, method="l2", **job_kwargs)
-        merged_analyzer, self.result["merges"], self.result["outs"] = auto_merge_units(
+        merged_analyzer, self.result["merged_pairs"], self.result["merges"], self.result["outs"] = auto_merge_units(
             sorting_analyzer, extra_outputs=True, **self.method_kwargs, **job_kwargs
         )

@@ -40,7 +40,7 @@ def compute_result(self, **result_params):
         comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True)
         self.result["gt_comparison"] = comp

-    _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("outs", "pickle")]
+    _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("merged_pairs", "pickle"), ("outs", "pickle")]
     _result_key_saved = [("gt_comparison", "pickle")]


@@ -166,19 +166,7 @@ def plot_performed_merges(self, case_key, backend="ipywidgets"):
         if analyzer.get_extension("correlograms") is None:
             analyzer.compute(["correlograms"])

-        all_merges = self.benchmarks[case_key].result["merges"]
-        # if recursive:
-        #     final_merges = {}
-        #     for merges in all_merges:
-        #         for merge in merges:
-        #             for m in merge:
-        #                 new_list = m
-        #                 for k in m:
-        #                     if k in final_merges:
-        #                         new_list.remove(k)
-        #                         new_list += final_merges[k]
-        #                 final_merges[m[0]] = new_list
-        #     all_merges = list(final_merges.values())
+        all_merges = list(self.benchmarks[case_key].result["merged_pairs"].values())

        from spikeinterface.widgets import plot_potential_merges
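
Since result["merged_pairs"] is the mapping returned by auto_merge_units (new merged unit id -> list of original unit ids), taking its values() hands plot_potential_merges one flat group per performed merge. A minimal sketch of the shapes involved, with made-up unit ids:

    # Hypothetical contents of result["merged_pairs"] after two merging passes:
    merged_pairs = {20: [5, 1, 2], 11: [3, 4]}
    all_merges = list(merged_pairs.values())  # [[5, 1, 2], [3, 4]]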

75 changes: 42 additions & 33 deletions src/spikeinterface/curation/auto_merge.py
@@ -371,6 +371,24 @@ def compute_merge_unit_groups(
     return merge_unit_groups


+def resolve_pairs(existing_merges, new_merges):
+    if existing_merges is None:
+        return new_merges.copy()
+    else:
+        resolved_merges = existing_merges.copy()
+        old_keys = list(existing_merges.keys())
+        for key, pair in new_merges.items():
+            nested_merge = np.flatnonzero([i in pair for i in old_keys])
+            if len(nested_merge) == 0:
+                resolved_merges.update({key : pair})
+            else:
+                for n in nested_merge:
+                    previous_merges = resolved_merges.pop(old_keys[n])
+                    pair.remove(old_keys[n])
+                    pair += previous_merges
+                resolved_merges.update({key : pair})
+        return resolved_merges
+
 def auto_merge_units_internal(
     sorting_analyzer: SortingAnalyzer,
     compute_merge_kwargs: dict = {},
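
resolve_pairs chains merge mappings across successive passes: whenever a unit id produced by an earlier merge appears inside a new merge group, that id is replaced by the original unit ids it came from, so every key in the result maps a final unit back to the full set of source units. A minimal sketch of the intended behavior, assuming resolve_pairs and numpy (as np, already imported in auto_merge.py) are in scope, with made-up unit ids:

    # Pass 1: units 1 and 2 became unit 10; units 3 and 4 became unit 11.
    first_pass = {10: [1, 2], 11: [3, 4]}
    # Pass 2: merged unit 10 is merged again with unit 5 into unit 20.
    second_pass = {20: [10, 5]}
    resolved = resolve_pairs(first_pass, second_pass)
    # -> {11: [3, 4], 20: [5, 1, 2]}: unit 10 is expanded into its sources.

Note that the groups in new_merges are mutated in place (pair.remove / pair +=); the raw per-pass groups are kept separately in all_merging_groups by the callers below.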
@@ -423,43 +441,43 @@ def auto_merge_units_internal(

         if extra_outputs:
             merge_unit_groups, outs = merge_unit_groups

-        merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs)
-
-
+        merged_analyzer, new_unit_ids = sorting_analyzer.merge_units(merge_unit_groups,
+                                                                     return_new_unit_ids=True,
+                                                                     **apply_merge_kwargs, **job_kwargs)
+        resolved_merges = {key : value for (key, value) in zip(new_unit_ids, merge_unit_groups)}
     else:
         merged_units = True
         merged_analyzer = sorting_analyzer

         if extra_outputs:
             all_merging_groups = []
+            resolved_merges = {}
             all_outs = []

         while merged_units:
             merge_unit_groups = compute_merge_unit_groups(
                 merged_analyzer, **compute_merge_kwargs, extra_outputs=extra_outputs, force_copy=False, **job_kwargs
             )

             if extra_outputs:
                 merge_unit_groups, outs = merge_unit_groups
-                all_merging_groups += [merge_unit_groups]
-                all_outs += [outs]
-
-            merged_analyzer = merged_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs)
             merged_units = len(merge_unit_groups) > 0

-        if extra_outputs:
-            merge_unit_groups = {}
-            for merges in all_merging_groups:
-                for m in merges:
-                    new_list = m
-                    for k in m:
-                        if k in merge_unit_groups:
-                            new_list.remove(k)
-                            new_list += merge_unit_groups[k]
-                    merge_unit_groups[m[0]] = new_list
-            merge_unit_groups = list(merge_unit_groups.values())
-            outs = all_outs
+            if merged_units:
+                merged_analyzer, new_unit_ids = merged_analyzer.merge_units(merge_unit_groups,
+                                                                            return_new_unit_ids=True,
+                                                                            **apply_merge_kwargs, **job_kwargs)
+
+                if extra_outputs:
+                    all_merging_groups += [merge_unit_groups]
+                    new_merges = {key : value for (key, value) in zip(new_unit_ids, merge_unit_groups)}
+                    resolved_merges = resolve_pairs(resolved_merges, new_merges)
+                    all_outs += [outs]

     if extra_outputs:
-        return merged_analyzer, merge_unit_groups, outs
+        return merged_analyzer, resolved_merges, merge_unit_groups, outs
     else:
         return merged_analyzer
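
In the recursive branch, the analyzer is now merged only when compute_merge_unit_groups actually found groups, and each pass's id mapping is folded into resolved_merges. Stripped of the surrounding diff, the loop behaves roughly like this (a simplified sketch assuming extra_outputs=True, not the committed code verbatim):

    resolved_merges = {}
    merged_units = True
    while merged_units:
        merge_unit_groups, outs = compute_merge_unit_groups(merged_analyzer, extra_outputs=True)
        merged_units = len(merge_unit_groups) > 0
        if merged_units:
            merged_analyzer, new_unit_ids = merged_analyzer.merge_units(
                merge_unit_groups, return_new_unit_ids=True
            )
            # Map each freshly created unit id to its source group, then expand
            # any ids that were themselves created by an earlier pass.
            new_merges = dict(zip(new_unit_ids, merge_unit_groups))
            resolved_merges = resolve_pairs(resolved_merges, new_merges)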

@@ -728,6 +746,7 @@ def auto_merge_units(
     if extra_outputs:
         all_merging_groups = []
         all_outs = []
+        resolved_merges = {}

     if force_copy:
         sorting_analyzer = sorting_analyzer.copy()
@@ -752,23 +771,13 @@ def auto_merge_units(
         )

         if extra_outputs:
-            sorting_analyzer, merge_unit_groups, outs = sorting_analyzer
+            sorting_analyzer, new_merges, merge_unit_groups, outs = sorting_analyzer
             all_merging_groups += [merge_unit_groups]
+            resolved_merges = resolve_pairs(resolved_merges, new_merges)
             all_outs += [outs]

     if extra_outputs:
-        merge_unit_groups = {}
-        for merges in all_merging_groups[::-1]:
-            for m in merges:
-                new_list = m
-                for k in m:
-                    if k in merge_unit_groups:
-                        new_list.remove(k)
-                        new_list += merge_unit_groups[k]
-                merge_unit_groups[m[0]] = new_list
-        merge_unit_groups = list(merge_unit_groups.values())
-
-        return sorting_analyzer, merge_unit_groups, all_outs
+        return sorting_analyzer, resolved_merges, merge_unit_groups, all_outs
     else:
         return sorting_analyzer
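
With this change, callers passing extra_outputs=True unpack four values instead of three, exactly as the benchmark update above does. A minimal usage sketch (merging-method kwargs and job kwargs omitted; they depend on the chosen preset):

    merged_analyzer, merged_pairs, merge_unit_groups, all_outs = auto_merge_units(
        sorting_analyzer, extra_outputs=True
    )
    # merged_pairs: {new_unit_id: [original unit ids it absorbed]}, already
    # resolved across successive merging steps by resolve_pairs.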
