diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index eae793d21b..aecbdc1b1d 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,12 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge, auto_merge_units_iterative +from .auto_merge import ( + compute_merge_unit_groups, + auto_merge_units, + get_potential_auto_merge, + auto_merge_units_iterative, +) # manual sorting, diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 28f41ee5be..ae3e629927 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -82,56 +82,56 @@ # **job_kwargs, # ) - # DEBUG - # import matplotlib.pyplot as plt - # from spikeinterface.curation.auto_merge import normalize_correlogram - # templates_diff = outs['templates_diff'] - # correlogram_diff = outs['correlogram_diff'] - # bins = outs['bins'] - # correlograms_smoothed = outs['correlograms_smoothed'] - # correlograms = outs['correlograms'] - # win_sizes = outs['win_sizes'] - - # fig, ax = plt.subplots() - # ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05)) - - # fig, ax = plt.subplots() - # ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05)) - - # m = correlograms.shape[2] // 2 - - # for unit_id1, unit_id2 in merge_unit_groups[:5]: - # unit_ind1 = sorting_with_split.id_to_index(unit_id1) - # unit_ind2 = sorting_with_split.id_to_index(unit_id2) - - # bins2 = bins[:-1] + np.mean(np.diff(bins)) - # fig, axs = plt.subplots(ncols=3) - # ax = axs[0] - # ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b') - # ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r') - # ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b') - # ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r') - - # ax.set_title(f'{unit_id1} {unit_id2}') - # ax = axs[1] - # ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g') - - # auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :]) - # auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :]) - # cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :]) - - # ax = axs[2] - # ax.plot(bins2, auto_corr1, color='b') - # ax.plot(bins2, auto_corr2, color='r') - # ax.plot(bins2, cross_corr, color='g') - - # ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b') - # ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b') - # ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r') - # ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r') - - # ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}') - # plt.show() +# DEBUG +# import matplotlib.pyplot as plt +# from spikeinterface.curation.auto_merge import normalize_correlogram +# templates_diff = outs['templates_diff'] +# correlogram_diff = outs['correlogram_diff'] +# bins = outs['bins'] +# correlograms_smoothed = outs['correlograms_smoothed'] +# correlograms = outs['correlograms'] +# win_sizes = outs['win_sizes'] + +# fig, ax = plt.subplots() +# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05)) + +# fig, ax = plt.subplots() +# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05)) + +# m = correlograms.shape[2] // 2 + +# for unit_id1, unit_id2 in merge_unit_groups[:5]: +# unit_ind1 = sorting_with_split.id_to_index(unit_id1) +# unit_ind2 = sorting_with_split.id_to_index(unit_id2) + +# bins2 = bins[:-1] + np.mean(np.diff(bins)) +# fig, axs = plt.subplots(ncols=3) +# ax = axs[0] +# ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b') +# ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r') +# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b') +# ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r') + +# ax.set_title(f'{unit_id1} {unit_id2}') +# ax = axs[1] +# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g') + +# auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :]) +# auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :]) +# cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :]) + +# ax = axs[2] +# ax.plot(bins2, auto_corr1, color='b') +# ax.plot(bins2, auto_corr2, color='r') +# ax.plot(bins2, cross_corr, color='g') + +# ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b') +# ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b') +# ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r') +# ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r') + +# ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}') +# plt.show() def test_auto_merge_units(sorting_analyzer_for_curation): @@ -154,9 +154,9 @@ def test_auto_merge_units(sorting_analyzer_for_curation): **job_kwargs, ) - merged_sorting = auto_merge_units(sorting_analyzer, {"preset" : "x_contaminations"}) + merged_sorting = auto_merge_units(sorting_analyzer, {"preset": "x_contaminations"}) assert len(merged_sorting.unit_ids) < len(sorting_analyzer_for_curation.unit_ids) - + def test_auto_merge_units_iterative(sorting_analyzer_for_curation): @@ -178,7 +178,7 @@ def test_auto_merge_units_iterative(sorting_analyzer_for_curation): **job_kwargs, ) - merged_sorting = auto_merge_units_iterative(sorting_analyzer, [{"preset" : "x_contaminations"}]) + merged_sorting = auto_merge_units_iterative(sorting_analyzer, [{"preset": "x_contaminations"}]) assert len(merged_sorting.unit_ids) < len(sorting_analyzer_for_curation.unit_ids) @@ -186,6 +186,6 @@ def test_auto_merge_units_iterative(sorting_analyzer_for_curation): sorting_analyzer = make_sorting_analyzer(sparse=True) # preset = "x_contaminations" preset = None - #test_compute_merge_unit_groups(sorting_analyzer, preset=preset) + # test_compute_merge_unit_groups(sorting_analyzer, preset=preset) test_auto_merge_units(sorting_analyzer) - test_auto_merge_units_iterative(sorting_analyzer) \ No newline at end of file + test_auto_merge_units_iterative(sorting_analyzer) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4d43177684..d4b2caee1a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -347,7 +347,7 @@ def final_cleaning_circus( sorting, templates, similarity_kwargs={"method": "l2", "support": "union", "max_lag_ms": 0.1}, - apply_merge_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.5, "censor_ms" : 3.}, + apply_merge_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.5, "censor_ms": 3.0}, correlograms_kwargs={}, ):