Skip to content

Commit

Permalink
Allow define_nested_list and define_flat_list to have tuple fields
Browse files Browse the repository at this point in the history
  • Loading branch information
daveraja committed Mar 2, 2024
1 parent 88ce485 commit 9fbe2e4
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
37 changes: 27 additions & 10 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,31 @@ def pytocl(value):
return String(value)


# ------------------------------------------------------------------------------------------
# Internal function to process field definitions. If the input is a BaseField subclass then
# return it directly. Otherwise tries to generate a new BaseField subclass from a tuple field
# specification. This can be used by the field builder functions when passed a union based
# field.
# ------------------------------------------------------------------------------------------


def _process_field_definition(field_defn: _FieldDefinition) -> Type[BaseField]:
if isinstance(field_defn, tuple):
try:
module = sys._getframe(1).f_globals.get("__name__", "__main__")
except (AttributeError, ValueError):
module = None
return _create_complex_term(field_defn, module=module)
if not inspect.isclass(field_defn) or not issubclass(field_defn, BaseField):
raise TypeError(
(
f"'{field_defn}' must be a '{BaseField}' sub-class or a nested "
f"tuple of {BaseField} sub-class leaves"
)
)
return field_defn


# ------------------------------------------------------------------------------
# refine_field is a function that creates a sub-class of a BaseField (or BaseField
# sub-class). It restricts the set of allowable values based on a functor or an
Expand Down Expand Up @@ -1980,11 +2005,7 @@ def define_flat_list_field(element_field: Type[BaseField], *, name: str = "") ->
"""
subclass_name = name if name else "AnonymousFlatSeqField"
efield = element_field

# The element_field must be a BaseField sub-class
if not inspect.isclass(efield) or not issubclass(efield, BaseField):
raise TypeError("'{}' is not a BaseField or a sub-class".format(efield))
efield = _process_field_definition(element_field)

def _checkpy(v):
if isinstance(v, str) or not isinstance(v, cabc.Iterable):
Expand Down Expand Up @@ -2093,11 +2114,7 @@ def define_nested_list_field(
"""
subclass_name = name if name else "AnonymousNestedSeqField"
efield = element_field

# The element_field must be a BaseField sub-class
if not inspect.isclass(efield) or not issubclass(efield, BaseField):
raise TypeError("'{}' is not a BaseField or a sub-class".format(efield))
efield = _process_field_definition(element_field)

# Support function - to check input values
def _checkpy(v):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_forward_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,31 @@ class P2(Predicate):
p2 = module.XP2(a=p1)
self.assertEqual(str(p2), 'p2(p1(c,3,"42"))')

def test_postponed_annotations_headlist(self):
code = """
from __future__ import annotations
from clorm import Predicate, HeadList
class P(Predicate):
x: HeadList[tuple[int,str]]
"""
with self._create_module(code) as module:
p = module.P(x=((1, "a"), (2, "b")))
self.assertEqual(str(p), 'p(((1,"a"),((2,"b"),())))')

def test_postponed_annotations_flatlist(self):
code = """
from __future__ import annotations
from clorm import Predicate
class P(Predicate):
x: tuple[tuple[int,str], ...]
"""
with self._create_module(code) as module:
p = module.P(x=((1, "a"), (2, "b")))
self.assertEqual(str(p), 'p(((1,"a"),(2,"b")))')


def test_forward_ref(self):
def module_():
from typing import ForwardRef
Expand Down
30 changes: 29 additions & 1 deletion tests/test_orm_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,21 @@ def test_api_nested_list_field(self):
# Some badly defined fields
with self.assertRaises(TypeError) as ctx:
tmp = define_nested_list_field("FG", name="FG")
check_errmsg("'FG' is not a ", ctx)
check_errmsg("'FG' must be ", ctx)

def test_api_nested_list_field_complex_element_field(self):
XField = define_nested_list_field((IntegerField, (ConstantField, StringField)))

symvalue1 = Function("", [Number(1), Function("", [Function("a",[]), String("A")])])
symvalue2 = Function("", [Number(2), Function("", [Function("b",[]), String("B")])])
symnlist = Function("", [symvalue1, Function("", [symvalue2, Function("",[])])])

value1 = (1,("a", "A"))
value2 = (2,("b", "B"))
nlist = (value1, value2)

self.assertEqual(XField.cltopy(symnlist), nlist)
self.assertEqual(XField.pytocl(nlist), symnlist)

# --------------------------------------------------------------------------
# Test defining a field that handles python lists/sequences as a tuple of
Expand Down Expand Up @@ -810,6 +824,20 @@ def test_api_flat_list_field(self):
tmp = CLField.pytocl(1)
check_errmsg("'1' is not a sequence", ctx)

def test_api_flat_list_field_complex_element_field(self):
XField = define_flat_list_field((IntegerField, (ConstantField, StringField)))

symvalue1 = Function("", [Number(1), Function("", [Function("a",[]), String("A")])])
symvalue2 = Function("", [Number(2), Function("", [Function("b",[]), String("B")])])
symnlist = Function("", [symvalue1, symvalue2])

value1 = (1,("a", "A"))
value2 = (2,("b", "B"))
nlist = (value1, value2)

self.assertEqual(XField.cltopy(symnlist), nlist)
self.assertEqual(XField.pytocl(nlist), symnlist)

# --------------------------------------------------------------------------
# Test the different variants for defining a nested list encoding of the
# Python sequence (1,2,3)
Expand Down

0 comments on commit 9fbe2e4

Please sign in to comment.