Improve merging and iterative merging #3487

Open
wants to merge 91 commits into main
Changes from 25 commits
Commits
91 commits
3221527
WIP
yger Oct 16, 2024
557a77a
Iterative merges
yger Oct 16, 2024
ee23401
WIP
yger Oct 16, 2024
813c888
Merge branch 'auto_merge_refactoring' into meta_merging
yger Oct 16, 2024
af0cfa8
Merge branch 'auto_merge_refactoring' into meta_merging
yger Oct 16, 2024
a399b28
WIP
yger Oct 16, 2024
1aeaecf
WIP
yger Oct 16, 2024
9642e15
WIP
yger Oct 17, 2024
fb63348
Display
yger Oct 17, 2024
c88f43b
Harmonization of splitting functions and removing dependencies
yger Oct 17, 2024
49efdb1
WIP
yger Oct 17, 2024
297b6b0
WIP
yger Oct 18, 2024
b5a7800
WIP
yger Oct 18, 2024
8e8d91e
Merge branch 'SpikeInterface:main' into meta_merging
yger Oct 18, 2024
08a599b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
b10ca04
Merge branch 'SpikeInterface:main' into meta_merging
yger Oct 28, 2024
e7fe01a
Adding tests
yger Oct 28, 2024
7f53178
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2024
96ceb9b
Naming
yger Oct 28, 2024
ae42eb9
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 28, 2024
467c751
Merge branch 'auto_merge_refactoring' into meta_merging
yger Oct 28, 2024
ae8de50
Adding an auto_merge_iterative function
yger Oct 28, 2024
0a64a67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2024
109abf0
Adding tests
yger Oct 28, 2024
ef3f4bc
Tests
yger Oct 28, 2024
07fdb5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2024
b023161
Tests
yger Oct 28, 2024
9808e37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2024
5f7ac5b
Testsg
yger Oct 28, 2024
633d628
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 28, 2024
6c0f90d
Cleaning benchmark
yger Oct 28, 2024
6163fbe
Cleaning benchmark
yger Oct 28, 2024
8b033ec
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 29, 2024
448d3cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
507df4c
WIP
yger Oct 29, 2024
fd25ca1
Merge branch 'auto_merge_refactoring' into meta_merging
yger Oct 29, 2024
831bb76
Refactoring
yger Oct 29, 2024
e01e0cc
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 29, 2024
622115f
Naming circus 2
yger Oct 29, 2024
94d869a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
42a3464
WIP
yger Oct 29, 2024
eb69353
Merge branch 'auto_merge_refactoring' into meta_merging
yger Oct 29, 2024
0a5f32f
WIP
yger Oct 29, 2024
e301e1a
WIPé
yger Oct 29, 2024
c6e71b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
57c5b41
Tests
yger Oct 29, 2024
08da941
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
0ec16e1
WIP
yger Oct 29, 2024
a5c6be3
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 29, 2024
528db42
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 29, 2024
2f06bb6
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 29, 2024
2bf61e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
16dcfb8
Docs
yger Oct 29, 2024
f6563cd
Docs
yger Oct 29, 2024
7c3e7ba
Fix the bug in benchmarks
yger Oct 30, 2024
9d6c1f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
48f65bb
Fix
yger Oct 30, 2024
ba6d835
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 30, 2024
43b383b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
f003edc
Docs
yger Oct 30, 2024
9223bb2
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 30, 2024
73b3a83
Cleaning
yger Oct 30, 2024
b296c81
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
86a00af
Merge branch 'auto_merge_refactoring' into meta_merging
yger Oct 30, 2024
285f785
Benchmarking
yger Oct 30, 2024
b34ad92
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
9a9cf7d
Benchmark
yger Oct 30, 2024
de3a4a4
Benchmark
yger Oct 30, 2024
23c455e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
60e98b4
WIP
yger Oct 30, 2024
27707fd
Merge branch 'meta_merging' of github.com:yger/spikeinterface into me…
yger Oct 30, 2024
e913a7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
a396275
Merge branch 'SpikeInterface:main' into meta_merging
yger Oct 31, 2024
392cdcc
Sync with auto_merge
yger Oct 31, 2024
33ae9eb
Merge branch 'SpikeInterface:main' into meta_merging
yger Nov 3, 2024
2f80b8d
Merge branch 'SpikeInterface:main' into meta_merging
yger Nov 4, 2024
42754ae
Merge branch 'SpikeInterface:main' into meta_merging
yger Nov 5, 2024
de4faf8
WIP
yger Nov 5, 2024
85cd163
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
3b76b5a
Merge branch 'SpikeInterface:main' into meta_merging
yger Nov 8, 2024
397e54d
To properly resolve the pairs
yger Nov 8, 2024
b584e67
Resolving merging while iterative auto_merge
yger Nov 8, 2024
7f4d2d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2024
d12f4ee
Merge branch 'extension_dependencies' into meta_merging
yger Nov 8, 2024
27fed08
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Nov 11, 2024
1129b22
Fixing tests
yger Nov 12, 2024
cdac4b9
Merge branch 'SpikeInterface:main' into meta_merging
yger Nov 29, 2024
d310021
Calibration with first benchmarks
yger Nov 29, 2024
097d46c
Removing correlograms
yger Nov 29, 2024
e36c2b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2024
a07e94a
Merge branch 'SpikeInterface:main' into meta_merging
yger Dec 11, 2024
158 changes: 158 additions & 0 deletions src/spikeinterface/benchmark/benchmark_merging.py
@@ -0,0 +1,158 @@
from __future__ import annotations

from spikeinterface.sortingcomponents.merging import merge_spikes
from spikeinterface.comparison import compare_sorter_to_ground_truth
from spikeinterface.widgets import (
plot_agreement_matrix,
plot_unit_templates,
plot_amplitudes,
plot_crosscorrelograms,
)

import numpy as np
from .benchmark_base import Benchmark, BenchmarkStudy


class MergingBenchmark(Benchmark):

def __init__(self, recording, splitted_sorting, params, gt_sorting, splitted_cells=None):
self.recording = recording
self.splitted_sorting = splitted_sorting
self.method = params["method"]
self.gt_sorting = gt_sorting
self.splitted_cells = splitted_cells
self.method_kwargs = params["method_kwargs"]
self.result = {}

def run(self, **job_kwargs):
self.result["sorting"], self.result["merges"], self.result["outs"] = merge_spikes(
self.recording,
self.splitted_sorting,
method=self.method,
verbose=True,
extra_outputs=True,
method_kwargs=self.method_kwargs,
)

def compute_result(self, **result_params):
sorting = self.result["sorting"]
comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True)
self.result["gt_comparison"] = comp

_run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("outs", "pickle")]
_result_key_saved = [("gt_comparison", "pickle")]


class MergingStudy(BenchmarkStudy):

benchmark_class = MergingBenchmark

def create_benchmark(self, key):
dataset_key = self.cases[key]["dataset"]
recording, gt_sorting = self.datasets[dataset_key]
params = self.cases[key]["params"]
init_kwargs = self.cases[key]["init_kwargs"]
benchmark = MergingBenchmark(recording, gt_sorting, params, **init_kwargs)
return benchmark

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
import pandas as pd

if case_keys is None:
case_keys = list(self.cases.keys())

if isinstance(case_keys[0], str):
index = pd.Index(case_keys, name=self.levels)
else:
index = pd.MultiIndex.from_tuples(case_keys, names=self.levels)

columns = ["num_gt", "num_sorter", "num_well_detected"]
comp = self.get_result(case_keys[0])["gt_comparison"]
if comp.exhaustive_gt:
columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"])
count_units = pd.DataFrame(index=index, columns=columns, dtype=int)

for key in case_keys:
comp = self.get_result(key)["gt_comparison"]
assert comp is not None, "You need to do study.run_comparisons() first"

gt_sorting = comp.sorting1
sorting = comp.sorting2

count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)

if comp.exhaustive_gt:
count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
count_units.loc[key, "num_bad"] = comp.count_bad_units()

return count_units

def plot_agreement_matrix(self, **kwargs):
from .benchmark_plot_tools import plot_agreement_matrix

return plot_agreement_matrix(self, **kwargs)

def plot_unit_counts(self, case_keys=None, figsize=None):
from spikeinterface.widgets.widget_list import plot_study_unit_counts

plot_study_unit_counts(self, case_keys, figsize=figsize)

def get_splitted_pairs(self, case_key):
return self.benchmarks[case_key].splitted_cells

def get_splitted_pairs_index(self, case_key, pair):
for count, i in enumerate(self.benchmarks[case_key].splitted_cells):
if i == pair:
return count

def plot_splitted_amplitudes(self, case_key, pair_index=0, backend="ipywidgets"):
analyzer = self.get_sorting_analyzer(case_key)
if analyzer.get_extension("spike_amplitudes") is None:
analyzer.compute(["spike_amplitudes"])
plot_amplitudes(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index], backend=backend)

def plot_splitted_correlograms(self, case_key, pair_index=0, backend="ipywidgets"):
analyzer = self.get_sorting_analyzer(case_key)
if analyzer.get_extension("correlograms") is None:
analyzer.compute(["correlograms"])
if analyzer.get_extension("template_similarity") is None:
analyzer.compute(["template_similarity"])
plot_crosscorrelograms(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index])

def plot_splitted_templates(self, case_key, pair_index=0, backend="ipywidgets"):
analyzer = self.get_sorting_analyzer(case_key)
if analyzer.get_extension("spike_amplitudes") is None:
analyzer.compute(["spike_amplitudes"])
plot_unit_templates(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index], backend=backend)

def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"):
analyzer = self.get_sorting_analyzer(case_key)
mylist = self.get_splitted_pairs(case_key)

if analyzer.get_extension("spike_amplitudes") is None:
analyzer.compute(["spike_amplitudes"])
if analyzer.get_extension("correlograms") is None:
analyzer.compute(["correlograms"])

if min_snr is not None:
select_from = analyzer.sorting.unit_ids
if analyzer.get_extension("noise_levels") is None:
analyzer.compute("noise_levels")
if analyzer.get_extension("quality_metrics") is None:
analyzer.compute("quality_metrics", metric_names=["snr"])

snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values
select_from = select_from[snr > min_snr]
mylist_selection = []
for i in mylist:
if (i[0] in select_from) or (i[1] in select_from):
mylist_selection += [i]
mylist = mylist_selection

from spikeinterface.widgets import plot_potential_merges

plot_potential_merges(analyzer, mylist, backend=backend)
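
For orientation, a minimal sketch (not part of the diff) of how the new MergingBenchmark class might be driven directly, assuming a recording, a ground-truth sorting and a ground-truth analyzer are already available. The split helper and its (sorting, pairs) return values are assumed from the test file below, and "circus" is one of the merge methods used there.

# Hedged sketch: run a single MergingBenchmark by hand.
# `recording`, `gt_sorting` and `gt_analyzer` are assumed to exist already.
from spikeinterface.benchmark.benchmark_merging import MergingBenchmark
from spikeinterface.generation.splitting_tools import split_sorting_by_amplitudes

# Artificially over-split the ground-truth units, then try to merge them back
splitted_sorting, splitted_cells = split_sorting_by_amplitudes(gt_analyzer)

params = {"method": "circus", "method_kwargs": {}}
benchmark = MergingBenchmark(recording, splitted_sorting, params, gt_sorting, splitted_cells=splitted_cells)
benchmark.run()
benchmark.compute_result()

comp = benchmark.result["gt_comparison"]         # comparison against gt_sorting
print(benchmark.result["merges"])                # merges proposed by the method
print(comp.count_well_detected_units(well_detected_score=0.9))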
74 changes: 74 additions & 0 deletions src/spikeinterface/benchmark/tests/test_benchmark_merging.py
@@ -0,0 +1,74 @@
import pytest
from pathlib import Path
import numpy as np

import shutil

from spikeinterface.benchmark.benchmark_merging import MergingStudy
from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset

from spikeinterface.generation.splitting_tools import split_sorting_by_amplitudes, split_sorting_by_times


@pytest.mark.skip()
def test_benchmark_merging(create_cache_folder):
cache_folder = create_cache_folder
job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")

recording, gt_sorting, gt_analyzer = make_dataset()

# create study
study_folder = cache_folder / "study_clustering"
# datasets = {"toy": (recording, gt_sorting)}
datasets = {"toy": gt_analyzer}

gt_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])

splitted_sorting = {}
splitted_sorting["times"] = split_sorting_by_times(gt_analyzer)
splitted_sorting["amplitudes"] = split_sorting_by_amplitudes(gt_analyzer)

cases = {}
for splits in ["times", "amplitudes"]:
for method in ["circus", "lussac"]:
cases[(method, splits)] = {
"label": f"{method}",
"dataset": "toy",
"init_kwargs": {"gt_sorting": gt_sorting, "splitted_cells": splitted_sorting[splits][1]},
"params": {"method": method, "splitted_sorting": splitted_sorting[splits][0], "method_kwargs": {}},
}

if study_folder.exists():
shutil.rmtree(study_folder)
study = MergingStudy.create(study_folder, datasets=datasets, cases=cases)
print(study)

# this study needs analyzer
# study.create_sorting_analyzer_gt(**job_kwargs)
study.compute_metrics()

study = MergingStudy(study_folder)

# run and result
study.run(**job_kwargs)
study.compute_results()

# load study to check persistency
study = MergingStudy(study_folder)
print(study)

# plots
# study.plot_performances_vs_snr()
study.plot_agreements()
study.plot_unit_counts()
# study.plot_error_metrics()
# study.plot_metrics_vs_snr()
# study.plot_run_times()
# study.plot_metrics_vs_snr("cosine")
# study.homogeneity_score(ignore_noise=False)
# import matplotlib.pyplot as plt
# plt.show()


if __name__ == "__main__":
test_benchmark_merging()
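
Once the study in the test above has run, the result-inspection helpers added in benchmark_merging.py could be used along these lines; this is a sketch only, with the case key and min_snr value chosen for illustration.

# Hedged sketch: inspect a MergingStudy once study.run() and study.compute_results() are done.
from spikeinterface.benchmark.benchmark_merging import MergingStudy

study = MergingStudy(cache_folder / "study_clustering")   # folder name from the test above

counts = study.get_count_units()      # per-case counts: num_gt, num_sorter, num_well_detected, ...
print(counts)

study.plot_agreement_matrix()
case_key = ("circus", "amplitudes")   # one of the (method, splits) case keys built in the test
study.plot_splitted_amplitudes(case_key, pair_index=0)
study.plot_potential_merges(case_key, min_snr=4)   # min_snr value is arbitrary here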
127 changes: 125 additions & 2 deletions src/spikeinterface/curation/auto_merge.py
@@ -223,7 +223,9 @@ def auto_merges(
sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs)
res_ext = sorting_analyzer.get_extension(step)
if res_ext is None:
print(f"Extension {ext} is computed with default params. Precompute it with custom params if needed")
print(
f"Extension {ext} is computed with default params. Precompute it with custom params if needed"
)
sorting_analyzer.compute(ext, **job_kwargs)
elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext):
raise ValueError(f"{step} requires {ext} extension")
@@ -233,7 +235,7 @@
params.update(steps_params[step])

# STEP : remove units with too few spikes
if step == "num_spikes":
if step == "min_spikes":

num_spikes = sorting.count_num_spikes_per_unit(outputs="array")
to_remove = num_spikes < params["min_spikes"]
@@ -545,6 +547,127 @@ def get_potential_auto_merge(
)


def iterative_merges(
sorting_analyzer,
presets,
presets_params=None,
merging_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.5, "censor_ms": 3},
compute_needed_extensions=True,
verbose=False,
greedy_merges=False,
extra_outputs=False,
**job_kwargs,
):
"""
Wrapper to conveniently be able to launch several presets for auto_merges in a row, as a list. Merges
are applied sequentially or until no more merges are done, one preset at a time, and extensions are
not recomputed thanks to the merging units.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer
presets : list of presets for the auto_merges() functions. Presets can be in
"similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors"
(see auto_merge for more details)
params : list of params that should be given to all presets. Should have the same length as presets
merging_kwargs : dict, the paramaters that should be used while merging units after each preset
compute_needed_extensions : bool, default True
During the preset, boolean to specify is extensions needed by the steps should be recomputed,
or used as they are if already present in the sorting_analyzer
greedy_merges : bool, default: False
If True, then each presets of the list is applied until no further merges can be done, before trying
the next one
extra_outputs : bool, default: False
If True, additional list of merges applied at every preset, and dictionary (`outs`) with processed data are returned.
Returns
-------
sorting_analyzer:
The new sorting analyzer where all the merges from all the presets have been applied
merges, outs:
Returned only when extra_outputs=True
A list with all the merges performed at every steps, and dictionaries that contains data for debugging and plotting.
"""

if presets_params is None:
presets_params = [dict()] * len(presets)

assert len(presets) == len(presets_params)
n_units = max(sorting_analyzer.unit_ids) + 1

if compute_needed_extensions:
sorting_analyzer = sorting_analyzer.copy()

if extra_outputs:
all_merges = []
all_outs = []

preset_number = 0

while True:

if preset_number == len(presets_params):
break

should_compute_extensions = compute_needed_extensions and (preset_number == 0)

if extra_outputs:
merges, outs = auto_merges(
sorting_analyzer,
preset=presets[preset_number],
resolve_graph=True,
compute_needed_extensions=should_compute_extensions,
extra_outputs=extra_outputs,
force_copy=False,
steps_params=presets_params[preset_number],
**job_kwargs,
)

all_merges += [merges]
all_outs += [outs]
else:
merges = auto_merges(
sorting_analyzer,
preset=presets[preset_number],
resolve_graph=True,
compute_needed_extensions=should_compute_extensions,
extra_outputs=extra_outputs,
force_copy=False,
steps_params=presets_params[preset_number],
**job_kwargs,
)

n_merges = int(np.sum([len(i) for i in merges]))
if verbose:
print(f"{n_merges} merges have been made during pass {presets[preset_number]}")

if greedy_merges:
if n_merges == 0:
preset_number += 1
else:
preset_number += 1

sorting_analyzer = sorting_analyzer.merge_units(merges, **merging_kwargs, **job_kwargs)

if extra_outputs:

final_merges = {}
for merge in all_merges:
for count, m in enumerate(merge):
new_list = m
for k in m:
if k in final_merges:
new_list.remove(k)
new_list += final_merges[k]
final_merges[count + n_units] = new_list
if len(final_merges.keys()) > 0:
n_units = max(final_merges.keys()) + 1

return sorting_analyzer.sorting, list(final_merges.values()), all_outs
else:
return sorting_analyzer.sorting


def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs):

sorting = sorting_analyzer.sorting
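
To see how the pieces fit together, here is a sketch (not part of the diff) of calling the new iterative_merges() on an existing SortingAnalyzer. The preset names come from the docstring above; the per-preset parameter dicts are left empty since their step-level format follows auto_merges().

# Hedged sketch: chain two presets with iterative_merges(); `sorting_analyzer` is assumed to exist.
from spikeinterface.curation.auto_merge import iterative_merges

presets = ["x_contaminations", "temporal_splits"]
presets_params = [{}, {}]   # one (possibly empty) dict of step params per preset

sorting, merges, outs = iterative_merges(
    sorting_analyzer,
    presets=presets,
    presets_params=presets_params,
    greedy_merges=True,    # repeat each preset until it finds no more merges
    extra_outputs=True,    # also return the merges applied at every pass and the debug outputs
    verbose=True,
)
print(merges)   # merge groups accumulated across the presets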