[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 5, 2024
1 parent dd99121 commit 8783b7f
Showing 3 changed files with 63 additions and 61 deletions.
46 changes: 23 additions & 23 deletions src/spikeinterface/core/sorting_tools.py
@@ -225,17 +225,18 @@ def random_spikes_selection(
return random_spikes_indices


def apply_merges_to_sorting(
    sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
):
"""
Function to apply a resolved representation of the merges to a sorting object.
    This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible.
If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed.
    Optionally, the boolean mask of kept spikes is returned.
Parameters
----------
sorting : Sorting
@@ -251,7 +252,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
return_kept : bool, default False
        also return a boolean mask of kept spikes
new_id_strategy : "append" | "take_first", default "append"
        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
            "append" : new_unit_ids will be appended after max(sorting.unit_ids)
            "take_first" : new_unit_ids will be the first unit_id of every list of merges
@@ -267,38 +268,39 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
spikes = sorting.to_spike_vector().copy()
keep_mask = np.ones(len(spikes), dtype=bool)

    new_unit_ids = generate_unit_ids_for_merge_group(
        sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
    )

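    # map each old unit_id inside a merge group to its new merged unit_id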
rename_ids = {}
for i, merge_group in enumerate(units_to_merge):
for unit_id in merge_group:
rename_ids[unit_id] = new_unit_ids[i]

all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids)
all_unit_ids = list(all_unit_ids)

num_seg = sorting.get_num_segments()
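    # the spike vector is ordered by segment_index, so searchsorted gives every
    # segment's start/stop positions in one call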
segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
segment_slices = []
for i in range(num_seg):
        segment_slices += [(segment_limits[i], segment_limits[i + 1])]

    # using this function avoids the mask approach and simplifies the algorithm a lot
spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)

for old_unit_id in sorting.unit_ids:
if old_unit_id in rename_ids.keys():
new_unit_id = rename_ids[old_unit_id]
else:
new_unit_id = old_unit_id

new_unit_index = all_unit_ids.index(new_unit_id)
for segment_index in range(num_seg):
spike_inds = spike_indices[segment_index][old_unit_id]
spikes["unit_index"][spike_inds] = new_unit_index

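    # optionally censor: inside each merge group, a spike closer to its
    # predecessor than the refractory period (in samples) is not kept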
if censor_ms is not None:
rpv = int(sorting.sampling_frequency * censor_ms / 1000.0)
for group_old_ids in units_to_merge:
Expand All @@ -308,7 +310,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
group_indices.append(spike_indices[segment_index][unit_id])
group_indices = np.concatenate(group_indices)
group_indices = np.sort(group_indices)
inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv )
inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv)
keep_mask[group_indices[inds + 1]] = False

spikes = spikes[keep_mask]
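To make the censoring arithmetic concrete, a small sketch (numbers taken from one segment of the merge group ["a", "b"] in the test below):

    import numpy as np

    group_samples = np.array([0, 0, 20, 300])  # sorted sample_index of the merged group
    rpv = int(10_000 * 1.5 / 1000.0)  # censor_ms=1.5 at 10 kHz -> 15 samples
    keep = np.ones(group_samples.size, dtype=bool)
    inds = np.flatnonzero(np.diff(group_samples) < rpv)
    keep[inds + 1] = False  # the later spike of each violating pair is dropped
    # keep -> [True, False, True, True]: one duplicate removed, three spikes survive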
@@ -326,7 +328,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
be provided.
Every new unit_id will be added at the end if not already present.
Parameters
----------
old_unit_ids : np.array
Expand All @@ -341,7 +343,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
-------
all_unit_ids : The unit ids in the merged sorting
        The unit_ids that will be present after merges
"""
old_unit_ids = np.asarray(old_unit_ids)
@@ -362,8 +364,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
return np.array(all_unit_ids)


def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
"""
    Function to generate new unit_ids during a merging procedure. If new_unit_ids
    are provided, it will return these unit ids, checking that they have the same length as
@@ -380,20 +381,19 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids
        The new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None,
        merged units will have the first unit_id of every list of merges
new_id_strategy : "append" | "take_first", default "append"
        The strategy that should be used, if new_unit_ids is None, to create new unit_ids.
            "append" : new_unit_ids will be appended after max(sorting.unit_ids)
            "take_first" : new_unit_ids will be the first unit_id of every list of merges
Returns
-------
new_unit_ids : The new unit ids
        The new unit_ids associated with the merges
"""
old_unit_ids = np.asarray(old_unit_ids)

if new_unit_ids is not None:
assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
else:
@@ -418,4 +418,4 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids
else:
raise ValueError("wrong new_id_strategy")

    return new_unit_ids
70 changes: 35 additions & 35 deletions src/spikeinterface/core/tests/test_sorting_tools.py
@@ -11,7 +11,7 @@
spike_vector_to_indices,
apply_merges_to_sorting,
_get_ids_after_merging,
    generate_unit_ids_for_merge_group,
)


@@ -77,56 +77,47 @@ def test_random_spikes_selection():
random_spikes_indices = random_spikes_selection(sorting, num_samples, method="all")
assert random_spikes_indices.size == spikes.size


def test_apply_merges_to_sorting():

times = np.array([0, 0, 10, 20, 300])
    labels = np.array(["a", "b", "c", "a", "b"])

# unit_ids str
    sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
spikes1 = sorting1.to_spike_vector()

    sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None)
spikes2 = sorting2.to_spike_vector()
assert sorting2.unit_ids.size == 2
    assert sorting1.to_spike_vector().size == sorting2.to_spike_vector().size
    assert np.array_equal(["c", "merge0"], sorting2.unit_ids)
assert np.array_equal(
        spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"]
)


sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True)
spikes3 = sorting3.to_spike_vector()
assert spikes3.size < spikes1.size
assert not keep_mask[1]
st = sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0")
assert st.size == 3 # one spike is removed by censor period

# unit_ids int
sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30])
spikes1 = sorting1.to_spike_vector()
sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None)
assert np.array_equal(sorting2.unit_ids, [30, 31])

sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first")
assert np.array_equal(sorting2.unit_ids, ["a", "c"])


def test_get_ids_after_merging():

all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"])
assert np.array_equal(all_unit_ids, ["c", "d", "x"])
# print(all_unit_ids)

all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9])
@@ -136,24 +127,33 @@ def test_get_ids_after_merging():

def test_generate_unit_ids_for_merge_group():

new_unit_ids = generate_unit_ids_for_merge_group(
["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append"
)
assert np.array_equal(new_unit_ids, ["merge0", "merge1"])

new_unit_ids = generate_unit_ids_for_merge_group(
["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first"
)
assert np.array_equal(new_unit_ids, ["a", "d"])

new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append")
assert np.array_equal(new_unit_ids, [16, 17])

new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first")
assert np.array_equal(new_unit_ids, [0, 9])

new_unit_ids = generate_unit_ids_for_merge_group(
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append"
)
assert np.array_equal(new_unit_ids, ["16", "17"])

new_unit_ids = generate_unit_ids_for_merge_group(
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="take_first"
)
assert np.array_equal(new_unit_ids, ["0", "9"])


if __name__ == "__main__":
# test_spike_vector_to_spike_trains()
# test_spike_vector_to_indices()
8 changes: 5 additions & 3 deletions src/spikeinterface/curation/mergeunitssorting.py
@@ -6,6 +6,7 @@
from copy import deepcopy
from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group


class MergeUnitsSorting(BaseSorting):
"""
Class that handles several merges of units from a Sorting object based on a list of lists of unit_ids.
@@ -45,9 +46,10 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
sampling_frequency = sorting.get_sampling_frequency()

from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group

new_unit_ids = generate_unit_ids_for_merge_group(
sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append"
)

all_removed_ids = []
for ids in units_to_merge:
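For context, a minimal construction sketch (a sketch only; it assumes MergeUnitsSorting is importable from spikeinterface.curation and uses just the arguments visible in this diff):

    from spikeinterface.curation import MergeUnitsSorting

    # lazily merge units "a" and "b"; with new_unit_ids=None the "append"
    # strategy above generates the id of the merged unit
    merged_sorting = MergeUnitsSorting(sorting, units_to_merge=[["a", "b"]])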
