Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Correct unit ID matching in sortingview curation #2037

Merged
merged 17 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,4 @@ test_folder/

# Mac OS
.DS_Store
test_data.json
51 changes: 34 additions & 17 deletions src/spikeinterface/curation/sortingview_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,55 @@ def apply_sortingview_curation(
unit_ids_dtype = sorting.unit_ids.dtype

# STEP 1: merge groups
labels_dict = sortingview_curation_dict["labelsByUnit"]
if "mergeGroups" in sortingview_curation_dict and not skip_merge:
merge_groups = sortingview_curation_dict["mergeGroups"]
for mg in merge_groups:
for merge_group in merge_groups:
# Store labels of units that are about to be merged
labels_to_inherit = []
for unit in merge_group:
labels_to_inherit.extend(labels_dict.get(str(unit), []))
labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates

if verbose:
print(f"Merging {mg}")
print(f"Merging {merge_group}")
if unit_ids_dtype.kind in ("U", "S"):
# if unit dtype is str, set new id as "{unit1}-{unit2}"
new_unit_id = "-".join(mg)
new_unit_id = "-".join(merge_group)
curation_sorting.merge(merge_group, new_unit_id=new_unit_id)
else:
# in this case, the CurationSorting takes care of finding a new unused int
new_unit_id = None
curation_sorting.merge(mg, new_unit_id=new_unit_id)
curation_sorting.merge(merge_group, new_unit_id=None)
new_unit_id = curation_sorting.max_used_id # merged unit id
labels_dict[str(new_unit_id)] = labels_to_inherit
rkim48 marked this conversation as resolved.
Show resolved Hide resolved

# STEP 2: gather and apply sortingview curation labels

# In sortingview, a unit is not required to have all labels.
# For example, the first 3 units could be labeled as "accept".
# In this case, the first 3 values of the property "accept" will be True, the rest False
labels_dict = sortingview_curation_dict["labelsByUnit"]
properties = {}
for _, labels in labels_dict.items():
for label in labels:
if label not in properties:
properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)

# Initialize the properties dictionary
properties = {
label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for labels in labels_dict.values()
for label in labels
}

# Populate the properties dictionary
for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
labels_unit = []
for unit_label, labels in labels_dict.items():
if unit_label in str(unit_id):
labels_unit.extend(labels)
labels_unit = set()

# Check for exact match first
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
if str(unit_id) in labels_dict:
labels_unit.update(labels_dict[str(unit_id)])
# If no exact match, check if unit_label is a substring of unit_id (for string unit ID merged unit)
else:
for unit_label, labels in labels_dict.items():
if isinstance(unit_id, str) and unit_label in unit_id:
labels_unit.update(labels)
rkim48 marked this conversation as resolved.
Show resolved Hide resolved
for label in labels_unit:
properties[label][u_i] = True

for prop_name, prop_values in properties.items():
curation_sorting.current_sorting.set_property(prop_name, prop_values)

Expand All @@ -103,5 +121,4 @@ def apply_sortingview_curation(
units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True])
units_to_remove = np.unique(units_to_remove)
curation_sorting.remove_units(units_to_remove)

return curation_sorting.current_sorting
193 changes: 187 additions & 6 deletions src/spikeinterface/curation/tests/test_sortingview_curation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from pathlib import Path
import os
import json
import numpy as np

import spikeinterface as si
import spikeinterface.extractors as se
from spikeinterface.extractors import read_mearec
from spikeinterface import set_global_tmp_folder
from spikeinterface.postprocessing import (
Expand All @@ -19,7 +22,6 @@
cache_folder = Path("cache_folder") / "curation"

parent_folder = Path(__file__).parent

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY"))

Expand Down Expand Up @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset():

@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_gh_curation():
"""
Test curation using GitHub URI.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

# from GH
# curated link:
# https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22}
gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json"
sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True)
print(f"From GH: {sorting_curated_gh}")

assert len(sorting_curated_gh.unit_ids) == 9
assert "#8-#9" in sorting_curated_gh.unit_ids
Expand All @@ -78,6 +80,9 @@ def test_gh_curation():

@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_sha1_curation():
"""
Test curation using SHA1 URI.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

Expand All @@ -93,7 +98,7 @@ def test_sha1_curation():
assert "accept" in sorting_curated_sha1.get_property_keys()
assert "mua" in sorting_curated_sha1.get_property_keys()
assert "artifact" in sorting_curated_sha1.get_property_keys()

unit_ids = sorting_curated_sha1.unit_ids
sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"])
sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"])
sorting_curated_sha1_art_mua = apply_sortingview_curation(
Expand All @@ -105,13 +110,16 @@ def test_sha1_curation():


def test_json_curation():
"""
Test curation using a JSON file.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

# from curation.json
json_file = parent_folder / "sv-sorting-curation.json"
print(f"Sorting: {sorting.get_unit_ids()}")
sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)
print(f"From JSON: {sorting_curated_json}")

assert len(sorting_curated_json.unit_ids) == 9
assert "#8-#9" in sorting_curated_json.unit_ids
Expand All @@ -130,9 +138,182 @@ def test_json_curation():
assert len(sorting_curated_json_mua.unit_ids) == 6
assert len(sorting_curated_json_mua1.unit_ids) == 5

print("Test for json curation passed!\n")


def test_false_positive_curation():
"""
Test curation for false positives.
"""
# https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_units = 20
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.randint(1, num_units + 1, size=num_spikes)

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)
print("Sorting: {}".format(sorting.get_unit_ids()))

# Test curation JSON:
test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]}

json_path = "test_data.json"
rkim48 marked this conversation as resolved.
Show resolved Hide resolved
with open(json_path, "w") as f:
json.dump(test_json, f, indent=4)

# Apply curation
sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True)
print("Curated:", sorting_curated_json.get_unit_ids())

# Assertions
assert sorting_curated_json.get_unit_property(unit_id=1, key="accept")
assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept")
assert 21 in sorting_curated_json.unit_ids
rkim48 marked this conversation as resolved.
Show resolved Hide resolved

print("False positive test for integer unit IDs passed!\n")


def test_label_inheritance_int():
"""
Test curation for label inheritance for integer unit IDs.
"""
# Setup
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.randint(1, 8, size=num_spikes) # 7 units: 1 to 7

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)

# Create a curation JSON with labels and merge groups
curation_dict = {
"labelsByUnit": {
"1": ["mua"],
"2": ["mua"],
"3": ["reject"],
"4": ["noise"],
"5": ["accept"],
"6": ["accept"],
"7": ["accept"],
},
"mergeGroups": [[1, 2], [3, 4], [5, 6]],
}

json_path = "test_curation_int.json"
rkim48 marked this conversation as resolved.
Show resolved Hide resolved
with open(json_path, "w") as f:
json.dump(curation_dict, f, indent=4)

# Apply curation
sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path)

# Assertions for merged units
print(f"Merge only: {sorting_merge.get_unit_ids()}")
assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2
assert not sorting_merge.get_unit_property(unit_id=8, key="reject")
assert not sorting_merge.get_unit_property(unit_id=8, key="noise")
assert not sorting_merge.get_unit_property(unit_id=8, key="accept")

assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4
assert sorting_merge.get_unit_property(unit_id=9, key="reject")
assert sorting_merge.get_unit_property(unit_id=9, key="noise")
assert not sorting_merge.get_unit_property(unit_id=9, key="accept")

assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6
assert not sorting_merge.get_unit_property(unit_id=10, key="reject")
assert not sorting_merge.get_unit_property(unit_id=10, key="noise")
assert sorting_merge.get_unit_property(unit_id=10, key="accept")

# Assertions for exclude_labels
sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"])
print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}")
assert 9 not in sorting_exclude_noise.get_unit_ids()

# Assertions for include_labels
sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"])
print(f"Include accept: {sorting_include_accept.get_unit_ids()}")
assert 8 not in sorting_include_accept.get_unit_ids()
assert 9 not in sorting_include_accept.get_unit_ids()
assert 10 in sorting_include_accept.get_unit_ids()


def test_label_inheritance_str():
"""
Test curation for label inheritance for string unit IDs.
"""
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes)

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)
print(f"Sorting: {sorting.get_unit_ids()}")
# Create a curation JSON with labels and merge groups
curation_dict = {
"labelsByUnit": {
"a": ["mua"],
"b": ["mua"],
"c": ["reject"],
"d": ["noise"],
"e": ["accept"],
"f": ["accept"],
"g": ["accept"],
},
"mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]],
}

json_path = "test_curation_str.json"
with open(json_path, "w") as f:
json.dump(curation_dict, f, indent=4)

# Check label inheritance for merged units
merged_id_1 = "a-b"
merged_id_2 = "c-d"
merged_id_3 = "e-f"
# Apply curation
sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True)

# Assertions for merged units
print(f"Merge only: {sorting_merge.get_unit_ids()}")
assert sorting_merge.get_unit_property(unit_id="a-b", key="mua")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept")

assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua")
assert sorting_merge.get_unit_property(unit_id="c-d", key="reject")
assert sorting_merge.get_unit_property(unit_id="c-d", key="noise")
assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept")

assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua")
assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject")
assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise")
assert sorting_merge.get_unit_property(unit_id="e-f", key="accept")

# Assertions for exclude_labels
sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"])
print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}")
assert "c-d" not in sorting_exclude_noise.get_unit_ids()

# Assertions for include_labels
sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"])
print(f"Include accept: {sorting_include_accept.get_unit_ids()}")
assert "a-b" not in sorting_include_accept.get_unit_ids()
assert "c-d" not in sorting_include_accept.get_unit_ids()
assert "e-f" in sorting_include_accept.get_unit_ids()


if __name__ == "__main__":
# generate_sortingview_curation_dataset()
test_sha1_curation()
test_gh_curation()
test_json_curation()
test_false_positive_curation()
test_label_inheritance_int()
test_label_inheritance_str()