diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py index b6cb479c7d..09b8affe55 100644 --- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py +++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py @@ -5,6 +5,7 @@ from spikeinterface.core import NpzSortingExtractor from spikeinterface.core import create_sorting_npz +from spikeinterface.core import generate_sorting def test_unitsaggregationsorting(create_cache_folder): @@ -92,5 +93,42 @@ def test_unitsaggregationsorting(create_cache_folder): print(sorting_agg_prop.get_property("brain_area")) +def test_unit_aggregation_preserve_ids(): + + sorting1 = generate_sorting(num_units=3) + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + sorting2 = generate_sorting(num_units=3) + sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"]) + + aggregated_sorting = aggregate_units([sorting1, sorting2]) + assert aggregated_sorting.get_num_units() == 6 + assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"] + + +def test_unit_aggregation_does_not_preserve_ids_if_not_unique(): + sorting1 = generate_sorting(num_units=3) + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + sorting2 = generate_sorting(num_units=3) + sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + aggregated_sorting = aggregate_units([sorting1, sorting2]) + assert aggregated_sorting.get_num_units() == 6 + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"] + + +def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): + sorting1 = generate_sorting(num_units=3) + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) + + sorting2 = generate_sorting(num_units=2) + sorting2 = sorting2.rename_units(new_unit_ids=[1, 2]) + + aggregated_sorting = aggregate_units([sorting1, sorting2]) + assert aggregated_sorting.get_num_units() == 5 + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] + + if __name__ == "__main__": test_unitsaggregationsorting() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index ea019268fb..9eb37e31ea 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -34,7 +34,21 @@ def __init__(self, sorting_list, renamed_unit_ids=None): ) unit_ids = list(renamed_unit_ids) else: - unit_ids = list(np.arange(num_all_units)) + unit_ids_dtypes = [sort.get_unit_ids().dtype for sort in sorting_list] + all_ids_are_same_type = np.unique(unit_ids_dtypes).size == 1 + all_units_ids_are_unique = False + if all_ids_are_same_type: + combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list]) + all_units_ids_are_unique = np.unique(combined_ids).size == num_all_units + + if all_ids_are_same_type and all_units_ids_are_unique: + unit_ids = combined_ids + else: + default_unit_ids = [str(i) for i in range(num_all_units)] + if all_ids_are_same_type and np.issubdtype(unit_ids_dtypes[0], np.integer): + unit_ids = np.arange(num_all_units, dtype=np.uint64) + else: + unit_ids = default_unit_ids # unit map maps unit ids that are used to get spike trains u_id = 0