diff --git a/doc/api.rst b/doc/api.rst
index c5c9ebe4dd..c73cd812da 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -60,6 +60,10 @@ spikeinterface.core
     .. autofunction:: select_segment_sorting
     .. autofunction:: read_binary
     .. autofunction:: read_zarr
+    .. autofunction:: apply_merges_to_sorting
+    .. autofunction:: spike_vector_to_spike_trains
+    .. autofunction:: random_spikes_selection
+
 
 Low-level
 ~~~~~~~~~
@@ -67,7 +71,6 @@ Low-level
 .. automodule:: spikeinterface.core
     :noindex:
 
-    .. autoclass:: BaseWaveformExtractorExtension
     .. autoclass:: ChunkRecordingExecutor
 
 spikeinterface.extractors
diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py
index a5e1f44842..674f1ac463 100644
--- a/src/spikeinterface/core/__init__.py
+++ b/src/spikeinterface/core/__init__.py
@@ -101,7 +101,7 @@
     get_chunk_with_margin,
     order_channels_by_depth,
 )
 
-from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection
+from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting
 from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator
 from .snippets_tools import snippets_from_sorting
diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 02f4529a98..918d95bf52 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
-from .basesorting import BaseSorting
+
 import numpy as np
 
+from .basesorting import BaseSorting
+from .numpyextractors import NumpySorting
+
 
 def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict[str, np.array]]:
     """
@@ -220,3 +223,202 @@ def random_spikes_selection(
         raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'")
 
     return random_spikes_indices
+
+
+def apply_merges_to_sorting(
+    sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
+):
+    """
+    Apply a resolved representation of the merges to a sorting object.
+
+    This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible.
+
+    If `censor_ms` is not None, spikes violating this refractory period within a merge group are removed.
+
+    Optionally, the boolean mask of kept spikes is returned.
+
+    Parameters
+    ----------
+    sorting : Sorting
+        The Sorting object to which the merges are applied.
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : list | None, default: None
+        Optional new unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`.
+        If None, new ids are generated according to `new_id_strategy`.
+    censor_ms : float | None, default: None
+        If not None, spikes violating this refractory period (in ms) within a merged unit are discarded.
+    return_kept : bool, default: False
+        If True, also return a boolean mask of the kept spikes.
+    new_id_strategy : "append" | "take_first", default: "append"
+        The strategy used to create new unit_ids when `new_unit_ids` is None.
+
+        * "append" : new unit_ids are appended after max(sorting.unit_ids)
+        * "take_first" : new_unit_ids will be the first unit_id of every list of merges
+
+    Returns
+    -------
+    sorting : NumpySorting
+        The newly created sorting with the merged units.
+    keep_mask : numpy.array
+        A boolean mask, if censor_ms is not None, telling which spikes from the original spike vector
+        have been kept, given the refractory period violations (None if censor_ms is None).
+    """
+
+    spikes = sorting.to_spike_vector().copy()
+    keep_mask = np.ones(len(spikes), dtype=bool)
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
+    )
+
+    rename_ids = {}
+    for i, merge_group in enumerate(units_to_merge):
+        for unit_id in merge_group:
+            rename_ids[unit_id] = new_unit_ids[i]
+
+    all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids)
+    all_unit_ids = list(all_unit_ids)
+
+    num_seg = sorting.get_num_segments()
+    seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
+    segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)]
+
+    # using spike_vector_to_indices() avoids the mask approach and simplifies the algorithm a lot
+    spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
+    spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
+
+    for old_unit_id in sorting.unit_ids:
+        if old_unit_id in rename_ids.keys():
+            new_unit_id = rename_ids[old_unit_id]
+        else:
+            new_unit_id = old_unit_id
+
+        new_unit_index = all_unit_ids.index(new_unit_id)
+        for segment_index in range(num_seg):
+            spike_inds = spike_indices[segment_index][old_unit_id]
+            spikes["unit_index"][spike_inds] = new_unit_index
+
+    if censor_ms is not None:
+        rpv = int(sorting.sampling_frequency * censor_ms / 1000.0)
+        for group_old_ids in units_to_merge:
+            for segment_index in range(num_seg):
+                group_indices = []
+                for unit_id in group_old_ids:
+                    group_indices.append(spike_indices[segment_index][unit_id])
+                group_indices = np.concatenate(group_indices)
+                group_indices = np.sort(group_indices)
+                inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv)
+                keep_mask[group_indices[inds + 1]] = False
+
+    spikes = spikes[keep_mask]
+    sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)
+
+    if return_kept:
+        return sorting, keep_mask
+    else:
+        return sorting
+
+
+def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
+    """
+    Function to get the list of unique unit_ids after some merges, when new_unit_ids
+    are provided.
+
+    Every new unit_id will be added at the end if not already present.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : list
+        The new unit_ids for the merged units. It needs to have the same length as `units_to_merge`.
+
+    Returns
+    -------
+
+    all_unit_ids : np.array
+        The unit_ids that will be present after the merges.
+
+    """
+    old_unit_ids = np.asarray(old_unit_ids)
+
+    assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+
+    all_unit_ids = list(old_unit_ids.copy())
+    for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge):
+        assert len(group_ids) > 1, "A merge should have at least two units"
+        for unit_id in group_ids:
+            assert unit_id in old_unit_ids, "Merged ids should be in the sorting"
+        for unit_id in group_ids:
+            if unit_id != new_unit_id:
+                # new_unit_id can be inside group_ids
+                all_unit_ids.remove(unit_id)
+        if new_unit_id not in all_unit_ids:
+            all_unit_ids.append(new_unit_id)
+    return np.array(all_unit_ids)
+
+
+def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
+    """
+    Function to generate new unit_ids during a merging procedure. If new_unit_ids
+    are provided, they are returned after checking that they have the same
+    length as `units_to_merge`.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : list | None, default: None
+        Optional new unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`.
+        If None, new ids will be generated.
+    new_id_strategy : "append" | "take_first", default: "append"
+        The strategy used to create new unit_ids when `new_unit_ids` is None.
+
+        * "append" : new unit_ids are appended after max(sorting.unit_ids)
+        * "take_first" : new_unit_ids will be the first unit_id of every list of merges
+
+    Returns
+    -------
+    new_unit_ids : list
+        The new unit_ids associated with the merges.
+ """ + old_unit_ids = np.asarray(old_unit_ids) + + if new_unit_ids is not None: + # then only doing a consistency check + assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + # new_unit_ids can also be part of old_unit_ids only inside the same group: + for i, new_unit_id in enumerate(new_unit_ids): + if new_unit_id in old_unit_ids: + assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups" + else: + dtype = old_unit_ids.dtype + num_merge = len(units_to_merge) + # select new_unit_ids greater that the max id, event greater than the numerical str ids + if new_id_strategy == "take_first": + new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge] + elif new_id_strategy == "append": + if np.issubdtype(dtype, np.character): + # dtype str + if all(p.isdigit() for p in old_unit_ids): + # All str are digit : we can generate a max + m = max(int(p) for p in old_unit_ids) + 1 + new_unit_ids = [str(m + i) for i in range(num_merge)] + else: + # we cannot automatically find new names + new_unit_ids = [f"merge{i}" for i in range(num_merge)] + else: + # dtype int + new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) + else: + raise ValueError("wrong new_id_strategy") + + return new_unit_ids diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 1aefeeb062..38baf62c35 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -9,6 +9,9 @@ spike_vector_to_spike_trains, random_spikes_selection, spike_vector_to_indices, + apply_merges_to_sorting, + _get_ids_after_merging, + generate_unit_ids_for_merge_group, ) @@ -75,7 +78,87 @@ def test_random_spikes_selection(): assert random_spikes_indices.size == spikes.size +def test_apply_merges_to_sorting(): + + times = np.array([0, 0, 10, 20, 300]) + labels = np.array(["a", "b", "c", "a", "b"]) + + # unit_ids str + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) + spikes1 = sorting1.to_spike_vector() + + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None) + spikes2 = sorting2.to_spike_vector() + assert sorting2.unit_ids.size == 2 + assert sorting1.to_spike_vector().size == sorting1.to_spike_vector().size + assert np.array_equal(["c", "merge0"], sorting2.unit_ids) + assert np.array_equal( + spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"] + ) + + sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True) + spikes3 = sorting3.to_spike_vector() + assert spikes3.size < spikes1.size + assert not keep_mask[1] + st = sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0") + assert st.size == 3 # one spike is removed by censor period + + # unit_ids int + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30]) + spikes1 = sorting1.to_spike_vector() + sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None) + assert np.array_equal(sorting2.unit_ids, [30, 31]) + + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first") + assert np.array_equal(sorting2.unit_ids, ["a", "c"]) + + +def 
+def test_get_ids_after_merging():
+
+    all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"])
+    assert np.array_equal(all_unit_ids, ["c", "d", "x"])
+    # print(all_unit_ids)
+
+    all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9])
+    assert np.array_equal(all_unit_ids, [12, 9, 28])
+    # print(all_unit_ids)
+
+
+def test_generate_unit_ids_for_merge_group():
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append"
+    )
+    assert np.array_equal(new_unit_ids, ["merge0", "merge1"])
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first"
+    )
+    assert np.array_equal(new_unit_ids, ["a", "d"])
+
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append")
+    assert np.array_equal(new_unit_ids, [16, 17])
+
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first")
+    assert np.array_equal(new_unit_ids, [0, 9])
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append"
+    )
+    assert np.array_equal(new_unit_ids, ["16", "17"])
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="take_first"
+    )
+    assert np.array_equal(new_unit_ids, ["0", "9"])
+
+
 if __name__ == "__main__":
     # test_spike_vector_to_spike_trains()
     # test_spike_vector_to_indices()
-    test_random_spikes_selection()
+    # test_random_spikes_selection()
+
+    test_apply_merges_to_sorting()
+    test_get_ids_after_merging()
+    test_generate_unit_ids_for_merge_group()
diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py
index bbdb70b2f6..11f26ea778 100644
--- a/src/spikeinterface/curation/mergeunitssorting.py
+++ b/src/spikeinterface/curation/mergeunitssorting.py
@@ -4,6 +4,7 @@
 from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment
 from spikeinterface.core.core_tools import define_function_from_class
 from copy import deepcopy
+from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
 
 
 class MergeUnitsSorting(BaseSorting):
@@ -44,35 +45,15 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         parents_unit_ids = sorting.unit_ids
         sampling_frequency = sorting.get_sampling_frequency()
 
+        new_unit_ids = generate_unit_ids_for_merge_group(
+            sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append"
+        )
+
         all_removed_ids = []
         for ids in units_to_merge:
            all_removed_ids.extend(ids)
         keep_unit_ids = [u for u in parents_unit_ids if u not in all_removed_ids]
 
-        if new_unit_ids is None:
-            dtype = parents_unit_ids.dtype
-            # select new_unit_ids greater that the max id, event greater than the numerical str ids
-            if np.issubdtype(dtype, np.character):
-                # dtype str
-                if all(p.isdigit() for p in parents_unit_ids):
-                    # All str are digit : we can generate a max
-                    m = max(int(p) for p in parents_unit_ids) + 1
-                    new_unit_ids = [str(m + i) for i in range(num_merge)]
-                else:
-                    # we cannot automatically find new names
-                    new_unit_ids = [f"merge{i}" for i in range(num_merge)]
-                if np.any(np.isin(new_unit_ids, keep_unit_ids)):
-                    raise ValueError(
-                        "Unable to find 'new_unit_ids' because it is a string and parents "
-                        "already contain merges. Pass a list of 'new_unit_ids' as an argument."
-                    )
-            else:
-                # dtype int
-                new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
-        else:
-            if np.any(np.isin(new_unit_ids, keep_unit_ids)):
-                raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones")
-            assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge"
 
         # some checks
@@ -81,7 +62,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         assert properties_policy in ("keep", "remove"), "properties_policy must be " "keep" " or " "remove" ""
 
         # new units are put at the end
-        unit_ids = keep_unit_ids + new_unit_ids
+        unit_ids = keep_unit_ids + list(new_unit_ids)
         BaseSorting.__init__(self, sampling_frequency, unit_ids)
 
         # assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids'
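
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new apply_merges_to_sorting() and generate_unit_ids_for_merge_group() helpers introduced by this diff would be used, mirroring the added tests. The toy spike times, unit ids, merge group, and the 1.5 ms censor value are made up for the example.

import numpy as np
from spikeinterface.core import NumpySorting, apply_merges_to_sorting
from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group

# toy two-segment sorting with three units, as in the new tests
times = np.array([0, 0, 10, 20, 300])
labels = np.array(["a", "b", "c", "a", "b"])
sorting = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])

# merge "a" and "b"; with the default new_id_strategy="append" the merged unit gets a fresh id
merged = apply_merges_to_sorting(sorting, [["a", "b"]], censor_ms=None)
print(merged.unit_ids)  # ['c' 'merge0']

# with censor_ms set, spikes violating the refractory period inside a merge group
# are discarded and the boolean mask of kept spikes can be returned as well
merged_censored, keep_mask = apply_merges_to_sorting(sorting, [["a", "b"]], censor_ms=1.5, return_kept=True)

# new ids can also be previewed separately before merging
print(generate_unit_ids_for_merge_group(sorting.unit_ids, [["a", "b"]], new_id_strategy="take_first"))  # ['a']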