Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Correct unit ID matching in sortingview curation #2037

Merged
merged 17 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,4 @@ test_folder/

# Mac OS
.DS_Store
test_data.json
49 changes: 29 additions & 20 deletions src/spikeinterface/curation/sortingview_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,47 @@ def apply_sortingview_curation(
unit_ids_dtype = sorting.unit_ids.dtype

# STEP 1: merge groups
labels_dict = sortingview_curation_dict["labelsByUnit"]
if "mergeGroups" in sortingview_curation_dict and not skip_merge:
merge_groups = sortingview_curation_dict["mergeGroups"]
for mg in merge_groups:
for merge_group in merge_groups:
# Store labels of units that are about to be merged
labels_to_inherit = []
for unit in merge_group:
labels_to_inherit.extend(labels_dict.get(str(unit), []))
labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates

if verbose:
print(f"Merging {mg}")
print(f"Merging {merge_group}")
if unit_ids_dtype.kind in ("U", "S"):
# if unit dtype is str, set new id as "{unit1}-{unit2}"
new_unit_id = "-".join(mg)
new_unit_id = "-".join(merge_group)
curation_sorting.merge(merge_group, new_unit_id=new_unit_id)
else:
# in this case, the CurationSorting takes care of finding a new unused int
new_unit_id = None
curation_sorting.merge(mg, new_unit_id=new_unit_id)
curation_sorting.merge(merge_group, new_unit_id=None)
new_unit_id = curation_sorting.max_used_id # merged unit id
labels_dict[str(new_unit_id)] = labels_to_inherit

# STEP 2: gather and apply sortingview curation labels

# In sortingview, a unit is not required to have all labels.
# For example, the first 3 units could be labeled as "accept".
# In this case, the first 3 values of the property "accept" will be True, the rest False
labels_dict = sortingview_curation_dict["labelsByUnit"]
properties = {}
for _, labels in labels_dict.items():
for label in labels:
if label not in properties:
properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
labels_unit = []
for unit_label, labels in labels_dict.items():
if unit_label in str(unit_id):
labels_unit.extend(labels)
for label in labels_unit:
properties[label][u_i] = True

# Initialize the properties dictionary
properties = {
label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for labels in labels_dict.values()
for label in labels
}

# Populate the properties dictionary
for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
unit_id_str = str(unit_id)
if unit_id_str in labels_dict:
for label in labels_dict[unit_id_str]:
properties[label][unit_index] = True

for prop_name, prop_values in properties.items():
curation_sorting.current_sorting.set_property(prop_name, prop_values)

Expand All @@ -103,5 +113,4 @@ def apply_sortingview_curation(
units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True])
units_to_remove = np.unique(units_to_remove)
curation_sorting.remove_units(units_to_remove)

return curation_sorting.current_sorting
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"labelsByUnit": {
"1": [
"accept"
],
"2": [
"artifact"
],
"12": [
"artifact"
]
},
"mergeGroups": [
[
2,
12
]
]
}
39 changes: 39 additions & 0 deletions src/spikeinterface/curation/tests/sv-sorting-curation-int.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"labelsByUnit": {
"1": [
"mua"
],
"2": [
"mua"
],
"3": [
"reject"
],
"4": [
"noise"
],
"5": [
"accept"
],
"6": [
"accept"
],
"7": [
"accept"
]
},
"mergeGroups": [
[
1,
2
],
[
3,
4
],
[
5,
6
]
]
}
39 changes: 39 additions & 0 deletions src/spikeinterface/curation/tests/sv-sorting-curation-str.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"labelsByUnit": {
"a": [
"mua"
],
"b": [
"mua"
],
"c": [
"reject"
],
"d": [
"noise"
],
"e": [
"accept"
],
"f": [
"accept"
],
"g": [
"accept"
]
},
"mergeGroups": [
[
"a",
"b"
],
[
"c",
"d"
],
[
"e",
"f"
]
]
}
147 changes: 140 additions & 7 deletions src/spikeinterface/curation/tests/test_sortingview_curation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from pathlib import Path
import os
import json
import numpy as np

import spikeinterface as si
import spikeinterface.extractors as se
from spikeinterface.extractors import read_mearec
from spikeinterface import set_global_tmp_folder
from spikeinterface.postprocessing import (
Expand All @@ -19,7 +22,6 @@
cache_folder = Path("cache_folder") / "curation"

parent_folder = Path(__file__).parent

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY"))

Expand Down Expand Up @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset():

@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_gh_curation():
"""
Test curation using GitHub URI.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

# from GH
# curated link:
# https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22}
gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json"
sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True)
print(f"From GH: {sorting_curated_gh}")

assert len(sorting_curated_gh.unit_ids) == 9
assert "#8-#9" in sorting_curated_gh.unit_ids
Expand All @@ -78,6 +80,9 @@ def test_gh_curation():

@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_sha1_curation():
"""
Test curation using SHA1 URI.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

Expand All @@ -86,14 +91,14 @@ def test_sha1_curation():
# https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22}
sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22"
sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True)
print(f"From SHA: {sorting_curated_sha1}")
# print(f"From SHA: {sorting_curated_sha1}")

assert len(sorting_curated_sha1.unit_ids) == 9
assert "#8-#9" in sorting_curated_sha1.unit_ids
assert "accept" in sorting_curated_sha1.get_property_keys()
assert "mua" in sorting_curated_sha1.get_property_keys()
assert "artifact" in sorting_curated_sha1.get_property_keys()

unit_ids = sorting_curated_sha1.unit_ids
sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"])
sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"])
sorting_curated_sha1_art_mua = apply_sortingview_curation(
Expand All @@ -105,13 +110,16 @@ def test_sha1_curation():


def test_json_curation():
"""
Test curation using a JSON file.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

# from curation.json
json_file = parent_folder / "sv-sorting-curation.json"
# print(f"Sorting: {sorting.get_unit_ids()}")
sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)
print(f"From JSON: {sorting_curated_json}")

assert len(sorting_curated_json.unit_ids) == 9
assert "#8-#9" in sorting_curated_json.unit_ids
Expand All @@ -131,8 +139,133 @@ def test_json_curation():
assert len(sorting_curated_json_mua1.unit_ids) == 5


def test_false_positive_curation():
"""
Test curation for false positives.
"""
# https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_units = 20
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.randint(1, num_units + 1, size=num_spikes)

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)
# print("Sorting: {}".format(sorting.get_unit_ids()))

json_file = parent_folder / "sv-sorting-curation-false-positive.json"
sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)
# print("Curated:", sorting_curated_json.get_unit_ids())

# Assertions
assert sorting_curated_json.get_unit_property(unit_id=1, key="accept")
assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept")
assert 21 in sorting_curated_json.unit_ids
rkim48 marked this conversation as resolved.
Show resolved Hide resolved


def test_label_inheritance_int():
"""
Test curation for label inheritance for integer unit IDs.
"""
# Setup
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_spikes = 1000
num_units = 7
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)

json_file = parent_folder / "sv-sorting-curation-int.json"
sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file)

# Assertions for merged units
# print(f"Merge only: {sorting_merge.get_unit_ids()}")
assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2
assert not sorting_merge.get_unit_property(unit_id=8, key="reject")
assert not sorting_merge.get_unit_property(unit_id=8, key="noise")
assert not sorting_merge.get_unit_property(unit_id=8, key="accept")

assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4
assert sorting_merge.get_unit_property(unit_id=9, key="reject")
assert sorting_merge.get_unit_property(unit_id=9, key="noise")
assert not sorting_merge.get_unit_property(unit_id=9, key="accept")

assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6
assert not sorting_merge.get_unit_property(unit_id=10, key="reject")
assert not sorting_merge.get_unit_property(unit_id=10, key="noise")
assert sorting_merge.get_unit_property(unit_id=10, key="accept")

# Assertions for exclude_labels
sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"])
# print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}")
assert 9 not in sorting_exclude_noise.get_unit_ids()

# Assertions for include_labels
sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"])
# print(f"Include accept: {sorting_include_accept.get_unit_ids()}")
assert 8 not in sorting_include_accept.get_unit_ids()
assert 9 not in sorting_include_accept.get_unit_ids()
assert 10 in sorting_include_accept.get_unit_ids()


def test_label_inheritance_str():
"""
Test curation for label inheritance for string unit IDs.
"""
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes)

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)
# print(f"Sorting: {sorting.get_unit_ids()}")

# Apply curation
json_file = parent_folder / "sv-sorting-curation-str.json"
sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)

# Assertions for merged units
# print(f"Merge only: {sorting_merge.get_unit_ids()}")
assert sorting_merge.get_unit_property(unit_id="a-b", key="mua")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept")

assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua")
assert sorting_merge.get_unit_property(unit_id="c-d", key="reject")
assert sorting_merge.get_unit_property(unit_id="c-d", key="noise")
assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept")

assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua")
assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject")
assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise")
assert sorting_merge.get_unit_property(unit_id="e-f", key="accept")

# Assertions for exclude_labels
sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"])
# print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}")
assert "c-d" not in sorting_exclude_noise.get_unit_ids()

# Assertions for include_labels
sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"])
# print(f"Include accept: {sorting_include_accept.get_unit_ids()}")
assert "a-b" not in sorting_include_accept.get_unit_ids()
assert "c-d" not in sorting_include_accept.get_unit_ids()
assert "e-f" in sorting_include_accept.get_unit_ids()


if __name__ == "__main__":
# generate_sortingview_curation_dataset()
test_sha1_curation()
test_gh_curation()
test_json_curation()
test_false_positive_curation()
test_label_inheritance_int()
test_label_inheritance_str()