Merge pull request #3154 from samuelgarcia/apply_merge_unit_to_sorting
Implement apply_merges_to_sorting()
alejoe91 authored Jul 6, 2024
2 parents 1e8b551 + e46b86a commit d983b19
Showing 5 changed files with 298 additions and 29 deletions.
5 changes: 4 additions & 1 deletion doc/api.rst
@@ -60,14 +60,17 @@ spikeinterface.core
.. autofunction:: select_segment_sorting
.. autofunction:: read_binary
.. autofunction:: read_zarr
.. autofunction:: apply_merges_to_sorting
.. autofunction:: spike_vector_to_spike_trains
.. autofunction:: random_spikes_selection


Low-level
~~~~~~~~~

.. automodule:: spikeinterface.core
:noindex:

.. autoclass:: BaseWaveformExtractorExtension
.. autoclass:: ChunkRecordingExecutor

spikeinterface.extractors
2 changes: 1 addition & 1 deletion src/spikeinterface/core/__init__.py
@@ -101,7 +101,7 @@
get_chunk_with_margin,
order_channels_by_depth,
)
from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection
from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting

from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator
from .snippets_tools import snippets_from_sorting
204 changes: 203 additions & 1 deletion src/spikeinterface/core/sorting_tools.py
@@ -1,7 +1,10 @@
from __future__ import annotations
from .basesorting import BaseSorting

import numpy as np

from .basesorting import BaseSorting
from .numpyextractors import NumpySorting


def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict[str, np.array]]:
"""
@@ -220,3 +223,202 @@ def random_spikes_selection(
raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'")

return random_spikes_indices


def apply_merges_to_sorting(
sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
):
"""
Apply a resolved representation of the merges to a sorting object.
This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible.
If `censor_ms` is not None, duplicated spikes violating the `censor_ms` refractory period are removed.
Optionally, the boolean mask of kept spikes is returned.
Parameters
----------
sorting : Sorting
The Sorting object to apply merges.
units_to_merge : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None, default: None
A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None,
merged units will have the first unit_id of every lists of merges.
censor_ms: float | None, default: None
When applying the merges, should be discard consecutive spikes violating a given refractory per
return_kept : bool, default: False
If True, also return also a booolean mask of kept spikes.
new_id_strategy : "append" | "take_first", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
* "append" : new_units_ids will be added at the end of max(sorging.unit_ids)
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
Returns
-------
sorting : The new Sorting object
The newly create sorting with the merged units
keep_mask : numpy.array
A boolean mask, if censor_ms is not None, telling which spike from the original spike vector
has been kept, given the refractory period violations (None if censor_ms is None)
"""

spikes = sorting.to_spike_vector().copy()
keep_mask = np.ones(len(spikes), dtype=bool)

new_unit_ids = generate_unit_ids_for_merge_group(
sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
)

rename_ids = {}
for i, merge_group in enumerate(units_to_merge):
for unit_id in merge_group:
rename_ids[unit_id] = new_unit_ids[i]

all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids)
all_unit_ids = list(all_unit_ids)

num_seg = sorting.get_num_segments()
seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)]

# using this function avoids the mask approach and simplifies the algorithm a lot
spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)

for old_unit_id in sorting.unit_ids:
if old_unit_id in rename_ids.keys():
new_unit_id = rename_ids[old_unit_id]
else:
new_unit_id = old_unit_id

new_unit_index = all_unit_ids.index(new_unit_id)
for segment_index in range(num_seg):
spike_inds = spike_indices[segment_index][old_unit_id]
spikes["unit_index"][spike_inds] = new_unit_index

if censor_ms is not None:
rpv = int(sorting.sampling_frequency * censor_ms / 1000.0)
for group_old_ids in units_to_merge:
for segment_index in range(num_seg):
group_indices = []
for unit_id in group_old_ids:
group_indices.append(spike_indices[segment_index][unit_id])
group_indices = np.concatenate(group_indices)
group_indices = np.sort(group_indices)
inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv)
keep_mask[group_indices[inds + 1]] = False

spikes = spikes[keep_mask]
sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)

if return_kept:
return sorting, keep_mask
else:
return sorting


def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
"""
Function to get the list of unique unit_ids after some merges, with given new_units_ids would
be provided.
Every new unit_id will be added at the end if not already present.
Parameters
----------
old_unit_ids : np.array
The old unit_ids.
units_to_merge : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None
A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`.
Returns
-------
all_unit_ids : The unit ids in the merged sorting
The units_ids that will be present after merges
"""
old_unit_ids = np.asarray(old_unit_ids)

assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"

all_unit_ids = list(old_unit_ids.copy())
for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge):
assert len(group_ids) > 1, "A merge should have at least two units"
for unit_id in group_ids:
assert unit_id in old_unit_ids, "Merged ids should be in the sorting"
for unit_id in group_ids:
if unit_id != new_unit_id:
# new_unit_id can be inside group_ids
all_unit_ids.remove(unit_id)
if new_unit_id not in all_unit_ids:
all_unit_ids.append(new_unit_id)
return np.array(all_unit_ids)


def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
"""
Function to generate new units ids during a merging procedure. If new_units_ids
are provided, it will return these unit ids, checking that they have the the same
length as `units_to_merge`.
Parameters
----------
old_unit_ids : np.array
The old unit_ids.
units_to_merge : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None, default: None
Optional new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`.
If None, new ids will be generated.
new_id_strategy : "append" | "take_first", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
* "append" : new_units_ids will be added at the end of max(sorging.unit_ids)
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
Returns
-------
new_unit_ids : The new unit ids
The new units_ids associated with the merges.
"""
old_unit_ids = np.asarray(old_unit_ids)

if new_unit_ids is not None:
# then only doing a consistency check
assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
# new_unit_ids can also be part of old_unit_ids only inside the same group:
for i, new_unit_id in enumerate(new_unit_ids):
if new_unit_id in old_unit_ids:
assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups"
else:
dtype = old_unit_ids.dtype
num_merge = len(units_to_merge)
# select new_unit_ids greater than the max id, even greater than the numerical str ids
if new_id_strategy == "take_first":
new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge]
elif new_id_strategy == "append":
if np.issubdtype(dtype, np.character):
# dtype str
if all(p.isdigit() for p in old_unit_ids):
# All str are digit : we can generate a max
m = max(int(p) for p in old_unit_ids) + 1
new_unit_ids = [str(m + i) for i in range(num_merge)]
else:
# we cannot automatically find new names
new_unit_ids = [f"merge{i}" for i in range(num_merge)]
else:
# dtype int
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
else:
raise ValueError("wrong new_id_strategy")

return new_unit_ids
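
For orientation, here is a minimal usage sketch of the new apply_merges_to_sorting(), mirroring the tests below; the toy spike times, labels and unit ids are illustrative only:

import numpy as np
from spikeinterface.core import NumpySorting, apply_merges_to_sorting

# two spikes at sample 0 become a refractory violation once "a" and "b" are merged
times = np.array([0, 0, 10, 20, 300])
labels = np.array(["a", "b", "c", "a", "b"])
sorting = NumpySorting.from_times_labels([times], [labels], 10_000.0, unit_ids=["a", "b", "c"])

# "a" and "b" collapse into one unit; with str ids and new_id_strategy="append" it is named "merge0"
merged, kept = apply_merges_to_sorting(sorting, [["a", "b"]], censor_ms=1.5, return_kept=True)
print(merged.unit_ids)                                   # ["c", "merge0"]
print(int(kept.sum()), "of", kept.size, "spikes kept")   # 4 of 5 spikes kept
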
85 changes: 84 additions & 1 deletion src/spikeinterface/core/tests/test_sorting_tools.py
@@ -9,6 +9,9 @@
spike_vector_to_spike_trains,
random_spikes_selection,
spike_vector_to_indices,
apply_merges_to_sorting,
_get_ids_after_merging,
generate_unit_ids_for_merge_group,
)


@@ -75,7 +78,87 @@ def test_random_spikes_selection():
assert random_spikes_indices.size == spikes.size


def test_apply_merges_to_sorting():

times = np.array([0, 0, 10, 20, 300])
labels = np.array(["a", "b", "c", "a", "b"])

# unit_ids str
sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
spikes1 = sorting1.to_spike_vector()

sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None)
spikes2 = sorting2.to_spike_vector()
assert sorting2.unit_ids.size == 2
assert spikes1.size == spikes2.size
assert np.array_equal(["c", "merge0"], sorting2.unit_ids)
assert np.array_equal(
spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"]
)

sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True)
spikes3 = sorting3.to_spike_vector()
assert spikes3.size < spikes1.size
assert not keep_mask[1]
st = sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0")
assert st.size == 3 # one spike is removed by censor period

# unit_ids int
sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30])
spikes1 = sorting1.to_spike_vector()
sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None)
assert np.array_equal(sorting2.unit_ids, [30, 31])

sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"])
sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first")
assert np.array_equal(sorting2.unit_ids, ["a", "c"])


def test_get_ids_after_merging():

all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"])
assert np.array_equal(all_unit_ids, ["c", "d", "x"])
# print(all_unit_ids)

all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9])
assert np.array_equal(all_unit_ids, [12, 9, 28])
# print(all_unit_ids)


def test_generate_unit_ids_for_merge_group():

new_unit_ids = generate_unit_ids_for_merge_group(
["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append"
)
assert np.array_equal(new_unit_ids, ["merge0", "merge1"])

new_unit_ids = generate_unit_ids_for_merge_group(
["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first"
)
assert np.array_equal(new_unit_ids, ["a", "d"])

new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append")
assert np.array_equal(new_unit_ids, [16, 17])

new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first")
assert np.array_equal(new_unit_ids, [0, 9])

new_unit_ids = generate_unit_ids_for_merge_group(
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append"
)
assert np.array_equal(new_unit_ids, ["16", "17"])

new_unit_ids = generate_unit_ids_for_merge_group(
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="take_first"
)
assert np.array_equal(new_unit_ids, ["0", "9"])


if __name__ == "__main__":
# test_spike_vector_to_spike_trains()
# test_spike_vector_to_indices()
test_random_spikes_selection()
# test_random_spikes_selection()

test_apply_merges_to_sorting()
test_get_ids_after_merging()
test_generate_unit_ids_for_merge_group()
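
The censoring step of apply_merges_to_sorting() can also be sketched in isolation. The following standalone snippet illustrates the rule used in the function above (the numbers are arbitrary, not taken from the tests):

import numpy as np

sampling_frequency = 10_000.0
censor_ms = 1.5
rpv = int(sampling_frequency * censor_ms / 1000.0)    # refractory period in samples: 15

# sorted sample indices of one merged group
merged_sample_index = np.array([0, 0, 20, 300])
keep = np.ones(merged_sample_index.size, dtype=bool)
violations = np.flatnonzero(np.diff(merged_sample_index) < rpv)
keep[violations + 1] = False                          # drop the second spike of each violating pair
print(keep)                                           # [ True False  True  True]
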
31 changes: 6 additions & 25 deletions src/spikeinterface/curation/mergeunitssorting.py
@@ -4,6 +4,7 @@
from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment
from spikeinterface.core.core_tools import define_function_from_class
from copy import deepcopy
from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group


class MergeUnitsSorting(BaseSorting):
@@ -44,35 +45,15 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
parents_unit_ids = sorting.unit_ids
sampling_frequency = sorting.get_sampling_frequency()

new_unit_ids = generate_unit_ids_for_merge_group(
sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append"
)

all_removed_ids = []
for ids in units_to_merge:
all_removed_ids.extend(ids)
keep_unit_ids = [u for u in parents_unit_ids if u not in all_removed_ids]

if new_unit_ids is None:
dtype = parents_unit_ids.dtype
# select new_unit_ids greater that the max id, event greater than the numerical str ids
if np.issubdtype(dtype, np.character):
# dtype str
if all(p.isdigit() for p in parents_unit_ids):
# All str are digit : we can generate a max
m = max(int(p) for p in parents_unit_ids) + 1
new_unit_ids = [str(m + i) for i in range(num_merge)]
else:
# we cannot automatically find new names
new_unit_ids = [f"merge{i}" for i in range(num_merge)]
if np.any(np.isin(new_unit_ids, keep_unit_ids)):
raise ValueError(
"Unable to find 'new_unit_ids' because it is a string and parents "
"already contain merges. Pass a list of 'new_unit_ids' as an argument."
)
else:
# dtype int
new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
else:
if np.any(np.isin(new_unit_ids, keep_unit_ids)):
raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones")

assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge"

# some checks
@@ -81,7 +62,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
assert properties_policy in ("keep", "remove"), "properties_policy must be 'keep' or 'remove'"

# new units are put at the end
unit_ids = keep_unit_ids + new_unit_ids
unit_ids = keep_unit_ids + list(new_unit_ids)
BaseSorting.__init__(self, sampling_frequency, unit_ids)
# assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids'
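
As a rough illustration of how the lazy MergeUnitsSorting and the new eager apply_merges_to_sorting() now share the same id-generation logic, here is a sketch; it assumes MergeUnitsSorting is importable from spikeinterface.curation and relies on the default properties_policy:

import numpy as np
from spikeinterface.core import NumpySorting, apply_merges_to_sorting
from spikeinterface.curation import MergeUnitsSorting

times = np.array([0, 10, 20, 300])
labels = np.array(["a", "b", "a", "b"])
sorting = NumpySorting.from_times_labels([times], [labels], 10_000.0, unit_ids=["a", "b"])

lazy = MergeUnitsSorting(sorting, [["a", "b"]])           # lazy view over the parent sorting
eager = apply_merges_to_sorting(sorting, [["a", "b"]])    # compact NumpySorting built eagerly
# both call generate_unit_ids_for_merge_group with new_id_strategy="append",
# so the merged unit is named "merge0" in each case
print(lazy.unit_ids, eager.unit_ids)
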
