Skip to content

Commit

Permalink
add support for self-ref objects
Browse files Browse the repository at this point in the history
  • Loading branch information
synodriver committed Jan 4, 2024
1 parent 12b21d0 commit 642b375
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 27 deletions.
54 changes: 33 additions & 21 deletions lupa/_lupa.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -525,15 +525,15 @@ cdef class LuaRuntime:
"""
return self.table_from(items, kwargs)

def table_from(self, *args, int max_depth=1):
def table_from(self, *args, bint recursive=False):
"""Create a new table from Python mapping or iterable.
table_from() accepts either a dict/mapping or an iterable with items.
Items from dicts are set as key-value pairs; items from iterables
are placed in the table in order.
Nested mappings / iterables are passed to Lua as userdata
(wrapped Python objects) if recursive depth is greater than `max_depth`, they are not converted to Lua tables.
(wrapped Python objects) if `recursive` is False, they are not converted to Lua tables.
"""
assert self._state is not NULL
cdef lua_State *L = self._state
Expand All @@ -547,8 +547,8 @@ cdef class LuaRuntime:
for obj in args:
if isinstance(obj, dict):
for key, value in obj.iteritems():
py_to_lua(self, L, key, wrap_none=True, max_depth=max_depth)
py_to_lua(self, L, value, wrap_none=False, max_depth=max_depth)
py_to_lua(self, L, key, True, recursive)
py_to_lua(self, L, value, False, recursive)
lua.lua_rawset(L, -3)

elif isinstance(obj, _LuaTable):
Expand All @@ -564,12 +564,12 @@ cdef class LuaRuntime:
elif isinstance(obj, Mapping):
for key in obj:
value = obj[key]
py_to_lua(self, L, key, wrap_none=True, max_depth=max_depth)
py_to_lua(self, L, value, wrap_none=False, max_depth=max_depth)
py_to_lua(self, L, key, True, recursive)
py_to_lua(self, L, value, False, recursive)
lua.lua_rawset(L, -3)
else:
for arg in obj:
py_to_lua(self, L, arg, wrap_none=False, max_depth=max_depth)
py_to_lua(self, L, arg, False, recursive)
lua.lua_rawseti(L, -2, i)
i += 1
return py_from_lua(self, L, -1)
Expand Down Expand Up @@ -1550,7 +1550,7 @@ cdef int py_to_lua_handle_overflow(LuaRuntime runtime, lua_State *L, object o) e
lua.lua_settop(L, old_top)
raise

cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False, int max_depth=1, int current_depth=0) except -1:
cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False, bint recursive=False, dict mapped_objs = None) except -1:
"""Converts Python object to Lua
Preconditions:
1 extra slot in the Lua stack
Expand All @@ -1559,8 +1559,6 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa
Returns 0 if cannot convert Python object to Lua
Returns 1 if the Python object was converted successfully and pushed onto the stack
"""
if current_depth >= max_depth:
raise ValueError("max recursive depth reached")
cdef int pushed_values_count = 0
cdef int type_flags = 0

Expand Down Expand Up @@ -1608,18 +1606,32 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa
type_flags = (<_PyProtocolWrapper> o)._type_flags
o = (<_PyProtocolWrapper> o)._obj
pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags)
elif max_depth>1 and isinstance(o, (list, Sequence)):
lua.lua_createtable(L, _len_as_int(len(o)), 0) # create a table at the top of stack, with narr already known
for i, v in enumerate(o, 1):
py_to_lua(runtime, L, v, wrap_none, max_depth, current_depth+1)
lua.lua_rawseti(L, -2, i)
elif recursive and isinstance(o, (list, Sequence)):
if mapped_objs is None:
mapped_objs = {}
if id(o) not in mapped_objs:
lua.lua_createtable(L, _len_as_int(len(o)), 0) # create a table at the top of stack, with narr already known
mapped_objs[id(o)] = lua.lua_gettop(L)
for i, v in enumerate(o, 1):
py_to_lua(runtime, L, v, wrap_none, recursive, mapped_objs)
lua.lua_rawseti(L, -2, i)
else: # self-reference detected
idx = mapped_objs[id(o)]
lua.lua_pushvalue(L, idx)
pushed_values_count = 1
elif max_depth>1 and isinstance(o, (dict, Mapping)):
lua.lua_createtable(L, 0, _len_as_int(len(o))) # create a table at the top of stack, with nrec already known
for key, value in o.items():
py_to_lua(runtime, L, key, wrap_none, max_depth, current_depth+1)
py_to_lua(runtime, L, value, wrap_none, max_depth, current_depth+1)
lua.lua_rawset(L, -3)
elif recursive and isinstance(o, (dict, Mapping)):
if mapped_objs is None:
mapped_objs = {}
if id(o) not in mapped_objs:
lua.lua_createtable(L, 0, _len_as_int(len(o))) # create a table at the top of stack, with nrec already known
mapped_objs[id(o)] = lua.lua_gettop(L)
for key, value in o.items():
py_to_lua(runtime, L, key, wrap_none, recursive, mapped_objs)
py_to_lua(runtime, L, value, wrap_none, recursive, mapped_objs)
lua.lua_rawset(L, -3)
else: # self-reference detected
idx = mapped_objs[id(o)]
lua.lua_pushvalue(L, idx)
pushed_values_count = 1
else:
# prefer __getitem__ over __getattr__ by default
Expand Down
24 changes: 18 additions & 6 deletions lupa/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def test_table_from_table_iter_indirect(self):

def test_table_from_nested_dict(self):
data = {"a": {"a": "foo"}, "b": {"b": "bar"}}
table = self.lua.table_from(data, max_depth=10)
table = self.lua.table_from(data, recursive=True)
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(table["b"]["b"], "bar")
self.lua.globals()["data"] = table
Expand All @@ -659,7 +659,7 @@ def test_table_from_nested_dict(self):

def test_table_from_nested_list(self):
data = {"a": {"a": "foo"}, "b": [1, 2, 3]}
table = self.lua.table_from(data, max_depth=10)
table = self.lua.table_from(data, recursive=True)
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(table["b"][1], 1)
self.assertEqual(table["b"][2], 2)
Expand All @@ -686,7 +686,7 @@ def test_table_from_nested_list(self):

def test_table_from_nested_list_bad(self):
data = {"a": {"a": "foo"}, "b": [1, 2, 3]}
table = self.lua.table_from(data, max_depth=10) # in this case, lua will get userdata instead of table
table = self.lua.table_from(data, recursive=True) # in this case, lua will get userdata instead of table
self.assertEqual(table["a"]["a"], "foo")
print(list(table["b"]))
self.assertEqual(table["b"][1], 1)
Expand All @@ -699,11 +699,23 @@ def test_table_from_nested_list_bad(self):

del self.lua.globals()["data"]

def test_table_from_recursive_dict(self):
def test_table_from_self_ref_obj(self):
data = {}
data["key"] = data
with self.assertRaises(ValueError):
self.lua.table_from(data, max_depth=10)
l = []
l.append(l)
data["list"] = l
table = self.lua.table_from(data, recursive=True)
self.lua.globals()["data"] = table
self.lua.eval("assert(type(data)=='table', '')")
self.lua.eval("assert(type(data['key'])=='table', '')")
self.lua.eval("assert(type(data['list'])=='table', '')")
self.lua.eval("assert(data['list']==data['list'][1], 'wrong self-ref list')")
self.lua.eval("assert(type(data['key']['key']['key']['key'])=='table', 'wrong self-ref map')")
self.lua.eval("assert(type(data['key']['key']['key']['key']['list'])=='table', 'wrong self-ref map')")
# self.assertEqual(table["key"], table)
# self.assertEqual(table["list"], table["list"][0])
del self.lua.globals()["data"]

# FIXME: it segfaults
# def test_table_from_generator_calling_lua_functions(self):
Expand Down

0 comments on commit 642b375

Please sign in to comment.