Skip to content

Commit

Permalink
Extend the forward reference handling to more complex boundary cases
Browse files Browse the repository at this point in the history
Should now work when the Predicate definition with type annotation is defined in a non-global
scope (eg. within a function) and is referring to another predicate definition of a different
scope.
  • Loading branch information
daveraja committed Feb 13, 2024
1 parent c66f441 commit 23f8290
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 27 deletions.
27 changes: 23 additions & 4 deletions clorm/orm/_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import inspect
from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast

from clingo import Symbol
Expand Down Expand Up @@ -65,6 +66,7 @@ def resolve_annotations(
) -> Dict[str, Type[Any]]:
"""
Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376
with some modifications for handling when the first _eval_type() call fails.
Resolve string or ForwardRef annotations into type objects if possible.
"""
Expand All @@ -86,9 +88,26 @@ def resolve_annotations(
else:
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
type_ = _eval_type(value, base_globals, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
# The type annotation could refer to a definition at a non-global scope so build
# the locals from the calling context.
currframe = inspect.currentframe()
finfos = inspect.getouterframes(currframe)
if len(finfos) < 4:
raise RuntimeError('Cannot resolve field "{name}" with type annotation "{value}"')
locals_ = {}
type_ = None
for finfo in finfos[3:]:
locals_.update(finfo.frame.f_locals)
try:
type_ = _eval_type(value, base_globals, locals_)
break
except NameError:
pass
if type_ is None:
raise RuntimeError(
f'Cannot resolve field "{name}" with type annotation "{value}"'
)
annotations[name] = type_
return annotations
50 changes: 29 additions & 21 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,7 +2918,6 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF
def _make_predicatedefn(
class_name: str, namespace: Dict[str, Any], meta_dct: Dict[str, Any]
) -> PredicateDefn:

# Set the default predicate name
pname = _predicatedefn_default_predicate_name(class_name)
anon = False
Expand Down Expand Up @@ -3001,15 +3000,37 @@ def _make_predicatedefn(
)
fields_from_dct[fname] = fdefn

fields_from_annotations = {}
module = namespace.get("__module__", None)
for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items():
if name in fields_from_dct: # first check if FieldDefinition was assigned
fields_from_annotations[name] = fields_from_dct[name]
else: # if not try to infer the definition based on the type

# Get the list of fields with annotations
annotations = namespace.get("__annotations__", {})

# If using type annotations then all fields must be annotated - however some fields can
# have field definition overrides.
if not annotations:
field_specification = fields_from_dct
else:
set_anno = set(annotations)
set_dct = set(fields_from_dct)
if not (set_dct <= set_anno):
raise TypeError(
(
f"Predicate '{pname}' contains a mixture of type annotated and un-annotated "
f"fields ({set_anno} and {set_dct}). If one field is annotated then all fields "
"must be annotated"
)
)

raw_annotations = {
name_: type_ for name_, type_ in annotations.items() if name_ not in fields_from_dct
}
field_specification = {
name: fields_from_dct.get(name, None) for name, _ in annotations.items()
}
for name, type_ in resolve_annotations(raw_annotations, module).items():
fdefn = infer_field_definition(type_, module)
if fdefn:
fields_from_annotations[name] = fdefn
field_specification[name] = fdefn
elif inspect.isclass(type_):
raise TypeError(
(
Expand All @@ -3018,22 +3039,9 @@ def _make_predicatedefn(
)
)

# TODO can this be done more elegantly
set_anno = set(fields_from_annotations)
set_dct = set(fields_from_dct)
set_union = set_dct.union(set_anno)
if set_dct < set_union > set_anno:
raise TypeError(
(
f"Predicate '{pname}': Mixed fields are not allowed. "
"(one field has just an annotation, the other one was only assigned a FieldDefinition)"
)
)

fas = []
idx = 0
fields_from_annotations.update(**fields_from_dct)
for fname, fdefn in fields_from_annotations.items():
for fname, fdefn in field_specification.items():
try:
fd = get_field_definition(fdefn, module)
fa = FieldAccessor(fname, idx, fd)
Expand Down
62 changes: 61 additions & 1 deletion tests/test_forward_ref.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import importlib
import importlib.util
import inspect
import secrets
import sys
Expand Down Expand Up @@ -86,6 +86,7 @@ class P1(Predicate):

def test_postponed_annotations_complex(self):
code = """
from __future__ import annotations
from clorm import Predicate
from typing import Union
Expand All @@ -105,6 +106,65 @@ class P3(Predicate):
p = module.P3(a=module.P2(a=42))
self.assertEqual(str(p), "p3(p2(42))")

def test_postponed_annotations_nonglobal1(self):
code = """
from __future__ import annotations
from clorm import Predicate, ConstantField, field
from typing import Union
def define_predicates():
class P1(Predicate):
a1: str = field(ConstantField)
a: int
b: str
class P2(Predicate):
a: Union[int, P1]
return P1, P2
XP1, XP2 = define_predicates()
"""
with self._create_module(code) as module:
p1 = module.XP1(a1="c", a=3, b="42")
self.assertEqual(str(p1), 'p1(c,3,"42")')
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')

def test_postponed_annotations_nonglobal2(self):
code = """
from __future__ import annotations
from clorm import Predicate, ConstantField, field
from typing import Union
def define_predicates():
class P1(Predicate):
a1: str = field(ConstantField)
a: int
b: str
def define_complex():
class P2(Predicate):
a: Union[int, P1]
return P2
return P1, define_complex()
XP1, XP2 = define_predicates()
"""
with self._create_module(code) as module:
p1 = module.XP1(a1="c", a=3, b="42")
self.assertEqual(str(p1), 'p1(c,3,"42")')
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')


def test_forward_ref(self):
def module_():
from typing import ForwardRef
Expand Down
2 changes: 1 addition & 1 deletion tests/test_orm_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,7 @@ class Blah(ComplexTerm):
def test_predicates_with_annotated_fields(self):
class P(Predicate):
a: int = IntegerField
b = StringField
b: str = StringField

class P1(Predicate):
a: int
Expand Down

0 comments on commit 23f8290

Please sign in to comment.