diff --git a/tools/schemacode/bidsschematools/tests/test_expressions.py b/tools/schemacode/bidsschematools/tests/test_expressions.py index 3c32cda023..e2e0704b37 100644 --- a/tools/schemacode/bidsschematools/tests/test_expressions.py +++ b/tools/schemacode/bidsschematools/tests/test_expressions.py @@ -1,3 +1,7 @@ +from collections.abc import Mapping +from functools import singledispatch +from typing import Union + import pytest from pyparsing.exceptions import ParseException @@ -11,6 +15,7 @@ RightOp, expression, ) +from ..types import Namespace def test_schema_expressions(schema_obj): @@ -88,65 +93,93 @@ def test_expected_failures(expr): expression.parse_string(expr) +def walk_schema(schema_obj, predicate): + for key, value in schema_obj.items(): + if predicate(key, value): + yield key, value + if isinstance(value, Mapping): + for subkey, value in walk_schema(value, predicate): + yield f"{key}.{subkey}", value + + def test_valid_sidecar_field(schema_obj): """Check sidecar fields actually exist in the metadata listed in the schema. Test failures are usually due to typos. """ - for rules, level in ((schema_obj.rules.checks, 3),): - keys = (key for key in rules.keys(level=level) if key.endswith("selectors")) - check_fields(schema_obj, rules, keys) - - keys = (key for key in rules.keys(level=level) if key.endswith("checks")) - check_fields(schema_obj, rules, keys) - - -def check_fields(schema_obj, rules, keys): - for key in keys: - for rule in rules[key]: - ast = expression.parse_string(rule)[0] - if isinstance(ast, BinOp): - check_binop(schema_obj, ast) - elif isinstance(ast, Function): - check_function(schema_obj, ast) - elif isinstance(ast, Property): - check_property(schema_obj, ast) - elif isinstance(ast, RightOp): - check_half(schema_obj, ast.rh) - - -def check_binop(schema_obj, binop): - for half in [binop.lh, binop.rh]: - check_half(schema_obj, half) + field_names = {field.name for key, field in schema_obj.objects.metadata.items()} - -def check_half(schema_obj, half): - if isinstance(half, BinOp): - check_binop(schema_obj, half) - elif isinstance(half, Function): - check_function(schema_obj, half) - elif isinstance(half, Property): - check_property(schema_obj, half) - elif isinstance(half, Element): - check_property(schema_obj, half.name) - - -def check_function(schema_obj, function): - for x in function.args: - if isinstance(x, Property): - check_property(schema_obj, x) - elif isinstance(x, Function): - check_function(schema_obj, x) - elif isinstance(x, Array): - check_array(schema_obj, x) - - -def check_array(schema_obj, array): - for element in array.elements: - if isinstance(element, Property): - check_property(schema_obj, element) - - -def check_property(schema_obj, property): - if property.name == "sidecar": - assert property.field in schema_obj.objects.metadata + for key, rule in walk_schema( + schema_obj.rules, lambda k, v: isinstance(v, Mapping) and v.get("selectors") + ): + for selector in rule["selectors"]: + ast = expression.parse_string(selector)[0] + for name in find_names(ast): + if name.startswith(("json.", "sidecar.")): + assert ( + name.split(".", 1)[1] in field_names + ), f"Bad field in selector: {name} ({key})" + for selector in rule.get("checks", []): + ast = expression.parse_string(selector)[0] + for name in find_names(ast): + if name.startswith(("json.", "sidecar.")): + assert ( + name.split(".", 1)[1] in field_names + ), f"Bad field in selector: {name} ({key})" + + +def test_test_valid_sidecar_field(): + schema_obj = Namespace.build( + { + "objects": { + "metadata": { + "a": {"name": "a"}, + } + }, + "rules": {"myruleA": {"selectors": ["sidecar.a"], "checks": ["json.a == sidecar.a"]}}, + } + ) + test_valid_sidecar_field(schema_obj) + + schema_obj.objects.metadata.a["name"] = "b" + with pytest.raises(AssertionError): + test_valid_sidecar_field(schema_obj) + + +@singledispatch +def find_names(node: Union[ASTNode, str]): + # Walk AST nodes + if isinstance(node, BinOp): + yield from find_names(node.lh) + yield from find_names(node.rh) + elif isinstance(node, RightOp): + yield from find_names(node.rh) + elif isinstance(node, Array): + for element in node.elements: + yield from find_names(element) + elif isinstance(node, Element): + yield from find_names(node.name) + yield from find_names(node.index) + elif isinstance(node, (int, float)): + return + else: + raise TypeError(f"Unexpected node type: {node!r}") + + +@find_names.register +def find_function_names(node: Function): + yield node.name + for arg in node.args: + yield from find_names(arg) + + +@find_names.register +def find_property_name(node: Property): + # Properties are left-associative, so expand the left side + yield f"{next(find_names(node.name))}.{node.field}" + + +@find_names.register +def find_identifiers(node: str): + if not node.startswith(('"', "'")): + yield node