diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index 473272a..0bacdf1 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -1,4 +1,6 @@ +import inspect import sys +from inspect import FrameInfo from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast from clingo import Symbol @@ -60,11 +62,29 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]: return getattr(t, "__args__", ()) +if sys.version_info < (3, 9): + from ast import literal_eval + + def _strip_quoted_annotations(annotation: str) -> str: + """Strip quotes around any annotations. + + This is needed because the _eval_type() function for Python 3.8 and 3.7 doesn't + handle a ForwardRef that contains a quoted string. + + """ + try: + output = literal_eval(annotation) + return output if isinstance(output, str) else annotation + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + return annotation + + def resolve_annotations( raw_annotations: Dict[str, Type[Any]], module_name: Optional[str] = None ) -> Dict[str, Type[Any]]: """ Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376 + with some modifications for handling when the first _eval_type() call fails. Resolve string or ForwardRef annotations into type objects if possible. """ @@ -79,16 +99,45 @@ def resolve_annotations( base_globals = module.__dict__ annotations = {} + frameinfos: Union[list[FrameInfo], None] = None + locals_ = {} for name, value in raw_annotations.items(): if isinstance(value, str): + + # Strip quoted string annotions for Python 3.7 and 3.8 + if sys.version_info < (3, 9): + value = _strip_quoted_annotations(value) + + # Turn the string type annotation into a ForwardRef for processing if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): value = ForwardRef(value, is_argument=False, is_class=True) else: value = ForwardRef(value, is_argument=False) try: - value = _eval_type(value, base_globals, None) + type_ = _eval_type(value, base_globals, None) + except NameError: - # this is ok, it can be fixed with update_forward_refs - pass - annotations[name] = value + # The type annotation could refer to a definition at a non-global scope so build + # the locals from the calling context. We reuse the same set of locals for + # multiple annotations. + if frameinfos is None: + frameinfos = inspect.stack() + if len(frameinfos) < 4: + raise RuntimeError( + 'Cannot resolve field "{name}" with type annotation "{value}"' + ) + frameinfos = frameinfos[3:] + type_ = None + while frameinfos: + try: + type_ = _eval_type(value, base_globals, locals_) + break + except NameError: + finfo = frameinfos.pop(0) + locals_.update(finfo.frame.f_locals) + if type_ is None: + raise RuntimeError( + f'Cannot resolve field "{name}" with type annotation "{value}"' + ) + annotations[name] = type_ return annotations diff --git a/clorm/orm/core.py b/clorm/orm/core.py index cec0826..6a9da2f 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -2918,7 +2918,6 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF def _make_predicatedefn( class_name: str, namespace: Dict[str, Any], meta_dct: Dict[str, Any] ) -> PredicateDefn: - # Set the default predicate name pname = _predicatedefn_default_predicate_name(class_name) anon = False @@ -3001,15 +3000,37 @@ def _make_predicatedefn( ) fields_from_dct[fname] = fdefn - fields_from_annotations = {} module = namespace.get("__module__", None) - for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items(): - if name in fields_from_dct: # first check if FieldDefinition was assigned - fields_from_annotations[name] = fields_from_dct[name] - else: # if not try to infer the definition based on the type + + # Get the list of fields with annotations + annotations = namespace.get("__annotations__", {}) + + # If using type annotations then all fields must be annotated - however some fields can + # have field definition overrides. + if not annotations: + field_specification = fields_from_dct + else: + set_anno = set(annotations) + set_dct = set(fields_from_dct) + if not (set_dct <= set_anno): + raise TypeError( + ( + f"Predicate '{pname}' contains a mixture of type annotated and un-annotated " + f"fields ({set_anno} and {set_dct}). If one field is annotated then all fields " + "must be annotated" + ) + ) + + raw_annotations = { + name_: type_ for name_, type_ in annotations.items() if name_ not in fields_from_dct + } + field_specification = { + name: fields_from_dct.get(name, None) for name, _ in annotations.items() + } + for name, type_ in resolve_annotations(raw_annotations, module).items(): fdefn = infer_field_definition(type_, module) if fdefn: - fields_from_annotations[name] = fdefn + field_specification[name] = fdefn elif inspect.isclass(type_): raise TypeError( ( @@ -3018,22 +3039,9 @@ def _make_predicatedefn( ) ) - # TODO can this be done more elegantly - set_anno = set(fields_from_annotations) - set_dct = set(fields_from_dct) - set_union = set_dct.union(set_anno) - if set_dct < set_union > set_anno: - raise TypeError( - ( - f"Predicate '{pname}': Mixed fields are not allowed. " - "(one field has just an annotation, the other one was only assigned a FieldDefinition)" - ) - ) - fas = [] idx = 0 - fields_from_annotations.update(**fields_from_dct) - for fname, fdefn in fields_from_annotations.items(): + for fname, fdefn in field_specification.items(): try: fd = get_field_definition(fdefn, module) fa = FieldAccessor(fname, idx, fd) diff --git a/requirements.txt b/requirements.txt index 836997d..5f1b583 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ clingo>=5.5.1 -typing_extensions; python_version < '3.8' +typing_extensions; python_version < '3.11' dataclasses; python_version == '3.6' diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 979898a..005607d 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -1,4 +1,4 @@ -import importlib +import importlib.util import inspect import secrets import sys @@ -86,6 +86,7 @@ class P1(Predicate): def test_postponed_annotations_complex(self): code = """ +from __future__ import annotations from clorm import Predicate from typing import Union @@ -105,6 +106,64 @@ class P3(Predicate): p = module.P3(a=module.P2(a=42)) self.assertEqual(str(p), "p3(p2(42))") + def test_postponed_annotations_nonglobal1(self): + code = """ +from __future__ import annotations +from clorm import Predicate, ConstantField, field +from typing import Union + +def define_predicates(): + + + class P1(Predicate): + a1: str = field(ConstantField) + a: int + b: str + + class P2(Predicate): + a: Union[int, P1] + + return P1, P2 + +XP1, XP2 = define_predicates() + +""" + with self._create_module(code) as module: + p1 = module.XP1(a1="c", a=3, b="42") + self.assertEqual(str(p1), 'p1(c,3,"42")') + p2 = module.XP2(a=p1) + self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') + + def test_postponed_annotations_nonglobal2(self): + code = """ +from __future__ import annotations +from clorm import Predicate, ConstantField, field +from typing import Union + +def define_predicates(): + + + class P1(Predicate): + a1: str = field(ConstantField) + a: int + b: str + + def define_complex(): + class P2(Predicate): + a: Union[int, P1] + return P2 + + return P1, define_complex() + +XP1, XP2 = define_predicates() + +""" + with self._create_module(code) as module: + p1 = module.XP1(a1="c", a=3, b="42") + self.assertEqual(str(p1), 'p1(c,3,"42")') + p2 = module.XP2(a=p1) + self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') + def test_forward_ref(self): def module_(): from typing import ForwardRef diff --git a/tests/test_mypy_query.py b/tests/test_mypy_query.py index ed6333c..05e3d92 100644 --- a/tests/test_mypy_query.py +++ b/tests/test_mypy_query.py @@ -1,6 +1,10 @@ +import sys from typing import Tuple -from typing_extensions import reveal_type +if sys.version_info < (3, 11): + from typing_extensions import reveal_type +else: + from typing import reveal_type from clorm import FactBase, Predicate from clorm.orm._queryimpl import GroupedQuery, UnGroupedQuery diff --git a/tests/test_orm_core.py b/tests/test_orm_core.py index ae2134d..a157786 100644 --- a/tests/test_orm_core.py +++ b/tests/test_orm_core.py @@ -1489,7 +1489,7 @@ class Blah(ComplexTerm): def test_predicates_with_annotated_fields(self): class P(Predicate): a: int = IntegerField - b = StringField + b: str = StringField class P1(Predicate): a: int