diff --git a/jsonpath/env.py b/jsonpath/env.py
index 01afedc..f1dfe4c 100644
--- a/jsonpath/env.py
+++ b/jsonpath/env.py
@@ -20,7 +20,15 @@
 from . import function_extensions
 from .exceptions import JSONPathNameError
 from .exceptions import JSONPathSyntaxError
+from .exceptions import JSONPathTypeError
 from .filter import UNDEFINED
+from .filter import VALUE_TYPE_EXPRESSIONS
+from .filter import FilterExpression
+from .filter import FunctionExtension
+from .filter import InfixExpression
+from .filter import Path
+from .function_extensions import ExpressionType
+from .function_extensions import FilterFunction
 from .function_extensions import validate
 from .lex import Lexer
 from .match import JSONPathMatch
@@ -120,6 +128,7 @@ def __init__(
         *,
         filter_caching: bool = True,
         unicode_escape: bool = True,
+        well_typed: bool = True,
     ) -> None:
         self.filter_caching: bool = filter_caching
         """Enable or disable filter expression caching."""
@@ -128,6 +137,9 @@ def __init__(
         """Enable or disable decoding of UTF-16 escape sequences found in
         JSONPath string literals."""
 
+        self.well_typed: bool = well_typed
+        """Control well-typedness checks on filter function expressions."""
+
         self.lexer: Lexer = self.lexer_class(env=self)
         """The lexer bound to this environment."""
 
@@ -336,12 +348,92 @@ def validate_function_extension_signature(
                 f"function {token.value!r} is not defined", token=token
             ) from err
 
+        # Type-aware function extensions use the spec's type system.
+        if self.well_typed and isinstance(func, FilterFunction):
+            self.check_well_typedness(token, func, args)
+            return args
+
+        # A callable with a `validate` method?
         if hasattr(func, "validate"):
             args = func.validate(self, args, token)
             assert isinstance(args, list)
             return args
+
+        # Generic validation using introspection.
         return validate(self, func, args, token)
 
+    def check_well_typedness(
+        self,
+        token: Token,
+        func: FilterFunction,
+        args: List[FilterExpression],
+    ) -> None:
+        """Check the well-typedness of a function's arguments at compile-time.
+
+        Arguments:
+            token: The function name token, used for error reporting.
+            func: The type-aware function definition being called.
+            args: The parsed argument expressions, in call order.
+
+        Raises:
+            JSONPathTypeError: If the number of arguments does not match
+                `func.arg_types`, or an argument is not of the declared type.
+        """
+        # Correct number of arguments?
+        if len(args) != len(func.arg_types):
+            raise JSONPathTypeError(
+                f"{token.value!r}() requires {len(func.arg_types)} arguments",
+                token=token,
+            )
+
+        # Argument types
+        for idx, typ in enumerate(func.arg_types):
+            arg = args[idx]
+            if typ == ExpressionType.VALUE:
+                # A value expression, a singular query, or a ValueType function.
+                if not (
+                    isinstance(arg, VALUE_TYPE_EXPRESSIONS)
+                    or (isinstance(arg, Path) and arg.path.singular_query())
+                    or self._function_return_type(arg) == ExpressionType.VALUE
+                ):
+                    raise JSONPathTypeError(
+                        f"{token.value}() argument {idx} must be of ValueType",
+                        token=token,
+                    )
+            elif typ == ExpressionType.LOGICAL:
+                # A query, an infix expression, or a function whose declared
+                # result is LogicalType, or NodesType (implicitly converted).
+                if not (
+                    isinstance(arg, (Path, InfixExpression))
+                    or self._function_return_type(arg)
+                    in (ExpressionType.LOGICAL, ExpressionType.NODES)
+                ):
+                    raise JSONPathTypeError(
+                        f"{token.value}() argument {idx} must be of LogicalType",
+                        token=token,
+                    )
+            elif typ == ExpressionType.NODES and not (
+                isinstance(arg, Path)
+                or self._function_return_type(arg) == ExpressionType.NODES
+            ):
+                raise JSONPathTypeError(
+                    f"{token.value}() argument {idx} must be of NodesType",
+                    token=token,
+                )
+
+    def _function_return_type(self, expr: FilterExpression) -> Optional[ExpressionType]:
+        """Return the type returned from a filter function.
+
+        If _expr_ is not a `FunctionExtension` or the registered function definition is
+        not type-aware, return `None`.
+        """
+        if not isinstance(expr, FunctionExtension):
+            return None
+        func = self.function_extensions.get(expr.name)
+        if isinstance(func, FilterFunction):
+            return func.return_type
+        return None
+
     def getitem(self, obj: Any, key: Any) -> Any:
         """Sequence and mapping item getter used throughout JSONPath resolution.
 
diff --git a/jsonpath/filter.py b/jsonpath/filter.py
index e8fbc67..965c2b4 100644
--- a/jsonpath/filter.py
+++ b/jsonpath/filter.py
@@ -690,3 +690,13 @@ def walk(expr: FilterExpression) -> Iterable[FilterExpression]:
     yield expr
     for child in expr.children():
         yield from walk(child)
+
+
+VALUE_TYPE_EXPRESSIONS = (
+    Nil,
+    Undefined,
+    Literal,
+    RegexArgument,
+    ListLiteral,
+    CurrentKey,
+)
diff --git a/jsonpath/function_extensions/__init__.py b/jsonpath/function_extensions/__init__.py
index 983f3ba..302d339 100644
--- a/jsonpath/function_extensions/__init__.py
+++ b/jsonpath/function_extensions/__init__.py
@@ -1,6 +1,8 @@
 # noqa: D104
 from .arguments import validate
 from .count import Count
+from .filter_function import ExpressionType
+from .filter_function import FilterFunction
 from .is_instance import IsInstance
 from .keys import keys
 from .length import length
@@ -11,6 +13,8 @@
 
 __all__ = (
     "Count",
+    "ExpressionType",
+    "FilterFunction",
     "IsInstance",
     "keys",
     "length",
diff --git a/jsonpath/function_extensions/filter_function.py b/jsonpath/function_extensions/filter_function.py
new file mode 100644
index 0000000..7391323
--- /dev/null
+++ b/jsonpath/function_extensions/filter_function.py
@@ -0,0 +1,32 @@
+"""Classes modeling the JSONPath spec type system for function extensions."""
+from abc import ABC
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+from typing import List
+
+
+class ExpressionType(Enum):
+    """The type of a filter function argument or return value."""
+
+    VALUE = 1
+    LOGICAL = 2
+    NODES = 3
+
+
+class FilterFunction(ABC):
+    """Base class for typed function extensions."""
+
+    @property
+    @abstractmethod
+    def arg_types(self) -> List[ExpressionType]:
+        """Argument types expected by the filter function."""
+
+    @property
+    @abstractmethod
+    def return_type(self) -> ExpressionType:
+        """The type of the value returned by the filter function."""
+
+    @abstractmethod
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        """Call the filter function."""
diff --git a/jsonpath/function_extensions/match.py b/jsonpath/function_extensions/match.py
index 8fd2fbd..0334701 100644
--- a/jsonpath/function_extensions/match.py
+++ b/jsonpath/function_extensions/match.py
@@ -1,56 +1,21 @@
 """The standard `match` function extension."""
 
 import re
-from typing import TYPE_CHECKING
-from typing import List
-from typing import Pattern
-from typing import Union
 
-from jsonpath.exceptions import JSONPathTypeError
-from jsonpath.filter import RegexArgument
-from jsonpath.filter import StringLiteral
+from jsonpath.function_extensions import ExpressionType
+from jsonpath.function_extensions import FilterFunction
 
-if TYPE_CHECKING:
-    from jsonpath.env import JSONPathEnvironment
-    from jsonpath.token import Token
 
+class Match(FilterFunction):
+    """A type-aware implementation of the standard `match` function."""
 
-class Match:
-    """The built-in `match` function.
-
-    This implementation uses the standard _re_ module, without attempting to map
-    I-Regexps to Python regex.
-    """
-
-    def __call__(self, string: str, pattern: Union[str, Pattern[str], None]) -> bool:
-        """Return `True` if _pattern_ matches the given string, `False` otherwise."""
-        # The IETF JSONPath draft requires us to return `False` if the pattern was
-        # invalid. We use `None` to indicate the pattern could not be compiled.
-        if string is None or pattern is None:
-            return False
+    arg_types = [ExpressionType.VALUE, ExpressionType.VALUE]
+    return_type = ExpressionType.LOGICAL
 
+    def __call__(self, string: str, pattern: str) -> bool:
+        """Return `True` if _string_ matches _pattern_, or `False` otherwise."""
         try:
+            # re.fullmatch caches compiled patterns internally
             return bool(re.fullmatch(pattern, string))
         except (TypeError, re.error):
             return False
-
-    def validate(
-        self,
-        _: "JSONPathEnvironment",
-        args: List[object],
-        token: "Token",
-    ) -> List[object]:
-        """Function argument validation."""
-        if len(args) != 2:  # noqa: PLR2004
-            raise JSONPathTypeError(
-                f"{token.value!r} requires 2 arguments, found {len(args)}",
-                token=token,
-            )
-
-        if isinstance(args[1], StringLiteral):
-            try:
-                return [args[0], RegexArgument(re.compile(args[1].value))]
-            except re.error:
-                return [None, None]
-
-        return args
diff --git a/jsonpath/path.py b/jsonpath/path.py
index 0c8fc7c..10359ed 100644
--- a/jsonpath/path.py
+++ b/jsonpath/path.py
@@ -17,6 +17,9 @@
 from jsonpath._data import load_data
 from jsonpath.match import FilterContextVars
 from jsonpath.match import JSONPathMatch
+from jsonpath.selectors import IndexSelector
+from jsonpath.selectors import ListSelector
+from jsonpath.selectors import PropertySelector
 
 if TYPE_CHECKING:
     from io import IOBase
@@ -206,6 +209,20 @@ def empty(self) -> bool:
         """Return `True` if this path has no selectors."""
         return not bool(self.selectors)
 
+    def singular_query(self) -> bool:
+        """Return `True` if this JSONPath query is a singular query."""
+        for selector in self.selectors:
+            if isinstance(selector, PropertySelector):
+                continue
+            if (
+                isinstance(selector, ListSelector)
+                and len(selector.items) == 1
+                and isinstance(selector.items[0], (PropertySelector, IndexSelector))
+            ):
+                continue
+            return False
+        return True
+
 
 class CompoundJSONPath:
     """Multiple `JSONPath`s combined."""