From dd99121652f8875d130f008ee10c6d1eb290201e Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 5 Jul 2024 14:24:14 +0200
Subject: [PATCH 1/6] Implement apply_merges_to_sorting()

---
 src/spikeinterface/core/sorting_tools.py         | 201 +++++++++++++++++-
 .../core/tests/test_sorting_tools.py             |  85 +++++++-
 .../curation/mergeunitssorting.py                |  33 +--
 3 files changed, 291 insertions(+), 28 deletions(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 02f4529a98..9ee8ecb528 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
-from .basesorting import BaseSorting
+
 import numpy as np
 
+from .basesorting import BaseSorting
+from .numpyextractors import NumpySorting
+
 
 def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict[str, np.array]]:
     """
@@ -220,3 +223,199 @@ def random_spikes_selection(
         raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'")
 
     return random_spikes_indices
+
+
+
+def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy='append'):
+    """
+    Function to apply a resolved representation of the merges to a sorting object.
+
+    This function is not lazy and create a new NumpySorting with a compact spike_vector as fast as possible.
+
+    If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed.
+
+    Optionaly, the boolean of kept spikes is returned
+
+    Parameters
+    ----------
+    sorting : Sorting
+        The Sorting object to apply merges
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : None or list
+        New unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`. If None,
+        merged units will have the first unit_id of every list of merges
+    censor_ms : None or float
+        When applying the merges, consecutive spikes violating the given refractory period are discarded
+    return_kept : bool, default False
+        return also a boolean mask of kept spikes
+    new_id_strategy : "append" | "take_first", default "append"
+        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
+        "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+        "take_first" : new_unit_ids will be the first unit_id of every list of merges
+
+    Returns
+    -------
+    sorting : The new Sorting object
+        The newly created sorting with the merged units
+    keep_mask : numpy.array
+        A boolean mask, if censor_ms is not None, telling which spikes from the original spike vector
+        have been kept, given the refractory period violations (all True if censor_ms is None)
+    """
+
+    spikes = sorting.to_spike_vector().copy()
+    keep_mask = np.ones(len(spikes), dtype=bool)
+
+    new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge,
+                                                     new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy)
+
+    rename_ids = {}
+    for i, merge_group in enumerate(units_to_merge):
+        for unit_id in merge_group:
+            rename_ids[unit_id] = new_unit_ids[i]
+
+    all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids)
+    all_unit_ids = list(all_unit_ids)
+
+    num_seg = sorting.get_num_segments()
+    segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
+    segment_slices = []
+    for i in range(num_seg):
+        segment_slices += [(segment_limits[i], segment_limits[i+1])]
+
+    # using this function avoids the mask approach and simplifies the algorithm a lot
+    spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
+    spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
+
+    for old_unit_id in sorting.unit_ids:
+        if old_unit_id in rename_ids.keys():
+            new_unit_id = rename_ids[old_unit_id]
+        else:
+            new_unit_id = old_unit_id
+
+        new_unit_index = all_unit_ids.index(new_unit_id)
+        for segment_index in range(num_seg):
+            spike_inds = spike_indices[segment_index][old_unit_id]
+            spikes["unit_index"][spike_inds] = new_unit_index
+
+    if censor_ms is not None:
+        rpv = int(sorting.sampling_frequency * censor_ms / 1000.0)
+        for group_old_ids in units_to_merge:
+            for segment_index in range(num_seg):
+                group_indices = []
+                for unit_id in group_old_ids:
+                    group_indices.append(spike_indices[segment_index][unit_id])
+                group_indices = np.concatenate(group_indices)
+                group_indices = np.sort(group_indices)
+                inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv )
+                keep_mask[group_indices[inds + 1]] = False
+
+    spikes = spikes[keep_mask]
+    sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)
+
+    if return_kept:
+        return sorting, keep_mask
+    else:
+        return sorting
+
+
+def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
+    """
+    Function to get the list of unique unit_ids after some merges, when the new_unit_ids
+    are provided.
+
+    Every new unit_id will be added at the end if not already present.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : None or list
+        New unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`.
+
+    Returns
+    -------
+
+    all_unit_ids : The unit ids in the merged sorting
+        The unit_ids that will be present after merges
+
+    """
+    old_unit_ids = np.asarray(old_unit_ids)
+
+    assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+
+    all_unit_ids = list(old_unit_ids.copy())
+    for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge):
+        assert len(group_ids) > 1, "A merge should have at least two units"
+        for unit_id in group_ids:
+            assert unit_id in old_unit_ids, "Merged ids should be in the sorting"
+        for unit_id in group_ids:
+            if unit_id != new_unit_id:
+                # new_unit_id can be inside group_ids
+                all_unit_ids.remove(unit_id)
+        if new_unit_id not in all_unit_ids:
+            all_unit_ids.append(new_unit_id)
+    return np.array(all_unit_ids)
+
+
+
+def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy='append'):
+    """
+    Function to generate new unit ids during a merging procedure. If new_unit_ids
+    are provided, it will return these unit ids, checking that they have the length as
+    to_be_merged.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : None or list
+        A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None,
+        merged units will have the first unit_id of every lists of merges
+    new_id_strategy : "append" | "take_first", default "append"
+        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
+        "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+        "take_first" : new_unit_ids will be the first unit_id of every list of merges
+
+    Returns
+    -------
+    new_unit_ids : The new unit ids
+        The new unit_ids associated with the merges
+
+
+    """
+    old_unit_ids = np.asarray(old_unit_ids)
+
+
+    if new_unit_ids is not None:
+        assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+    else:
+        dtype = old_unit_ids.dtype
+        num_merge = len(units_to_merge)
+        # select new_unit_ids greater than the max id, even greater than the numerical str ids
+        if new_id_strategy == "take_first":
+            new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge]
+        elif new_id_strategy == "append":
+            if np.issubdtype(dtype, np.character):
+                # dtype str
+                if all(p.isdigit() for p in old_unit_ids):
+                    # All str are digit : we can generate a max
+                    m = max(int(p) for p in old_unit_ids) + 1
+                    new_unit_ids = [str(m + i) for i in range(num_merge)]
+                else:
+                    # we cannot automatically find new names
+                    new_unit_ids = [f"merge{i}" for i in range(num_merge)]
+            else:
+                # dtype int
+                new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
+        else:
+            raise ValueError("wrong new_id_strategy")
+
+    return new_unit_ids
\ No newline at end of file
diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py
index 1aefeeb062..24739fb374 100644
--- a/src/spikeinterface/core/tests/test_sorting_tools.py
+++ b/src/spikeinterface/core/tests/test_sorting_tools.py
@@ -9,6 +9,9 @@
     spike_vector_to_spike_trains,
     random_spikes_selection,
     spike_vector_to_indices,
+    apply_merges_to_sorting,
+    _get_ids_after_merging,
+    generate_unit_ids_for_merge_group
 )
 
 
@@ -74,8 +77,88 @@ def test_random_spikes_selection():
     random_spikes_indices = random_spikes_selection(sorting, num_samples, method="all")
     assert random_spikes_indices.size == spikes.size
 
+
+def test_apply_merges_to_sorting():
+
+    times = np.array([0, 0, 10, 20, 300])
+    labels = np.array(['a', 'b', 'c', 'a', 'b'])
+
+    # unit_ids str
+    sorting1 = NumpySorting.from_times_labels(
+        [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c']
+    )
+    spikes1 = sorting1.to_spike_vector()
+
+    sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None)
+    spikes2 = sorting2.to_spike_vector()
+    assert sorting2.unit_ids.size == 2
+    assert sorting1.to_spike_vector().size == sorting2.to_spike_vector().size
+    assert np.array_equal(['c', 'merge0'], sorting2.unit_ids)
+    assert np.array_equal(
+        spikes1[spikes1['unit_index'] == 2]['sample_index'],
+        spikes2[spikes2['unit_index'] == 0]['sample_index']
+    )
+
+
+    sorting3, keep_mask = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=1.5, return_kept=True)
+    spikes3 = sorting3.to_spike_vector()
+    assert spikes3.size < spikes1.size
+    assert not keep_mask[1]
+    st = sorting3.get_unit_spike_train(segment_index=0, unit_id='merge0')
+    assert st.size == 3  # one spike is removed by the censor period
+
+
+    # unit_ids int
+    sorting1 = NumpySorting.from_times_labels(
+        [times, times], [labels, labels], 10_000., unit_ids=[10, 20, 30]
+    )
+    spikes1 = sorting1.to_spike_vector()
+    sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None)
+    assert np.array_equal(sorting2.unit_ids, [30, 31])
+
+    sorting1 = NumpySorting.from_times_labels(
+        [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c']
+    )
+    sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None, new_id_strategy="take_first")
+    assert np.array_equal(sorting2.unit_ids, ['a', 'c'])
+
+
+
+def test_get_ids_after_merging():
+
+    all_unit_ids = _get_ids_after_merging(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], ['x', 'd'])
+    assert np.array_equal(all_unit_ids, ['c', 'd', 'x'])
+    # print(all_unit_ids)
+
+    all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9])
+    assert np.array_equal(all_unit_ids, [12, 9, 28])
+    # print(all_unit_ids)
+
+
+def test_generate_unit_ids_for_merge_group():
+
+    new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='append')
+    assert np.array_equal(new_unit_ids, ['merge0', 'merge1'])
+
+    new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='take_first')
+    assert np.array_equal(new_unit_ids, ['a', 'd'])
+
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='append')
+    assert np.array_equal(new_unit_ids, [16, 17])
+
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='take_first')
+    assert np.array_equal(new_unit_ids, [0, 9])
+
+    new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='append')
+    assert np.array_equal(new_unit_ids, ["16", "17"])
+
+    new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='take_first')
+    assert np.array_equal(new_unit_ids, ["0", "9"])
 
 
 if __name__ == "__main__":
     # test_spike_vector_to_spike_trains()
     # test_spike_vector_to_indices()
-    test_random_spikes_selection()
+    # test_random_spikes_selection()
+
+    test_apply_merges_to_sorting()
+    test_get_ids_after_merging()
+    test_generate_unit_ids_for_merge_group()
diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py
index bbdb70b2f6..c182d4130a 100644
--- a/src/spikeinterface/curation/mergeunitssorting.py
+++ b/src/spikeinterface/curation/mergeunitssorting.py
@@ -4,7 +4,7 @@
 from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment
 from spikeinterface.core.core_tools import define_function_from_class
 from copy import deepcopy
-
+from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
 
 class MergeUnitsSorting(BaseSorting):
     """
@@ -44,35 +44,16 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         parents_unit_ids = sorting.unit_ids
         sampling_frequency = sorting.get_sampling_frequency()
 
+        from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
+        new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge,
+                                                         new_unit_ids=new_unit_ids,
+                                                         new_id_strategy='append')
+
         all_removed_ids = []
         for ids in units_to_merge:
             all_removed_ids.extend(ids)
         keep_unit_ids = [u for u in parents_unit_ids if u not in all_removed_ids]
 
-        if new_unit_ids is None:
-            dtype = parents_unit_ids.dtype
-            # select new_unit_ids greater that the max id, event greater than the numerical str ids
-            if np.issubdtype(dtype, np.character):
-                # dtype str
-                if all(p.isdigit() for p in parents_unit_ids):
-                    # All str are digit : we can generate a max
-                    m = max(int(p) for p in parents_unit_ids) + 1
-                    new_unit_ids = [str(m + i) for i in range(num_merge)]
-                else:
-                    # we cannot automatically find new names
-                    new_unit_ids = [f"merge{i}" for i in range(num_merge)]
-                    if np.any(np.isin(new_unit_ids, keep_unit_ids)):
-                        raise ValueError(
-                            "Unable to find 'new_unit_ids' because it is a string and parents "
-                            "already contain merges. Pass a list of 'new_unit_ids' as an argument."
-                        )
-            else:
-                # dtype int
-                new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
-        else:
-            if np.any(np.isin(new_unit_ids, keep_unit_ids)):
-                raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones")
-            assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge"
 
         # some checks
@@ -81,7 +62,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         assert properties_policy in ("keep", "remove"), "properties_policy must be " "keep" " or " "remove" ""
 
         # new units are put at the end
-        unit_ids = keep_unit_ids + new_unit_ids
+        unit_ids = keep_unit_ids + list(new_unit_ids)
         BaseSorting.__init__(self, sampling_frequency, unit_ids)
         # assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids'
 

From 8783b7fbba6d6029fbac09357a85ccbf27ea0ca9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 5 Jul 2024 12:28:37 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/sorting_tools.py         | 46 ++++++------
 .../core/tests/test_sorting_tools.py             | 70 +++++++++----------
 .../curation/mergeunitssorting.py                |  8 ++-
 3 files changed, 63 insertions(+), 61 deletions(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 9ee8ecb528..8d038aa45b 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -225,17 +225,18 @@ def random_spikes_selection(
     return random_spikes_indices
 
 
-
-def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy='append'):
+def apply_merges_to_sorting(
+    sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
+):
     """
     Function to apply a resolved representation of the merges to a sorting object.
 
     This function is not lazy and create a new NumpySorting with a compact spike_vector as fast as possible.
-
+
     If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed.
 
     Optionaly, the boolean of kept spikes is returned
-
+
     Parameters
     ----------
     sorting : Sorting
         The Sorting object to apply merges
     units_to_merge : list/tuple of lists/tuples
         A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
         but it can also have more (merge multiple units at once).
     new_unit_ids : None or list
         New unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`. If None,
         merged units will have the first unit_id of every list of merges
     censor_ms : None or float
         When applying the merges, consecutive spikes violating the given refractory period are discarded
     return_kept : bool, default False
         return also a boolean mask of kept spikes
     new_id_strategy : "append" | "take_first", default "append"
-        The strategy that should be used, if new_unit_ids is None, to create new unit_ids. 
+        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
"append" : new_units_ids will be added at the end of max(sorging.unit_ids) "take_first" : new_unit_ids will be the first unit_id of every list of merges @@ -267,14 +268,15 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m spikes = sorting.to_spike_vector().copy() keep_mask = np.ones(len(spikes), dtype=bool) - new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge, - new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy) + new_unit_ids = generate_unit_ids_for_merge_group( + sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + ) rename_ids = {} for i, merge_group in enumerate(units_to_merge): for unit_id in merge_group: rename_ids[unit_id] = new_unit_ids[i] - + all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids) all_unit_ids = list(all_unit_ids) @@ -282,23 +284,23 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) segment_slices = [] for i in range(num_seg): - segment_slices += [(segment_limits[i], segment_limits[i+1])] + segment_slices += [(segment_limits[i], segment_limits[i + 1])] # using this function vaoid to use the mask approach and simplify a lot the algo spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) - + for old_unit_id in sorting.unit_ids: if old_unit_id in rename_ids.keys(): new_unit_id = rename_ids[old_unit_id] else: new_unit_id = old_unit_id - + new_unit_index = all_unit_ids.index(new_unit_id) for segment_index in range(num_seg): spike_inds = spike_indices[segment_index][old_unit_id] spikes["unit_index"][spike_inds] = new_unit_index - + if censor_ms is not None: rpv = int(sorting.sampling_frequency * censor_ms / 1000.0) for group_old_ids in units_to_merge: @@ -308,7 +310,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m group_indices.append(spike_indices[segment_index][unit_id]) group_indices = np.concatenate(group_indices) group_indices = np.sort(group_indices) - inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv ) + inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv) keep_mask[group_indices[inds + 1]] = False spikes = spikes[keep_mask] @@ -326,7 +328,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): be provided. Every new unit_id will be added at the end if not already present. - + Parameters ---------- old_unit_ids : np.array @@ -341,7 +343,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): ------- all_unit_ids : The unit ids in the merged sorting - The units_ids that will be present after merges + The units_ids that will be present after merges """ old_unit_ids = np.asarray(old_unit_ids) @@ -362,8 +364,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): return np.array(all_unit_ids) - -def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy='append'): +def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"): """ Function to generate new units ids during a merging procedure. 
     are provided, it will return these unit ids, checking that they have the length as
     to_be_merged.
 
@@ -380,20 +381,19 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids
         A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None,
         merged units will have the first unit_id of every lists of merges
     new_id_strategy : "append" | "take_first", default "append"
-        The strategy that should be used, if new_unit_ids is None, to create new unit_ids. 
+        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
         "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
         "take_first" : new_unit_ids will be the first unit_id of every list of merges
-
+
     Returns
     -------
     new_unit_ids : The new unit ids
-        The new unit_ids associated with the merges 
+        The new unit_ids associated with the merges
 
-
     """
     old_unit_ids = np.asarray(old_unit_ids)
 
-
     if new_unit_ids is not None:
         assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
     else:
         dtype = old_unit_ids.dtype
@@ -418,4 +418,4 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids
     else:
         raise ValueError("wrong new_id_strategy")
 
-    return new_unit_ids
\ No newline at end of file
+    return new_unit_ids
diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py
index 24739fb374..38baf62c35 100644
--- a/src/spikeinterface/core/tests/test_sorting_tools.py
+++ b/src/spikeinterface/core/tests/test_sorting_tools.py
@@ -11,7 +11,7 @@
     spike_vector_to_indices,
     apply_merges_to_sorting,
     _get_ids_after_merging,
-    generate_unit_ids_for_merge_group
+    generate_unit_ids_for_merge_group,
 )
 
 
@@ -77,56 +77,47 @@ def test_random_spikes_selection():
     random_spikes_indices = random_spikes_selection(sorting, num_samples, method="all")
     assert random_spikes_indices.size == spikes.size
 
+
 def test_apply_merges_to_sorting():
 
     times = np.array([0, 0, 10, 20, 300])
-    labels = np.array(['a', 'b', 'c', 'a', 'b'])
+    labels = np.array(["a", "b", "c", "a", "b"])
 
     # unit_ids str
-    sorting1 = NumpySorting.from_times_labels(
-        [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c']
-    )
+    sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
     spikes1 = sorting1.to_spike_vector()
 
-    sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None)
+    sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None)
     spikes2 = sorting2.to_spike_vector()
     assert sorting2.unit_ids.size == 2
     assert sorting1.to_spike_vector().size == sorting2.to_spike_vector().size
-    assert np.array_equal(['c', 'merge0'], sorting2.unit_ids)
+    assert np.array_equal(["c", "merge0"], sorting2.unit_ids)
     assert np.array_equal(
-        spikes1[spikes1['unit_index'] == 2]['sample_index'],
-        spikes2[spikes2['unit_index'] == 0]['sample_index']
+        spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"]
     )
 
-
-    sorting3, keep_mask = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=1.5, return_kept=True)
+    sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True)
     spikes3 = sorting3.to_spike_vector()
     assert spikes3.size < spikes1.size
     assert not keep_mask[1]
-    st = sorting3.get_unit_spike_train(segment_index=0, unit_id='merge0')
-    assert st.size == 3  # one spike is removed by the censor period
-
+    st = sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0")
+    assert st.size == 3  # one spike is removed by the censor period
 
     # unit_ids int
-    sorting1 = NumpySorting.from_times_labels(
-        [times, times], [labels, labels], 10_000., unit_ids=[10, 20, 30]
-    )
+    sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30])
     spikes1 = sorting1.to_spike_vector()
     sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None)
     assert np.array_equal(sorting2.unit_ids, [30, 31])
 
-    sorting1 = NumpySorting.from_times_labels(
-        [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c']
-    )
-    sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None, new_id_strategy="take_first")
-    assert np.array_equal(sorting2.unit_ids, ['a', 'c'])
+    sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
+    sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first")
+    assert np.array_equal(sorting2.unit_ids, ["a", "c"])
 
 
 def test_get_ids_after_merging():
 
-    all_unit_ids = _get_ids_after_merging(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], ['x', 'd'])
-    assert np.array_equal(all_unit_ids, ['c', 'd', 'x'])
+    all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"])
+    assert np.array_equal(all_unit_ids, ["c", "d", "x"])
     # print(all_unit_ids)
 
     all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9])
@@ -136,24 +127,33 @@ def test_get_ids_after_merging():
 
 def test_generate_unit_ids_for_merge_group():
 
-    new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='append')
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append"
+    )
     assert np.array_equal(new_unit_ids, ["merge0", "merge1"])
 
-    new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='take_first')
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first"
+    )
     assert np.array_equal(new_unit_ids, ["a", "d"])
 
-    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='append')
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append")
     assert np.array_equal(new_unit_ids, [16, 17])
-
-    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='take_first')
+
+    new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first")
     assert np.array_equal(new_unit_ids, [0, 9])
 
-    new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='append')
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append"
+    )
     assert np.array_equal(new_unit_ids, ["16", "17"])
-
-    new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='take_first')
+
+    new_unit_ids = generate_unit_ids_for_merge_group(
+        ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="take_first"
+    )
     assert np.array_equal(new_unit_ids, ["0", "9"])
 
+
 if __name__ == "__main__":
     # test_spike_vector_to_spike_trains()
     # test_spike_vector_to_indices()
diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py
index c182d4130a..3771b1c63c 100644
--- a/src/spikeinterface/curation/mergeunitssorting.py
+++ b/src/spikeinterface/curation/mergeunitssorting.py
@@ -6,6 +6,7 @@
 from copy import deepcopy
 from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
 
+
 class MergeUnitsSorting(BaseSorting):
     """
     Class that handles several merges of units from a Sorting object based on a list of lists of unit_ids.
@@ -45,9 +46,10 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         sampling_frequency = sorting.get_sampling_frequency()
 
         from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
-        new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge,
-                                                         new_unit_ids=new_unit_ids,
-                                                         new_id_strategy='append')
+
+        new_unit_ids = generate_unit_ids_for_merge_group(
+            sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append"
+        )
 
         all_removed_ids = []
         for ids in units_to_merge:

From 85d504dd9e0bc685069c4e426fcd32a99df48f5d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 5 Jul 2024 15:44:10 +0200
Subject: [PATCH 3/6] docstrings

---
 src/spikeinterface/core/sorting_tools.py | 54 ++++++++++++------------
 1 file changed, 28 insertions(+), 26 deletions(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 8d038aa45b..ca9697b222 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -229,32 +229,33 @@ def apply_merges_to_sorting(
     sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
 ):
     """
-    Function to apply a resolved representation of the merges to a sorting object.
+    Apply a resolved representation of the merges to a sorting object.
 
-    This function is not lazy and create a new NumpySorting with a compact spike_vector as fast as possible.
+    This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible.
 
-    If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed.
+    If `censor_ms` is not None, duplicated spikes violating the `censor_ms` refractory period are removed.
 
-    Optionaly, the boolean of kept spikes is returned
+    Optionally, the boolean mask of kept spikes is returned.
 
     Parameters
     ----------
     sorting : Sorting
-        The Sorting object to apply merges
+        The Sorting object to apply merges.
     units_to_merge : list/tuple of lists/tuples
         A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
         but it can also have more (merge multiple units at once).
-    new_unit_ids : None or list
+    new_unit_ids : list | None, default: None
         New unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`. If None,
-        merged units will have the first unit_id of every list of merges
+        merged units will have the first unit_id of every list of merges.
-    censor_ms : None or float
+    censor_ms : float | None, default: None
         When applying the merges, consecutive spikes violating the given refractory period are discarded
-    return_kept : bool, default False
-        return also a boolean mask of kept spikes
+    return_kept : bool, default: False
+        If True, also return a boolean mask of kept spikes.
-    new_id_strategy : "append" | "take_first", default "append"
-        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
+    new_id_strategy : "append" | "take_first", default: "append"
+        The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+
-        "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
-        "take_first" : new_unit_ids will be the first unit_id of every list of merges
+        * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+        * "take_first" : new_unit_ids will be the first unit_id of every list of merges
 
     Returns
     -------
@@ -336,7 +337,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
     units_to_merge : list/tuple of lists/tuples
         A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
         but it can also have more (merge multiple units at once).
-    new_unit_ids : None or list
+    new_unit_ids : list | None
         New unit_ids for the merged units. If given, it needs to have the same length as `units_to_merge`.
 
     Returns
@@ -367,28 +368,29 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
 def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
     """
     Function to generate new unit ids during a merging procedure. If new_unit_ids
-    are provided, it will return these unit ids, checking that they have the length as
-    to_be_merged.
+    are provided, it will return these unit ids, checking that they have the same
+    length as `units_to:merge`.
 
     Parameters
     ----------
     old_unit_ids : np.array
-        The old unit_ids
+        The old unit_ids.
     units_to_merge : list/tuple of lists/tuples
         A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
         but it can also have more (merge multiple units at once).
-    new_unit_ids : None or list
-        A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None,
-        merged units will have the first unit_id of every lists of merges
-    new_id_strategy : "append" | "take_first", default "append"
-        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
-        "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
-        "take_first" : new_unit_ids will be the first unit_id of every list of merges
+    new_unit_ids : list | None, default: None
+        Optional new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`.
+        If None, new ids will be generated.
+    new_id_strategy : "append" | "take_first", default: "append"
+        The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+
+        * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+        * "take_first" : new_unit_ids will be the first unit_id of every list of merges
 
     Returns
     -------
     new_unit_ids : The new unit ids
-        The new unit_ids associated with the merges
+        The new unit_ids associated with the merges.
""" From 752603ee4dd9678c2f66322079e3c84487552c0b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 5 Jul 2024 17:51:10 +0200 Subject: [PATCH 4/6] Add apply_merges_to_sorting in api.rst --- doc/api.rst | 5 ++++- src/spikeinterface/core/__init__.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c5c9ebe4dd..c73cd812da 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -60,6 +60,10 @@ spikeinterface.core .. autofunction:: select_segment_sorting .. autofunction:: read_binary .. autofunction:: read_zarr + .. autofunction:: apply_merges_to_sorting + .. autofunction:: spike_vector_to_spike_trains + .. autofunction:: random_spikes_selection + Low-level ~~~~~~~~~ @@ -67,7 +71,6 @@ Low-level .. automodule:: spikeinterface.core :noindex: - .. autoclass:: BaseWaveformExtractorExtension .. autoclass:: ChunkRecordingExecutor spikeinterface.extractors diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index a5e1f44842..674f1ac463 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -101,7 +101,7 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection +from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator from .snippets_tools import snippets_from_sorting From a4ad437d4ad27be283faa18d30ba9fdac3028015 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 5 Jul 2024 19:07:33 +0200 Subject: [PATCH 5/6] more checks --- src/spikeinterface/core/sorting_tools.py | 15 ++++++++------- src/spikeinterface/curation/mergeunitssorting.py | 2 -- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index ca9697b222..d2104eec73 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -282,10 +282,8 @@ def apply_merges_to_sorting( all_unit_ids = list(all_unit_ids) num_seg = sorting.get_num_segments() - segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) - segment_slices = [] - for i in range(num_seg): - segment_slices += [(segment_limits[i], segment_limits[i + 1])] + seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) + segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)] # using this function vaoid to use the mask approach and simplify a lot the algo spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] @@ -369,7 +367,7 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids """ Function to generate new units ids during a merging procedure. If new_units_ids are provided, it will return these unit ids, checking that they have the the same - length as `units_to:merge`. + length as `units_to_merge`. Parameters ---------- @@ -391,13 +389,16 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids ------- new_unit_ids : The new unit ids The new units_ids associated with the merges. 
-
-
     """
     old_unit_ids = np.asarray(old_unit_ids)
 
     if new_unit_ids is not None:
+        # we only need a consistency check
         assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+        # new_unit_ids can also be part of old_unit_ids, but only inside the same group:
+        for i, new_unit_id in enumerate(new_unit_ids):
+            if new_unit_id in old_unit_ids:
+                assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups"
     else:
         dtype = old_unit_ids.dtype
diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py
index 3771b1c63c..11f26ea778 100644
--- a/src/spikeinterface/curation/mergeunitssorting.py
+++ b/src/spikeinterface/curation/mergeunitssorting.py
@@ -45,8 +45,6 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
         parents_unit_ids = sorting.unit_ids
         sampling_frequency = sorting.get_sampling_frequency()
 
-        from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group
-
         new_unit_ids = generate_unit_ids_for_merge_group(
             sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append"
         )

From e46b86a02c6fde2dc9539857a0ff5b16728c3419 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 5 Jul 2024 17:53:49 +0000
Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/sorting_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index d2104eec73..918d95bf52 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -395,7 +395,7 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids
     if new_unit_ids is not None:
         # we only need a consistency check
         assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
-        # new_unit_ids can also be part of old_unit_ids, but only inside the same group: 
+        # new_unit_ids can also be part of old_unit_ids, but only inside the same group:
         for i, new_unit_id in enumerate(new_unit_ids):
             if new_unit_id in old_unit_ids:
                 assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups"
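
A quick usage sketch of the API this series introduces. It assumes the imports
exposed by the diffs above (apply_merges_to_sorting is re-exported from
spikeinterface.core by PATCH 4/6) and reuses the toy spike trains from
test_apply_merges_to_sorting(); the printed values are the ones the tests assert.

    import numpy as np
    from spikeinterface.core import NumpySorting, apply_merges_to_sorting
    from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group

    # three units; "a" and "b" fire simultaneously at sample 0
    times = np.array([0, 0, 10, 20, 300])
    labels = np.array(["a", "b", "c", "a", "b"])
    sorting = NumpySorting.from_times_labels([times], [labels], 10_000.0, unit_ids=["a", "b", "c"])

    # merge "a" and "b"; censor_ms=1.5 (15 samples at 10 kHz) drops the duplicated
    # spike at sample 0 and also returns the boolean mask of kept spikes
    merged, keep_mask = apply_merges_to_sorting(sorting, [["a", "b"]], censor_ms=1.5, return_kept=True)
    print(merged.unit_ids)            # ["c", "merge0"] with the default "append" strategy
    print(int(np.sum(~keep_mask)))    # 1 censored spike

    # the id-generation helper can also be called on its own
    print(generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]]))  # [16, 17]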
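The censoring step is the least obvious part of apply_merges_to_sorting(), so
here is a minimal standalone sketch of the same np.diff trick, for a single
merged group in one segment (censor_mask is an illustrative name, not part of
the PR): after sorting the group's sample indices, consecutive spikes closer
than the censor period are flagged and the later spike of each pair is dropped.

    import numpy as np

    def censor_mask(sample_index, censor_samples):
        """Keep the first spike of every pair closer than censor_samples."""
        order = np.argsort(sample_index)
        keep = np.ones(sample_index.size, dtype=bool)
        # flag the *later* spike of each violating pair, as the PR does
        violations = np.flatnonzero(np.diff(sample_index[order]) < censor_samples)
        keep[order[violations + 1]] = False
        return keep

    print(censor_mask(np.array([0, 0, 20, 300]), 15))  # [ True False  True  True]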