Skip to content

Commit

Permalink
rf: Walk all rules, collect all names
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Aug 30, 2024
1 parent 218e8d9 commit 0f7bfb1
Showing 1 changed file with 74 additions and 57 deletions.
131 changes: 74 additions & 57 deletions tools/schemacode/bidsschematools/tests/test_expressions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from collections.abc import Mapping
from functools import singledispatch
from typing import Union

import pytest
from pyparsing.exceptions import ParseException

from ..expressions import (
Array,
ASTNode,
BinOp,
Element,
Function,
Property,
RightOp,
expression,
)
from ..expressions import Array, ASTNode, BinOp, Function, Property, RightOp, expression
from ..types import Namespace


def test_schema_expressions(schema_obj):
Expand Down Expand Up @@ -88,65 +84,86 @@ def test_expected_failures(expr):
expression.parse_string(expr)


def walk_schema(schema_obj, predicate):
for key, value in schema_obj.items():
if predicate(key, value):
yield key, value
if isinstance(value, Mapping):
for subkey, value in walk_schema(value, predicate):
yield f"{key}.{subkey}", value


def test_valid_sidecar_field(schema_obj):
"""Check sidecar fields actually exist in the metadata listed in the schema.
Test failures are usually due to typos.
"""
for rules, level in ((schema_obj.rules.checks, 3),):
keys = (key for key in rules.keys(level=level) if key.endswith("selectors"))
check_fields(schema_obj, rules, keys)

keys = (key for key in rules.keys(level=level) if key.endswith("checks"))
check_fields(schema_obj, rules, keys)


def check_fields(schema_obj, rules, keys):
for key in keys:
for rule in rules[key]:
ast = expression.parse_string(rule)[0]
if isinstance(ast, BinOp):
check_binop(schema_obj, ast)
elif isinstance(ast, Function):
check_function(schema_obj, ast)
elif isinstance(ast, Property):
check_property(schema_obj, ast)
elif isinstance(ast, RightOp):
check_half(schema_obj, ast.rh)
field_names = {field.name for key, field in schema_obj.objects.metadata.items()}

for key, rule in walk_schema(
schema_obj.rules, lambda k, v: isinstance(v, Mapping) and v.get("selectors")
):
for selector in rule["selectors"]:
ast = expression.parse_string(selector)[0]
for name in find_names(ast):
if name.startswith(("json.", "sidecar.")):
assert (
name.split(".", 1)[1] in field_names
), f"Bad field in selector: {name} ({key})"
for selector in rule.get("checks", []):
ast = expression.parse_string(selector)[0]
for name in find_names(ast):
if name.startswith(("json.", "sidecar.")):
assert (
name.split(".", 1)[1] in field_names
), f"Bad field in selector: {name} ({key})"


def test_test_valid_sidecar_field():
schema_obj = Namespace.build(
{
"objects": {
"metadata": {
"a": {"name": "a"},
}
},
"rules": {"myruleA": {"selectors": ["sidecar.a"], "checks": ["json.a == sidecar.a"]}},
}
)
test_valid_sidecar_field(schema_obj)

def check_binop(schema_obj, binop):
for half in [binop.lh, binop.rh]:
check_half(schema_obj, half)
schema_obj.objects.metadata.a["name"] = "b"
with pytest.raises(AssertionError):
test_valid_sidecar_field(schema_obj)


def check_half(schema_obj, half):
if isinstance(half, BinOp):
check_binop(schema_obj, half)
elif isinstance(half, Function):
check_function(schema_obj, half)
elif isinstance(half, Property):
check_property(schema_obj, half)
elif isinstance(half, Element):
check_property(schema_obj, half.name)
@singledispatch
def find_names(node: Union[ASTNode, str]):
# Walk AST nodes
if isinstance(node, BinOp):
yield from find_names(node.lh)
yield from find_names(node.rh)
elif isinstance(node, RightOp):
yield from find_names(node.rh)
elif isinstance(node, Array):
for element in node.elements:
yield from find_names(element)


def check_function(schema_obj, function):
for x in function.args:
if isinstance(x, Property):
check_property(schema_obj, x)
elif isinstance(x, Function):
check_function(schema_obj, x)
elif isinstance(x, Array):
check_array(schema_obj, x)
@find_names.register
def find_function_names(node: Function):
yield node.name
for arg in node.args:
yield from find_names(arg)


def check_array(schema_obj, array):
for element in array.elements:
if isinstance(element, Property):
check_property(schema_obj, element)
@find_names.register
def find_property_name(node: Property):
# Properties are left-associative, so expand the left side
yield f"{next(find_names(node.name))}.{node.field}"


def check_property(schema_obj, property):
if property.name == "sidecar":
assert property.field in schema_obj.objects.metadata
@find_names.register
def find_identifiers(node: str):
if not node.startswith(('"', "'")):
yield node

0 comments on commit 0f7bfb1

Please sign in to comment.