Skip to content

Commit

Permalink
Merge pull request #2427 from alejoe91/fix-merging
Browse files Browse the repository at this point in the history
Fix sv curation and merge units with array properties
  • Loading branch information
samuelgarcia authored Jan 26, 2024
2 parents c2f18bd + 7bac0c4 commit 4af9205
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 22 deletions.
32 changes: 22 additions & 10 deletions src/spikeinterface/curation/mergeunitssorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties

if new_unit_ids is None:
dtype = parents_unit_ids.dtype
# select new_units_ids greater that the max id, event greater than the numerical str ids
# select new_unit_ids greater that the max id, event greater than the numerical str ids
if np.issubdtype(dtype, np.character):
# dtype str
if all(p.isdigit() for p in parents_unit_ids):
Expand Down Expand Up @@ -79,8 +79,8 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
assert properties_policy in ("keep", "remove"), "properties_policy must be " "keep" " or " "remove" ""

# new units are put at the end
units_ids = keep_unit_ids + new_unit_ids
BaseSorting.__init__(self, sampling_frequency, units_ids)
unit_ids = keep_unit_ids + new_unit_ids
BaseSorting.__init__(self, sampling_frequency, unit_ids)
# assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids'

if delta_time_ms is None:
Expand All @@ -99,25 +99,37 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
# ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids)
keep_inds = self.ids_to_indices(keep_unit_ids)
# ~ merge_inds = self.ids_to_indices(new_unit_ids)
prop_keys = parent_sorting._properties.keys()
for k in prop_keys:
parent_values = parent_sorting._properties[k]
prop_keys = parent_sorting.get_property_keys()
for key in prop_keys:
parent_values = parent_sorting.get_property(key)

if properties_policy == "keep":
# propagate keep values
new_values = np.empty(shape=len(units_ids), dtype=parent_values.dtype)
shape = (len(unit_ids),) + parent_values.shape[1:]
new_values = np.empty(shape=shape, dtype=parent_values.dtype)
new_values[keep_inds] = parent_values[keep_parent_inds]
for new_id, ids in zip(new_unit_ids, units_to_merge):
removed_inds = parent_sorting.ids_to_indices(ids)
merge_values = parent_values[removed_inds]
if all(merge_values == merge_values[0]):

same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]])

if same_property_values:
# and new values only if they are all similar
ind = self.id_to_index(new_id)
new_values[ind] = merge_values[0]
self.set_property(k, new_values)
else:
if parent_values.dtype.kind == "f":
new_values[ind] = np.nan
elif parent_values.dtype.kind in ("U", "S"):
new_values[ind] = ""
else:
new_values = new_values.astype(object)
new_values[ind] = None
self.set_property(key, new_values)

elif properties_policy == "remove":
self.set_property(k, parent_values[keep_parent_inds], keep_unit_ids)
self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/curation/sortingview_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def apply_sortingview_curation(
if verbose:
print(f"Merging {merge_group}")
if unit_ids_dtype.kind in ("U", "S"):
merge_group = [str(unit) for unit in merge_group]
# if unit dtype is str, set new id as "{unit1}-{unit2}"
new_unit_id = "-".join(merge_group)
curation_sorting.merge(merge_group, new_unit_id=new_unit_id)
Expand Down
54 changes: 45 additions & 9 deletions src/spikeinterface/curation/tests/test_curationsorting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import numpy as np

from spikeinterface.core import NumpySorting
import numpy as np
from spikeinterface.curation import CurationSorting, MergeUnitsSorting, SplitUnitSorting


Expand All @@ -19,20 +18,58 @@ def test_split_merge():
parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms
parent_sort.set_property("someprop", [float(k) for k in spikestimes[0].keys()]) # float

# %%
split_index = [v[4] % 2 for v in spikestimes] # spit class 4 in even and odds
splited = SplitUnitSorting(
splitted = SplitUnitSorting(
parent_sort, split_unit_id=4, indices_list=split_index, new_unit_ids=[8, 10], properties_policy="keep"
)
merged = MergeUnitsSorting(splited, units_to_merge=[[8, 10]], new_unit_ids=[4], properties_policy="keep")

# Test 1D and multi-D properties of different types
# add 1D str property (with different values for units to be merged) -> ""
some_str_prop = ["merge"] * len(splitted.unit_ids)
some_str_prop[-1] = "different"
splitted.set_property("some_str_prop", some_str_prop)

# add 1D float property (with same values for units to be merged) -> keep
some_prop_to_keep = np.ones(len(splitted.unit_ids))
splitted.set_property("some_prop_to_keep", some_prop_to_keep)
# add 1D float property (with different values for units to be merged) -> nan
some_prop_to_remove = np.arange(len(splitted.unit_ids), dtype=float)
splitted.set_property("some_prop_to_remove", some_prop_to_remove)
# add 1D int property (with different values for units to be merged) -> None
some_prop_to_none = np.arange(len(splitted.unit_ids), dtype=int)
splitted.set_property("some_prop_to_none", some_prop_to_none)

# add array property (with same values for units to be merged) -> keep
some_array_prop_same_values = np.ones((len(splitted.unit_ids), 2))
splitted.set_property("some_array_prop_to_keep", some_array_prop_same_values)
# add float array property (with different values for units to be merged) -> nan
some_array_prop_to_remove = np.random.randn(len(splitted.unit_ids), 2)
splitted.set_property("some_array_prop_to_remove", some_array_prop_to_remove)
# add int array property (with different values for units to be merged) -> None
some_array_prop_to_none = np.ones((len(splitted.unit_ids), 2), dtype=int)
some_array_prop_to_none[-1] = [1, 2]
splitted.set_property("some_array_prop_to_none", some_array_prop_to_none)

merged = MergeUnitsSorting(splitted, units_to_merge=[[8, 10]], new_unit_ids=[4], properties_policy="keep")
for i in range(len(spikestimes)):
assert (
all(parent_sort.get_unit_spike_train(4, segment_index=i) == merged.get_unit_spike_train(4, segment_index=i))
== True
), "splir or merge error"
assert parent_sort.get_unit_property(4, "someprop") == merged.get_unit_property(4, "someprop"), (
"property wasn" "t kept"
)
assert parent_sort.get_unit_property(4, "someprop") == merged.get_unit_property(
4, "someprop"
), "property wasn't kept"
# 1d
assert merged.get_unit_property(4, "some_str_prop") == "", "error with array property"
assert merged.get_unit_property(4, "some_prop_to_keep") == 1, "error with array property"
assert np.isnan(merged.get_unit_property(4, "some_prop_to_remove")), "error with array property"
assert merged.get_unit_property(4, "some_prop_to_none") is None, "error with array property"
# 2d
assert np.array_equal(merged.get_unit_property(4, "some_array_prop_to_keep"), [1, 1]), "error with array property"
assert np.all(np.isnan(merged.get_unit_property(4, "some_array_prop_to_remove"))), "error with array property"
assert np.array_equal(
merged.get_unit_property(4, "some_array_prop_to_none"), [None, None]
), "error with array property"

merged_with_dups = MergeUnitsSorting(
parent_sort, new_unit_ids=[8], units_to_merge=[[0, 1]], properties_policy="remove", delta_time_ms=0.5
Expand All @@ -57,7 +94,6 @@ def test_curation():
parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms
parent_sort.set_property("some_names", ["unit_{}".format(k) for k in spikestimes[0].keys()]) # float
cs = CurationSorting(parent_sort, properties_policy="remove")
# %%
cs.merge(["a", "c"])
assert cs.sorting.get_num_units() == len(spikestimes[0]) - 1
split_index = [v["b"] < 6 for v in spikestimes] # split class 4 in even and odds
Expand Down
3 changes: 0 additions & 3 deletions src/spikeinterface/widgets/sorting_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ class SortingSummaryWidget(BaseWidget):
curation : bool, default: False
If True, manual curation is enabled
(sortingview backend)
unit_table_properties : list or None, default: None
List of properties to be added to the unit table
(sortingview backend)
label_choices : list or None, default: None
List of labels to be added to the curation table
(sortingview backend)
Expand Down

0 comments on commit 4af9205

Please sign in to comment.