diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index b031ab9146..702bb587f7 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -18,8 +18,8 @@ class CurationSorting: Parameters ---------- - parent_sorting : Recording - The recording object + sorting: BaseSorting + The sorting object properties_policy : "keep" | "remove", default: "keep" Policy used to propagate properties after split and merge operation. If "keep" the properties will be passed to the new units (if the original units have the same value). If "remove" the new units will have @@ -32,12 +32,13 @@ class CurationSorting: Sorting object with the selected units merged """ - def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"): + def __init__(self, sorting, make_graph=False, properties_policy="keep"): + # to allow undo and redo a list of sortingextractors is keep - self._sorting_stages = [parent_sorting] + self._sorting_stages = [sorting] self._sorting_stages_i = 0 self._properties_policy = properties_policy - parent_units = parent_sorting.get_unit_ids() + parent_units = sorting.get_unit_ids() self._make_graph = make_graph if make_graph: # to easily allow undo and redo a list of graphs with the history of the curation is keep @@ -52,7 +53,7 @@ def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"): else: self.max_used_id = max(parent_units) if len(parent_units) > 0 else 0 - self._kwargs = dict(parent_sorting=parent_sorting, make_graph=make_graph, properties_policy=properties_policy) + self._kwargs = dict(sorting=sorting, make_graph=make_graph, properties_policy=properties_policy) def _get_unused_id(self, n=1): # check units in the graph to the next unused unit id @@ -121,7 +122,7 @@ def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4): elif new_unit_id not in units_to_merge: assert new_unit_id not in current_sorting.unit_ids, f"new_unit_id already exists!" new_sorting = MergeUnitsSorting( - parent_sorting=current_sorting, + sorting=current_sorting, units_to_merge=units_to_merge, new_unit_ids=[new_unit_id], delta_time_ms=delta_time_ms, diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 3bf1d1d43a..bbdb70b2f6 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): Parameters ---------- - parent_sorting : Recording + sorting: BaseSorting The sorting object units_to_merge : list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), @@ -32,8 +32,8 @@ class MergeUnitsSorting(BaseSorting): Sorting object with the selected units merged """ - def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): - self._parent_sorting = parent_sorting + def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): + self._parent_sorting = sorting if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge @@ -41,8 +41,8 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties num_merge = len(units_to_merge) - parents_unit_ids = parent_sorting.unit_ids - sampling_frequency = parent_sorting.get_sampling_frequency() + parents_unit_ids = sorting.unit_ids + sampling_frequency = sorting.get_sampling_frequency() all_removed_ids = [] for ids in units_to_merge: @@ -93,17 +93,17 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta) self.add_sorting_segment(sub_segment) - ann_keys = parent_sorting._annotations.keys() - self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys}) + ann_keys = sorting._annotations.keys() + self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys}) # copy properties for unchanged units, and check if units propierties are the same - keep_parent_inds = parent_sorting.ids_to_indices(keep_unit_ids) + keep_parent_inds = sorting.ids_to_indices(keep_unit_ids) # ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids) keep_inds = self.ids_to_indices(keep_unit_ids) # ~ merge_inds = self.ids_to_indices(new_unit_ids) - prop_keys = parent_sorting.get_property_keys() + prop_keys = sorting.get_property_keys() for key in prop_keys: - parent_values = parent_sorting.get_property(key) + parent_values = sorting.get_property(key) if properties_policy == "keep": # propagate keep values @@ -111,7 +111,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties new_values = np.empty(shape=shape, dtype=parent_values.dtype) new_values[keep_inds] = parent_values[keep_parent_inds] for new_id, ids in zip(new_unit_ids, units_to_merge): - removed_inds = parent_sorting.ids_to_indices(ids) + removed_inds = sorting.ids_to_indices(ids) merge_values = parent_values[removed_inds] same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]]) @@ -133,13 +133,13 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties elif properties_policy == "remove": self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids) - if parent_sorting.has_recording(): - self.register_recording(parent_sorting._recording) + if sorting.has_recording(): + self.register_recording(sorting._recording) # make it jsonable units_to_merge = [list(e) for e in units_to_merge] self._kwargs = dict( - parent_sorting=parent_sorting, + sorting=sorting, units_to_merge=units_to_merge, new_unit_ids=new_unit_ids, properties_policy=properties_policy, diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 8fc6afcde8..33c14dfe5a 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -13,8 +13,8 @@ class SplitUnitSorting(BaseSorting): Parameters ---------- - parent_sorting : Recording - The recording object + sorting: BaseSorting + The sorting object parent_unit_id : int Unit id of the unit to split indices_list : list or np.array @@ -34,11 +34,11 @@ class SplitUnitSorting(BaseSorting): Sorting object with the selected units split """ - def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"): + def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"): if type(indices_list) is not list: indices_list = [indices_list] - parents_unit_ids = parent_sorting.unit_ids - assert parent_sorting.get_num_segments() == len( + parents_unit_ids = sorting.unit_ids + assert sorting.get_num_segments() == len( indices_list ), "The length of indices_list must be the same as parent_sorting.get_num_segments" split_unit_indices = np.unique([np.unique(v) for v in indices_list]) @@ -70,10 +70,10 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non np.isin(new_unit_ids, unchanged_units) ), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id" - sampling_frequency = parent_sorting.get_sampling_frequency() + sampling_frequency = sorting.get_sampling_frequency() units_ids = np.concatenate([unchanged_units, new_unit_ids]) - self._parent_sorting = parent_sorting + self._parent_sorting = sorting BaseSorting.__init__(self, sampling_frequency, units_ids) assert all( @@ -85,18 +85,18 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non self.add_sorting_segment(sub_segment) # copy properties - ann_keys = parent_sorting._annotations.keys() - self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys}) + ann_keys = sorting._annotations.keys() + self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys}) # copy properties for unchanged units, and check if units propierties - keep_parent_inds = parent_sorting.ids_to_indices(unchanged_units) - split_unit_id_ind = parent_sorting.id_to_index(split_unit_id) + keep_parent_inds = sorting.ids_to_indices(unchanged_units) + split_unit_id_ind = sorting.id_to_index(split_unit_id) keep_units_inds = self.ids_to_indices(unchanged_units) split_unit_ind = self.ids_to_indices(new_unit_ids) # copy properties from original units to split ones - prop_keys = parent_sorting._properties.keys() + prop_keys = sorting._properties.keys() for k in prop_keys: - values = parent_sorting._properties[k] + values = sorting._properties[k] if properties_policy == "keep": new_values = np.empty_like(values, shape=len(units_ids)) new_values[keep_units_inds] = values[keep_parent_inds] @@ -105,11 +105,11 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non continue self.set_property(k, values[keep_parent_inds], unchanged_units) - if parent_sorting.has_recording(): - self.register_recording(parent_sorting._recording) + if sorting.has_recording(): + self.register_recording(sorting._recording) self._kwargs = dict( - parent_sorting=parent_sorting, + sorting=sorting, split_unit_id=split_unit_id, indices_list=indices_list, new_unit_ids=new_unit_ids,