Skip to content

Commit

Permalink
Tensor copy_compatible_to_dims native (#1409)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz authored Oct 6, 2023
1 parent 4e9adb3 commit 13cb8ff
Show file tree
Hide file tree
Showing 9 changed files with 443 additions and 70 deletions.
18 changes: 18 additions & 0 deletions returnn/frontend/_native/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import hashlib
from glob import glob
import textwrap
from returnn.util.py_ext_mod_compiler import PyExtModCompiler

_module = None
Expand All @@ -29,6 +30,21 @@ def get_module(*, verbose: bool = False):
src_code += f"// {os.path.basename(fn)} code hash md5: {_code_hash_md5(fn)}\n"
src_code += f'#include "{os.path.basename(fn)}"\n'

if os.environ.get("RETURNN_TEST") == "1":
src_code = (
textwrap.dedent(
"""\
#define DEBUG 1
#ifdef NDEBUG
#undef NDEBUG
#endif
"""
)
+ src_code
)
verbose = True

compiler = PyExtModCompiler(
base_name="_returnn_frontend_native",
code_version=1,
Expand Down Expand Up @@ -106,6 +122,8 @@ def setup():
"copy": "tensor_copy",
"copy_template": "tensor_copy_template",
"get_out_permutation_to_dims": "tensor_get_out_permutation_to_dims",
"copy_compatible_to_dims": "tensor_copy_compatible_to_dims",
"copy_compatible_to_dims_raw": "tensor_copy_compatible_to_dims_raw",
}.items():
assert hasattr(_TensorMixin, rf_name)
native_func = getattr(mod, "_" + native_name + "_instancemethod")
Expand Down
6 changes: 6 additions & 0 deletions returnn/frontend/_native/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ static PyMethodDef _pyModuleMethods[] = {
{"tensor_copy_template", (PyCFunction) pyTensorCopyTemplate, METH_VARARGS | METH_KEYWORDS, "Tensor.copy_template"},
{"tensor_get_out_permutation_to_dims", (PyCFunction) pyTensorGetOutPermutationsToDims, METH_FASTCALL,
"Tensor.get_out_permutation_to_dims"},
{"tensor_copy_compatible_to_dims", (PyCFunction) pyTensorCopyCompatibleToDims, METH_FASTCALL, "Tensor.copy_compatible_to_dims"},
{"tensor_copy_compatible_to_dims_raw", (PyCFunction) pyTensorCopyCompatibleToDimsRaw, METH_FASTCALL, "Tensor.copy_compatible_to_dims_raw"},

{"tensor_compare", (PyCFunction) pyTensorCompare, METH_VARARGS | METH_KEYWORDS, "rf.compare"},
{"tensor_combine", (PyCFunction) pyTensorCombine, METH_VARARGS | METH_KEYWORDS, "rf.combine"},
Expand Down Expand Up @@ -82,6 +84,8 @@ int PyModuleState::pyInitModuleExec(PyObject* module) {
if(!mod) return -1;
_tensorType = PyObject_GetAttrString(mod, "Tensor");
if(!_tensorType) return -1;
_dimType = PyObject_GetAttrString(mod, "Dim");
if(!_dimType) return -1;
}

{
Expand Down Expand Up @@ -134,6 +138,8 @@ int PyModuleState::pyInitModuleExec(PyObject* module) {
AddInstanceMethod(copy);
AddInstanceMethod(copy_template);
AddInstanceMethod(get_out_permutation_to_dims);
AddInstanceMethod(copy_compatible_to_dims);
AddInstanceMethod(copy_compatible_to_dims_raw);

AddInstanceMethod(eq);
AddInstanceMethod(ne);
Expand Down
4 changes: 4 additions & 0 deletions returnn/frontend/_native/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class PyModuleState {
for(int i = 0; i < _rawTensorTypesLen; ++i)
Py_VISIT(_rawTensorTypes[i]);
Py_VISIT(_tensorType);
Py_VISIT(_dimType);
Py_VISIT(_globalBackend);
Py_VISIT(_backendTensorTypeDispatchTable);
for(int i = 0; i < NumBackendsWithCachedOps * NumTOps; ++i)
Expand All @@ -130,6 +131,7 @@ class PyModuleState {
for(unsigned int i = 0; i < sizeof(_rawTensorTypes)/sizeof(_rawTensorTypes[0]); ++i)
Py_CLEAR(_rawTensorTypes[i]);
Py_CLEAR(_tensorType);
Py_CLEAR(_dimType);
Py_CLEAR(_globalBackend);
Py_CLEAR(_backendTensorTypeDispatchTable);
for(int i = 0; i < NumBackendsWithCachedOps * NumTOps; ++i)
Expand All @@ -144,6 +146,7 @@ class PyModuleState {

inline PyObject* notSpecified() const { return _notSpecified; }
inline PyObject* tensorType() const { return _tensorType; }
inline PyObject* dimType() const { return _dimType; }
inline PyObject* globalBackend() const { return _globalBackend; }
inline PyObject* cachedOp(RawOp op, BackendWithCachedOps backend) {
if(!_cachedOps[backend * NumTOps + op])
Expand All @@ -163,6 +166,7 @@ class PyModuleState {
int _rawTensorTypesLen;
PyObject* _rawTensorTypes[10];
PyObject* _tensorType;
PyObject* _dimType;
PyObject* _globalBackend;
PyObject* _backendTensorTypeDispatchTable;
PyObject* _cachedOps[NumBackendsWithCachedOps * NumTOps];
Expand Down
18 changes: 18 additions & 0 deletions returnn/frontend/_native/py_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,23 @@ class PyTupleOrListStaticRef {

public:
PyTupleOrListStaticRef(PyObject* obj) : _obj(obj) {
#ifdef DEBUG
assert(obj);
if(isTuple) assert(PyTuple_Check(obj));
else assert(PyList_Check(obj));
#endif
if(isTuple) _size = PyTuple_GET_SIZE(obj);
else _size = PyList_GET_SIZE(obj);
#ifdef DEBUG
assert(_size >= 0);
#endif
}

int size() const { return _size; }
PyObject* getItem(int i) const {
#ifdef DEBUG
assert(i >= 0 && i < _size);
#endif
if(isTuple) return PyTuple_GET_ITEM(_obj, i);
else return PyList_GET_ITEM(_obj, i);
}
Expand All @@ -73,11 +84,18 @@ class PyTupleOrListRef {
if(_type == TupleType) _size = PyTuple_GET_SIZE(obj);
else if(_type == ListType) _size = PyList_GET_SIZE(obj);
else _size = -1;
#ifdef DEBUG
if(_type != UnknownType)
assert(_size >= 0);
#endif
}

bool isValid() const { return _type != UnknownType; }
int size() const { return _size; }
PyObject* getItem(int i) const {
#ifdef DEBUG
assert(i >= 0 && i < _size);
#endif
if(_type == TupleType) return PyTuple_GET_ITEM(_obj, i);
else if(_type == ListType) return PyList_GET_ITEM(_obj, i);
else return NULL;
Expand Down
Loading

0 comments on commit 13cb8ff

Please sign in to comment.