
Commit

Merge branch 'merging_units' of github.com:yger/spikeinterface into merging_units
yger committed Jul 9, 2024
2 parents f313a60 + feb0ff8 commit 5e29e24
Showing 18 changed files with 666 additions and 401 deletions.
4 changes: 2 additions & 2 deletions doc/modules/curation.rst
@@ -62,7 +62,7 @@ This format has two parts:
* **manual output** curation with the following keys:

* "manual_labels"
* "merged_unit_groups"
* "merge_unit_groups"
* "removed_units"

Here is the description of the format with a simple example:
@@ -128,7 +128,7 @@ Here is the description of the format with a simple example:
]
}
],
"merged_unit_groups": [
"merge_unit_groups": [
[
"u3",
"u6"
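The hunk stops inside the example, but the shape of the renamed manual-output section is clear from the keys above. A compact sketch for orientation (editorial, not part of the commit; the values are illustrative):

# Illustrative only: the manual-output keys after the rename.
manual_output = {
    "manual_labels": [],                   # per-unit label assignments (contents not shown in this hunk)
    "merge_unit_groups": [["u3", "u6"]],   # each inner list is one group of units to merge together
    "removed_units": [],                   # unit ids curated away entirely
}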
413 changes: 261 additions & 152 deletions src/spikeinterface/core/analyzer_extension_core.py

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions src/spikeinterface/core/sorting_tools.py
@@ -227,7 +227,7 @@ def random_spikes_selection(


def apply_merges_to_sorting(
- sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
+ sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
):
"""
Apply a resolved representation of the merges to a sorting object.
@@ -242,16 +242,16 @@
----------
sorting : Sorting
The Sorting object to apply merges.
- units_to_merge : list/tuple of lists/tuples
+ merge_unit_groups : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None, default: None
- A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None,
+ A new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`. If None,
merged units will have the first unit_id of each list of merges.
censor_ms : float | None, default: None
When applying the merges, consecutive spikes violating this refractory period (in ms) are discarded.
return_kept : bool, default: False
- If True, also return also a booolean mask of kept spikes.
+ If True, also return a boolean mask of kept spikes.
new_id_strategy : "append" | "take_first", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
@@ -271,15 +271,15 @@
keep_mask = np.ones(len(spikes), dtype=bool)

new_unit_ids = generate_unit_ids_for_merge_group(
- sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
+ sorting.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
)

rename_ids = {}
- for i, merge_group in enumerate(units_to_merge):
+ for i, merge_group in enumerate(merge_unit_groups):
for unit_id in merge_group:
rename_ids[unit_id] = new_unit_ids[i]

- all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids)
+ all_unit_ids = _get_ids_after_merging(sorting.unit_ids, merge_unit_groups, new_unit_ids)
all_unit_ids = list(all_unit_ids)

num_seg = sorting.get_num_segments()
@@ -303,7 +303,7 @@

if censor_ms is not None:
rpv = int(sorting.sampling_frequency * censor_ms / 1000.0)
- for group_old_ids in units_to_merge:
+ for group_old_ids in merge_unit_groups:
for segment_index in range(num_seg):
group_indices = []
for unit_id in group_old_ids:
@@ -322,7 +322,7 @@
return sorting
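A minimal usage sketch of the renamed signature (editorial, not part of the commit). It assumes `generate_sorting` from `spikeinterface.core` for a toy sorting and that `return_kept=True` returns the `(sorting, keep_mask)` pair described in the docstring:

from spikeinterface.core import generate_sorting
from spikeinterface.core.sorting_tools import apply_merges_to_sorting

# toy sorting with 6 units; the exact unit ids depend on the generator
sorting = generate_sorting(num_units=6, durations=[10.0], seed=42)
unit_ids = list(sorting.unit_ids)

# two merge groups, each with at least two units
merge_unit_groups = [unit_ids[0:2], unit_ids[2:4]]

# censor_ms discards spikes that would violate a 0.3 ms refractory period after merging
merged_sorting, keep_mask = apply_merges_to_sorting(
    sorting, merge_unit_groups, censor_ms=0.3, return_kept=True, new_id_strategy="append"
)
print(merged_sorting.unit_ids)          # 2 untouched units plus 2 newly created merged ids
print(keep_mask.sum(), "spikes kept")   # keep_mask is a boolean array over all spikes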


- def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
+ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids):
"""
Function to get the list of unique unit_ids after some merges, when new_unit_ids are provided.
@@ -333,11 +333,11 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
----------
old_unit_ids : np.array
The old unit_ids.
- units_to_merge : list/tuple of lists/tuples
+ merge_unit_groups : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None
- A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`.
+ A new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`.
Returns
-------
@@ -348,10 +348,10 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
"""
old_unit_ids = np.asarray(old_unit_ids)

- assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+ assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups"

all_unit_ids = list(old_unit_ids.copy())
- for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge):
+ for new_unit_id, group_ids in zip(new_unit_ids, merge_unit_groups):
assert len(group_ids) > 1, "A merge should have at least two units"
for unit_id in group_ids:
assert unit_id in old_unit_ids, "Merged ids should be in the sorting"
@@ -364,21 +364,21 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
return np.array(all_unit_ids)
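An editorial illustration of what this helper computes; the membership of the result follows from the docstring, while the exact ordering of the returned array is an implementation detail:

import numpy as np
from spikeinterface.core.sorting_tools import _get_ids_after_merging

old_unit_ids = np.array(["u1", "u2", "u3", "u4", "u5", "u6"])
merge_unit_groups = [["u3", "u6"]]
new_unit_ids = ["u7"]  # one new id per merge group

all_unit_ids = _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids)
# contains u1, u2, u4, u5 plus the new u7; the merged u3 and u6 are gone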


- def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
+ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"):
"""
Function to generate new unit ids during a merging procedure. If new_unit_ids
are provided, it will return these unit ids, checking that they have the same
- length as `units_to_merge`.
+ length as `merge_unit_groups`.
Parameters
----------
old_unit_ids : np.array
The old unit_ids.
- units_to_merge : list/tuple of lists/tuples
+ merge_unit_groups : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None, default: None
- Optional new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`.
+ Optional new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`.
If None, new ids will be generated.
new_id_strategy : "append" | "take_first", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
@@ -395,17 +395,17 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids

if new_unit_ids is not None:
# then only doing a consistency check
- assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+ assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups"
# new_unit_ids can also be part of old_unit_ids only inside the same group:
for i, new_unit_id in enumerate(new_unit_ids):
if new_unit_id in old_unit_ids:
- assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups"
+ assert new_unit_id in merge_unit_groups[i], "new_unit_ids already exists but outside the merged groups"
else:
dtype = old_unit_ids.dtype
- num_merge = len(units_to_merge)
+ num_merge = len(merge_unit_groups)
# select new_unit_ids greater than the max id, even greater than the numerical str ids
if new_id_strategy == "take_first":
- new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge]
+ new_unit_ids = [to_be_merged[0] for to_be_merged in merge_unit_groups]
elif new_id_strategy == "append":
if np.issubdtype(dtype, np.character):
# dtype str
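To close, a short editorial sketch contrasting the two id strategies on integer ids; the "append" values are presumed from the docstring rule of picking ids greater than the current maximum, not from lines shown here:

import numpy as np
from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group

old_unit_ids = np.array([0, 1, 2, 3, 4, 5])
merge_unit_groups = [[0, 1], [2, 5]]

# "take_first": each merged unit keeps the first id of its group, here 0 and 2
print(generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_id_strategy="take_first"))

# "append": new ids are allocated past the current maximum id (presumably 6 and 7 here)
print(generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_id_strategy="append"))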