Skip to content

Commit

Permalink
Merge pull request potassco#131 from florianfischer91/handle_future_a…
Browse files Browse the repository at this point in the history
…nnotations

add support for 'from __future__ import annotations'
  • Loading branch information
daveraja authored Feb 12, 2024
2 parents 3b3d795 + ebcac86 commit c66f441
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 18 deletions.
36 changes: 35 additions & 1 deletion clorm/orm/_typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Any, Optional, Tuple, Type, TypeVar, Union, cast
from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast

from clingo import Symbol

Expand Down Expand Up @@ -58,3 +58,37 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]:
res = (list(res[:-1]), res[-1])
return res
return getattr(t, "__args__", ())


def resolve_annotations(
raw_annotations: Dict[str, Type[Any]], module_name: Optional[str] = None
) -> Dict[str, Type[Any]]:
"""
Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376
Resolve string or ForwardRef annotations into type objects if possible.
"""
base_globals: Optional[Dict[str, Any]] = None
if module_name:
try:
module = sys.modules[module_name]
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
else:
base_globals = module.__dict__

annotations = {}
for name, value in raw_annotations.items():
if isinstance(value, str):
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
value = ForwardRef(value, is_argument=False, is_class=True)
else:
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
return annotations
28 changes: 17 additions & 11 deletions clorm/orm/atsyntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import collections.abc as cabc
import functools
import inspect
from typing import Any, Callable, List, Sequence, Tuple, Type
from typing import Any, Callable, List, Sequence, Tuple, Type, Union

from .core import BaseField, get_field_definition, infer_field_definition
from .core import BaseField, get_field_definition, infer_field_definition, resolve_annotations

__all__ = [
"TypeCastSignature",
Expand Down Expand Up @@ -36,6 +36,7 @@ class TypeCastSignature(object):
r"""Defines a signature for converting to/from Clingo data types.
Args:
module: Name of the module where the signature is defined
sigs(\*sigs): A list of signature elements.
- Inputs. Match the sub-elements [:-1] define the input signature while
Expand All @@ -54,9 +55,9 @@ class DateField(StringField):
pytocl = lambda dt: dt.strftime("%Y%m%d")
cltopy = lambda s: datetime.datetime.strptime(s,"%Y%m%d").date()
drsig = TypeCastSignature(DateField, DateField, [DateField])
drsig = TypeCastSignature(DateField, DateField, [DateField], module = "__main__")
@drsig.make_clingo_wrapper
@drsig.wrap_function
def date_range(start, end):
return [ start + timedelta(days=x) for x in range(0,end-start) ]
Expand Down Expand Up @@ -97,24 +98,29 @@ def _is_output_field(o):
return _is_output_field(se[0])
return _is_output_field(se)

def __init__(self, *sigs: Any) -> None:
def __init__(self, *sigs: Any, module: Union[str, None] = None) -> None:
module = self.__module__ if module is None else module

def _validate_basic_sig(sig):
if TypeCastSignature._is_input_element(sig):
return True
raise TypeError(
("TypeCastSignature element {} must be a BaseField " "subclass".format(sig))
"TypeCastSignature element {} must be a BaseField subclass".format(sig)
)

insigs: List[Type[BaseField]] = []
for s in sigs[:-1]:
field = None
try:
field = infer_field_definition(s, "")
resolved = resolve_annotations({"__tmp__": s}, module)["__tmp__"]
field = infer_field_definition(resolved, "")
except Exception:
pass
insigs.append(field if field else type(get_field_definition(s)))
try:
self._outsig = infer_field_definition(sigs[-1], "") or sigs[-1]
outsig = sigs[-1]
outsig = resolve_annotations({"__tmp__": outsig}, module)["__tmp__"]
self._outsig = infer_field_definition(outsig, "") or outsig
except Exception:
self._outsig = sigs[-1]

Expand Down Expand Up @@ -327,7 +333,7 @@ def make_function_asp_callable(*args: Any) -> _AnyCallable:

# A decorator function that adjusts for the given signature
def _sig_decorate(func):
s = TypeCastSignature(*sigs)
s = TypeCastSignature(*sigs, module=func.__module__)
return s.wrap_function(func)

# If no function and sig then called as a decorator with arguments
Expand Down Expand Up @@ -372,7 +378,7 @@ def make_method_asp_callable(*args: Any) -> _AnyCallable:

# A decorator function that adjusts for the given signature
def _sig_decorate(func):
s = TypeCastSignature(*sigs)
s = TypeCastSignature(*sigs, module=func.__module__)
return s.wrap_method(func)

# If no function and sig then called as a decorator with arguments
Expand Down Expand Up @@ -479,7 +485,7 @@ def _decorator(fn):
args = sigargs
else:
args = _get_annotations(fn)
s = TypeCastSignature(*args)
s = TypeCastSignature(*args, module=fn.__module__)
self._add_function(fname, s, fn)
return fn

Expand Down
13 changes: 7 additions & 6 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
TailListReversed,
)

from ._typing import AnySymbol, get_args, get_origin
from ._typing import AnySymbol, get_args, get_origin, resolve_annotations
from .noclingo import (
Function,
Number,
Expand Down Expand Up @@ -2866,6 +2866,9 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF
tuple(infer_field_definition(arg, module) for arg in args), module
)
)
if not isinstance(type_, type):
return None
# from here on only check for subclass
if issubclass(type_, enum.Enum):
# if type_ just inherits from Enum is IntegerField, otherwise find appropriate Field
field = (
Expand Down Expand Up @@ -3000,13 +3003,11 @@ def _make_predicatedefn(

fields_from_annotations = {}
module = namespace.get("__module__", None)
for name, type_ in namespace.get("__annotations__", {}).items():
for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items():
if name in fields_from_dct: # first check if FieldDefinition was assigned
fields_from_annotations[name] = fields_from_dct[name]
else:
fdefn = infer_field_definition(
type_, module
) # if not try to infer the definition based on the type
else: # if not try to infer the definition based on the type
fdefn = infer_field_definition(type_, module)
if fdefn:
fields_from_annotations[name] = fdefn
elif inspect.isclass(type_):
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
os.environ["CLORM_NOCLINGO"] = "True"

from .test_clingo import *
from .test_forward_ref import *
from .test_json import *
from .test_libdate import LibDateTestCase
from .test_libtimeslot import *
Expand Down
167 changes: 167 additions & 0 deletions tests/test_forward_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import importlib
import inspect
import secrets
import sys
import tempfile
import textwrap
import unittest
from contextlib import contextmanager
from pathlib import Path
from types import FunctionType

from clingo import Number, String

__all__ = [
"ForwardRefTestCase",
]


def _extract_source_code_from_function(function):
if function.__code__.co_argcount:
raise RuntimeError(f"function {function.__qualname__} cannot have any arguments")

code_lines = ""
body_started = False
for line in textwrap.dedent(inspect.getsource(function)).split("\n"):
if line.startswith("def "):
body_started = True
continue
elif body_started:
code_lines += f"{line}\n"

return textwrap.dedent(code_lines)


def _create_module_file(code, tmp_path, name):
name = f"{name}_{secrets.token_hex(5)}"
path = Path(tmp_path, f"{name}.py")
path.write_text(code)
return name, str(path)


def create_module(tmp_path, method_name):
def run(source_code_or_function):
"""
Create module object, execute it and return
:param source_code_or_function string or function with body as a source code for created module
"""
if isinstance(source_code_or_function, FunctionType):
source_code = _extract_source_code_from_function(source_code_or_function)
else:
source_code = source_code_or_function

module_name, filename = _create_module_file(source_code, tmp_path, method_name)

spec = importlib.util.spec_from_file_location(module_name, filename, loader=None)
sys.modules[module_name] = module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

return run


class ForwardRefTestCase(unittest.TestCase):
def setUp(self):
@contextmanager
def f(source_code_or_function):
with tempfile.TemporaryDirectory() as tmp_path:
yield create_module(tmp_path, self._testMethodName)(source_code_or_function)

self._create_module = f

def test_postponed_annotations(self):
code = """
from __future__ import annotations
from clorm import Predicate
class P1(Predicate):
a: int
b: str
"""
with self._create_module(code) as module:
p = module.P1(a=3, b="42")
self.assertEqual(str(p), 'p1(3,"42")')

def test_postponed_annotations_complex(self):
code = """
from clorm import Predicate
from typing import Union
class P1(Predicate):
a: int
b: str
class P2(Predicate):
a: int
class P3(Predicate):
a: 'Union[P1, P2]'
"""
with self._create_module(code) as module:
p = module.P3(a=module.P1(a=3, b="42"))
self.assertEqual(str(p), 'p3(p1(3,"42"))')
p = module.P3(a=module.P2(a=42))
self.assertEqual(str(p), "p3(p2(42))")

def test_forward_ref(self):
def module_():
from typing import ForwardRef

from clorm import Predicate

class A(Predicate):
a: int

ARef = ForwardRef("A")

class B(Predicate):
a: ARef

with self._create_module(module_) as module:
b = module.B(a=module.A(a=42))
self.assertEqual(str(b), "b(a(42))")

def test_forward_ref_list(self):
def module_():
from typing import ForwardRef

from clorm import HeadList, Predicate

class A(Predicate):
a: int

ARef = ForwardRef("A")

class B(Predicate):
a: HeadList[ARef]

with self._create_module(module_) as module:
b = module.B(a=[module.A(a=41), module.A(a=42)])
self.assertEqual(str(b), "b((a(41),(a(42),())))")

def test_forward_ref_asp_callable(self):
code = """
from __future__ import annotations
from clorm import Predicate, make_function_asp_callable, make_method_asp_callable
class P1(Predicate):
a: int
b: str
@make_function_asp_callable
def f(a: int, b: str) -> P1:
return P1(a,b)
class Context:
@make_method_asp_callable
def f(self, a: int, b: str) -> P1:
return P1(a,b)
"""
with self._create_module(code) as module:
p = module.f(Number(2), String("2"))
self.assertEqual(str(p), 'p1(2,"2")')
ctx = module.Context()
p = ctx.f(Number(2), String("2"))
self.assertEqual(str(p), 'p1(2,"2")')

0 comments on commit c66f441

Please sign in to comment.