Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 28, 2024
1 parent 109abf0 commit 07fdb5b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 57 deletions.
7 changes: 6 additions & 1 deletion src/spikeinterface/curation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from .remove_redundant import remove_redundant_units, find_redundant_units
from .remove_duplicated_spikes import remove_duplicated_spikes
from .remove_excess_spikes import remove_excess_spikes
from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge, auto_merge_units_iterative
from .auto_merge import (
compute_merge_unit_groups,
auto_merge_units,
get_potential_auto_merge,
auto_merge_units_iterative,
)


# manual sorting,
Expand Down
110 changes: 55 additions & 55 deletions src/spikeinterface/curation/tests/test_auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,56 +82,56 @@
# **job_kwargs,
# )

# DEBUG
# import matplotlib.pyplot as plt
# from spikeinterface.curation.auto_merge import normalize_correlogram
# templates_diff = outs['templates_diff']
# correlogram_diff = outs['correlogram_diff']
# bins = outs['bins']
# correlograms_smoothed = outs['correlograms_smoothed']
# correlograms = outs['correlograms']
# win_sizes = outs['win_sizes']

# fig, ax = plt.subplots()
# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05))

# fig, ax = plt.subplots()
# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05))

# m = correlograms.shape[2] // 2

# for unit_id1, unit_id2 in merge_unit_groups[:5]:
# unit_ind1 = sorting_with_split.id_to_index(unit_id1)
# unit_ind2 = sorting_with_split.id_to_index(unit_id2)

# bins2 = bins[:-1] + np.mean(np.diff(bins))
# fig, axs = plt.subplots(ncols=3)
# ax = axs[0]
# ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b')
# ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r')
# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b')
# ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r')

# ax.set_title(f'{unit_id1} {unit_id2}')
# ax = axs[1]
# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g')

# auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :])
# auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :])
# cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :])

# ax = axs[2]
# ax.plot(bins2, auto_corr1, color='b')
# ax.plot(bins2, auto_corr2, color='r')
# ax.plot(bins2, cross_corr, color='g')

# ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b')
# ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b')
# ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r')
# ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r')

# ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}')
# plt.show()
# DEBUG
# import matplotlib.pyplot as plt
# from spikeinterface.curation.auto_merge import normalize_correlogram
# templates_diff = outs['templates_diff']
# correlogram_diff = outs['correlogram_diff']
# bins = outs['bins']
# correlograms_smoothed = outs['correlograms_smoothed']
# correlograms = outs['correlograms']
# win_sizes = outs['win_sizes']

# fig, ax = plt.subplots()
# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05))

# fig, ax = plt.subplots()
# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05))

# m = correlograms.shape[2] // 2

# for unit_id1, unit_id2 in merge_unit_groups[:5]:
# unit_ind1 = sorting_with_split.id_to_index(unit_id1)
# unit_ind2 = sorting_with_split.id_to_index(unit_id2)

# bins2 = bins[:-1] + np.mean(np.diff(bins))
# fig, axs = plt.subplots(ncols=3)
# ax = axs[0]
# ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b')
# ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r')
# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b')
# ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r')

# ax.set_title(f'{unit_id1} {unit_id2}')
# ax = axs[1]
# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g')

# auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :])
# auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :])
# cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :])

# ax = axs[2]
# ax.plot(bins2, auto_corr1, color='b')
# ax.plot(bins2, auto_corr2, color='r')
# ax.plot(bins2, cross_corr, color='g')

# ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b')
# ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b')
# ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r')
# ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r')

# ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}')
# plt.show()


def test_auto_merge_units(sorting_analyzer_for_curation):
Expand All @@ -154,9 +154,9 @@ def test_auto_merge_units(sorting_analyzer_for_curation):
**job_kwargs,
)

merged_sorting = auto_merge_units(sorting_analyzer, {"preset" : "x_contaminations"})
merged_sorting = auto_merge_units(sorting_analyzer, {"preset": "x_contaminations"})
assert len(merged_sorting.unit_ids) < len(sorting_analyzer_for_curation.unit_ids)


def test_auto_merge_units_iterative(sorting_analyzer_for_curation):

Expand All @@ -178,14 +178,14 @@ def test_auto_merge_units_iterative(sorting_analyzer_for_curation):
**job_kwargs,
)

merged_sorting = auto_merge_units_iterative(sorting_analyzer, [{"preset" : "x_contaminations"}])
merged_sorting = auto_merge_units_iterative(sorting_analyzer, [{"preset": "x_contaminations"}])
assert len(merged_sorting.unit_ids) < len(sorting_analyzer_for_curation.unit_ids)


if __name__ == "__main__":
sorting_analyzer = make_sorting_analyzer(sparse=True)
# preset = "x_contaminations"
preset = None
#test_compute_merge_unit_groups(sorting_analyzer, preset=preset)
# test_compute_merge_unit_groups(sorting_analyzer, preset=preset)
test_auto_merge_units(sorting_analyzer)
test_auto_merge_units_iterative(sorting_analyzer)
test_auto_merge_units_iterative(sorting_analyzer)
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def final_cleaning_circus(
sorting,
templates,
similarity_kwargs={"method": "l2", "support": "union", "max_lag_ms": 0.1},
apply_merge_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.5, "censor_ms" : 3.},
apply_merge_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.5, "censor_ms": 3.0},
correlograms_kwargs={},
):

Expand Down

0 comments on commit 07fdb5b

Please sign in to comment.