
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 18, 2024
1 parent 8e8d91e commit 08a599b
Showing 7 changed files with 39 additions and 24 deletions.
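The changes below are characteristic of standard formatting hooks: black-style line wrapping and quote normalization, trailing-whitespace removal, and missing end-of-file newlines. As a hedged sketch (the hook list is inferred from the changes themselves, not read from the repository's actual config), a .pre-commit-config.yaml producing fixes like these could look like:

# Hypothetical config; the repository's real one may differ.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0  # assumed revision
    hooks:
      - id: trailing-whitespace  # strips trailing spaces (the paired identical lines below)
      - id: end-of-file-fixer    # adds the missing final newline in several files below
  - repo: https://github.com/psf/black
    rev: 24.8.0  # assumed revision
    hooks:
      - id: black                # rewraps long lines and normalizes quotes to double quotes

pre-commit.ci runs the configured hooks on the pull request and pushes the resulting auto-fix commit, which is what this commit is.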
2 changes: 1 addition & 1 deletion src/spikeinterface/benchmark/benchmark_merging.py
@@ -155,4 +155,4 @@ def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"):
 
         from spikeinterface.widgets import plot_potential_merges
 
-        plot_potential_merges(analyzer, mylist, backend=backend)
\ No newline at end of file
+        plot_potential_merges(analyzer, mylist, backend=backend)
7 changes: 3 additions & 4 deletions src/spikeinterface/benchmark/tests/test_benchmark_merging.py
@@ -5,12 +5,11 @@
 import shutil
 
 from spikeinterface.benchmark.benchmark_merging import MergingStudy
-from spikeinterface.benchmark.tests.common_benchmark_testing import (
-    make_dataset
-)
+from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
 
 from spikeinterface.generation.splitting_tools import split_sorting_by_amplitudes, split_sorting_by_times
 
+
 @pytest.mark.skip()
 def test_benchmark_merging(create_cache_folder):
     cache_folder = create_cache_folder
@@ -72,4 +71,4 @@ def test_benchmark_merging(create_cache_folder):
 
 
 if __name__ == "__main__":
-    test_benchmark_merging()
\ No newline at end of file
+    test_benchmark_merging()
8 changes: 5 additions & 3 deletions src/spikeinterface/curation/auto_merge.py
@@ -223,7 +223,9 @@ def auto_merges(
                 sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs)
             res_ext = sorting_analyzer.get_extension(step)
             if res_ext is None:
-                print(f"Extension {ext} is computed with default params. Precompute it with custom params if needed")
+                print(
+                    f"Extension {ext} is computed with default params. Precompute it with custom params if needed"
+                )
                 sorting_analyzer.compute(ext, **job_kwargs)
         elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext):
             raise ValueError(f"{step} requires {ext} extension")
@@ -558,9 +560,9 @@ def iterative_merges(
 ):
     """
     Wrapper to conveniently be able to launch several presets for auto_merges in a row, as a list. Merges
-    are applied sequentially or until no more merges are done, one preset at a time, and extensions are 
+    are applied sequentially or until no more merges are done, one preset at a time, and extensions are
     not recomputed thanks to the merging units.
 
     Parameters
     ----------
     sorting_analyzer : SortingAnalyzer
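To make the wrapper's behavior concrete, here is a minimal usage sketch mirroring how this commit calls it from final_cleaning_circus further down; the analyzer variable and its precomputed extensions are assumptions, not part of this diff:

import numpy as np

from spikeinterface.curation.auto_merge import iterative_merges

# `analyzer` is assumed to be a SortingAnalyzer with templates,
# template_similarity and correlograms already computed.
template_diff_thresh = np.arange(0.05, 0.25, 0.05)
presets = ["x_contaminations"] * len(template_diff_thresh)
presets_params = [{"template_similarity": {"template_diff_thresh": i}} for i in template_diff_thresh]

# One preset per threshold, applied in sequence; extensions are carried
# through the merges instead of being recomputed at each round.
merged_analyzer = iterative_merges(analyzer, presets=presets, presets_params=presets_params)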
8 changes: 5 additions & 3 deletions src/spikeinterface/generation/splitting_tools.py
@@ -63,12 +63,14 @@ def split_sorting_by_times(
             new_unit_ids += [max_index + 1]
             splitted_pairs += [(unit_id, new_unit_ids[-1])]
             max_index += 1
 
     new_sorting = NumpySorting(new_spikes, sampling_frequency=sa.sampling_frequency, unit_ids=new_unit_ids)
     return new_sorting, splitted_pairs
 
 
-def split_sorting_by_amplitudes(sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None):
+def split_sorting_by_amplitudes(
+    sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None
+):
     """
     Fonction used to split a sorting based on the amplitudes of the units. This
     might be used for benchmarking meta merging step (see components)
@@ -134,4 +136,4 @@ def split_sorting_by_amplitudes(sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None):
         max_index += 1
 
     new_sorting = NumpySorting(new_spikes, sampling_frequency=sa.sampling_frequency, unit_ids=new_unit_ids)
-    return new_sorting, splitted_pairs
\ No newline at end of file
+    return new_sorting, splitted_pairs
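For orientation, a minimal sketch of how the reformatted splitter can be driven (argument values are illustrative, and the requirement of precomputed amplitudes is an assumption based on the function's name and docstring):

from spikeinterface.generation.splitting_tools import split_sorting_by_amplitudes

# `sorting_analyzer` is assumed to be a SortingAnalyzer with spike
# amplitudes available.
new_sorting, splitted_pairs = split_sorting_by_amplitudes(
    sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, min_snr=5, seed=42
)
# `splitted_pairs` holds (original_unit_id, new_unit_id) tuples, the ground
# truth a merging step can later be benchmarked against.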
10 changes: 6 additions & 4 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "apply_motion_correction": True,
         "motion_correction": {"preset": "dredge_fast"},
         "merging": {
-            "similarity_kwargs": {"method": "l2", "support": "union", "max_lag_ms": 0.1},
+            "similarity_kwargs": {"method": "l2", "support": "union", "max_lag_ms": 0.1},
         },
         "clustering": {"legacy": True},
         "matching": {"method": "circus-omp-svd"},
@@ -90,6 +90,7 @@ def get_sorter_version(cls):
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
         try:
             import hdbscan
+
             HAVE_HDBSCAN = True
         except:
             HAVE_HDBSCAN = False
@@ -374,10 +375,11 @@ def final_cleaning_circus(recording, sorting, templates, **merging_kwargs):
     sa.compute("template_similarity", **similarity_kwargs)
     correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {})
     sa.compute("correlograms", **correlograms_kwargs)
+
     from spikeinterface.curation.auto_merge import iterative_merges
 
     template_diff_thresh = np.arange(0.05, 0.25, 0.05)
-    presets_params = [{'template_similarity' : {'template_diff_thresh' : i}} for i in template_diff_thresh]
-    presets = ['x_contaminations'] * len(template_diff_thresh)
+    presets_params = [{"template_similarity": {"template_diff_thresh": i}} for i in template_diff_thresh]
+    presets = ["x_contaminations"] * len(template_diff_thresh)
     final_sa = iterative_merges(sa, presets=presets, presets_params=presets_params)
     return final_sa.sorting
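The "merging" block in the default params above feeds this final cleaning step. A hedged sketch of overriding it from the user side, assuming the usual run_sorter parameter plumbing (the recording variable is a placeholder):

from spikeinterface.sorters import run_sorter

# Sorter params are passed as keyword arguments and merged into the
# defaults shown above.
sorting = run_sorter(
    "spykingcircus2",
    recording,  # a preprocessed Recording, assumed to exist
    merging={"similarity_kwargs": {"method": "l2", "support": "union", "max_lag_ms": 0.1}},
)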
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/merging/__init__.py
@@ -1 +1 @@
-from .main import merge_spikes
\ No newline at end of file
+from .main import merge_spikes
26 changes: 18 additions & 8 deletions src/spikeinterface/sortingcomponents/merging/main.py
@@ -8,6 +8,7 @@
 
 merging_methods = ["circus", "auto_merges"]
 
+
 def create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty=True):
     sparsity = templates.sparsity
     templates_array = templates.get_dense_templates().copy()
@@ -29,19 +30,27 @@ def create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty=True):
     return sa
 
 
-def merging_circus(sorting_analyzer, similarity_kwargs={"method": "l2", "support": "union", "max_lag_ms": 0.1}, extra_outputs=False, **job_kwargs):
+def merging_circus(
+    sorting_analyzer,
+    similarity_kwargs={"method": "l2", "support": "union", "max_lag_ms": 0.1},
+    extra_outputs=False,
+    **job_kwargs,
+):
 
-    if sorting_analyzer.get_extension('templates') is None:
+    if sorting_analyzer.get_extension("templates") is None:
         sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs)
         sorting_analyzer.compute("unit_locations", method="monopolar_triangulation")
         sorting_analyzer.compute("template_similarity", **similarity_kwargs)
         sorting_analyzer.compute("correlograms")
 
     from spikeinterface.curation.auto_merge import iterative_merges
+
     template_diff_thresh = np.arange(0.05, 0.25, 0.05)
-    presets_params = [{'template_similarity' : {'template_diff_thresh' : i}} for i in template_diff_thresh]
-    presets = ['x_contaminations'] * len(template_diff_thresh)
-    return iterative_merges(sorting_analyzer, presets=presets, presets_params=presets_params, extra_outputs=extra_outputs, **job_kwargs)
+    presets_params = [{"template_similarity": {"template_diff_thresh": i}} for i in template_diff_thresh]
+    presets = ["x_contaminations"] * len(template_diff_thresh)
+    return iterative_merges(
+        sorting_analyzer, presets=presets, presets_params=presets_params, extra_outputs=extra_outputs, **job_kwargs
+    )
 
 
 def merge_spikes(
@@ -89,11 +98,12 @@ def merge_spikes(
         return merging_circus(sorting_analyzer, extra_outputs=extra_outputs, **method_kwargs)
     elif method == "auto_merges":
         from spikeinterface.curation.auto_merge import get_potential_auto_merge
+
         merges = get_potential_auto_merge(sorting_analyzer, **method_kwargs, resolve_graph=True)
         new_sa = sorting_analyzer.copy()
         new_sa = new_sa.merge_units(merges, merging_mode="soft", sparsity_overlap=0.5, censor_ms=3, **job_kwargs)
         sorting = new_sa.sorting
-        if extra_outputs: 
+        if extra_outputs:
             return sorting, merges, []
         else:
-            return sorting
\ No newline at end of file
+            return sorting
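Pulling the dispatcher together: the visible branches route "circus" to merging_circus (iterative preset merges) and "auto_merges" to get_potential_auto_merge followed by a soft merge_units. A sketch of a call, with the signature partly assumed since the def merge_spikes( line is collapsed in this diff:

from spikeinterface.sortingcomponents.merging import merge_spikes

# Assumed signature: the hunk above shows sorting_analyzer, method_kwargs and
# extra_outputs being used, so the call below mirrors that usage.
sorting = merge_spikes(sorting_analyzer, method="circus", method_kwargs={}, extra_outputs=False)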
