From 7b49ef9cb9faa891d61ff4434807d705afe31668 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 22 Jan 2024 15:42:47 +0100 Subject: [PATCH 1/3] Fix sv curation and merg units with array properties --- src/spikeinterface/curation/mergeunitssorting.py | 12 ++++++++++-- src/spikeinterface/curation/sortingview_curation.py | 1 + src/spikeinterface/widgets/sorting_summary.py | 3 --- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index ae033d5531..af321915da 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -105,12 +105,20 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties if properties_policy == "keep": # propagate keep values - new_values = np.empty(shape=len(units_ids), dtype=parent_values.dtype) + if isinstance(parent_values, np.ndarray): + shape = (len(units_ids),) + parent_values.shape[1:] + else: + shape = len(units_ids) + new_values = np.empty(shape=shape, dtype=parent_values.dtype) new_values[keep_inds] = parent_values[keep_parent_inds] for new_id, ids in zip(new_unit_ids, units_to_merge): removed_inds = parent_sorting.ids_to_indices(ids) merge_values = parent_values[removed_inds] - if all(merge_values == merge_values[0]): + if isinstance(parent_values, np.ndarray): + same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]]) + else: + same_property_values = np.all([m == merge_values[0] for m in merge_values[1:]]) + if same_property_values: # and new values only if they are all similar ind = self.id_to_index(new_id) new_values[ind] = merge_values[0] diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index e6427b32a2..c99a32bcad 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -70,6 +70,7 @@ def apply_sortingview_curation( if verbose: print(f"Merging {merge_group}") if unit_ids_dtype.kind in ("U", "S"): + merge_group = [str(unit) for unit in merge_group] # if unit dtype is str, set new id as "{unit1}-{unit2}" new_unit_id = "-".join(merge_group) curation_sorting.merge(merge_group, new_unit_id=new_unit_id) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index e281e89257..af3275d0d0 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -35,9 +35,6 @@ class SortingSummaryWidget(BaseWidget): curation : bool, default: False If True, manual curation is enabled (sortingview backend) - unit_table_properties : list or None, default: None - List of properties to be added to the unit table - (sortingview backend) label_choices : list or None, default: None List of labels to be added to the curation table (sortingview backend) From 6f102801681ab7b3dff2dd74e9e118915dfe7298 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 25 Jan 2024 11:18:28 +0100 Subject: [PATCH 2/3] Simplify logic for property propagation in mergeunits and add tests --- .../curation/mergeunitssorting.py | 36 ++++++++++--------- .../curation/tests/test_curationsorting.py | 36 ++++++++++++++----- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index af321915da..ca451032ba 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -49,7 +49,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties if new_unit_ids is None: dtype = parents_unit_ids.dtype - # select new_units_ids greater that the max id, event greater than the numerical str ids + # select new_unit_ids greater that the max id, event greater than the numerical str ids if np.issubdtype(dtype, np.character): # dtype str if all(p.isdigit() for p in parents_unit_ids): @@ -79,8 +79,8 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties assert properties_policy in ("keep", "remove"), "properties_policy must be " "keep" " or " "remove" "" # new units are put at the end - units_ids = keep_unit_ids + new_unit_ids - BaseSorting.__init__(self, sampling_frequency, units_ids) + unit_ids = keep_unit_ids + new_unit_ids + BaseSorting.__init__(self, sampling_frequency, unit_ids) # assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids' if delta_time_ms is None: @@ -99,33 +99,37 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids) keep_inds = self.ids_to_indices(keep_unit_ids) # ~ merge_inds = self.ids_to_indices(new_unit_ids) - prop_keys = parent_sorting._properties.keys() - for k in prop_keys: - parent_values = parent_sorting._properties[k] + prop_keys = parent_sorting.get_property_keys() + for key in prop_keys: + parent_values = parent_sorting.get_property(key) if properties_policy == "keep": # propagate keep values - if isinstance(parent_values, np.ndarray): - shape = (len(units_ids),) + parent_values.shape[1:] - else: - shape = len(units_ids) + shape = (len(unit_ids),) + parent_values.shape[1:] new_values = np.empty(shape=shape, dtype=parent_values.dtype) new_values[keep_inds] = parent_values[keep_parent_inds] for new_id, ids in zip(new_unit_ids, units_to_merge): removed_inds = parent_sorting.ids_to_indices(ids) merge_values = parent_values[removed_inds] - if isinstance(parent_values, np.ndarray): - same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]]) - else: - same_property_values = np.all([m == merge_values[0] for m in merge_values[1:]]) + + same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]]) + if same_property_values: # and new values only if they are all similar ind = self.id_to_index(new_id) new_values[ind] = merge_values[0] - self.set_property(k, new_values) + else: + if parent_values.dtype.kind == "f": + new_values[ind] = np.nan + elif parent_values.dtype.kind in ("U", "S"): + new_values[ind] = "" + else: + new_values = new_values.astype(object) + new_values[ind] = None + self.set_property(key, new_values) elif properties_policy == "remove": - self.set_property(k, parent_values[keep_parent_inds], keep_unit_ids) + self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids) if parent_sorting.has_recording(): self.register_recording(parent_sorting._recording) diff --git a/src/spikeinterface/curation/tests/test_curationsorting.py b/src/spikeinterface/curation/tests/test_curationsorting.py index 91bc21a49f..0609fdc53b 100644 --- a/src/spikeinterface/curation/tests/test_curationsorting.py +++ b/src/spikeinterface/curation/tests/test_curationsorting.py @@ -1,7 +1,6 @@ -import pytest +import numpy as np from spikeinterface.core import NumpySorting -import numpy as np from spikeinterface.curation import CurationSorting, MergeUnitsSorting, SplitUnitSorting @@ -19,20 +18,40 @@ def test_split_merge(): parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms parent_sort.set_property("someprop", [float(k) for k in spikestimes[0].keys()]) # float - # %% split_index = [v[4] % 2 for v in spikestimes] # spit class 4 in even and odds - splited = SplitUnitSorting( + splitted = SplitUnitSorting( parent_sort, split_unit_id=4, indices_list=split_index, new_unit_ids=[8, 10], properties_policy="keep" ) - merged = MergeUnitsSorting(splited, units_to_merge=[[8, 10]], new_unit_ids=[4], properties_policy="keep") + # add array property (with same values for units to be merged) -> keep + some_array_prop_same_values = np.ones((len(splitted.unit_ids), 2)) + splitted.set_property("some_array_prop_to_keep", some_array_prop_same_values) + # add float array property (with different values for units to be merged) -> nan + some_array_prop_to_remove = np.random.randn(len(splitted.unit_ids), 2) + splitted.set_property("some_array_prop_to_remove", some_array_prop_to_remove) + # add str property (with different values for units to be merged) -> "" + some_str_prop = ["merge"] * len(splitted.unit_ids) + some_str_prop[-1] = "different" + splitted.set_property("some_str_prop", some_str_prop) + # add int array property (with different values for units to be merged) -> None + some_array_prop_to_set_none = np.ones((len(splitted.unit_ids), 2), dtype=int) + some_array_prop_to_set_none[-1] = [1, 2] + splitted.set_property("some_array_prop_to_set_none", some_array_prop_to_set_none) + + merged = MergeUnitsSorting(splitted, units_to_merge=[[8, 10]], new_unit_ids=[4], properties_policy="keep") for i in range(len(spikestimes)): assert ( all(parent_sort.get_unit_spike_train(4, segment_index=i) == merged.get_unit_spike_train(4, segment_index=i)) == True ), "splir or merge error" - assert parent_sort.get_unit_property(4, "someprop") == merged.get_unit_property(4, "someprop"), ( - "property wasn" "t kept" - ) + assert parent_sort.get_unit_property(4, "someprop") == merged.get_unit_property( + 4, "someprop" + ), "property wasn't kept" + assert np.array_equal(merged.get_unit_property(4, "some_array_prop_to_keep"), [1, 1]), "error with array property" + assert np.all(np.isnan(merged.get_unit_property(4, "some_array_prop_to_remove"))), "error with array property" + assert np.array_equal( + merged.get_unit_property(4, "some_array_prop_to_set_none"), [None, None] + ), "error with array property" + assert merged.get_unit_property(4, "some_str_prop") == "", "error with array property" merged_with_dups = MergeUnitsSorting( parent_sort, new_unit_ids=[8], units_to_merge=[[0, 1]], properties_policy="remove", delta_time_ms=0.5 @@ -57,7 +76,6 @@ def test_curation(): parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms parent_sort.set_property("some_names", ["unit_{}".format(k) for k in spikestimes[0].keys()]) # float cs = CurationSorting(parent_sort, properties_policy="remove") - # %% cs.merge(["a", "c"]) assert cs.sorting.get_num_units() == len(spikestimes[0]) - 1 split_index = [v["b"] < 6 for v in spikestimes] # split class 4 in even and odds From 7bac0c4207c0ed0be7f598ab3870da53f97e3ccc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 25 Jan 2024 16:23:04 +0100 Subject: [PATCH 3/3] Add tests with 1d properties --- .../curation/tests/test_curationsorting.py | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_curationsorting.py b/src/spikeinterface/curation/tests/test_curationsorting.py index 0609fdc53b..4f9c57cb8f 100644 --- a/src/spikeinterface/curation/tests/test_curationsorting.py +++ b/src/spikeinterface/curation/tests/test_curationsorting.py @@ -22,20 +22,33 @@ def test_split_merge(): splitted = SplitUnitSorting( parent_sort, split_unit_id=4, indices_list=split_index, new_unit_ids=[8, 10], properties_policy="keep" ) + + # Test 1D and multi-D properties of different types + # add 1D str property (with different values for units to be merged) -> "" + some_str_prop = ["merge"] * len(splitted.unit_ids) + some_str_prop[-1] = "different" + splitted.set_property("some_str_prop", some_str_prop) + + # add 1D float property (with same values for units to be merged) -> keep + some_prop_to_keep = np.ones(len(splitted.unit_ids)) + splitted.set_property("some_prop_to_keep", some_prop_to_keep) + # add 1D float property (with different values for units to be merged) -> nan + some_prop_to_remove = np.arange(len(splitted.unit_ids), dtype=float) + splitted.set_property("some_prop_to_remove", some_prop_to_remove) + # add 1D int property (with different values for units to be merged) -> None + some_prop_to_none = np.arange(len(splitted.unit_ids), dtype=int) + splitted.set_property("some_prop_to_none", some_prop_to_none) + # add array property (with same values for units to be merged) -> keep some_array_prop_same_values = np.ones((len(splitted.unit_ids), 2)) splitted.set_property("some_array_prop_to_keep", some_array_prop_same_values) # add float array property (with different values for units to be merged) -> nan some_array_prop_to_remove = np.random.randn(len(splitted.unit_ids), 2) splitted.set_property("some_array_prop_to_remove", some_array_prop_to_remove) - # add str property (with different values for units to be merged) -> "" - some_str_prop = ["merge"] * len(splitted.unit_ids) - some_str_prop[-1] = "different" - splitted.set_property("some_str_prop", some_str_prop) # add int array property (with different values for units to be merged) -> None - some_array_prop_to_set_none = np.ones((len(splitted.unit_ids), 2), dtype=int) - some_array_prop_to_set_none[-1] = [1, 2] - splitted.set_property("some_array_prop_to_set_none", some_array_prop_to_set_none) + some_array_prop_to_none = np.ones((len(splitted.unit_ids), 2), dtype=int) + some_array_prop_to_none[-1] = [1, 2] + splitted.set_property("some_array_prop_to_none", some_array_prop_to_none) merged = MergeUnitsSorting(splitted, units_to_merge=[[8, 10]], new_unit_ids=[4], properties_policy="keep") for i in range(len(spikestimes)): @@ -46,12 +59,17 @@ def test_split_merge(): assert parent_sort.get_unit_property(4, "someprop") == merged.get_unit_property( 4, "someprop" ), "property wasn't kept" + # 1d + assert merged.get_unit_property(4, "some_str_prop") == "", "error with array property" + assert merged.get_unit_property(4, "some_prop_to_keep") == 1, "error with array property" + assert np.isnan(merged.get_unit_property(4, "some_prop_to_remove")), "error with array property" + assert merged.get_unit_property(4, "some_prop_to_none") is None, "error with array property" + # 2d assert np.array_equal(merged.get_unit_property(4, "some_array_prop_to_keep"), [1, 1]), "error with array property" assert np.all(np.isnan(merged.get_unit_property(4, "some_array_prop_to_remove"))), "error with array property" assert np.array_equal( - merged.get_unit_property(4, "some_array_prop_to_set_none"), [None, None] + merged.get_unit_property(4, "some_array_prop_to_none"), [None, None] ), "error with array property" - assert merged.get_unit_property(4, "some_str_prop") == "", "error with array property" merged_with_dups = MergeUnitsSorting( parent_sort, new_unit_ids=[8], units_to_merge=[[0, 1]], properties_policy="remove", delta_time_ms=0.5