From b7c430976d68bd86e516ecef6ad30f92614be266 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 10 Jul 2024 16:22:19 -0600 Subject: [PATCH] units aggergation should preserve ids --- .../tests/test_unitsaggregationsorting.py | 46 +++++++++++++++++-- .../core/unitsaggregationsorting.py | 12 ++++- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py index b6cb479c7d..e68bdc6939 100644 --- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py +++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py @@ -5,6 +5,7 @@ from spikeinterface.core import NpzSortingExtractor from spikeinterface.core import create_sorting_npz +from spikeinterface.core import generate_sorting def test_unitsaggregationsorting(create_cache_folder): @@ -33,10 +34,12 @@ def test_unitsaggregationsorting(create_cache_folder): spiketrain1_1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=seg) spiketrains2_0 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=seg) spiketrains3_2 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=seg) - assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=seg)) - assert np.allclose(spiketrains2_0, sorting_agg.get_unit_spike_train(num_units + unit_ids[0], segment_index=seg)) + assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(str(unit_ids[1]), segment_index=seg)) assert np.allclose( - spiketrains3_2, sorting_agg.get_unit_spike_train(2 * num_units + unit_ids[2], segment_index=seg) + spiketrains2_0, sorting_agg.get_unit_spike_train(str(num_units + unit_ids[0]), segment_index=seg) + ) + assert np.allclose( + spiketrains3_2, sorting_agg.get_unit_spike_train(str(2 * num_units + unit_ids[2]), segment_index=seg) ) # test rename units @@ -92,5 +95,42 @@ def test_unitsaggregationsorting(create_cache_folder): print(sorting_agg_prop.get_property("brain_area")) +def test_unit_aggregation_preserve_ids(): + + sorting1 = generate_sorting(num_units=3) + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + sorting2 = generate_sorting(num_units=3) + sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"]) + + aggregated_sorting = aggregate_units([sorting1, sorting2]) + assert aggregated_sorting.get_num_units() == 6 + assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"] + + +def test_unit_aggregation_does_not_preserve_ids_if_not_unique(): + sorting1 = generate_sorting(num_units=3) + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + sorting2 = generate_sorting(num_units=3) + sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + aggregated_sorting = aggregate_units([sorting1, sorting2]) + assert aggregated_sorting.get_num_units() == 6 + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"] + + +def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): + sorting1 = generate_sorting(num_units=3) + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + sorting2 = generate_sorting(num_units=2) + sorting2 = sorting2.rename_units(new_unit_ids=[1, 2]) + + aggregated_sorting = aggregate_units([sorting1, sorting2]) + assert aggregated_sorting.get_num_units() == 5 + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] + + if __name__ == "__main__": test_unitsaggregationsorting() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index ea019268fb..df5a5589e8 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -34,7 +34,17 @@ def __init__(self, sorting_list, renamed_unit_ids=None): ) unit_ids = list(renamed_unit_ids) else: - unit_ids = list(np.arange(num_all_units)) + all_ids_are_same_type = np.unique([sort.get_unit_ids().dtype for sort in sorting_list]).size == 1 + all_units_ids_are_unique = False + if all_ids_are_same_type: + combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list]) + all_units_ids_are_unique = np.unique(combined_ids).size == num_all_units + + if all_ids_are_same_type and all_units_ids_are_unique: + unit_ids = combined_ids + else: + default_unit_ids = [str(i) for i in range(num_all_units)] + unit_ids = default_unit_ids # unit map maps unit ids that are used to get spike trains u_id = 0