Skip to content

Commit

Permalink
Merge pull request potassco#132 from potassco/postponed_annotations
Browse files Browse the repository at this point in the history
Extend the forward reference handling to more complex boundary cases
  • Loading branch information
daveraja authored Feb 14, 2024
2 parents c66f441 + 8cd478c commit 24d053f
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 29 deletions.
57 changes: 53 additions & 4 deletions clorm/orm/_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import sys
from inspect import FrameInfo
from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast

from clingo import Symbol
Expand Down Expand Up @@ -60,11 +62,29 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]:
return getattr(t, "__args__", ())


if sys.version_info < (3, 9):
from ast import literal_eval

def _strip_quoted_annotations(annotation: str) -> str:
"""Strip quotes around any annotations.
This is needed because the _eval_type() function for Python 3.8 and 3.7 doesn't
handle a ForwardRef that contains a quoted string.
"""
try:
output = literal_eval(annotation)
return output if isinstance(output, str) else annotation
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
return annotation


def resolve_annotations(
raw_annotations: Dict[str, Type[Any]], module_name: Optional[str] = None
) -> Dict[str, Type[Any]]:
"""
Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376
with some modifications for handling when the first _eval_type() call fails.
Resolve string or ForwardRef annotations into type objects if possible.
"""
Expand All @@ -79,16 +99,45 @@ def resolve_annotations(
base_globals = module.__dict__

annotations = {}
frameinfos: Union[list[FrameInfo], None] = None
locals_ = {}
for name, value in raw_annotations.items():
if isinstance(value, str):

# Strip quoted string annotions for Python 3.7 and 3.8
if sys.version_info < (3, 9):
value = _strip_quoted_annotations(value)

# Turn the string type annotation into a ForwardRef for processing
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
value = ForwardRef(value, is_argument=False, is_class=True)
else:
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
type_ = _eval_type(value, base_globals, None)

except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
# The type annotation could refer to a definition at a non-global scope so build
# the locals from the calling context. We reuse the same set of locals for
# multiple annotations.
if frameinfos is None:
frameinfos = inspect.stack()
if len(frameinfos) < 4:
raise RuntimeError(
'Cannot resolve field "{name}" with type annotation "{value}"'
)
frameinfos = frameinfos[3:]
type_ = None
while frameinfos:
try:
type_ = _eval_type(value, base_globals, locals_)
break
except NameError:
finfo = frameinfos.pop(0)
locals_.update(finfo.frame.f_locals)
if type_ is None:
raise RuntimeError(
f'Cannot resolve field "{name}" with type annotation "{value}"'
)
annotations[name] = type_
return annotations
50 changes: 29 additions & 21 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,7 +2918,6 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF
def _make_predicatedefn(
class_name: str, namespace: Dict[str, Any], meta_dct: Dict[str, Any]
) -> PredicateDefn:

# Set the default predicate name
pname = _predicatedefn_default_predicate_name(class_name)
anon = False
Expand Down Expand Up @@ -3001,15 +3000,37 @@ def _make_predicatedefn(
)
fields_from_dct[fname] = fdefn

fields_from_annotations = {}
module = namespace.get("__module__", None)
for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items():
if name in fields_from_dct: # first check if FieldDefinition was assigned
fields_from_annotations[name] = fields_from_dct[name]
else: # if not try to infer the definition based on the type

# Get the list of fields with annotations
annotations = namespace.get("__annotations__", {})

# If using type annotations then all fields must be annotated - however some fields can
# have field definition overrides.
if not annotations:
field_specification = fields_from_dct
else:
set_anno = set(annotations)
set_dct = set(fields_from_dct)
if not (set_dct <= set_anno):
raise TypeError(
(
f"Predicate '{pname}' contains a mixture of type annotated and un-annotated "
f"fields ({set_anno} and {set_dct}). If one field is annotated then all fields "
"must be annotated"
)
)

raw_annotations = {
name_: type_ for name_, type_ in annotations.items() if name_ not in fields_from_dct
}
field_specification = {
name: fields_from_dct.get(name, None) for name, _ in annotations.items()
}
for name, type_ in resolve_annotations(raw_annotations, module).items():
fdefn = infer_field_definition(type_, module)
if fdefn:
fields_from_annotations[name] = fdefn
field_specification[name] = fdefn
elif inspect.isclass(type_):
raise TypeError(
(
Expand All @@ -3018,22 +3039,9 @@ def _make_predicatedefn(
)
)

# TODO can this be done more elegantly
set_anno = set(fields_from_annotations)
set_dct = set(fields_from_dct)
set_union = set_dct.union(set_anno)
if set_dct < set_union > set_anno:
raise TypeError(
(
f"Predicate '{pname}': Mixed fields are not allowed. "
"(one field has just an annotation, the other one was only assigned a FieldDefinition)"
)
)

fas = []
idx = 0
fields_from_annotations.update(**fields_from_dct)
for fname, fdefn in fields_from_annotations.items():
for fname, fdefn in field_specification.items():
try:
fd = get_field_definition(fdefn, module)
fa = FieldAccessor(fname, idx, fd)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
clingo>=5.5.1
typing_extensions; python_version < '3.8'
typing_extensions; python_version < '3.11'
dataclasses; python_version == '3.6'
61 changes: 60 additions & 1 deletion tests/test_forward_ref.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import importlib
import importlib.util
import inspect
import secrets
import sys
Expand Down Expand Up @@ -86,6 +86,7 @@ class P1(Predicate):

def test_postponed_annotations_complex(self):
code = """
from __future__ import annotations
from clorm import Predicate
from typing import Union
Expand All @@ -105,6 +106,64 @@ class P3(Predicate):
p = module.P3(a=module.P2(a=42))
self.assertEqual(str(p), "p3(p2(42))")

def test_postponed_annotations_nonglobal1(self):
code = """
from __future__ import annotations
from clorm import Predicate, ConstantField, field
from typing import Union
def define_predicates():
class P1(Predicate):
a1: str = field(ConstantField)
a: int
b: str
class P2(Predicate):
a: Union[int, P1]
return P1, P2
XP1, XP2 = define_predicates()
"""
with self._create_module(code) as module:
p1 = module.XP1(a1="c", a=3, b="42")
self.assertEqual(str(p1), 'p1(c,3,"42")')
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')

def test_postponed_annotations_nonglobal2(self):
code = """
from __future__ import annotations
from clorm import Predicate, ConstantField, field
from typing import Union
def define_predicates():
class P1(Predicate):
a1: str = field(ConstantField)
a: int
b: str
def define_complex():
class P2(Predicate):
a: Union[int, P1]
return P2
return P1, define_complex()
XP1, XP2 = define_predicates()
"""
with self._create_module(code) as module:
p1 = module.XP1(a1="c", a=3, b="42")
self.assertEqual(str(p1), 'p1(c,3,"42")')
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')

def test_forward_ref(self):
def module_():
from typing import ForwardRef
Expand Down
6 changes: 5 additions & 1 deletion tests/test_mypy_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import sys
from typing import Tuple

from typing_extensions import reveal_type
if sys.version_info < (3, 11):
from typing_extensions import reveal_type
else:
from typing import reveal_type

from clorm import FactBase, Predicate
from clorm.orm._queryimpl import GroupedQuery, UnGroupedQuery
Expand Down
2 changes: 1 addition & 1 deletion tests/test_orm_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,7 @@ class Blah(ComplexTerm):
def test_predicates_with_annotated_fields(self):
class P(Predicate):
a: int = IntegerField
b = StringField
b: str = StringField

class P1(Predicate):
a: int
Expand Down

0 comments on commit 24d053f

Please sign in to comment.