From 23f829045eca4f012d359be7d8975620c33af32c Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Tue, 13 Feb 2024 18:58:14 +1100 Subject: [PATCH 1/4] Extend the forward reference handling to more complex boundary cases Should now work when the Predicate definition with type annotation is defined in a non-global scope (eg. within a function) and is referring to another predicate definition of a different scope. --- clorm/orm/_typing.py | 27 ++++++++++++++--- clorm/orm/core.py | 50 ++++++++++++++++++------------- tests/test_forward_ref.py | 62 ++++++++++++++++++++++++++++++++++++++- tests/test_orm_core.py | 2 +- 4 files changed, 114 insertions(+), 27 deletions(-) diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index 473272a..c386f43 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -1,4 +1,5 @@ import sys +import inspect from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast from clingo import Symbol @@ -65,6 +66,7 @@ def resolve_annotations( ) -> Dict[str, Type[Any]]: """ Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376 + with some modifications for handling when the first _eval_type() call fails. Resolve string or ForwardRef annotations into type objects if possible. """ @@ -86,9 +88,26 @@ def resolve_annotations( else: value = ForwardRef(value, is_argument=False) try: - value = _eval_type(value, base_globals, None) + type_ = _eval_type(value, base_globals, None) except NameError: - # this is ok, it can be fixed with update_forward_refs - pass - annotations[name] = value + # The type annotation could refer to a definition at a non-global scope so build + # the locals from the calling context. + currframe = inspect.currentframe() + finfos = inspect.getouterframes(currframe) + if len(finfos) < 4: + raise RuntimeError('Cannot resolve field "{name}" with type annotation "{value}"') + locals_ = {} + type_ = None + for finfo in finfos[3:]: + locals_.update(finfo.frame.f_locals) + try: + type_ = _eval_type(value, base_globals, locals_) + break + except NameError: + pass + if type_ is None: + raise RuntimeError( + f'Cannot resolve field "{name}" with type annotation "{value}"' + ) + annotations[name] = type_ return annotations diff --git a/clorm/orm/core.py b/clorm/orm/core.py index cec0826..6a9da2f 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -2918,7 +2918,6 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF def _make_predicatedefn( class_name: str, namespace: Dict[str, Any], meta_dct: Dict[str, Any] ) -> PredicateDefn: - # Set the default predicate name pname = _predicatedefn_default_predicate_name(class_name) anon = False @@ -3001,15 +3000,37 @@ def _make_predicatedefn( ) fields_from_dct[fname] = fdefn - fields_from_annotations = {} module = namespace.get("__module__", None) - for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items(): - if name in fields_from_dct: # first check if FieldDefinition was assigned - fields_from_annotations[name] = fields_from_dct[name] - else: # if not try to infer the definition based on the type + + # Get the list of fields with annotations + annotations = namespace.get("__annotations__", {}) + + # If using type annotations then all fields must be annotated - however some fields can + # have field definition overrides. + if not annotations: + field_specification = fields_from_dct + else: + set_anno = set(annotations) + set_dct = set(fields_from_dct) + if not (set_dct <= set_anno): + raise TypeError( + ( + f"Predicate '{pname}' contains a mixture of type annotated and un-annotated " + f"fields ({set_anno} and {set_dct}). If one field is annotated then all fields " + "must be annotated" + ) + ) + + raw_annotations = { + name_: type_ for name_, type_ in annotations.items() if name_ not in fields_from_dct + } + field_specification = { + name: fields_from_dct.get(name, None) for name, _ in annotations.items() + } + for name, type_ in resolve_annotations(raw_annotations, module).items(): fdefn = infer_field_definition(type_, module) if fdefn: - fields_from_annotations[name] = fdefn + field_specification[name] = fdefn elif inspect.isclass(type_): raise TypeError( ( @@ -3018,22 +3039,9 @@ def _make_predicatedefn( ) ) - # TODO can this be done more elegantly - set_anno = set(fields_from_annotations) - set_dct = set(fields_from_dct) - set_union = set_dct.union(set_anno) - if set_dct < set_union > set_anno: - raise TypeError( - ( - f"Predicate '{pname}': Mixed fields are not allowed. " - "(one field has just an annotation, the other one was only assigned a FieldDefinition)" - ) - ) - fas = [] idx = 0 - fields_from_annotations.update(**fields_from_dct) - for fname, fdefn in fields_from_annotations.items(): + for fname, fdefn in field_specification.items(): try: fd = get_field_definition(fdefn, module) fa = FieldAccessor(fname, idx, fd) diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 979898a..11a78b8 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -1,4 +1,4 @@ -import importlib +import importlib.util import inspect import secrets import sys @@ -86,6 +86,7 @@ class P1(Predicate): def test_postponed_annotations_complex(self): code = """ +from __future__ import annotations from clorm import Predicate from typing import Union @@ -105,6 +106,65 @@ class P3(Predicate): p = module.P3(a=module.P2(a=42)) self.assertEqual(str(p), "p3(p2(42))") + def test_postponed_annotations_nonglobal1(self): + code = """ +from __future__ import annotations +from clorm import Predicate, ConstantField, field +from typing import Union + +def define_predicates(): + + + class P1(Predicate): + a1: str = field(ConstantField) + a: int + b: str + + class P2(Predicate): + a: Union[int, P1] + + return P1, P2 + +XP1, XP2 = define_predicates() + +""" + with self._create_module(code) as module: + p1 = module.XP1(a1="c", a=3, b="42") + self.assertEqual(str(p1), 'p1(c,3,"42")') + p2 = module.XP2(a=p1) + self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') + + def test_postponed_annotations_nonglobal2(self): + code = """ +from __future__ import annotations +from clorm import Predicate, ConstantField, field +from typing import Union + +def define_predicates(): + + + class P1(Predicate): + a1: str = field(ConstantField) + a: int + b: str + + def define_complex(): + class P2(Predicate): + a: Union[int, P1] + return P2 + + return P1, define_complex() + +XP1, XP2 = define_predicates() + +""" + with self._create_module(code) as module: + p1 = module.XP1(a1="c", a=3, b="42") + self.assertEqual(str(p1), 'p1(c,3,"42")') + p2 = module.XP2(a=p1) + self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') + + def test_forward_ref(self): def module_(): from typing import ForwardRef diff --git a/tests/test_orm_core.py b/tests/test_orm_core.py index ae2134d..a157786 100644 --- a/tests/test_orm_core.py +++ b/tests/test_orm_core.py @@ -1489,7 +1489,7 @@ class Blah(ComplexTerm): def test_predicates_with_annotated_fields(self): class P(Predicate): a: int = IntegerField - b = StringField + b: str = StringField class P1(Predicate): a: int From fa2a9a3fcd4c645fc20ddd38a301d958b56ea0cd Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Tue, 13 Feb 2024 21:19:20 +1100 Subject: [PATCH 2/4] Make resolve_annotations() more efficient for complex cases. --- clorm/orm/_typing.py | 24 +++++++++++++++--------- tests/test_forward_ref.py | 1 - 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index c386f43..da8c815 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -1,5 +1,6 @@ import sys import inspect +from inspect import FrameInfo from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast from clingo import Symbol @@ -81,6 +82,8 @@ def resolve_annotations( base_globals = module.__dict__ annotations = {} + frameinfos: Union[list[FrameInfo], None] = None + locals_ = {} for name, value in raw_annotations.items(): if isinstance(value, str): if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): @@ -91,20 +94,23 @@ def resolve_annotations( type_ = _eval_type(value, base_globals, None) except NameError: # The type annotation could refer to a definition at a non-global scope so build - # the locals from the calling context. - currframe = inspect.currentframe() - finfos = inspect.getouterframes(currframe) - if len(finfos) < 4: - raise RuntimeError('Cannot resolve field "{name}" with type annotation "{value}"') - locals_ = {} + # the locals from the calling context. We reuse the same set of locals for + # multiple annotations. + if frameinfos is None: + frameinfos = inspect.stack() + if len(frameinfos) < 4: + raise RuntimeError( + 'Cannot resolve field "{name}" with type annotation "{value}"' + ) + frameinfos = frameinfos[3:] type_ = None - for finfo in finfos[3:]: - locals_.update(finfo.frame.f_locals) + while frameinfos: try: type_ = _eval_type(value, base_globals, locals_) break except NameError: - pass + finfo = frameinfos.pop(0) + locals_.update(finfo.frame.f_locals) if type_ is None: raise RuntimeError( f'Cannot resolve field "{name}" with type annotation "{value}"' diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 11a78b8..005607d 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -164,7 +164,6 @@ class P2(Predicate): p2 = module.XP2(a=p1) self.assertEqual(str(p2), 'p2(p1(c,3,"42"))') - def test_forward_ref(self): def module_(): from typing import ForwardRef From 3a2d2ca241b34e7a194427a981c4be713823f4f9 Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Tue, 13 Feb 2024 21:55:28 +1100 Subject: [PATCH 3/4] Formmatting and conditional imports based on python version --- clorm/orm/_typing.py | 2 +- requirements.txt | 2 +- tests/test_mypy_query.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index da8c815..d317904 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -1,5 +1,5 @@ -import sys import inspect +import sys from inspect import FrameInfo from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast diff --git a/requirements.txt b/requirements.txt index 836997d..5f1b583 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ clingo>=5.5.1 -typing_extensions; python_version < '3.8' +typing_extensions; python_version < '3.11' dataclasses; python_version == '3.6' diff --git a/tests/test_mypy_query.py b/tests/test_mypy_query.py index ed6333c..05e3d92 100644 --- a/tests/test_mypy_query.py +++ b/tests/test_mypy_query.py @@ -1,6 +1,10 @@ +import sys from typing import Tuple -from typing_extensions import reveal_type +if sys.version_info < (3, 11): + from typing_extensions import reveal_type +else: + from typing import reveal_type from clorm import FactBase, Predicate from clorm.orm._queryimpl import GroupedQuery, UnGroupedQuery From 8cd478ca059eb19871d1a5d5abb6cfd80902112b Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Wed, 14 Feb 2024 11:07:51 +1100 Subject: [PATCH 4/4] Fix resolve_annotations() bug for Python 3.7 and 3.8 For these Python versions the _eval_type() function fails on ForwardRefs that are quoted strings. So first strip any quotes from type annotation strings. --- clorm/orm/_typing.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index d317904..0bacdf1 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -62,6 +62,23 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]: return getattr(t, "__args__", ()) +if sys.version_info < (3, 9): + from ast import literal_eval + + def _strip_quoted_annotations(annotation: str) -> str: + """Strip quotes around any annotations. + + This is needed because the _eval_type() function for Python 3.8 and 3.7 doesn't + handle a ForwardRef that contains a quoted string. + + """ + try: + output = literal_eval(annotation) + return output if isinstance(output, str) else annotation + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + return annotation + + def resolve_annotations( raw_annotations: Dict[str, Type[Any]], module_name: Optional[str] = None ) -> Dict[str, Type[Any]]: @@ -86,12 +103,19 @@ def resolve_annotations( locals_ = {} for name, value in raw_annotations.items(): if isinstance(value, str): + + # Strip quoted string annotions for Python 3.7 and 3.8 + if sys.version_info < (3, 9): + value = _strip_quoted_annotations(value) + + # Turn the string type annotation into a ForwardRef for processing if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): value = ForwardRef(value, is_argument=False, is_class=True) else: value = ForwardRef(value, is_argument=False) try: type_ = _eval_type(value, base_globals, None) + except NameError: # The type annotation could refer to a definition at a non-global scope so build # the locals from the calling context. We reuse the same set of locals for