diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index 473272a..c386f43 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -1,4 +1,5 @@ import sys +import inspect from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast from clingo import Symbol @@ -65,6 +66,7 @@ def resolve_annotations( ) -> Dict[str, Type[Any]]: """ Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376 + with some modifications for handling when the first _eval_type() call fails. Resolve string or ForwardRef annotations into type objects if possible. """ @@ -86,9 +88,26 @@ def resolve_annotations( else: value = ForwardRef(value, is_argument=False) try: - value = _eval_type(value, base_globals, None) + type_ = _eval_type(value, base_globals, None) except NameError: - # this is ok, it can be fixed with update_forward_refs - pass - annotations[name] = value + # The type annotation could refer to a definition at a non-global scope so build + # the locals from the calling context. + currframe = inspect.currentframe() + finfos = inspect.getouterframes(currframe) + if len(finfos) < 4: + raise RuntimeError('Cannot resolve field "{name}" with type annotation "{value}"') + locals_ = {} + type_ = None + for finfo in finfos[3:]: + locals_.update(finfo.frame.f_locals) + try: + type_ = _eval_type(value, base_globals, locals_) + break + except NameError: + pass + if type_ is None: + raise RuntimeError( + f'Cannot resolve field "{name}" with type annotation "{value}"' + ) + annotations[name] = type_ return annotations diff --git a/clorm/orm/core.py b/clorm/orm/core.py index cec0826..6a9da2f 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -2918,7 +2918,6 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF def _make_predicatedefn( class_name: str, namespace: Dict[str, Any], meta_dct: Dict[str, Any] ) -> PredicateDefn: - # Set the default predicate name pname = _predicatedefn_default_predicate_name(class_name) anon = False @@ -3001,15 +3000,37 @@ def _make_predicatedefn( ) fields_from_dct[fname] = fdefn - fields_from_annotations = {} module = namespace.get("__module__", None) - for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items(): - if name in fields_from_dct: # first check if FieldDefinition was assigned - fields_from_annotations[name] = fields_from_dct[name] - else: # if not try to infer the definition based on the type + + # Get the list of fields with annotations + annotations = namespace.get("__annotations__", {}) + + # If using type annotations then all fields must be annotated - however some fields can + # have field definition overrides. + if not annotations: + field_specification = fields_from_dct + else: + set_anno = set(annotations) + set_dct = set(fields_from_dct) + if not (set_dct <= set_anno): + raise TypeError( + ( + f"Predicate '{pname}' contains a mixture of type annotated and un-annotated " + f"fields ({set_anno} and {set_dct}). If one field is annotated then all fields " + "must be annotated" + ) + ) + + raw_annotations = { + name_: type_ for name_, type_ in annotations.items() if name_ not in fields_from_dct + } + field_specification = { + name: fields_from_dct.get(name, None) for name, _ in annotations.items() + } + for name, type_ in resolve_annotations(raw_annotations, module).items(): fdefn = infer_field_definition(type_, module) if fdefn: - fields_from_annotations[name] = fdefn + field_specification[name] = fdefn elif inspect.isclass(type_): raise TypeError( ( @@ -3018,22 +3039,9 @@ def _make_predicatedefn( ) ) - # TODO can this be done more elegantly - set_anno = set(fields_from_annotations) - set_dct = set(fields_from_dct) - set_union = set_dct.union(set_anno) - if set_dct < set_union > set_anno: - raise TypeError( - ( - f"Predicate '{pname}': Mixed fields are not allowed. " - "(one field has just an annotation, the other one was only assigned a FieldDefinition)" - ) - ) - fas = [] idx = 0 - fields_from_annotations.update(**fields_from_dct) - for fname, fdefn in fields_from_annotations.items(): + for fname, fdefn in field_specification.items(): try: fd = get_field_definition(fdefn, module) fa = FieldAccessor(fname, idx, fd) diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 979898a..11a78b8 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -1,4 +1,4 @@ -import importlib +import importlib.util import inspect import secrets import sys @@ -86,6 +86,7 @@ class P1(Predicate): def test_postponed_annotations_complex(self): code = """ +from __future__ import annotations from clorm import Predicate from typing import Union @@ -105,6 +106,65 @@ class P3(Predicate): p = module.P3(a=module.P2(a=42)) self.assertEqual(str(p), "p3(p2(42))") + def test_postponed_annotations_nonglobal1(self): + code = """ +from __future__ import annotations +from clorm import Predicate, ConstantField, field +from typing import Union + +def define_predicates(): + + + class P1(Predicate): + a1: str = field(ConstantField) + a: int + b: str + + class P2(Predicate): + a: Union[int, P1] + + return P1, P2 + +XP1, XP2 = define_predicates() + +""" + with self._create_module(code) as module: + p1 = module.XP1(a1="c", a=3, b="42") + self.assertEqual(str(p1), 'p1(c,3,"42")') + p2 = module.XP2(a=p1) + self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') + + def test_postponed_annotations_nonglobal2(self): + code = """ +from __future__ import annotations +from clorm import Predicate, ConstantField, field +from typing import Union + +def define_predicates(): + + + class P1(Predicate): + a1: str = field(ConstantField) + a: int + b: str + + def define_complex(): + class P2(Predicate): + a: Union[int, P1] + return P2 + + return P1, define_complex() + +XP1, XP2 = define_predicates() + +""" + with self._create_module(code) as module: + p1 = module.XP1(a1="c", a=3, b="42") + self.assertEqual(str(p1), 'p1(c,3,"42")') + p2 = module.XP2(a=p1) + self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') + + def test_forward_ref(self): def module_(): from typing import ForwardRef diff --git a/tests/test_orm_core.py b/tests/test_orm_core.py index ae2134d..a157786 100644 --- a/tests/test_orm_core.py +++ b/tests/test_orm_core.py @@ -1489,7 +1489,7 @@ class Blah(ComplexTerm): def test_predicates_with_annotated_fields(self): class P(Predicate): a: int = IntegerField - b = StringField + b: str = StringField class P1(Predicate): a: int