From ebcac866f5d30c8d17a0a57e7dddc4c3d95a470d Mon Sep 17 00:00:00 2001 From: FlorianFischer Date: Sat, 10 Feb 2024 10:52:42 +0100 Subject: [PATCH] add support for future annotations --- clorm/orm/_typing.py | 36 +++++++- clorm/orm/atsyntax.py | 28 ++++--- clorm/orm/core.py | 13 +-- tests/__init__.py | 1 + tests/test_forward_ref.py | 167 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 227 insertions(+), 18 deletions(-) create mode 100644 tests/test_forward_ref.py diff --git a/clorm/orm/_typing.py b/clorm/orm/_typing.py index f896f8a..473272a 100644 --- a/clorm/orm/_typing.py +++ b/clorm/orm/_typing.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Optional, Tuple, Type, TypeVar, Union, cast +from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, TypeVar, Union, _eval_type, cast from clingo import Symbol @@ -58,3 +58,37 @@ def get_args(t: Type[Any]) -> Tuple[Any, ...]: res = (list(res[:-1]), res[-1]) return res return getattr(t, "__args__", ()) + + +def resolve_annotations( + raw_annotations: Dict[str, Type[Any]], module_name: Optional[str] = None +) -> Dict[str, Type[Any]]: + """ + Taken from https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/typing.py#L376 + + Resolve string or ForwardRef annotations into type objects if possible. + """ + base_globals: Optional[Dict[str, Any]] = None + if module_name: + try: + module = sys.modules[module_name] + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + else: + base_globals = module.__dict__ + + annotations = {} + for name, value in raw_annotations.items(): + if isinstance(value, str): + if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): + value = ForwardRef(value, is_argument=False, is_class=True) + else: + value = ForwardRef(value, is_argument=False) + try: + value = _eval_type(value, base_globals, None) + except NameError: + # this is ok, it can be fixed with update_forward_refs + pass + annotations[name] = value + return annotations diff --git a/clorm/orm/atsyntax.py b/clorm/orm/atsyntax.py index 1e1b910..97d3eff 100644 --- a/clorm/orm/atsyntax.py +++ b/clorm/orm/atsyntax.py @@ -6,9 +6,9 @@ import collections.abc as cabc import functools import inspect -from typing import Any, Callable, List, Sequence, Tuple, Type +from typing import Any, Callable, List, Sequence, Tuple, Type, Union -from .core import BaseField, get_field_definition, infer_field_definition +from .core import BaseField, get_field_definition, infer_field_definition, resolve_annotations __all__ = [ "TypeCastSignature", @@ -36,6 +36,7 @@ class TypeCastSignature(object): r"""Defines a signature for converting to/from Clingo data types. Args: + module: Name of the module where the signature is defined sigs(\*sigs): A list of signature elements. - Inputs. Match the sub-elements [:-1] define the input signature while @@ -54,9 +55,9 @@ class DateField(StringField): pytocl = lambda dt: dt.strftime("%Y%m%d") cltopy = lambda s: datetime.datetime.strptime(s,"%Y%m%d").date() - drsig = TypeCastSignature(DateField, DateField, [DateField]) + drsig = TypeCastSignature(DateField, DateField, [DateField], module = "__main__") - @drsig.make_clingo_wrapper + @drsig.wrap_function def date_range(start, end): return [ start + timedelta(days=x) for x in range(0,end-start) ] @@ -97,24 +98,29 @@ def _is_output_field(o): return _is_output_field(se[0]) return _is_output_field(se) - def __init__(self, *sigs: Any) -> None: + def __init__(self, *sigs: Any, module: Union[str, None] = None) -> None: + module = self.__module__ if module is None else module + def _validate_basic_sig(sig): if TypeCastSignature._is_input_element(sig): return True raise TypeError( - ("TypeCastSignature element {} must be a BaseField " "subclass".format(sig)) + "TypeCastSignature element {} must be a BaseField subclass".format(sig) ) insigs: List[Type[BaseField]] = [] for s in sigs[:-1]: field = None try: - field = infer_field_definition(s, "") + resolved = resolve_annotations({"__tmp__": s}, module)["__tmp__"] + field = infer_field_definition(resolved, "") except Exception: pass insigs.append(field if field else type(get_field_definition(s))) try: - self._outsig = infer_field_definition(sigs[-1], "") or sigs[-1] + outsig = sigs[-1] + outsig = resolve_annotations({"__tmp__": outsig}, module)["__tmp__"] + self._outsig = infer_field_definition(outsig, "") or outsig except Exception: self._outsig = sigs[-1] @@ -327,7 +333,7 @@ def make_function_asp_callable(*args: Any) -> _AnyCallable: # A decorator function that adjusts for the given signature def _sig_decorate(func): - s = TypeCastSignature(*sigs) + s = TypeCastSignature(*sigs, module=func.__module__) return s.wrap_function(func) # If no function and sig then called as a decorator with arguments @@ -372,7 +378,7 @@ def make_method_asp_callable(*args: Any) -> _AnyCallable: # A decorator function that adjusts for the given signature def _sig_decorate(func): - s = TypeCastSignature(*sigs) + s = TypeCastSignature(*sigs, module=func.__module__) return s.wrap_method(func) # If no function and sig then called as a decorator with arguments @@ -479,7 +485,7 @@ def _decorator(fn): args = sigargs else: args = _get_annotations(fn) - s = TypeCastSignature(*args) + s = TypeCastSignature(*args, module=fn.__module__) self._add_function(fname, s, fn) return fn diff --git a/clorm/orm/core.py b/clorm/orm/core.py index b034dad..cec0826 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -47,7 +47,7 @@ TailListReversed, ) -from ._typing import AnySymbol, get_args, get_origin +from ._typing import AnySymbol, get_args, get_origin, resolve_annotations from .noclingo import ( Function, Number, @@ -2866,6 +2866,9 @@ def infer_field_definition(type_: Type[Any], module: str) -> Optional[Type[BaseF tuple(infer_field_definition(arg, module) for arg in args), module ) ) + if not isinstance(type_, type): + return None + # from here on only check for subclass if issubclass(type_, enum.Enum): # if type_ just inherits from Enum is IntegerField, otherwise find appropriate Field field = ( @@ -3000,13 +3003,11 @@ def _make_predicatedefn( fields_from_annotations = {} module = namespace.get("__module__", None) - for name, type_ in namespace.get("__annotations__", {}).items(): + for name, type_ in resolve_annotations(namespace.get("__annotations__", {}), module).items(): if name in fields_from_dct: # first check if FieldDefinition was assigned fields_from_annotations[name] = fields_from_dct[name] - else: - fdefn = infer_field_definition( - type_, module - ) # if not try to infer the definition based on the type + else: # if not try to infer the definition based on the type + fdefn = infer_field_definition(type_, module) if fdefn: fields_from_annotations[name] = fdefn elif inspect.isclass(type_): diff --git a/tests/__init__.py b/tests/__init__.py index 59808ff..ee924d0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -8,6 +8,7 @@ os.environ["CLORM_NOCLINGO"] = "True" from .test_clingo import * +from .test_forward_ref import * from .test_json import * from .test_libdate import LibDateTestCase from .test_libtimeslot import * diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py new file mode 100644 index 0000000..979898a --- /dev/null +++ b/tests/test_forward_ref.py @@ -0,0 +1,167 @@ +import importlib +import inspect +import secrets +import sys +import tempfile +import textwrap +import unittest +from contextlib import contextmanager +from pathlib import Path +from types import FunctionType + +from clingo import Number, String + +__all__ = [ + "ForwardRefTestCase", +] + + +def _extract_source_code_from_function(function): + if function.__code__.co_argcount: + raise RuntimeError(f"function {function.__qualname__} cannot have any arguments") + + code_lines = "" + body_started = False + for line in textwrap.dedent(inspect.getsource(function)).split("\n"): + if line.startswith("def "): + body_started = True + continue + elif body_started: + code_lines += f"{line}\n" + + return textwrap.dedent(code_lines) + + +def _create_module_file(code, tmp_path, name): + name = f"{name}_{secrets.token_hex(5)}" + path = Path(tmp_path, f"{name}.py") + path.write_text(code) + return name, str(path) + + +def create_module(tmp_path, method_name): + def run(source_code_or_function): + """ + Create module object, execute it and return + + :param source_code_or_function string or function with body as a source code for created module + + """ + if isinstance(source_code_or_function, FunctionType): + source_code = _extract_source_code_from_function(source_code_or_function) + else: + source_code = source_code_or_function + + module_name, filename = _create_module_file(source_code, tmp_path, method_name) + + spec = importlib.util.spec_from_file_location(module_name, filename, loader=None) + sys.modules[module_name] = module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + return run + + +class ForwardRefTestCase(unittest.TestCase): + def setUp(self): + @contextmanager + def f(source_code_or_function): + with tempfile.TemporaryDirectory() as tmp_path: + yield create_module(tmp_path, self._testMethodName)(source_code_or_function) + + self._create_module = f + + def test_postponed_annotations(self): + code = """ +from __future__ import annotations +from clorm import Predicate + +class P1(Predicate): + a: int + b: str +""" + with self._create_module(code) as module: + p = module.P1(a=3, b="42") + self.assertEqual(str(p), 'p1(3,"42")') + + def test_postponed_annotations_complex(self): + code = """ +from clorm import Predicate +from typing import Union + +class P1(Predicate): + a: int + b: str + +class P2(Predicate): + a: int + +class P3(Predicate): + a: 'Union[P1, P2]' +""" + with self._create_module(code) as module: + p = module.P3(a=module.P1(a=3, b="42")) + self.assertEqual(str(p), 'p3(p1(3,"42"))') + p = module.P3(a=module.P2(a=42)) + self.assertEqual(str(p), "p3(p2(42))") + + def test_forward_ref(self): + def module_(): + from typing import ForwardRef + + from clorm import Predicate + + class A(Predicate): + a: int + + ARef = ForwardRef("A") + + class B(Predicate): + a: ARef + + with self._create_module(module_) as module: + b = module.B(a=module.A(a=42)) + self.assertEqual(str(b), "b(a(42))") + + def test_forward_ref_list(self): + def module_(): + from typing import ForwardRef + + from clorm import HeadList, Predicate + + class A(Predicate): + a: int + + ARef = ForwardRef("A") + + class B(Predicate): + a: HeadList[ARef] + + with self._create_module(module_) as module: + b = module.B(a=[module.A(a=41), module.A(a=42)]) + self.assertEqual(str(b), "b((a(41),(a(42),())))") + + def test_forward_ref_asp_callable(self): + code = """ +from __future__ import annotations +from clorm import Predicate, make_function_asp_callable, make_method_asp_callable + +class P1(Predicate): + a: int + b: str + +@make_function_asp_callable +def f(a: int, b: str) -> P1: + return P1(a,b) + +class Context: + @make_method_asp_callable + def f(self, a: int, b: str) -> P1: + return P1(a,b) +""" + with self._create_module(code) as module: + p = module.f(Number(2), String("2")) + self.assertEqual(str(p), 'p1(2,"2")') + ctx = module.Context() + p = ctx.f(Number(2), String("2")) + self.assertEqual(str(p), 'p1(2,"2")')