
Commit

WIP
yger committed Oct 29, 2024
2 parents 0a5f32f + 94d869a commit e301e1a
Showing 3 changed files with 16 additions and 23 deletions.
5 changes: 2 additions & 3 deletions src/spikeinterface/benchmark/benchmark_merging.py
@@ -29,8 +29,8 @@ def run(self, **job_kwargs):
sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs)
sorting_analyzer.compute("unit_locations", method="monopolar_triangulation")
sorting_analyzer.compute("template_similarity", **{"method": "l2", "support": "union", "max_lag_ms": 0.1})
#sorting_analyzer.compute("correlograms", **correlograms_kwargs)
# sorting_analyzer.compute("correlograms", **correlograms_kwargs)

merged_analyzer, self.result["merges"], self.result["outs"] = auto_merge_units(
sorting_analyzer,
extra_outputs=True,
@@ -160,4 +160,3 @@ def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"):
from spikeinterface.widgets import plot_potential_merges

plot_potential_merges(analyzer, mylist, backend=backend)

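For context, a minimal sketch (not part of this commit) of the extensions the benchmark's run() above expects on an existing SortingAnalyzer named sorting_analyzer before merging; the correlogram parameters (window_ms, bin_ms) are illustrative assumptions, since that step is commented out in the diff and its kwargs are not shown.

# Sketch only: prerequisite extensions mirrored from run() above.
from spikeinterface.curation.auto_merge import auto_merge_units

sorting_analyzer.compute(["random_spikes", "templates"])
sorting_analyzer.compute("unit_locations", method="monopolar_triangulation")
sorting_analyzer.compute("template_similarity", method="l2", support="union", max_lag_ms=0.1)
# if the commented-out correlograms step were re-enabled (illustrative parameters, not from this diff):
sorting_analyzer.compute("correlograms", window_ms=50.0, bin_ms=1.0)

# as in the benchmark, extra_outputs=True also returns the merges and the debugging data
merged_analyzer, merges, outs = auto_merge_units(sorting_analyzer, extra_outputs=True)
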
24 changes: 10 additions & 14 deletions src/spikeinterface/curation/auto_merge.py
@@ -347,7 +347,6 @@ def compute_merge_unit_groups(
outs["pairs_decreased_score"] = pairs_decreased_score

ind1, ind2 = np.nonzero(pair_mask)
print(step, len(ind1))

# FINAL STEP : create the final list from pair_mask boolean matrix
ind1, ind2 = np.nonzero(pair_mask)
@@ -390,7 +389,7 @@ def auto_merge_units_internal(
-------
sorting_analyzer:
The new sorting analyzer where all the merges from all the presets have been applied
merges, outs:
Returned only when extra_outputs=True
A list with the merges performed, and dictionaries that contain data for debugging and plotting.
@@ -626,7 +625,7 @@ def auto_merge_units(
steps_params: dict = None,
steps: list[str] | None = None,
apply_merge_kwargs: dict = {},
recursive : bool = False,
recursive : bool = True,
extra_outputs: bool = False,
**job_kwargs,
) -> SortingAnalyzer:
@@ -660,7 +659,7 @@
presets = [presets]

if (steps is not None) and (presets is not None):
raise Exception('presets and steps are mutually exclusive')
raise Exception("presets and steps are mutually exclusive")

if presets is not None:
to_be_launched = presets
@@ -672,26 +671,23 @@
if steps_params is not None:
assert len(steps_params) == len(to_be_launched), f"steps params should have the same size as {launch_mode}"
else:
steps_params = [None]*len(to_be_launched)
steps_params = [None] * len(to_be_launched)

if extra_outputs:
all_merging_groups = []
all_outs = []

for to_launch, params in zip(to_be_launched, steps_params):

if launch_mode == "presets":
compute_merge_kwargs = {"preset" : to_launch}
compute_merge_kwargs = {"preset": to_launch}
elif launch_mode == "steps":
compute_merge_kwargs = {"steps" : to_launch}
compute_merge_kwargs = {"steps": to_launch}

compute_merge_kwargs.update({"steps_params": params})
sorting_analyzer = auto_merge_units_internal(sorting_analyzer,
compute_merge_kwargs,
apply_merge_kwargs,
recursive,
extra_outputs,
**job_kwargs)
sorting_analyzer = auto_merge_units_internal(
sorting_analyzer, compute_merge_kwargs, apply_merge_kwargs, recursive, extra_outputs, **job_kwargs
)

if extra_outputs:
sorting_analyzer, merge_unit_groups, outs = sorting_analyzer
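To illustrate the two launch modes handled above, a small usage sketch assuming an existing SortingAnalyzer named sorting_analyzer; the "temporal_splits" preset name is an assumption, only "x_contaminations" appears in this commit.

from spikeinterface.curation.auto_merge import auto_merge_units

# presets mode: one merging pass per preset, each applied recursively by default (recursive=True after this commit)
clean_analyzer, merges, outs = auto_merge_units(
    sorting_analyzer,
    presets=["x_contaminations", "temporal_splits"],
    extra_outputs=True,  # also return the merge groups and the debugging/plotting dictionaries
)

# steps_params, when given, must have one entry per preset (or per entry in steps);
# presets and steps are mutually exclusive, passing both raises an Exception.
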
10 changes: 4 additions & 6 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -362,10 +362,8 @@ def final_cleaning_circus(

template_diff_thresh = np.arange(0.05, 0.25, 0.05)
presets = ["x_contaminations"] * len(template_diff_thresh)
steps_params = [{"template_similarity": {"template_diff_thresh": i}} for i in template_diff_thresh
]
final_sa = auto_merge_units(analyzer,
presets=presets,
steps_params=steps_params,
apply_merge_kwargs=apply_merge_kwargs)
steps_params = [{"template_similarity": {"template_diff_thresh": i}} for i in template_diff_thresh]
final_sa = auto_merge_units(
analyzer, presets=presets, steps_params=steps_params, apply_merge_kwargs=apply_merge_kwargs
)
return final_sa.sorting
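A hedged variant of the sweep performed by final_cleaning_circus above, adding extra_outputs=True to inspect the merges produced at each threshold; analyzer is the SortingAnalyzer already built by the sorter, and apply_merge_kwargs is omitted here since its content is not shown in this diff.

import numpy as np
from spikeinterface.curation.auto_merge import auto_merge_units

# increasingly permissive similarity thresholds, one "x_contaminations" pass per threshold
template_diff_thresh = np.arange(0.05, 0.25, 0.05)
presets = ["x_contaminations"] * len(template_diff_thresh)
steps_params = [{"template_similarity": {"template_diff_thresh": i}} for i in template_diff_thresh]

final_sa, merges, outs = auto_merge_units(
    analyzer, presets=presets, steps_params=steps_params, extra_outputs=True
)
cleaned_sorting = final_sa.sorting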
