Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix split in more than 2 units and extend curation docs and tests #2775

Merged
merged 7 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 114 additions & 20 deletions src/spikeinterface/curation/curationsorting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from collections import namedtuple
from collections.abc import Iterable
from hmac import new
zm711 marked this conversation as resolved.
Show resolved Hide resolved
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np

from .mergeunitssorting import MergeUnitsSorting
Expand Down Expand Up @@ -59,9 +62,29 @@ def _get_unused_id(self, n=1):
ids = [str(i) for i in ids]
return ids

def split(self, split_unit_id, indices_list):
def split(self, split_unit_id, indices_list, new_unit_ids=None):
"""
Split a unit into multiple units.

Parameters
----------
split_unit_id: int or str
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
The unit to split
indices_list: list or np.array
A list of index arrays selecting the spikes to split in each segment.
Each array can contain more than 2 indices (e.g. for splitting in 3 or more units) and it should
be the same length as the spike train (for each segment).
If the sorting has only one segment, indices_list can be a single array
new_unit_ids: list or None
List of new unit ids. If None, a new unit id is automatically selected
"""
current_sorting = self._sorting_stages[self._sorting_stages_i]
new_unit_ids = self._get_unused_id(2)
if not isinstance(indices_list, list):
indices_list = [indices_list]
if not isinstance(indices_list[0], Iterable):
raise ValueError("indices_list must be a list of iterable arrays")
if new_unit_ids is None:
new_unit_ids = self._get_unused_id(np.max([len(np.unique(v)) for v in indices_list]))
new_sorting = SplitUnitSorting(
current_sorting,
split_unit_id=split_unit_id,
Expand All @@ -81,6 +104,18 @@ def split(self, split_unit_id, indices_list):
self._add_new_stage(new_sorting, edges)

def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4):
"""
Merge a list of units into a new unit.

Parameters
----------
units_to_merge: list
List of unit ids to merge
new_unit_id: int or str or None, default: None
zm711 marked this conversation as resolved.
Show resolved Hide resolved
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
The new unit id. If None, a new unit id is automatically selected
delta_time_ms: float
Number of ms to consider for duplicated spikes. None won't check for duplications
"""
current_sorting = self._sorting_stages[self._sorting_stages_i]
if new_unit_id is None:
new_unit_id = self._get_unused_id()[0]
Expand All @@ -104,6 +139,14 @@ def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4):
self._add_new_stage(new_sorting, edges)

def remove_units(self, unit_ids):
"""
Remove a list of units.

Parameters
----------
unit_ids: list
List of unit ids to remove
"""
current_sorting = self._sorting_stages[self._sorting_stages_i]
unit2keep = [u for u in current_sorting.get_unit_ids() if u not in unit_ids]
if self._make_graph:
Expand All @@ -114,9 +157,27 @@ def remove_units(self, unit_ids):
self._add_new_stage(current_sorting.select_units(unit2keep), edges)

def remove_unit(self, unit_id):
    """
    Remove a single unit.

    Convenience wrapper that forwards to ``remove_units`` with a one-element list.

    Parameters
    ----------
    unit_id : int or str
        The unit id to remove
    """
    single_unit = [unit_id]
    self.remove_units(single_unit)

def select_units(self, unit_ids, renamed_unit_ids=None):
"""
Select a list of units.

Parameters
----------
unit_ids : list
List of unit ids to select
renamed_unit_ids : list or None, default: None
List of new unit ids to rename the selected units
"""
new_sorting = self._sorting_stages[self._sorting_stages_i].select_units(unit_ids, renamed_unit_ids)
if self._make_graph:
i = self._sorting_stages_i
Expand All @@ -129,20 +190,20 @@ def select_units(self, unit_ids, renamed_unit_ids=None):
self._add_new_stage(new_sorting, edges)

def rename(self, renamed_unit_ids):
self.select_units(self.current_sorting.unit_ids, renamed_unit_ids=renamed_unit_ids)
"""
Rename a list of units.

def _add_new_stage(self, new_sorting, edges):
# adds the stage to the stage list and creates the associated new graph
self._sorting_stages = self._sorting_stages[0 : self._sorting_stages_i + 1]
self._sorting_stages.append(new_sorting)
if self._make_graph:
self._graphs = self._graphs[0 : self._sorting_stages_i + 1]
new_graph = self._graphs[self._sorting_stages_i].copy()
new_graph.add_edges_from(edges)
self._graphs.append(new_graph)
self._sorting_stages_i += 1
Parameters
----------
renamed_unit_ids : list
List of unit ids to rename existing units
"""
self.select_units(self.current_sorting.unit_ids, renamed_unit_ids=renamed_unit_ids)

def remove_empty_units(self):
"""
Remove empty units.
"""
i = self._sorting_stages_i
new_sorting = self._sorting_stages[i].remove_empty_units()
if self._make_graph:
Expand All @@ -153,22 +214,52 @@ def remove_empty_units(self):
self._add_new_stage(new_sorting, edges)

def redo_available(self):
    """
    Check if redo is available.

    Redo is only possible when at least one stage is stored after the
    current one, i.e. after an ``undo`` that has not been overwritten by a
    new operation.

    Returns
    -------
    bool
        True if redo is available
    """
    # useful function for a gui
    # The current index always points at an existing stage, so comparing it
    # against len() alone is always True; the last stage has index len() - 1
    # and nothing can be redone from there.
    return self._sorting_stages_i < len(self._sorting_stages) - 1

def undo_available(self):
    """
    Tell whether there is an earlier stage to go back to.

    Returns
    -------
    bool
        True if undo is available, i.e. the current stage index is greater than 0
    """
    # useful function for a gui
    stage_index = self._sorting_stages_i
    return stage_index > 0

def undo(self):
    """
    Step back to the previous stage.

    Does nothing when ``undo_available`` reports that no earlier stage exists.
    """
    if not self.undo_available():
        return
    self._sorting_stages_i -= 1

def redo(self):
    """
    Redo the last undone operation.

    Moves the current stage index forward by one. Nothing happens when
    there is no later stage to move to.
    """
    # Explicit bounds guard: never step past the last stored stage, even if
    # redo_available() answers True while already on the newest stage.
    has_next_stage = self._sorting_stages_i < len(self._sorting_stages) - 1
    if has_next_stage and self.redo_available():
        self._sorting_stages_i += 1

def draw_graph(self, **kwargs):
"""
Draw the curation graph.

Parameters
----------
**kwargs: dict
Keyword arguments for Networkx draw function
"""
assert self._make_graph, "to make a graph use make_graph=True"
graph = self.graph
ids = [c.unit_id for c in graph.nodes]
Expand All @@ -189,13 +280,16 @@ def sorting(self):
def current_sorting(self):
return self._sorting_stages[self._sorting_stages_i]

# def __getattr__(self,name):
# #any method not define for this class will try to use the current
# # sorting stage. In that whay this class will behave as a sortingextractor
# current_sorting = self._sorting_stages[self._sorting_stages_i]

# attr = object.__getattribute__(current_sorting, name)
# return attr
def _add_new_stage(self, new_sorting, edges):
    """
    Append a new stage after the current one and make it the current stage.

    Any stages (and graphs) stored after the current index are dropped
    before the new one is added. When ``self._make_graph`` is set, the new
    graph is a copy of the current graph extended with ``edges``.

    Parameters
    ----------
    new_sorting
        The sorting object to store as the new stage
    edges
        Edges passed to ``add_edges_from`` on the copied graph
    """
    current = self._sorting_stages_i
    keep = current + 1

    # discard everything after the current stage, then add the new one
    self._sorting_stages = self._sorting_stages[:keep]
    self._sorting_stages.append(new_sorting)

    if self._make_graph:
        self._graphs = self._graphs[:keep]
        extended_graph = self._graphs[current].copy()
        extended_graph.add_edges_from(edges)
        self._graphs.append(extended_graph)

    self._sorting_stages_i = current + 1


curation_sorting = define_function_from_class(source_class=CurationSorting, name="curation_sorting")
10 changes: 5 additions & 5 deletions src/spikeinterface/curation/splitunitsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ class SplitUnitSorting(BaseSorting):
The recording object
split_unit_id: int or str
Unit id of the unit to split
indices_list: list
indices_list: list or np.array
A list of index arrays selecting the spikes to split in each segment.
Each array can contain more than 2 indices (e.g. for splitting in 3 or more units) and it should
be the same length as the spike train (for each segment)
be the same length as the spike train (for each segment).
If the sorting has only one segment, indices_list can be a single array
new_unit_ids: list or None
Unit ids of the new units to be created.
Unit ids of the new units to be created
properties_policy: "keep" | "remove", default: "keep"
Policy used to propagate properties. If "keep" the properties will be passed to the new units
(if the units_to_merge have the same value). If "remove" the new units will have an empty
value for all the properties of the new unit.
value for all the properties of the new unit
Returns
-------
sorting: Sorting
Expand All @@ -48,7 +49,6 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
new_unit_ids = max(parents_unit_ids) + 1
new_unit_ids = np.array([u + new_unit_ids for u in range(tot_splits)], dtype=parents_unit_ids.dtype)
else:
new_unit_ids = np.array(new_unit_ids, dtype=parents_unit_ids.dtype)
assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids must be unique"
assert len(new_unit_ids) <= tot_splits, "indices_list has more id indices than the length of new_unit_ids"

Expand Down
24 changes: 21 additions & 3 deletions src/spikeinterface/curation/tests/test_curationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,30 @@ def test_curation():
parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms
parent_sort.set_property("some_names", ["unit_{}".format(k) for k in spikestimes[0].keys()]) # float
cs = CurationSorting(parent_sort, properties_policy="remove")
cs.merge(["a", "c"])

# merge a-c
cs.merge(["a", "c"], new_unit_id="a-c")
assert cs.sorting.get_num_units() == len(spikestimes[0]) - 1
cs.undo()

# split b in 2
split_index = [v["b"] < 6 for v in spikestimes] # split class 4 in even and odds
cs.split("b", split_index)
cs.split("b", split_index, new_unit_ids=["b1", "b2"])
after_split = cs.sorting
assert cs.sorting.get_num_units() == len(spikestimes[0]) + 1
cs.undo()

# split one unit in 3
split_index3 = [v["b"] % 3 + 1 for v in spikestimes] # split class in 3
cs.split("b", split_index3, new_unit_ids=["b1", "b2", "b3"])
after_split = cs.sorting
assert after_split.get_num_units() == len(spikestimes[0]) + 2
cs.undo()

# split with renaming
cs.split("b", split_index3)
after_split = cs.sorting
assert cs.sorting.get_num_units() == len(spikestimes[0])
assert after_split.get_num_units() == len(spikestimes[0]) + 2

all_units = cs.sorting.get_unit_ids()
cs.merge(all_units, new_unit_id=all_units[0])
Expand Down
Loading