diff --git a/lupa/_lupa.pyx b/lupa/_lupa.pyx index 66b10e4..48e7e7f 100644 --- a/lupa/_lupa.pyx +++ b/lupa/_lupa.pyx @@ -50,10 +50,11 @@ cdef object exc_info from sys import exc_info cdef object Mapping +cdef object Sequence try: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence except ImportError: - from collections import Mapping # Py2 + from collections import Mapping, Sequence # Py2 cdef object wraps from functools import wraps @@ -169,6 +170,10 @@ def lua_type(obj): lua.lua_settop(L, old_top) unlock_runtime(lua_object._runtime) +cdef inline int _len_as_int(Py_ssize_t obj) except -1: + if obj > INT_MAX: + raise OverflowError + return obj @cython.no_gc_clear cdef class LuaRuntime: @@ -520,7 +525,7 @@ cdef class LuaRuntime: """ return self.table_from(items, kwargs) - def table_from(self, *args): + def table_from(self, *args, bint recursive=False): """Create a new table from Python mapping or iterable. table_from() accepts either a dict/mapping or an iterable with items. @@ -528,13 +533,15 @@ cdef class LuaRuntime: are placed in the table in order. Nested mappings / iterables are passed to Lua as userdata - (wrapped Python objects); they are not converted to Lua tables. + (wrapped Python objects) by default. If `recursive` is True, + they are converted to Lua tables recursively, handling loops + and duplicates via identity de-duplication. """ assert self._state is not NULL cdef lua_State *L = self._state lock_runtime(self) try: - return py_to_lua_table(self, L, args) + return py_to_lua_table(self, L, args, recursive=recursive) finally: unlock_runtime(self) @@ -1236,7 +1243,7 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): # already terminated raise StopIteration if args: - nargs = len(args) + nargs = _len_as_int(len(args)) push_lua_arguments(thread._runtime, co, args) with nogil: status = lua.lua_resume(co, L, nargs, &nres) @@ -1482,7 +1489,7 @@ cdef py_object* unpack_userdata(lua_State *L, int n) noexcept nogil: cdef int py_function_result_to_lua(LuaRuntime runtime, lua_State *L, object o) except -1: if runtime._unpack_returned_tuples and isinstance(o, tuple): push_lua_arguments(runtime, L, o) - return len(o) + return _len_as_int(len(o)) check_lua_stack(L, 1) return py_to_lua(runtime, L, o) @@ -1511,7 +1518,7 @@ cdef int py_to_lua_handle_overflow(LuaRuntime runtime, lua_State *L, object o) e lua.lua_settop(L, old_top) raise -cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False) except -1: +cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False, bint recursive=False, dict mapped_tables=None) except -1: """Converts Python object to Lua Preconditions: 1 extra slot in the Lua stack @@ -1563,13 +1570,19 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa elif isinstance(o, float): lua.lua_pushnumber(L, o) pushed_values_count = 1 + elif isinstance(o, _PyProtocolWrapper): + type_flags = (<_PyProtocolWrapper> o)._type_flags + o = (<_PyProtocolWrapper> o)._obj + pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags) + elif recursive and isinstance(o, (list, dict, Sequence, Mapping)): + if mapped_tables is None: + mapped_tables = {} + table = py_to_lua_table(runtime, L, (o,), recursive=recursive, mapped_tables=mapped_tables) + (<_LuaObject> table).push_lua_object(L) + pushed_values_count = 1 else: - if isinstance(o, _PyProtocolWrapper): - type_flags = (<_PyProtocolWrapper>o)._type_flags - o = (<_PyProtocolWrapper>o)._obj - else: - # prefer __getitem__ over __getattr__ by default - type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0 + # prefer __getitem__ over __getattr__ by default + type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0 pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags) return pushed_values_count @@ -1655,7 +1668,7 @@ cdef bytes _asciiOrNone(s): return s -cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items): +cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, tuple items, bint recursive=False, dict mapped_tables=None): """ Create a new Lua table and add different kinds of values from the sequence 'items' to it. @@ -1666,14 +1679,24 @@ cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items): check_lua_stack(L, 5) old_top = lua.lua_gettop(L) lua.lua_newtable(L) - # FIXME: how to check for failure? - + # FIXME: handle allocation errors + cdef int lua_table_ref = lua.lua_gettop(L) # the index of the lua table which we are filling + if recursive and mapped_tables is None: + mapped_tables = {} try: for obj in items: + if recursive: + if id(obj) not in mapped_tables: + # this object is never seen before, we should cache it + mapped_tables[id(obj)] = lua_table_ref + else: + # this object has been cached, just get the corresponding lua table's index + idx = mapped_tables[id(obj)] + return new_lua_table(runtime, L, idx) if isinstance(obj, dict): for key, value in (obj).items(): - py_to_lua(runtime, L, key, wrap_none=True) - py_to_lua(runtime, L, value) + py_to_lua(runtime, L, key, wrap_none=True, recursive=recursive, mapped_tables=mapped_tables) + py_to_lua(runtime, L, value, wrap_none=False, recursive=recursive, mapped_tables=mapped_tables) lua.lua_rawset(L, -3) elif isinstance(obj, _LuaTable): @@ -1689,13 +1712,13 @@ cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items): elif isinstance(obj, Mapping): for key in obj: value = obj[key] - py_to_lua(runtime, L, key, wrap_none=True) - py_to_lua(runtime, L, value) + py_to_lua(runtime, L, key, wrap_none=True, recursive=recursive, mapped_tables=mapped_tables) + py_to_lua(runtime, L, value, wrap_none=False, recursive=recursive, mapped_tables=mapped_tables) lua.lua_rawset(L, -3) else: for arg in obj: - py_to_lua(runtime, L, arg) + py_to_lua(runtime, L, arg, wrap_none=False, recursive=recursive, mapped_tables=mapped_tables) lua.lua_rawseti(L, -2, i) i += 1 @@ -1826,7 +1849,7 @@ cdef object execute_lua_call(LuaRuntime runtime, lua_State *L, Py_ssize_t nargs) lua.lua_replace(L, -2) lua.lua_insert(L, 1) has_lua_traceback_func = True - result_status = lua.lua_pcall(L, nargs, lua.LUA_MULTRET, has_lua_traceback_func) + result_status = lua.lua_pcall(L, nargs, lua.LUA_MULTRET, has_lua_traceback_func) if has_lua_traceback_func: lua.lua_remove(L, 1) results = unpack_lua_results(runtime, L) @@ -2004,7 +2027,7 @@ cdef bint call_python(LuaRuntime runtime, lua_State *L, py_object* py_obj) excep else: args = () kwargs = {} - + for i in range(nargs): arg = py_from_lua(runtime, L, i+2) if isinstance(arg, _PyArguments): diff --git a/lupa/tests/test.py b/lupa/tests/test.py index f2d036a..e662a9e 100644 --- a/lupa/tests/test.py +++ b/lupa/tests/test.py @@ -104,6 +104,9 @@ def get_attr(obj, name): class TestLuaRuntime(SetupLuaRuntimeMixin, LupaTestCase): + def assertLuaResult(self, lua_expression, result): + self.assertEqual(self.lua.eval(lua_expression), result) + def test_lua_version(self): version = self.lua.lua_version self.assertEqual(tuple, type(version)) @@ -598,27 +601,23 @@ def test_table_from_bad(self): self.assertRaises(TypeError, self.lua.table_from, None) self.assertRaises(TypeError, self.lua.table_from, {"a": 5}, 123) - def test_table_from_nested_datastructures(self): - from itertools import count - def make_ds(*children): - yield list(children) - yield dict(zip(count(), children)) - yield {chr(ord('A') + i): child for i, child in enumerate(children)} - - elements = [1, 2, 'x', 'y'] - for ds1 in make_ds(*elements): - for ds2 in make_ds(ds1): - for ds3 in make_ds(ds1, elements, ds2): - for ds in make_ds(ds1, ds2, ds3): - with self.subTest(ds=ds): - table = self.lua.table_from(ds) - # we don't translate transitively, so apply arbitrary test operation - self.assertTrue(list(table)) - - # def test_table_from_nested(self): - # table = self.lua.table_from({"obj": {"foo": "bar"}}) - # lua_type = self.lua.eval("type") - # self.assertEqual(lua_type(table["obj"]), "table") + def test_table_from_nested(self): + table = self.lua.table_from([[3, 3, 3]], recursive=True) + self.lua.globals()["data"] = table + self.assertLuaResult("data[1][1]", 3) + self.assertLuaResult("data[1][2]", 3) + self.assertLuaResult("data[1][3]", 3) + self.assertLuaResult("type(data)", "table") + self.assertLuaResult("type(data[1])", "table") + self.assertLuaResult("#data", 1) + self.assertLuaResult("#data[1]", 3) + + def test_table_from_nested2(self): + table2 = self.lua.table_from([{"a": "foo"}, {"b": 1}], recursive=True) + self.lua.globals()["data2"] = table2 + self.assertLuaResult("#data2", 2) + self.assertLuaResult("data2[1]['a']", "foo") + self.assertLuaResult("data2[2]['b']", 1) def test_table_from_table(self): table1 = self.lua.eval("{3, 4, foo='bar'}") @@ -649,6 +648,75 @@ def test_table_from_table_iter_indirect(self): self.assertEqual(list(table2.keys()), [1, 2, 3]) self.assertEqual(set(table2.values()), set([1, 2, "foo"])) + def test_table_from_nested_dict(self): + data = {"a": {"a": "foo"}, "b": {"b": "bar"}} + table = self.lua.table_from(data, recursive=True) + self.assertEqual(table["a"]["a"], "foo") + self.assertEqual(table["b"]["b"], "bar") + self.lua.globals()["data"] = table + self.assertLuaResult("data.a.a", "foo") + self.assertLuaResult("data.b.b", "bar") + self.assertLuaResult("type(data.a)", "table") + self.assertLuaResult("type(data.b)", "table") + + def test_table_from_nested_list(self): + data = {"a": {"a": "foo"}, "b": [1, 2, 3]} + table = self.lua.table_from(data, recursive=True) + self.assertEqual(table["a"]["a"], "foo") + self.assertEqual(table["b"][1], 1) + self.assertEqual(table["b"][2], 2) + self.assertEqual(table["b"][3], 3) + self.lua.globals()["data"] = table + self.assertLuaResult("data.a.a", "foo") + self.assertLuaResult("#data.b", 3) + self.lua.eval("assert(#data.b==3, 'failed')") + self.assertLuaResult("type(data.a)", "table") + self.assertLuaResult("type(data.b)", "table") + + def test_table_from_nested_list_bad(self): + data = {"a": {"a": "foo"}, "b": [1, 2, 3]} + table = self.lua.table_from(data) # in this case, lua will get userdata instead of table + self.assertEqual(table["a"]["a"], "foo") + self.assertEqual(list(table["b"]), [1, 2, 3]) + self.assertEqual(table["b"][0], 1) + self.assertEqual(table["b"][1], 2) + self.assertEqual(table["b"][2], 3) + self.lua.globals()["data"] = table + self.assertLuaResult("type(data.a)", "userdata") + self.assertLuaResult("type(data.b)", "userdata") + + def test_table_from_self_ref_obj(self): + data = {} + data["key"] = data + l = [] + l.append(l) + data["list"] = l + table = self.lua.table_from(data, recursive=True) + self.lua.globals()["data"] = table + self.assertLuaResult("type(data)", 'table') + self.assertLuaResult("type(data['key'])",'table') + self.assertLuaResult("type(data['list'])",'table') + self.assertLuaResult("data['list']==data['list'][1]", True) + self.assertLuaResult("type(data['key']['key']['key']['key'])", 'table') + self.assertLuaResult("type(data['key']['key']['key']['key']['list'])", 'table') + + def test_table_from_nested_datastructures(self): + from itertools import count + def make_ds(*children): + yield list(children) + yield dict(zip(count(), children)) + yield {chr(ord('A') + i): child for i, child in enumerate(children)} + + elements = [1, 2, 'x', 'y'] + for ds1 in make_ds(*elements): + for ds2 in make_ds(ds1): + for ds3 in make_ds(ds1, elements, ds2): + for ds in make_ds(ds1, ds2, ds3): + with self.subTest(ds=ds): + table = self.lua.table_from(ds) + # we don't translate transitively, so apply arbitrary test operation + self.assertTrue(list(table)) + # FIXME: it segfaults # def test_table_from_generator_calling_lua_functions(self): # func = self.lua.eval("function (obj) return obj end")