Skip to content

Commit

Permalink
Add support for transitively mapping Python data structures to Lua ta…
Browse files Browse the repository at this point in the history
…bles (GH-208)

Closes #199
  • Loading branch information
synodriver authored Feb 21, 2024
1 parent 6e8cb6d commit b204445
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 45 deletions.
71 changes: 47 additions & 24 deletions lupa/_lupa.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ cdef object exc_info
from sys import exc_info

cdef object Mapping
cdef object Sequence
try:
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
except ImportError:
from collections import Mapping # Py2
from collections import Mapping, Sequence # Py2

cdef object wraps
from functools import wraps
Expand Down Expand Up @@ -169,6 +170,10 @@ def lua_type(obj):
lua.lua_settop(L, old_top)
unlock_runtime(lua_object._runtime)

cdef inline int _len_as_int(Py_ssize_t obj) except -1:
if obj > <Py_ssize_t>INT_MAX:
raise OverflowError
return <int>obj

@cython.no_gc_clear
cdef class LuaRuntime:
Expand Down Expand Up @@ -520,21 +525,23 @@ cdef class LuaRuntime:
"""
return self.table_from(items, kwargs)

def table_from(self, *args):
def table_from(self, *args, bint recursive=False):
"""Create a new table from Python mapping or iterable.
table_from() accepts either a dict/mapping or an iterable with items.
Items from dicts are set as key-value pairs; items from iterables
are placed in the table in order.
Nested mappings / iterables are passed to Lua as userdata
(wrapped Python objects); they are not converted to Lua tables.
(wrapped Python objects) by default. If `recursive` is True,
they are converted to Lua tables recursively, handling loops
and duplicates via identity de-duplication.
"""
assert self._state is not NULL
cdef lua_State *L = self._state
lock_runtime(self)
try:
return py_to_lua_table(self, L, args)
return py_to_lua_table(self, L, args, recursive=recursive)
finally:
unlock_runtime(self)

Expand Down Expand Up @@ -1236,7 +1243,7 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args):
# already terminated
raise StopIteration
if args:
nargs = len(args)
nargs = _len_as_int(len(args))
push_lua_arguments(thread._runtime, co, args)
with nogil:
status = lua.lua_resume(co, L, nargs, &nres)
Expand Down Expand Up @@ -1482,7 +1489,7 @@ cdef py_object* unpack_userdata(lua_State *L, int n) noexcept nogil:
cdef int py_function_result_to_lua(LuaRuntime runtime, lua_State *L, object o) except -1:
if runtime._unpack_returned_tuples and isinstance(o, tuple):
push_lua_arguments(runtime, L, <tuple>o)
return len(<tuple>o)
return _len_as_int(len(<tuple>o))
check_lua_stack(L, 1)
return py_to_lua(runtime, L, o)

Expand Down Expand Up @@ -1511,7 +1518,7 @@ cdef int py_to_lua_handle_overflow(LuaRuntime runtime, lua_State *L, object o) e
lua.lua_settop(L, old_top)
raise

cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False) except -1:
cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False, bint recursive=False, dict mapped_tables=None) except -1:
"""Converts Python object to Lua
Preconditions:
1 extra slot in the Lua stack
Expand Down Expand Up @@ -1563,13 +1570,19 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa
elif isinstance(o, float):
lua.lua_pushnumber(L, <lua.lua_Number><double>o)
pushed_values_count = 1
elif isinstance(o, _PyProtocolWrapper):
type_flags = (<_PyProtocolWrapper> o)._type_flags
o = (<_PyProtocolWrapper> o)._obj
pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags)
elif recursive and isinstance(o, (list, dict, Sequence, Mapping)):
if mapped_tables is None:
mapped_tables = {}
table = py_to_lua_table(runtime, L, (o,), recursive=recursive, mapped_tables=mapped_tables)
(<_LuaObject> table).push_lua_object(L)
pushed_values_count = 1
else:
if isinstance(o, _PyProtocolWrapper):
type_flags = (<_PyProtocolWrapper>o)._type_flags
o = (<_PyProtocolWrapper>o)._obj
else:
# prefer __getitem__ over __getattr__ by default
type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0
# prefer __getitem__ over __getattr__ by default
type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0
pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags)
return pushed_values_count

Expand Down Expand Up @@ -1655,7 +1668,7 @@ cdef bytes _asciiOrNone(s):
return <bytes>s


cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items):
cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, tuple items, bint recursive=False, dict mapped_tables=None):
"""
Create a new Lua table and add different kinds of values from the sequence 'items' to it.
Expand All @@ -1666,14 +1679,24 @@ cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items):
check_lua_stack(L, 5)
old_top = lua.lua_gettop(L)
lua.lua_newtable(L)
# FIXME: how to check for failure?

# FIXME: handle allocation errors
cdef int lua_table_ref = lua.lua_gettop(L) # the index of the lua table which we are filling
if recursive and mapped_tables is None:
mapped_tables = {}
try:
for obj in items:
if recursive:
if id(obj) not in mapped_tables:
# this object is never seen before, we should cache it
mapped_tables[id(obj)] = lua_table_ref
else:
# this object has been cached, just get the corresponding lua table's index
idx = mapped_tables[id(obj)]
return new_lua_table(runtime, L, <int>idx)
if isinstance(obj, dict):
for key, value in (<dict>obj).items():
py_to_lua(runtime, L, key, wrap_none=True)
py_to_lua(runtime, L, value)
py_to_lua(runtime, L, key, wrap_none=True, recursive=recursive, mapped_tables=mapped_tables)
py_to_lua(runtime, L, value, wrap_none=False, recursive=recursive, mapped_tables=mapped_tables)
lua.lua_rawset(L, -3)

elif isinstance(obj, _LuaTable):
Expand All @@ -1689,13 +1712,13 @@ cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items):
elif isinstance(obj, Mapping):
for key in obj:
value = obj[key]
py_to_lua(runtime, L, key, wrap_none=True)
py_to_lua(runtime, L, value)
py_to_lua(runtime, L, key, wrap_none=True, recursive=recursive, mapped_tables=mapped_tables)
py_to_lua(runtime, L, value, wrap_none=False, recursive=recursive, mapped_tables=mapped_tables)
lua.lua_rawset(L, -3)

else:
for arg in obj:
py_to_lua(runtime, L, arg)
py_to_lua(runtime, L, arg, wrap_none=False, recursive=recursive, mapped_tables=mapped_tables)
lua.lua_rawseti(L, -2, i)
i += 1

Expand Down Expand Up @@ -1826,7 +1849,7 @@ cdef object execute_lua_call(LuaRuntime runtime, lua_State *L, Py_ssize_t nargs)
lua.lua_replace(L, -2)
lua.lua_insert(L, 1)
has_lua_traceback_func = True
result_status = lua.lua_pcall(L, nargs, lua.LUA_MULTRET, has_lua_traceback_func)
result_status = lua.lua_pcall(L, <int>nargs, lua.LUA_MULTRET, has_lua_traceback_func)
if has_lua_traceback_func:
lua.lua_remove(L, 1)
results = unpack_lua_results(runtime, L)
Expand Down Expand Up @@ -2004,7 +2027,7 @@ cdef bint call_python(LuaRuntime runtime, lua_State *L, py_object* py_obj) excep
else:
args = ()
kwargs = {}

for i in range(nargs):
arg = py_from_lua(runtime, L, i+2)
if isinstance(arg, _PyArguments):
Expand Down
110 changes: 89 additions & 21 deletions lupa/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def get_attr(obj, name):


class TestLuaRuntime(SetupLuaRuntimeMixin, LupaTestCase):
def assertLuaResult(self, lua_expression, result):
self.assertEqual(self.lua.eval(lua_expression), result)

def test_lua_version(self):
version = self.lua.lua_version
self.assertEqual(tuple, type(version))
Expand Down Expand Up @@ -598,27 +601,23 @@ def test_table_from_bad(self):
self.assertRaises(TypeError, self.lua.table_from, None)
self.assertRaises(TypeError, self.lua.table_from, {"a": 5}, 123)

def test_table_from_nested_datastructures(self):
from itertools import count
def make_ds(*children):
yield list(children)
yield dict(zip(count(), children))
yield {chr(ord('A') + i): child for i, child in enumerate(children)}

elements = [1, 2, 'x', 'y']
for ds1 in make_ds(*elements):
for ds2 in make_ds(ds1):
for ds3 in make_ds(ds1, elements, ds2):
for ds in make_ds(ds1, ds2, ds3):
with self.subTest(ds=ds):
table = self.lua.table_from(ds)
# we don't translate transitively, so apply arbitrary test operation
self.assertTrue(list(table))

# def test_table_from_nested(self):
# table = self.lua.table_from({"obj": {"foo": "bar"}})
# lua_type = self.lua.eval("type")
# self.assertEqual(lua_type(table["obj"]), "table")
def test_table_from_nested(self):
table = self.lua.table_from([[3, 3, 3]], recursive=True)
self.lua.globals()["data"] = table
self.assertLuaResult("data[1][1]", 3)
self.assertLuaResult("data[1][2]", 3)
self.assertLuaResult("data[1][3]", 3)
self.assertLuaResult("type(data)", "table")
self.assertLuaResult("type(data[1])", "table")
self.assertLuaResult("#data", 1)
self.assertLuaResult("#data[1]", 3)

def test_table_from_nested2(self):
table2 = self.lua.table_from([{"a": "foo"}, {"b": 1}], recursive=True)
self.lua.globals()["data2"] = table2
self.assertLuaResult("#data2", 2)
self.assertLuaResult("data2[1]['a']", "foo")
self.assertLuaResult("data2[2]['b']", 1)

def test_table_from_table(self):
table1 = self.lua.eval("{3, 4, foo='bar'}")
Expand Down Expand Up @@ -649,6 +648,75 @@ def test_table_from_table_iter_indirect(self):
self.assertEqual(list(table2.keys()), [1, 2, 3])
self.assertEqual(set(table2.values()), set([1, 2, "foo"]))

def test_table_from_nested_dict(self):
data = {"a": {"a": "foo"}, "b": {"b": "bar"}}
table = self.lua.table_from(data, recursive=True)
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(table["b"]["b"], "bar")
self.lua.globals()["data"] = table
self.assertLuaResult("data.a.a", "foo")
self.assertLuaResult("data.b.b", "bar")
self.assertLuaResult("type(data.a)", "table")
self.assertLuaResult("type(data.b)", "table")

def test_table_from_nested_list(self):
data = {"a": {"a": "foo"}, "b": [1, 2, 3]}
table = self.lua.table_from(data, recursive=True)
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(table["b"][1], 1)
self.assertEqual(table["b"][2], 2)
self.assertEqual(table["b"][3], 3)
self.lua.globals()["data"] = table
self.assertLuaResult("data.a.a", "foo")
self.assertLuaResult("#data.b", 3)
self.lua.eval("assert(#data.b==3, 'failed')")
self.assertLuaResult("type(data.a)", "table")
self.assertLuaResult("type(data.b)", "table")

def test_table_from_nested_list_bad(self):
data = {"a": {"a": "foo"}, "b": [1, 2, 3]}
table = self.lua.table_from(data) # in this case, lua will get userdata instead of table
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(list(table["b"]), [1, 2, 3])
self.assertEqual(table["b"][0], 1)
self.assertEqual(table["b"][1], 2)
self.assertEqual(table["b"][2], 3)
self.lua.globals()["data"] = table
self.assertLuaResult("type(data.a)", "userdata")
self.assertLuaResult("type(data.b)", "userdata")

def test_table_from_self_ref_obj(self):
data = {}
data["key"] = data
l = []
l.append(l)
data["list"] = l
table = self.lua.table_from(data, recursive=True)
self.lua.globals()["data"] = table
self.assertLuaResult("type(data)", 'table')
self.assertLuaResult("type(data['key'])",'table')
self.assertLuaResult("type(data['list'])",'table')
self.assertLuaResult("data['list']==data['list'][1]", True)
self.assertLuaResult("type(data['key']['key']['key']['key'])", 'table')
self.assertLuaResult("type(data['key']['key']['key']['key']['list'])", 'table')

def test_table_from_nested_datastructures(self):
from itertools import count
def make_ds(*children):
yield list(children)
yield dict(zip(count(), children))
yield {chr(ord('A') + i): child for i, child in enumerate(children)}

elements = [1, 2, 'x', 'y']
for ds1 in make_ds(*elements):
for ds2 in make_ds(ds1):
for ds3 in make_ds(ds1, elements, ds2):
for ds in make_ds(ds1, ds2, ds3):
with self.subTest(ds=ds):
table = self.lua.table_from(ds)
# we don't translate transitively, so apply arbitrary test operation
self.assertTrue(list(table))

# FIXME: it segfaults
# def test_table_from_generator_calling_lua_functions(self):
# func = self.lua.eval("function (obj) return obj end")
Expand Down

0 comments on commit b204445

Please sign in to comment.