diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 9ee8ecb528..8d038aa45b 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -225,17 +225,18 @@ def random_spikes_selection(
     return random_spikes_indices
 
-
-def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy='append'):
+def apply_merges_to_sorting(
+    sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
+):
     """
     Function to apply a resolved representation of the merges to a sorting object.
 
     This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible.
-
+
     If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed.
     Optionally, the boolean mask of kept spikes is returned.
-
+
     Parameters
     ----------
     sorting : Sorting
@@ -251,7 +252,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
     return_kept : bool, default False
         return also a boolean mask of kept spikes
     new_id_strategy : "append" | "take_first", default "append"
-        The strategy that should be used, if new_unit_ids is None, to create new unit_ids. 
+        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
         "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
        "take_first" : new_unit_ids will be the first unit_id of every list of merges
 
@@ -267,14 +268,15 @@
     spikes = sorting.to_spike_vector().copy()
     keep_mask = np.ones(len(spikes), dtype=bool)
 
-    new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge,
-                                                     new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy)
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
+    )
 
     rename_ids = {}
     for i, merge_group in enumerate(units_to_merge):
         for unit_id in merge_group:
             rename_ids[unit_id] = new_unit_ids[i]
-
+
     all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids)
     all_unit_ids = list(all_unit_ids)
 
@@ -282,23 +284,23 @@
     segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
     segment_slices = []
     for i in range(num_seg):
-        segment_slices += [(segment_limits[i], segment_limits[i+1])]
+        segment_slices += [(segment_limits[i], segment_limits[i + 1])]
 
     # using this function avoids the mask approach and simplifies the algorithm a lot
     spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
     spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
-
+
     for old_unit_id in sorting.unit_ids:
         if old_unit_id in rename_ids.keys():
             new_unit_id = rename_ids[old_unit_id]
         else:
             new_unit_id = old_unit_id
-
+
         new_unit_index = all_unit_ids.index(new_unit_id)
         for segment_index in range(num_seg):
             spike_inds = spike_indices[segment_index][old_unit_id]
             spikes["unit_index"][spike_inds] = new_unit_index
-
+
     if censor_ms is not None:
         rpv = int(sorting.sampling_frequency * censor_ms / 1000.0)
         for group_old_ids in units_to_merge:
@@ -308,7 +310,7 @@
             for segment_index in range(num_seg):
                 group_indices = []
                 for unit_id in group_old_ids:
                     group_indices.append(spike_indices[segment_index][unit_id])
                 group_indices = np.concatenate(group_indices)
                group_indices = np.sort(group_indices)
-                inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv )
+                inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv)
                 keep_mask[group_indices[inds + 1]] = False
 
     spikes = spikes[keep_mask]
@@ -326,7 +328,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
     be provided.
     Every new unit_id will be added at the end if not already present.
-
+
     Parameters
     ----------
     old_unit_ids : np.array
@@ -341,7 +343,7 @@
     -------
     all_unit_ids :  The unit ids in the merged sorting
-        The unit_ids that will be present after merges 
+        The unit_ids that will be present after merges
     """
 
     old_unit_ids = np.asarray(old_unit_ids)
@@ -362,8 +364,7 @@
 
     return np.array(all_unit_ids)
 
-
-def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy='append'):
+def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
     """
     Function to generate new unit ids during a merging procedure. If new_unit_ids
     are provided, it will return these unit ids, checking that they have the same length as
@@ -380,20 +381,19 @@
        New unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`. If None,
        merged units will have the first unit_id of every list of merges
     new_id_strategy : "append" | "take_first", default "append"
-        The strategy that should be used, if new_unit_ids is None, to create new unit_ids. 
+        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
"append" : new_units_ids will be added at the end of max(sorging.unit_ids) "take_first" : new_unit_ids will be the first unit_id of every list of merges - + Returns ------- new_unit_ids : The new unit ids - The new units_ids associated with the merges + The new units_ids associated with the merges + - """ old_unit_ids = np.asarray(old_unit_ids) - if new_unit_ids is not None: assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" else: @@ -418,4 +418,4 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids else: raise ValueError("wrong new_id_strategy") - return new_unit_ids \ No newline at end of file + return new_unit_ids diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 24739fb374..38baf62c35 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -11,7 +11,7 @@ spike_vector_to_indices, apply_merges_to_sorting, _get_ids_after_merging, - generate_unit_ids_for_merge_group + generate_unit_ids_for_merge_group, ) @@ -77,56 +77,47 @@ def test_random_spikes_selection(): random_spikes_indices = random_spikes_selection(sorting, num_samples, method="all") assert random_spikes_indices.size == spikes.size + def test_apply_merges_to_sorting(): times = np.array([0, 0, 10, 20, 300]) - labels = np.array(['a', 'b', 'c', 'a', 'b' ]) + labels = np.array(["a", "b", "c", "a", "b"]) # unit_ids str - sorting1 = NumpySorting.from_times_labels( - [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c'] - ) + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) spikes1 = sorting1.to_spike_vector() - sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None) + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None) spikes2 = sorting2.to_spike_vector() assert sorting2.unit_ids.size == 2 assert sorting1.to_spike_vector().size == sorting1.to_spike_vector().size - assert np.array_equal(['c', 'merge0'], sorting2.unit_ids) + assert np.array_equal(["c", "merge0"], sorting2.unit_ids) assert np.array_equal( - spikes1[spikes1['unit_index'] == 2]['sample_index'], - spikes2[spikes2['unit_index'] == 0]['sample_index'] + spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"] ) - - sorting3, keep_mask = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=1.5, return_kept=True) + sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True) spikes3 = sorting3.to_spike_vector() assert spikes3.size < spikes1.size assert not keep_mask[1] - st = sorting3.get_unit_spike_train(segment_index=0, unit_id='merge0') - assert st.size == 3 # one spike is removed by censor period - + st = sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0") + assert st.size == 3 # one spike is removed by censor period # unit_ids int - sorting1 = NumpySorting.from_times_labels( - [times, times], [labels, labels], 10_000., unit_ids=[10, 20, 30] - ) + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30]) spikes1 = sorting1.to_spike_vector() sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None) assert np.array_equal(sorting2.unit_ids, [30, 31]) - sorting1 = NumpySorting.from_times_labels( - [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c'] - 
-    )
-    sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None, new_id_strategy="take_first")
-    assert np.array_equal(sorting2.unit_ids, ['a', 'c'])
-
+    sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
+    sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first")
+    assert np.array_equal(sorting2.unit_ids, ["a", "c"])
 
 
 def test_get_ids_after_merging():
-    all_unit_ids = _get_ids_after_merging(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], ['x', 'd'])
-    assert np.array_equal(all_unit_ids, ['c', 'd', 'x'])
+    all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"])
+    assert np.array_equal(all_unit_ids, ["c", "d", "x"])
     # print(all_unit_ids)
 
     all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9])
@@ -136,24 +127,33 @@
 
 
 def test_generate_unit_ids_for_merge_group():
-    new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='append')
-    assert np.array_equal(new_unit_ids, ['merge0', 'merge1'])
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append"
+    )
+    assert np.array_equal(new_unit_ids, ["merge0", "merge1"])
 
-    new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='take_first')
-    assert np.array_equal(new_unit_ids, ['a', 'd'])
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first"
+    )
+    assert np.array_equal(new_unit_ids, ["a", "d"])
 
-    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='append')
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append")
     assert np.array_equal(new_unit_ids, [16, 17])
-    
-    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='take_first')
+
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first")
     assert np.array_equal(new_unit_ids, [0, 9])
 
-    new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='append')
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append"
+    )
     assert np.array_equal(new_unit_ids, ["16", "17"])
-    
-    new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='take_first')
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="take_first"
+    )
     assert np.array_equal(new_unit_ids, ["0", "9"])
 
+
 if __name__ == "__main__":
     # test_spike_vector_to_spike_trains()
     # test_spike_vector_to_indices()
diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py
index c182d4130a..3771b1c63c 100644
--- a/src/spikeinterface/curation/mergeunitssorting.py
+++ b/src/spikeinterface/curation/mergeunitssorting.py
@@ -6,6 +6,7 @@
 from copy import deepcopy
 from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
 
+
 class MergeUnitsSorting(BaseSorting):
     """
     Class that handles several merges of units from a Sorting object based on a list of lists of unit_ids.
@@ -45,9 +46,10 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         sampling_frequency = sorting.get_sampling_frequency()
 
         from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
-        new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge,
-                                                         new_unit_ids=new_unit_ids,
-                                                         new_id_strategy='append')
+
+        new_unit_ids = generate_unit_ids_for_merge_group(
+            sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append"
+        )
 
         all_removed_ids = []
         for ids in units_to_merge:
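
Usage sketch (not part of the patch): the snippet below strings the new helpers together the same way the tests above do. The spike times, labels, and sampling frequency are toy values chosen for illustration; the functions and their signatures are the ones introduced by this diff.

import numpy as np

from spikeinterface.core import NumpySorting
from spikeinterface.core.sorting_tools import (
    apply_merges_to_sorting,
    generate_unit_ids_for_merge_group,
)

# Toy single-segment sorting with three units; times and labels are made up.
times = np.array([0, 5, 10, 20, 300])
labels = np.array(["a", "b", "c", "a", "b"])
sorting = NumpySorting.from_times_labels([times], [labels], 10_000.0, unit_ids=["a", "b", "c"])

# Preview the ids a merge would produce: "append" mints fresh ids,
# "take_first" reuses the first id of each merge group.
print(generate_unit_ids_for_merge_group(sorting.unit_ids, [["a", "b"]], new_id_strategy="append"))      # ["merge0"]
print(generate_unit_ids_for_merge_group(sorting.unit_ids, [["a", "b"]], new_id_strategy="take_first"))  # ["a"]

# Apply the merge. With censor_ms set, spikes of a merged group that fall within
# the censor period of the previous spike (here 1.5 ms = 15 samples at 10 kHz)
# are dropped, and the boolean mask of kept spikes is returned as well.
merged, keep_mask = apply_merges_to_sorting(sorting, [["a", "b"]], censor_ms=1.5, return_kept=True)
print(merged.unit_ids)          # remaining unit "c" plus the new "merge0"
print(int((~keep_mask).sum()))  # 1: the merged spike at sample 5 is censored

The same id-resolution helper is what the lazy MergeUnitsSorting path in the last hunk now delegates to, so the eager and lazy merge code agree on the resulting unit ids.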