Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the forward reference handling to more complex boundary cases #132

Merged
merged 4 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions clorm/orm/_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import sys
from inspect import FrameInfo
from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast

from clingo import Symbol
Expand Down Expand Up @@ -60,11 +62,29 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]:
return getattr(t, "__args__", ())


if sys.version_info < (3, 9):
from ast import literal_eval

def _strip_quoted_annotations(annotation: str) -> str:
"""Strip quotes around any annotations.

This is needed because the _eval_type() function for Python 3.8 and 3.7 doesn't
handle a ForwardRef that contains a quoted string.

"""
try:
output = literal_eval(annotation)
return output if isinstance(output, str) else annotation
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
return annotation


def resolve_annotations(
raw_annotations: Dict[str, Type[Any]], module_name: Optional[str] = None
) -> Dict[str, Type[Any]]:
"""
Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376
with some modifications for handling when the first _eval_type() call fails.

Resolve string or ForwardRef annotations into type objects if possible.
"""
Expand All @@ -79,16 +99,45 @@ def resolve_annotations(
base_globals = module.__dict__

annotations = {}
frameinfos: Union[list[FrameInfo], None] = None
locals_ = {}
for name, value in raw_annotations.items():
if isinstance(value, str):

# Strip quoted string annotions for Python 3.7 and 3.8
if sys.version_info < (3, 9):
value = _strip_quoted_annotations(value)

# Turn the string type annotation into a ForwardRef for processing
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
value = ForwardRef(value, is_argument=False, is_class=True)
else:
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
type_ = _eval_type(value, base_globals, None)

except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
# The type annotation could refer to a definition at a non-global scope so build
# the locals from the calling context. We reuse the same set of locals for
# multiple annotations.
if frameinfos is None:
frameinfos = inspect.stack()
if len(frameinfos) < 4:
raise RuntimeError(
'Cannot resolve field "{name}" with type annotation "{value}"'
)
frameinfos = frameinfos[3:]
type_ = None
while frameinfos:
try:
type_ = _eval_type(value, base_globals, locals_)
break
except NameError:
finfo = frameinfos.pop(0)
locals_.update(finfo.frame.f_locals)
if type_ is None:
raise RuntimeError(
f'Cannot resolve field "{name}" with type annotation "{value}"'
)
annotations[name] = type_
return annotations
50 changes: 29 additions & 21 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,7 +2918,6 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF
def _make_predicatedefn(
class_name: str, namespace: Dict[str, Any], meta_dct: Dict[str, Any]
) -> PredicateDefn:

# Set the default predicate name
pname = _predicatedefn_default_predicate_name(class_name)
anon = False
Expand Down Expand Up @@ -3001,15 +3000,37 @@ def _make_predicatedefn(
)
fields_from_dct[fname] = fdefn

fields_from_annotations = {}
module = namespace.get("__module__", None)
for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items():
if name in fields_from_dct: # first check if FieldDefinition was assigned
fields_from_annotations[name] = fields_from_dct[name]
else: # if not try to infer the definition based on the type

# Get the list of fields with annotations
annotations = namespace.get("__annotations__", {})

# If using type annotations then all fields must be annotated - however some fields can
# have field definition overrides.
if not annotations:
field_specification = fields_from_dct
else:
set_anno = set(annotations)
set_dct = set(fields_from_dct)
if not (set_dct <= set_anno):
raise TypeError(
(
f"Predicate '{pname}' contains a mixture of type annotated and un-annotated "
f"fields ({set_anno} and {set_dct}). If one field is annotated then all fields "
"must be annotated"
)
)

raw_annotations = {
name_: type_ for name_, type_ in annotations.items() if name_ not in fields_from_dct
}
field_specification = {
name: fields_from_dct.get(name, None) for name, _ in annotations.items()
}
for name, type_ in resolve_annotations(raw_annotations, module).items():
fdefn = infer_field_definition(type_, module)
if fdefn:
fields_from_annotations[name] = fdefn
field_specification[name] = fdefn
elif inspect.isclass(type_):
raise TypeError(
(
Expand All @@ -3018,22 +3039,9 @@ def _make_predicatedefn(
)
)

# TODO can this be done more elegantly
set_anno = set(fields_from_annotations)
set_dct = set(fields_from_dct)
set_union = set_dct.union(set_anno)
if set_dct < set_union > set_anno:
raise TypeError(
(
f"Predicate '{pname}': Mixed fields are not allowed. "
"(one field has just an annotation, the other one was only assigned a FieldDefinition)"
)
)

fas = []
idx = 0
fields_from_annotations.update(**fields_from_dct)
for fname, fdefn in fields_from_annotations.items():
for fname, fdefn in field_specification.items():
try:
fd = get_field_definition(fdefn, module)
fa = FieldAccessor(fname, idx, fd)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
clingo>=5.5.1
typing_extensions; python_version < '3.8'
typing_extensions; python_version < '3.11'
dataclasses; python_version == '3.6'
61 changes: 60 additions & 1 deletion tests/test_forward_ref.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import importlib
import importlib.util
import inspect
import secrets
import sys
Expand Down Expand Up @@ -86,6 +86,7 @@ class P1(Predicate):

def test_postponed_annotations_complex(self):
code = """
from __future__ import annotations
from clorm import Predicate
from typing import Union

Expand All @@ -105,6 +106,64 @@ class P3(Predicate):
p = module.P3(a=module.P2(a=42))
self.assertEqual(str(p), "p3(p2(42))")

def test_postponed_annotations_nonglobal1(self):
code = """
from __future__ import annotations
from clorm import Predicate, ConstantField, field
from typing import Union

def define_predicates():


class P1(Predicate):
a1: str = field(ConstantField)
a: int
b: str

class P2(Predicate):
a: Union[int, P1]

return P1, P2

XP1, XP2 = define_predicates()

"""
with self._create_module(code) as module:
p1 = module.XP1(a1="c", a=3, b="42")
self.assertEqual(str(p1), 'p1(c,3,"42")')
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')

def test_postponed_annotations_nonglobal2(self):
code = """
from __future__ import annotations
from clorm import Predicate, ConstantField, field
from typing import Union

def define_predicates():


class P1(Predicate):
a1: str = field(ConstantField)
a: int
b: str

def define_complex():
class P2(Predicate):
a: Union[int, P1]
return P2

return P1, define_complex()

XP1, XP2 = define_predicates()

"""
with self._create_module(code) as module:
p1 = module.XP1(a1="c", a=3, b="42")
self.assertEqual(str(p1), 'p1(c,3,"42")')
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')

def test_forward_ref(self):
def module_():
from typing import ForwardRef
Expand Down
6 changes: 5 additions & 1 deletion tests/test_mypy_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import sys
from typing import Tuple

from typing_extensions import reveal_type
if sys.version_info < (3, 11):
from typing_extensions import reveal_type
else:
from typing import reveal_type

from clorm import FactBase, Predicate
from clorm.orm._queryimpl import GroupedQuery, UnGroupedQuery
Expand Down
2 changes: 1 addition & 1 deletion tests/test_orm_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,7 @@ class Blah(ComplexTerm):
def test_predicates_with_annotated_fields(self):
class P(Predicate):
a: int = IntegerField
b = StringField
b: str = StringField

class P1(Predicate):
a: int
Expand Down
Loading