Skip to content

Commit

Permalink
rf: Walk all rules, collect all names
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Aug 30, 2024
1 parent 218e8d9 commit cde2ba9
Showing 1 changed file with 89 additions and 56 deletions.
145 changes: 89 additions & 56 deletions tools/schemacode/bidsschematools/tests/test_expressions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from collections.abc import Mapping
from functools import singledispatch
from typing import Union

import pytest
from pyparsing.exceptions import ParseException

Expand All @@ -11,6 +15,7 @@
RightOp,
expression,
)
from ..types import Namespace


def test_schema_expressions(schema_obj):
Expand Down Expand Up @@ -88,65 +93,93 @@ def test_expected_failures(expr):
expression.parse_string(expr)


def walk_schema(schema_obj, predicate):
for key, value in schema_obj.items():
if predicate(key, value):
yield key, value
if isinstance(value, Mapping):
for subkey, value in walk_schema(value, predicate):
yield f"{key}.{subkey}", value


def test_valid_sidecar_field(schema_obj):
"""Check sidecar fields actually exist in the metadata listed in the schema.
Test failures are usually due to typos.
"""
for rules, level in ((schema_obj.rules.checks, 3),):
keys = (key for key in rules.keys(level=level) if key.endswith("selectors"))
check_fields(schema_obj, rules, keys)

keys = (key for key in rules.keys(level=level) if key.endswith("checks"))
check_fields(schema_obj, rules, keys)


def check_fields(schema_obj, rules, keys):
for key in keys:
for rule in rules[key]:
ast = expression.parse_string(rule)[0]
if isinstance(ast, BinOp):
check_binop(schema_obj, ast)
elif isinstance(ast, Function):
check_function(schema_obj, ast)
elif isinstance(ast, Property):
check_property(schema_obj, ast)
elif isinstance(ast, RightOp):
check_half(schema_obj, ast.rh)


def check_binop(schema_obj, binop):
for half in [binop.lh, binop.rh]:
check_half(schema_obj, half)
field_names = {field.name for key, field in schema_obj.objects.metadata.items()}


def check_half(schema_obj, half):
if isinstance(half, BinOp):
check_binop(schema_obj, half)
elif isinstance(half, Function):
check_function(schema_obj, half)
elif isinstance(half, Property):
check_property(schema_obj, half)
elif isinstance(half, Element):
check_property(schema_obj, half.name)


def check_function(schema_obj, function):
for x in function.args:
if isinstance(x, Property):
check_property(schema_obj, x)
elif isinstance(x, Function):
check_function(schema_obj, x)
elif isinstance(x, Array):
check_array(schema_obj, x)


def check_array(schema_obj, array):
for element in array.elements:
if isinstance(element, Property):
check_property(schema_obj, element)


def check_property(schema_obj, property):
if property.name == "sidecar":
assert property.field in schema_obj.objects.metadata
for key, rule in walk_schema(
schema_obj.rules, lambda k, v: isinstance(v, Mapping) and v.get("selectors")
):
for selector in rule["selectors"]:
ast = expression.parse_string(selector)[0]
for name in find_names(ast):
if name.startswith(("json.", "sidecar.")):
assert (
name.split(".", 1)[1] in field_names
), f"Bad field in selector: {name} ({key})"
for selector in rule.get("checks", []):
ast = expression.parse_string(selector)[0]
for name in find_names(ast):
if name.startswith(("json.", "sidecar.")):
assert (
name.split(".", 1)[1] in field_names
), f"Bad field in selector: {name} ({key})"


def test_test_valid_sidecar_field():
schema_obj = Namespace.build(
{
"objects": {
"metadata": {
"a": {"name": "a"},
}
},
"rules": {"myruleA": {"selectors": ["sidecar.a"], "checks": ["json.a == sidecar.a"]}},
}
)
test_valid_sidecar_field(schema_obj)

schema_obj.objects.metadata.a["name"] = "b"
with pytest.raises(AssertionError):
test_valid_sidecar_field(schema_obj)


@singledispatch
def find_names(node: Union[ASTNode, str]):
# Walk AST nodes
if isinstance(node, BinOp):
yield from find_names(node.lh)
yield from find_names(node.rh)
elif isinstance(node, RightOp):
yield from find_names(node.rh)
elif isinstance(node, Array):
for element in node.elements:
yield from find_names(element)
elif isinstance(node, Element):
yield from find_names(node.name)
yield from find_names(node.index)
elif isinstance(node, (int, float)):
return
else:
raise TypeError(f"Unexpected node type: {node!r}")


@find_names.register
def find_function_names(node: Function):
yield node.name
for arg in node.args:
yield from find_names(arg)


@find_names.register
def find_property_name(node: Property):
# Properties are left-associative, so expand the left side
yield f"{next(find_names(node.name))}.{node.field}"


@find_names.register
def find_identifiers(node: str):
if not node.startswith(('"', "'")):
yield node

0 comments on commit cde2ba9

Please sign in to comment.