Skip to content

Commit

Permalink
Merge branch 'main' into sc2_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Oct 2, 2023
2 parents bbc8167 + 8c35a3a commit 071f81f
Show file tree
Hide file tree
Showing 22 changed files with 1,032 additions and 464 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,4 @@ test_folder/

# Mac OS
.DS_Store
test_data.json
28 changes: 28 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def generate_sorting(
firing_rates=3.0,
empty_units=None,
refractory_period_ms=3.0, # in ms
add_spikes_on_borders=False,
num_spikes_per_border=3,
border_size_samples=20,
seed=None,
):
"""
Expand All @@ -142,6 +145,12 @@ def generate_sorting(
List of units that will have no spikes. (used for testing mainly).
refractory_period_ms : float, default: 3.0
The refractory period in ms
add_spikes_on_borders : bool, default: False
If True, spikes will be added close to the borders of the segments.
num_spikes_per_border : int, default: 3
The number of spikes to add close to the borders of the segments.
border_size_samples : int, default: 20
The size of the border in samples to add border spikes.
seed : int, default: None
The random seed
Expand All @@ -151,11 +160,13 @@ def generate_sorting(
The sorting object
"""
seed = _ensure_seed(seed)
rng = np.random.default_rng(seed)
num_segments = len(durations)
unit_ids = np.arange(num_units)

spikes = []
for segment_index in range(num_segments):
num_samples = int(sampling_frequency * durations[segment_index])
times, labels = synthesize_random_firings(
num_units=num_units,
sampling_frequency=sampling_frequency,
Expand All @@ -175,7 +186,23 @@ def generate_sorting(
spikes_in_seg["unit_index"] = labels
spikes_in_seg["segment_index"] = segment_index
spikes.append(spikes_in_seg)

if add_spikes_on_borders:
spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype)
spikes_on_borders["segment_index"] = segment_index
spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True)
# at start
spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers(
0, border_size_samples, num_spikes_per_border
)
# at end
spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers(
num_samples - border_size_samples, num_samples, num_spikes_per_border
)
spikes.append(spikes_on_borders)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]

sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

Expand Down Expand Up @@ -596,6 +623,7 @@ def __init__(
dtype = np.dtype(dtype).name # Cast to string for serialization
if dtype not in ("float32", "float64"):
raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}")
assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'"

BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype)

Expand Down
36 changes: 33 additions & 3 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,44 @@


def test_generate_recording():
    # TODO: even if this is extensively tested in all other functions,
    # a dedicated unit test would still be worth adding.
    pass


def test_generate_sorting():
    # TODO: even if this is extensively tested in all other functions,
    # a dedicated unit test would still be worth adding.
    pass


def test_generate_sorting_with_spikes_on_borders():
    """Check that `generate_sorting(add_spikes_on_borders=True)` places at least
    `num_spikes_per_border` spikes inside each border window of every segment,
    and that the resulting spike vector is sorted by segment and sample index.
    """
    num_spikes_on_borders = 10
    border_size_samples = 10
    segment_duration = 10  # seconds
    # Hoisted so the expected sample count below always matches the value
    # actually passed to generate_sorting (was hard-coded twice as 30000).
    sampling_frequency = 30000
    for nseg in [1, 2, 3]:
        sorting = generate_sorting(
            durations=[segment_duration] * nseg,
            sampling_frequency=sampling_frequency,
            num_units=10,
            add_spikes_on_borders=True,
            num_spikes_per_border=num_spikes_on_borders,
            border_size_samples=border_size_samples,
        )
        # check that segments are correctly sorted
        all_spikes = sorting.to_spike_vector()
        np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"]))

        spikes = sorting.to_spike_vector(concatenated=False)
        # at least num_spikes_on_borders spikes at each border, for all segments
        for spikes_in_segment in spikes:
            # check that sample indices are correctly sorted within segments
            np.testing.assert_array_equal(
                spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])
            )
            num_samples = int(segment_duration * sampling_frequency)
            assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders
            assert (
                np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples)
                >= num_spikes_on_borders
            )


def measure_memory_allocation(measure_in_process: bool = True) -> float:
"""
A local utility to measure memory allocation at a specific point in time.
Expand Down Expand Up @@ -399,7 +428,7 @@ def test_generate_ground_truth_recording():
if __name__ == "__main__":
strategy = "tile_pregenerated"
# strategy = "on_the_fly"
test_noise_generator_memory()
# test_noise_generator_memory()
# test_noise_generator_under_giga()
# test_noise_generator_correct_shape(strategy)
# test_noise_generator_consistency_across_calls(strategy, 0, 5)
Expand All @@ -410,3 +439,4 @@ def test_generate_ground_truth_recording():
# test_generate_templates()
# test_inject_templates()
# test_generate_ground_truth_recording()
test_generate_sorting_with_spikes_on_borders()
49 changes: 29 additions & 20 deletions src/spikeinterface/curation/sortingview_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,47 @@ def apply_sortingview_curation(
unit_ids_dtype = sorting.unit_ids.dtype

# STEP 1: merge groups
labels_dict = sortingview_curation_dict["labelsByUnit"]
if "mergeGroups" in sortingview_curation_dict and not skip_merge:
merge_groups = sortingview_curation_dict["mergeGroups"]
for mg in merge_groups:
for merge_group in merge_groups:
# Store labels of units that are about to be merged
labels_to_inherit = []
for unit in merge_group:
labels_to_inherit.extend(labels_dict.get(str(unit), []))
labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates

if verbose:
print(f"Merging {mg}")
print(f"Merging {merge_group}")
if unit_ids_dtype.kind in ("U", "S"):
# if unit dtype is str, set new id as "{unit1}-{unit2}"
new_unit_id = "-".join(mg)
new_unit_id = "-".join(merge_group)
curation_sorting.merge(merge_group, new_unit_id=new_unit_id)
else:
# in this case, the CurationSorting takes care of finding a new unused int
new_unit_id = None
curation_sorting.merge(mg, new_unit_id=new_unit_id)
curation_sorting.merge(merge_group, new_unit_id=None)
new_unit_id = curation_sorting.max_used_id # merged unit id
labels_dict[str(new_unit_id)] = labels_to_inherit

# STEP 2: gather and apply sortingview curation labels

# In sortingview, a unit is not required to have all labels.
# For example, the first 3 units could be labeled as "accept".
# In this case, the first 3 values of the property "accept" will be True, the rest False
labels_dict = sortingview_curation_dict["labelsByUnit"]
properties = {}
for _, labels in labels_dict.items():
for label in labels:
if label not in properties:
properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
labels_unit = []
for unit_label, labels in labels_dict.items():
if unit_label in str(unit_id):
labels_unit.extend(labels)
for label in labels_unit:
properties[label][u_i] = True

# Initialize the properties dictionary
properties = {
label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for labels in labels_dict.values()
for label in labels
}

# Populate the properties dictionary
for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
unit_id_str = str(unit_id)
if unit_id_str in labels_dict:
for label in labels_dict[unit_id_str]:
properties[label][unit_index] = True

for prop_name, prop_values in properties.items():
curation_sorting.current_sorting.set_property(prop_name, prop_values)

Expand All @@ -103,5 +113,4 @@ def apply_sortingview_curation(
units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True])
units_to_remove = np.unique(units_to_remove)
curation_sorting.remove_units(units_to_remove)

return curation_sorting.current_sorting
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"labelsByUnit": {
"1": [
"accept"
],
"2": [
"artifact"
],
"12": [
"artifact"
]
},
"mergeGroups": [
[
2,
12
]
]
}
39 changes: 39 additions & 0 deletions src/spikeinterface/curation/tests/sv-sorting-curation-int.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"labelsByUnit": {
"1": [
"mua"
],
"2": [
"mua"
],
"3": [
"reject"
],
"4": [
"noise"
],
"5": [
"accept"
],
"6": [
"accept"
],
"7": [
"accept"
]
},
"mergeGroups": [
[
1,
2
],
[
3,
4
],
[
5,
6
]
]
}
39 changes: 39 additions & 0 deletions src/spikeinterface/curation/tests/sv-sorting-curation-str.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"labelsByUnit": {
"a": [
"mua"
],
"b": [
"mua"
],
"c": [
"reject"
],
"d": [
"noise"
],
"e": [
"accept"
],
"f": [
"accept"
],
"g": [
"accept"
]
},
"mergeGroups": [
[
"a",
"b"
],
[
"c",
"d"
],
[
"e",
"f"
]
]
}
Loading

0 comments on commit 071f81f

Please sign in to comment.