Skip to content

Commit

Permalink
Merge pull request #3213 from yger/patch_release
Browse files Browse the repository at this point in the history
Patch for SC2 after the release, plus fixes for bugs in auto merge
  • Loading branch information
samuelgarcia authored Jul 17, 2024
2 parents 7562b24 + 661b1e9 commit 4947385
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/curation/auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,14 @@ def get_potential_auto_merge(
assert step in all_steps, f"{step} is not a valid step"

# STEP : remove units with too few spikes
if step == "min_spikes":
if step == "num_spikes":
num_spikes = sorting.count_num_spikes_per_unit(outputs="array")
to_remove = num_spikes < min_spikes
pair_mask[to_remove, :] = False
pair_mask[:, to_remove] = False

# STEP : remove units with too small SNR
elif step == "min_snr":
elif step == "snr":
qm_ext = sorting_analyzer.get_extension("quality_metrics")
if qm_ext is None:
sorting_analyzer.compute("noise_levels")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _check_one(self, sorting_analyzer, extension_class, params):

some_merges = [sorting_analyzer.unit_ids[:2].tolist()]
num_units_after_merge = len(sorting_analyzer.unit_ids) - 1
merged = sorting_analyzer.merge_units(some_merges, format="memory", mode="soft", sparsity_overlap=0.0)
merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0)
assert len(merged.unit_ids) == num_units_after_merge

def run_extension_tests(self, extension_class, params):
Expand Down
37 changes: 25 additions & 12 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.curation.auto_merge import get_potential_auto_merge
from spikeinterface.core.analyzer_extension_core import ComputeTemplates
from spikeinterface.core.sparsity import ChannelSparsity


class Spykingcircus2Sorter(ComponentsBasedSorter):
Expand All @@ -41,8 +42,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"correlograms_kwargs": {},
"auto_merge": {
"min_spikes": 10,
"corr_diff_thresh": 0.5,
"censor_correlograms_ms": 0.4,
"corr_diff_thresh": 0.25,
},
},
"clustering": {"legacy": True},
Expand Down Expand Up @@ -357,27 +357,40 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
return sorting


def create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty=True):
    """Build an in-memory SortingAnalyzer pre-populated with externally computed templates.

    Parameters
    ----------
    sorting : BaseSorting
        The sorting whose units the templates correspond to.
    recording : BaseRecording
        The recording the sorting was obtained from.
    templates : Templates
        Templates object carrying the dense templates, sparsity, and the
        ms_before / ms_after window used to compute them.
    remove_empty : bool, default True
        If True, units with no spikes are dropped from the sorting, and the
        templates array and sparsity mask are filtered to match the remaining units.

    Returns
    -------
    sa : SortingAnalyzer
        In-memory analyzer whose "templates" extension is filled from `templates`
        instead of being recomputed from the recording.
    """
    sparsity = templates.sparsity
    templates_array = templates.get_dense_templates().copy()

    if remove_empty:
        non_empty_unit_ids = sorting.get_non_empty_unit_ids()
        non_empty_sorting = sorting.remove_empty_units()
        # Indices are taken on the ORIGINAL sorting so they line up with the
        # rows of the dense templates array and the sparsity mask.
        non_empty_unit_indices = sorting.ids_to_indices(non_empty_unit_ids)
        templates_array = templates_array[non_empty_unit_indices]
        sparsity_mask = sparsity.mask[non_empty_unit_indices, :]
        sparsity = ChannelSparsity(sparsity_mask, non_empty_unit_ids, sparsity.channel_ids)
    else:
        non_empty_sorting = sorting

    sa = create_sorting_analyzer(non_empty_sorting, recording, format="memory", sparsity=sparsity)
    # Inject the precomputed templates directly into the extension cache so
    # downstream computations (similarity, merging) reuse them without a
    # "waveforms"/"templates" recompute pass over the recording.
    sa.extensions["templates"] = ComputeTemplates(sa)
    sa.extensions["templates"].params = {"ms_before": templates.ms_before, "ms_after": templates.ms_after}
    sa.extensions["templates"].data["average"] = templates_array
    return sa


def final_cleaning_circus(recording, sorting, templates, **merging_kwargs):
    """Run the final auto-merge curation step of SpyKING CIRCUS 2.

    Builds an in-memory analyzer seeded with the precomputed templates, computes
    the extensions needed by `get_potential_auto_merge` (unit locations,
    template similarity, correlograms), then applies the suggested merges.

    Parameters
    ----------
    recording : BaseRecording
        The recording being sorted.
    sorting : BaseSorting
        The sorting to be curated.
    templates : Templates
        Precomputed templates for the sorting's units.
    **merging_kwargs
        May contain "similarity_kwargs", "correlograms_kwargs" and "auto_merge"
        dicts forwarded to the corresponding computations.

    Returns
    -------
    sorting : BaseSorting
        The sorting after applying the potential merges.
    """
    from spikeinterface.core.sorting_tools import apply_merges_to_sorting

    sa = create_sorting_analyzer_with_templates(sorting, recording, templates)

    sa.compute("unit_locations", method="monopolar_triangulation")
    similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {})
    sa.compute("template_similarity", **similarity_kwargs)
    correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {})
    sa.compute("correlograms", **correlograms_kwargs)
    auto_merge_kwargs = merging_kwargs.pop("auto_merge", {})
    # resolve_graph=True makes get_potential_auto_merge return merge groups
    # directly, replacing the old separate resolve_merging_graph() call.
    merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs)
    # Apply merges to sa.sorting (not the input `sorting`): the analyzer may
    # hold a version with empty units removed, so unit ids must match.
    sorting = apply_merges_to_sorting(sa.sorting, merges)

    return sorting

0 comments on commit 4947385

Please sign in to comment.