diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index df5a5589e8..647854406a 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -34,7 +34,8 @@ def __init__(self, sorting_list, renamed_unit_ids=None): ) unit_ids = list(renamed_unit_ids) else: - all_ids_are_same_type = np.unique([sort.get_unit_ids().dtype for sort in sorting_list]).size == 1 + unit_ids_dtypes = [sort.get_unit_ids().dtype for sort in sorting_list] + all_ids_are_same_type = np.unique(unit_ids_dtypes).size == 1 all_units_ids_are_unique = False if all_ids_are_same_type: combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list]) @@ -44,7 +45,10 @@ def __init__(self, sorting_list, renamed_unit_ids=None): unit_ids = combined_ids else: default_unit_ids = [str(i) for i in range(num_all_units)] - unit_ids = default_unit_ids + if all_ids_are_same_type and np.issubdtype(unit_ids_dtypes[0], np.integer): + unit_ids = np.arange(num_all_units, dtype=dtype) + else: + unit_ids = default_unit_ids # unit map maps unit ids that are used to get spike trains u_id = 0