Skip to content

Commit

Permalink
add json validation to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bendichter committed Dec 8, 2024
1 parent 26a7d33 commit d529e0d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
8 changes: 6 additions & 2 deletions resources/probe.json.schema
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
"annotations": {
"type": "object",
"properties": {
"name": { "type": "string" },
"model_name": { "type": "string" },
"manufacturer": { "type": "string" }
},
"required": ["name", "manufacturer"],
"required": ["model_name", "manufacturer"],
"additionalProperties": true
},
"contact_annotations": {
Expand Down Expand Up @@ -101,6 +101,10 @@
"shank_ids": {
"type": "array",
"items": { "type": "string" }
},
"device_channel_indices": {
"type": "array",
"items": { "type": "integer" }
}
},
"required": [
Expand Down
12 changes: 12 additions & 0 deletions src/probeinterface/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import json
from pathlib import Path

from probeinterface import __version__ as version
import jsonschema

json_schema_file = Path(__file__).absolute().parent.parent.parent / "resources" / "probe.json.schema"
schema = json.load(open(json_schema_file, "r"))

def validate_probe_dict(probe_dict):
instance = dict(specification="probeinterface", version=version, probes=[probe_dict])
jsonschema.validate(instance=instance, schema=schema)
16 changes: 16 additions & 0 deletions tests/test_io/test_spikeglx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import pytest

from probeinterface import (
__version__ as version,
read_spikeglx,
parse_spikeglx_meta,
get_saved_channel_indices_from_spikeglx_meta,
)
from probeinterface.testing import validate_probe_dict

data_path = Path(__file__).absolute().parent.parent / "data" / "spikeglx"

Expand All @@ -34,6 +36,7 @@ def test_get_saved_channel_indices_from_spikeglx_meta():
def test_NP1():
probe = read_spikeglx(data_path / "Noise_g0_t0.imec0.ap.meta")
assert "1.0" in probe.model_name
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP_phase3A():
Expand All @@ -54,12 +57,14 @@ def test_NP_phase3A():

assert np.all(probe.contact_shape_params == {"width": contact_width})
assert np.all(probe.contact_shapes == contact_shape)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP2_1_shanks():
probe = read_spikeglx(data_path / "p2_g0_t0.imec0.ap.meta")
assert "2.0" in probe.model_name
assert probe.get_shank_count() == 1
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP2_4_shanks():
Expand All @@ -83,6 +88,7 @@ def test_NP2_4_shanks():
# This file does not save the channnels from 0 as the one above (NP2_4_shanks_g0_t0.imec0.ap.meta)
ypos = probe.contact_positions[:, 1]
assert np.min(ypos) == pytest.approx(0)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP2_2013_all():
Expand All @@ -108,6 +114,7 @@ def test_NP2_2013_all():
# This file does not save the channnels from 0 as the one above (NP2_4_shanks_g0_t0.imec0.ap.meta)
ypos = probe.contact_positions[:, 1]
assert np.min(ypos) == pytest.approx(0)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP2_2013_subset():
Expand All @@ -133,6 +140,7 @@ def test_NP2_2013_subset():
# This file does not save the channnels from 0 as the one above (NP2_4_shanks_g0_t0.imec0.ap.meta)
ypos = probe.contact_positions[:, 1]
assert np.min(ypos) == pytest.approx(0)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP2_4_shanks_with_different_electrodes_saved():
Expand All @@ -158,6 +166,7 @@ def test_NP2_4_shanks_with_different_electrodes_saved():
ypos = probe.contact_positions[:, 1]
assert np.min(ypos) == pytest.approx(4080.0)
assert np.max(ypos) == pytest.approx(4785.0)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP1_large_depth_span():
Expand All @@ -167,6 +176,7 @@ def test_NP1_large_depth_span():
assert probe.get_shank_count() == 1
ypos = probe.contact_positions[:, 1]
assert (np.max(ypos) - np.min(ypos)) > 7600
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NP1_other_example():
Expand All @@ -177,6 +187,7 @@ def test_NP1_other_example():
assert probe.get_shank_count() == 1
ypos = probe.contact_positions[:, 1]
assert (np.max(ypos) - np.min(ypos)) > 7600
validate_probe_dict(probe.to_dict(array_as_list=True))


def tes_NP1_384_channels():
Expand All @@ -185,6 +196,7 @@ def tes_NP1_384_channels():
assert probe.get_shank_count() == 1
assert probe.get_contact_count() == 151
assert 152 not in probe.contact_annotations["channel_ids"]
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NPH_long_staggered():
Expand Down Expand Up @@ -241,6 +253,7 @@ def test_NPH_long_staggered():
assert np.allclose(banks, 0)
assert np.allclose(references, 0)
assert np.allclose(filters, 1)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_NPH_short_linear_probe_type_0():
Expand Down Expand Up @@ -291,6 +304,7 @@ def test_NPH_short_linear_probe_type_0():
assert np.allclose(banks, 0)
assert np.allclose(references, 0)
assert np.allclose(filters, 1)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_ultra_probe():
Expand Down Expand Up @@ -319,11 +333,13 @@ def test_ultra_probe():
expected_electode_rows = 48
unique_y_values = np.unique(y)
assert unique_y_values.size == expected_electode_rows
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_CatGT_NP1():
probe = read_spikeglx(data_path / "catgt.meta")
assert "1.0" in probe.model_name
validate_probe_dict(probe.to_dict(array_as_list=True))


if __name__ == "__main__":
Expand Down

0 comments on commit d529e0d

Please sign in to comment.