From 13cb8ff1c75eb8aa5f0773f4297b8afa30a42f53 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 6 Oct 2023 20:08:55 +0200 Subject: [PATCH] Tensor copy_compatible_to_dims native (#1409) --- returnn/frontend/_native/__init__.py | 18 ++ returnn/frontend/_native/module.cpp | 6 + returnn/frontend/_native/module.hpp | 4 + returnn/frontend/_native/py_utils.hpp | 18 ++ returnn/frontend/_native/tensor_ops.cpp | 412 ++++++++++++++++++++---- returnn/frontend/_native/tensor_ops.hpp | 2 + returnn/tensor/_tensor_extra.py | 25 +- returnn/tensor/tensor.py | 2 +- tests/test_torch_frontend.py | 26 ++ 9 files changed, 443 insertions(+), 70 deletions(-) diff --git a/returnn/frontend/_native/__init__.py b/returnn/frontend/_native/__init__.py index fa2b1cdd4e..a8d2938d9c 100644 --- a/returnn/frontend/_native/__init__.py +++ b/returnn/frontend/_native/__init__.py @@ -7,6 +7,7 @@ import os import hashlib from glob import glob +import textwrap from returnn.util.py_ext_mod_compiler import PyExtModCompiler _module = None @@ -29,6 +30,21 @@ def get_module(*, verbose: bool = False): src_code += f"// {os.path.basename(fn)} code hash md5: {_code_hash_md5(fn)}\n" src_code += f'#include "{os.path.basename(fn)}"\n' + if os.environ.get("RETURNN_TEST") == "1": + src_code = ( + textwrap.dedent( + """\ + #define DEBUG 1 + #ifdef NDEBUG + #undef NDEBUG + #endif + + """ + ) + + src_code + ) + verbose = True + compiler = PyExtModCompiler( base_name="_returnn_frontend_native", code_version=1, @@ -106,6 +122,8 @@ def setup(): "copy": "tensor_copy", "copy_template": "tensor_copy_template", "get_out_permutation_to_dims": "tensor_get_out_permutation_to_dims", + "copy_compatible_to_dims": "tensor_copy_compatible_to_dims", + "copy_compatible_to_dims_raw": "tensor_copy_compatible_to_dims_raw", }.items(): assert hasattr(_TensorMixin, rf_name) native_func = getattr(mod, "_" + native_name + "_instancemethod") diff --git a/returnn/frontend/_native/module.cpp b/returnn/frontend/_native/module.cpp index 8ea0899cfd..d305fec306 100644 --- a/returnn/frontend/_native/module.cpp +++ b/returnn/frontend/_native/module.cpp @@ -21,6 +21,8 @@ static PyMethodDef _pyModuleMethods[] = { {"tensor_copy_template", (PyCFunction) pyTensorCopyTemplate, METH_VARARGS | METH_KEYWORDS, "Tensor.copy_template"}, {"tensor_get_out_permutation_to_dims", (PyCFunction) pyTensorGetOutPermutationsToDims, METH_FASTCALL, "Tensor.get_out_permutation_to_dims"}, + {"tensor_copy_compatible_to_dims", (PyCFunction) pyTensorCopyCompatibleToDims, METH_FASTCALL, "Tensor.copy_compatible_to_dims"}, + {"tensor_copy_compatible_to_dims_raw", (PyCFunction) pyTensorCopyCompatibleToDimsRaw, METH_FASTCALL, "Tensor.copy_compatible_to_dims_raw"}, {"tensor_compare", (PyCFunction) pyTensorCompare, METH_VARARGS | METH_KEYWORDS, "rf.compare"}, {"tensor_combine", (PyCFunction) pyTensorCombine, METH_VARARGS | METH_KEYWORDS, "rf.combine"}, @@ -82,6 +84,8 @@ int PyModuleState::pyInitModuleExec(PyObject* module) { if(!mod) return -1; _tensorType = PyObject_GetAttrString(mod, "Tensor"); if(!_tensorType) return -1; + _dimType = PyObject_GetAttrString(mod, "Dim"); + if(!_dimType) return -1; } { @@ -134,6 +138,8 @@ int PyModuleState::pyInitModuleExec(PyObject* module) { AddInstanceMethod(copy); AddInstanceMethod(copy_template); AddInstanceMethod(get_out_permutation_to_dims); + AddInstanceMethod(copy_compatible_to_dims); + AddInstanceMethod(copy_compatible_to_dims_raw); AddInstanceMethod(eq); AddInstanceMethod(ne); diff --git a/returnn/frontend/_native/module.hpp b/returnn/frontend/_native/module.hpp index c0e87dcde3..768b212d61 100644 --- a/returnn/frontend/_native/module.hpp +++ b/returnn/frontend/_native/module.hpp @@ -115,6 +115,7 @@ class PyModuleState { for(int i = 0; i < _rawTensorTypesLen; ++i) Py_VISIT(_rawTensorTypes[i]); Py_VISIT(_tensorType); + Py_VISIT(_dimType); Py_VISIT(_globalBackend); Py_VISIT(_backendTensorTypeDispatchTable); for(int i = 0; i < NumBackendsWithCachedOps * NumTOps; ++i) @@ -130,6 +131,7 @@ class PyModuleState { for(unsigned int i = 0; i < sizeof(_rawTensorTypes)/sizeof(_rawTensorTypes[0]); ++i) Py_CLEAR(_rawTensorTypes[i]); Py_CLEAR(_tensorType); + Py_CLEAR(_dimType); Py_CLEAR(_globalBackend); Py_CLEAR(_backendTensorTypeDispatchTable); for(int i = 0; i < NumBackendsWithCachedOps * NumTOps; ++i) @@ -144,6 +146,7 @@ class PyModuleState { inline PyObject* notSpecified() const { return _notSpecified; } inline PyObject* tensorType() const { return _tensorType; } + inline PyObject* dimType() const { return _dimType; } inline PyObject* globalBackend() const { return _globalBackend; } inline PyObject* cachedOp(RawOp op, BackendWithCachedOps backend) { if(!_cachedOps[backend * NumTOps + op]) @@ -163,6 +166,7 @@ class PyModuleState { int _rawTensorTypesLen; PyObject* _rawTensorTypes[10]; PyObject* _tensorType; + PyObject* _dimType; PyObject* _globalBackend; PyObject* _backendTensorTypeDispatchTable; PyObject* _cachedOps[NumBackendsWithCachedOps * NumTOps]; diff --git a/returnn/frontend/_native/py_utils.hpp b/returnn/frontend/_native/py_utils.hpp index 67d55fd133..119f311067 100644 --- a/returnn/frontend/_native/py_utils.hpp +++ b/returnn/frontend/_native/py_utils.hpp @@ -47,12 +47,23 @@ class PyTupleOrListStaticRef { public: PyTupleOrListStaticRef(PyObject* obj) : _obj(obj) { +#ifdef DEBUG + assert(obj); + if(isTuple) assert(PyTuple_Check(obj)); + else assert(PyList_Check(obj)); +#endif if(isTuple) _size = PyTuple_GET_SIZE(obj); else _size = PyList_GET_SIZE(obj); +#ifdef DEBUG + assert(_size >= 0); +#endif } int size() const { return _size; } PyObject* getItem(int i) const { +#ifdef DEBUG + assert(i >= 0 && i < _size); +#endif if(isTuple) return PyTuple_GET_ITEM(_obj, i); else return PyList_GET_ITEM(_obj, i); } @@ -73,11 +84,18 @@ class PyTupleOrListRef { if(_type == TupleType) _size = PyTuple_GET_SIZE(obj); else if(_type == ListType) _size = PyList_GET_SIZE(obj); else _size = -1; +#ifdef DEBUG + if(_type != UnknownType) + assert(_size >= 0); +#endif } bool isValid() const { return _type != UnknownType; } int size() const { return _size; } PyObject* getItem(int i) const { +#ifdef DEBUG + assert(i >= 0 && i < _size); +#endif if(_type == TupleType) return PyTuple_GET_ITEM(_obj, i); else if(_type == ListType) return PyList_GET_ITEM(_obj, i); else return NULL; diff --git a/returnn/frontend/_native/tensor_ops.cpp b/returnn/frontend/_native/tensor_ops.cpp index 1404b57a1a..e484a49663 100644 --- a/returnn/frontend/_native/tensor_ops.cpp +++ b/returnn/frontend/_native/tensor_ops.cpp @@ -25,6 +25,40 @@ PyObject* tensorCopy( return res.release(); } +// all but time_dim_axis (or other special axes, or any axes) +static bool _copyTensorExtraToKwargs(PyObject* extra, PyObject* kwargs) { + PyObjectScopedRef batch = PyObject_GetAttrString(extra, "batch"); + if(!batch) return false; + if(batch != Py_None) { + if(PyDict_SetItemString(kwargs, "batch", batch) < 0) return false; + } + { + PyObjectScopedRef beam = PyObject_GetAttrString(extra, "beam"); + if(!beam) return false; + if(beam == Py_None && batch != Py_None) { + beam = PyObject_GetAttrString(batch, "beam"); + } + if(beam != Py_None) { + if(PyDict_SetItemString(kwargs, "beam", beam) < 0) return false; + } + } + { + PyObjectScopedRef control_flow_ctx = PyObject_GetAttrString(extra, "control_flow_ctx"); + if(!control_flow_ctx) return false; + if(control_flow_ctx != Py_None) { + if(PyDict_SetItemString(kwargs, "control_flow_ctx", control_flow_ctx) < 0) return false; + } + } + { + PyObjectScopedRef available_for_inference = PyObject_GetAttrString(extra, "available_for_inference"); + if(!available_for_inference) return false; + if(available_for_inference != Py_None) { + if(PyDict_SetItemString(kwargs, "available_for_inference", available_for_inference) < 0) return false; + } + } + return true; +} + // copy of Tensor.copy_template() PyObject* tensorCopyTemplate( PyModuleState* modState, @@ -100,35 +134,8 @@ PyObject* tensorCopyTemplate( if(PyDict_SetItemString(kwargs, "version", version) < 0) return NULL; } if(extra != Py_None) { - PyObjectScopedRef batch = PyObject_GetAttrString(extra, "batch"); - if(!batch) return NULL; - if(batch != Py_None) { - if(PyDict_SetItemString(kwargs, "batch", batch) < 0) return NULL; - } - { - PyObjectScopedRef beam = PyObject_GetAttrString(extra, "beam"); - if(!beam) return NULL; - if(beam == Py_None && batch != Py_None) { - beam = PyObject_GetAttrString(batch, "beam"); - } - if(beam != Py_None) { - if(PyDict_SetItemString(kwargs, "beam", beam) < 0) return NULL; - } - } - { - PyObjectScopedRef control_flow_ctx = PyObject_GetAttrString(extra, "control_flow_ctx"); - if(!control_flow_ctx) return NULL; - if(control_flow_ctx != Py_None) { - if(PyDict_SetItemString(kwargs, "control_flow_ctx", control_flow_ctx) < 0) return NULL; - } - } - { - PyObjectScopedRef available_for_inference = PyObject_GetAttrString(extra, "available_for_inference"); - if(!available_for_inference) return NULL; - if(available_for_inference != Py_None) { - if(PyDict_SetItemString(kwargs, "available_for_inference", available_for_inference) < 0) return NULL; - } - } + if(!_copyTensorExtraToKwargs(extra, kwargs)) + return NULL; } return PyObject_Call(modState->tensorType(), emptyArgs, kwargs); @@ -501,7 +508,9 @@ static PyObject* _compareOrCombine_subsetDims( int j = 0; for(int i = 0; i < (int) outPermutation.size(); ++i) { if(outPermutation[i] < 0) continue; - PyTuple_SET_ITEM(permuteArg.get(), j, PyLong_FromLong(outPermutation[i])); + PyObject* intObj = PyLong_FromLong(outPermutation[i]); + if(!intObj) return NULL; + PyTuple_SET_ITEM(permuteArg.get(), j, intObj); ++j; } assert(j == PyTuple_GET_SIZE(permuteArg.get())); @@ -622,18 +631,19 @@ static bool _getPermutationSupersetToSubset(const char* funcName, ASeqT subset, funcName, subset.get(), superset.get()); return false; } - assert(outPermutation.size() == superset.size()); + assert((int) outPermutation.size() == superset.size()); return true; } +template static PyObject* _permuteAndExtend( const char* rawOpName, PyObject* permuteOp, PyObject* reshapeOp, PyObject* tensor, PyTupleOrListStaticRef dims, PyObject* rawTensor, PyObject* rawShape, - PyTupleOrListStaticRef outDims, std::vector outShape + OutDimSeqT outDims, + std::vector& outPermutation ) { // First find the mapping. - std::vector outPermutation; if(!_getPermutationSupersetToSubset(rawOpName, dims, outDims, outPermutation)) return NULL; @@ -641,27 +651,44 @@ static PyObject* _permuteAndExtend( PyObjectScopedRef rawTensorExt; // just for holding the ref and decrefing it later // Maybe permute the tensor - { - PyObjectScopedRef permuteArg = PyTuple_New(PyTuple_GET_SIZE(rawShape)); + bool needPermute = false; + for(int i = 0; i < (int) outPermutation.size(); ++i) { + if(i > 0 && outPermutation[i] != outPermutation[i - 1] + 1) { + needPermute = true; + break; + } + } + if(needPermute) { + PyObjectScopedRef permuteArg = PyTuple_New(dims.size()); if(!permuteArg) return NULL; int j = 0; for(int i = 0; i < (int) outPermutation.size(); ++i) { if(outPermutation[i] < 0) continue; - PyTuple_SET_ITEM(permuteArg.get(), j, PyLong_FromLong(outPermutation[i])); + PyObject* intObj = PyLong_FromLong(outPermutation[i]); + if(!intObj) return NULL; + assert(j < dims.size()); + PyTuple_SET_ITEM(permuteArg.get(), j, intObj); ++j; } assert(j == PyTuple_GET_SIZE(permuteArg.get())); + assert(j == dims.size()); rawTensor_ = PyObject_CallFunctionObjArgs(permuteOp, rawTensor_, permuteArg.get(), NULL); if(!rawTensor_) return NULL; rawTensorExt = rawTensor_; } // Maybe reshape the tensor - { + if(outDims.size() > dims.size()) { PyObjectScopedRef rawShapeExt = PyTuple_New(outPermutation.size()); if(!rawShapeExt) return NULL; for(int i = 0; i < (int) outPermutation.size(); ++i) { - PyObject* d = PyLong_FromLong((outPermutation[i] >= 0) ? outShape[i] : 1); + PyObject* d; + if(outPermutation[i] >= 0) { + d = PyTuple_GET_ITEM(rawShape, outPermutation[i]); + Py_XINCREF(d); + } + else + d = PyLong_FromLong(1); if(!d) return NULL; PyTuple_SET_ITEM(rawShapeExt.get(), i, d); } @@ -670,7 +697,8 @@ static PyObject* _permuteAndExtend( rawTensorExt = rawTensor_; } - rawTensorExt.release(); + if(rawTensorExt) rawTensorExt.release(); + else Py_INCREF(rawTensor_); // we still have it borrowed return rawTensor_; } @@ -743,6 +771,285 @@ PyObject* pyTensorGetOutPermutationsToDims(PyObject *self, PyObject *const *args return NULL; } +template +static PyObject* tensorCopyCompatibleToDims(const char* funcName, PyModuleState* modState, PyObject* tensor, PyObject* outDims) { + PyTupleOrListRef outDimsSeq(outDims); + if(!outDimsSeq.isValid()) { + PyErr_Format(PyExc_TypeError, "%s: expected dims to be tuple or list, got %R", funcName, outDims); + return NULL; + } + + PyObjectScopedRef dims = PyObject_GetAttrString(tensor, "_dims"); + if(!dims) return NULL; + if(!PyTuple_Check(dims)) { + PyErr_Format(PyExc_TypeError, "%s: expected tensor.dims to be tuple, got %R", funcName, dims.get()); + return NULL; + } + PyTupleOrListStaticRef dimsSeq(dims); + + PyObjectScopedRef rawTensor = PyObject_GetAttrString(tensor, "_raw_tensor"); + if(!rawTensor) return NULL; + + // follow Tensor.copy_compatible_to_dims logic + + std::vector outPermutation; + PyObjectScopedRef outRawTensor; + if(rawTensor == Py_None) { + if(rawMode) { + PyErr_Format(PyExc_ValueError, "%s: tensor does not have a raw_tensor", funcName); + return NULL; + } + if(!_getPermutationSupersetToSubset(funcName, dimsSeq, outDimsSeq, outPermutation)) + return NULL; + } + else if(modState->isTorchTensorType((PyObject*) Py_TYPE(rawTensor))) { + PyObjectScopedRef rawShape = PyObject_GetAttrString(rawTensor, "shape"); + if(!rawShape) return NULL; + if(!PyTuple_Check(rawShape)) { + PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", funcName, rawShape.get()); + return NULL; + } + PyObject* permuteOp = modState->cachedOp(TOp_Permute, BWCO_Torch); + if(!permuteOp) return NULL; + PyObject* reshapeOp = modState->cachedOp(TOp_Reshape, BWCO_Torch); + if(!reshapeOp) return NULL; + outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, tensor, dimsSeq, rawTensor, rawShape, outDimsSeq, outPermutation); + if(!outRawTensor) return NULL; + } + else { // generic backend fallback + PyObject* backend = getBackendForRawTensor(modState, rawTensor); + PyObjectScopedRef rawShape = PyObject_CallMethod(backend, "get_shape_tuple_raw", "O", rawTensor.get()); + if(!rawShape) return NULL; + if(!PyTuple_Check(rawShape)) { + PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", funcName, rawShape.get()); + return NULL; + } + PyObjectScopedRef permuteOp = PyObject_GetAttrString(backend, "transpose_raw"); + if(!permuteOp) return NULL; + PyObjectScopedRef reshapeOp = PyObject_GetAttrString(backend, "reshape_raw"); + if(!reshapeOp) return NULL; + outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, tensor, dimsSeq, rawTensor, rawShape, outDimsSeq, outPermutation); + if(!outRawTensor) return NULL; + } + + if(rawMode) { + assert(outRawTensor); + return outRawTensor.release(); + } + + assert((int) outPermutation.size() == outDimsSeq.size()); + PyObjectScopedRef outDims_ = PyTuple_New(outPermutation.size()); + if(!outDims_) return NULL; + for(int i = 0; (size_t) i < outPermutation.size(); ++i) { + PyObject* d; + if(outPermutation[i] >= 0) { + d = outDimsSeq.getItem(i); + if(!d) return NULL; + Py_INCREF(d); + } + else { + // create dummy broadcast dim + PyObject* dim = outDimsSeq.getItem(i); + PyObjectScopedRef kind = PyObject_GetAttrString(dim, "kind"); + if(!kind) return NULL; + PyObjectScopedRef description = PyObject_GetAttrString(dim, "description"); + if(!description) return NULL; + if(description == Py_None) description = PyUnicode_InternFromString("unnamed_bc_dim1"); + else description = PyUnicode_FromFormat("%S_bc_dim1", description.get()); + if(!description) return NULL; + PyObjectScopedRef dimValue = PyLong_FromLong(1); + if(!dimValue) return NULL; + PyObjectScopedRef args = PyTuple_New(0); + if(!args) return NULL; + PyObjectScopedRef kwargs = PyDict_New(); + if(!kwargs) return NULL; + if(PyDict_SetItemString(kwargs, "kind", kind) < 0) return NULL; + if(PyDict_SetItemString(kwargs, "description", description) < 0) return NULL; + if(PyDict_SetItemString(kwargs, "dimension", dimValue) < 0) return NULL; + if(PyDict_SetItemString(kwargs, "auto_generated", Py_True) < 0) return NULL; + d = PyObject_Call(modState->dimType(), args, kwargs); + if(!d) return NULL; + } + PyTuple_SET_ITEM(outDims_.get(), i, d); + } + + PyObjectScopedRef name = PyObject_GetAttrString(tensor, "name"); + if(!name) return NULL; + PyObjectScopedRef dtype = PyObject_GetAttrString(tensor, "dtype"); + if(!dtype) return NULL; + PyObjectScopedRef feature_dim_axis = PyObject_GetAttrString(tensor, "_feature_dim_axis"); + if(!feature_dim_axis) return NULL; + if(feature_dim_axis == Py_None) {} + else if(feature_dim_axis == modState->notSpecified()) {} + else { + if(!PyLong_Check(feature_dim_axis)) { + PyErr_Format( + PyExc_TypeError, + "%s: tensor._feature_dim_axis did not return an int, from tensor dims %R", + funcName, dimsSeq.get()); + return NULL; + } + long feature_dim_axisInt = PyLong_AsLong(feature_dim_axis); + if(feature_dim_axisInt < 0) { + if(!PyErr_Occurred()) + PyErr_Format( + PyExc_ValueError, + "%s: tensor._feature_dim_axis is negative, from tensor dims %R", + funcName, dimsSeq.get()); + return NULL; + } + if(feature_dim_axisInt >= (long) dimsSeq.size()) { + PyErr_Format( + PyExc_ValueError, + "%s: tensor._feature_dim_axis is out of range, from tensor dims %R", + funcName, dimsSeq.get()); + return NULL; + } + feature_dim_axis = NULL; + for(int i = 0; i < (int) outPermutation.size(); ++i) { + if(outPermutation[i] == feature_dim_axisInt) { + feature_dim_axis = PyLong_FromLong(i); + if(!feature_dim_axis) return NULL; + break; + } + } + if(!feature_dim_axis) { + PyErr_Format( + PyExc_SystemError, + "%s: tensor._feature_dim_axis is not in out_dims, from tensor dims %R", + funcName, dimsSeq.get()); + return NULL; + } + } + PyObjectScopedRef sparse_dim = PyObject_GetAttrString(tensor, "sparse_dim"); + if(!sparse_dim) return NULL; + + PyObjectScopedRef version = PyObject_GetAttrString(tensor, "version"); + if(!version) return NULL; + if(!PyLong_Check(version)) { + PyErr_Format( + PyExc_TypeError, + "tensorCopyTemplate: tensor.version did not return an int, from version %R", version.get()); + return NULL; + } + long versionInt = PyLong_AsLong(version); + if(versionInt != 1 && versionInt != 2) { + if(!PyErr_Occurred()) + PyErr_Format( + PyExc_ValueError, + "tensorCopyTemplate: tensor.version is invalid, from version %R", version.get()); + return NULL; + } + PyObjectScopedRef extra = PyObject_GetAttrString(tensor, "_extra"); + if(!extra) return NULL; + + PyObjectScopedRef outTensor; + if(versionInt == 2 && extra == Py_None) { + outTensor = PyObject_CallFunctionObjArgs( + modState->tensorType(), name.get(), outDims_.get(), dtype.get(), NULL); + if(!outTensor) return NULL; + } + else { + PyObjectScopedRef args = PyTuple_New(0); + if(!args) return NULL; + PyObjectScopedRef kwargs = PyDict_New(); + if(!kwargs) return NULL; + if(PyDict_SetItemString(kwargs, "name", name) < 0) return NULL; + if(PyDict_SetItemString(kwargs, "dims", outDims_) < 0) return NULL; + if(PyDict_SetItemString(kwargs, "dtype", dtype) < 0) return NULL; + if(PyDict_SetItemString(kwargs, "version", version) < 0) return NULL; + + if(extra != Py_None) { + if(versionInt == 1) { + PyObjectScopedRef time_dim_axis = PyObject_GetAttrString(tensor, "_time_dim_axis"); + if(!time_dim_axis) return NULL; + if(time_dim_axis != Py_None && time_dim_axis != modState->notSpecified()) { + if(!PyLong_Check(time_dim_axis)) { + PyErr_Format( + PyExc_TypeError, + "%s: tensor._time_dim_axis did not return an int, from tensor %R", + funcName, tensor); + return NULL; + } + long time_dim_axisInt = PyLong_AsLong(time_dim_axis); + if(time_dim_axisInt < 0) { + if(!PyErr_Occurred()) + PyErr_Format( + PyExc_ValueError, + "%s: tensor._time_dim_axis is negative, from tensor %R", + funcName, tensor); + return NULL; + } + if(time_dim_axisInt >= (long) dimsSeq.size()) { + PyErr_Format( + PyExc_ValueError, + "%s: tensor._time_dim_axis is out of range, from tensor %R", + funcName, tensor); + return NULL; + } + time_dim_axis.release(); + for(int i = 0; i < (int) outPermutation.size(); ++i) { + if(outPermutation[i] == time_dim_axisInt) { + time_dim_axis = PyLong_FromLong(i); + if(!time_dim_axis) return NULL; + break; + } + } + if(!time_dim_axis) { + PyErr_Format( + PyExc_SystemError, + "%s: tensor._time_dim_axis is not in out_dims, from tensor %R", + funcName, tensor); + return NULL; + } + } + if(PyDict_SetItemString(kwargs, "time_dim_axis", time_dim_axis) < 0) return NULL; + } + + if(!_copyTensorExtraToKwargs(extra, kwargs)) return NULL; + } + + outTensor = PyObject_Call(modState->tensorType(), args, kwargs); + if(!outTensor) return NULL; + } + + if(outRawTensor) + if(PyObject_SetAttrString(outTensor, "raw_tensor", outRawTensor) < 0) + return NULL; + + if(feature_dim_axis != ((versionInt == 2) ? Py_None : modState->notSpecified())) + if(PyObject_SetAttrString(outTensor, "_feature_dim_axis", feature_dim_axis) < 0) + return NULL; + if(sparse_dim != Py_None) + if(PyObject_SetAttrString(outTensor, "sparse_dim", sparse_dim) < 0) + return NULL; + return outTensor.release(); +} + +PyObject* pyTensorCopyCompatibleToDims(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + if(nargs != 2) { + PyErr_SetString(PyExc_TypeError, "tensor_copy_compatible_to_dims() takes exactly 2 args: tensor, dims"); + return NULL; + } + + PyModuleState* modState = (PyModuleState*) PyModule_GetState(self); + if(!modState) return NULL; + + return tensorCopyCompatibleToDims("tensor_copy_compatible_to_dims", modState, args[0], args[1]); +} + +PyObject* pyTensorCopyCompatibleToDimsRaw(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + if(nargs != 2) { + PyErr_SetString(PyExc_TypeError, "tensor_copy_compatible_to_dims_raw() takes exactly 2 args: tensor, dims"); + return NULL; + } + + PyModuleState* modState = (PyModuleState*) PyModule_GetState(self); + if(!modState) return NULL; + + return tensorCopyCompatibleToDims("tensor_copy_compatible_to_dims_raw", modState, args[0], args[1]); +} + static PyObject* compareOrCombine( PyObject* a, PyObject* b, bool resultIsBool, @@ -980,7 +1287,6 @@ static PyObject* compareOrCombine( return NULL; } PyObjectScopedRef allDims = PyList_New(0); - std::vector outShape; if(!allDims) return NULL; for(int i = 0; i < aDimsSeq.size() + bDimsSeq.size(); ++i) { PyObject* dim = @@ -1016,7 +1322,6 @@ static PyObject* compareOrCombine( if(bDimsCount < 0) return NULL; if(aDimsCount <= 1 && bDimsCount <= 1) { if(PyList_Append(allDims, dim) < 0) return NULL; - outShape.push_back(dimValue); continue; } int c = 0; @@ -1029,7 +1334,6 @@ static PyObject* compareOrCombine( if(!eq) continue; } if(PyList_Append(allDims, dim_) < 0) return NULL; - outShape.push_back(dimValue); ++c; } if(c != std::max(aDimsCount, bDimsCount)) { @@ -1041,7 +1345,6 @@ static PyObject* compareOrCombine( } } PyTupleOrListStaticRef allDimsSeq(allDims); - assert(outShape.size() == (size_t) allDimsSeq.size()); // check if all dims are in a and b, or whether we need allowBroadcastAllSources bool error = false; @@ -1061,9 +1364,9 @@ static PyObject* compareOrCombine( // maybe reorder according to dimOrder if(dimOrder != Py_None) { - std::vector> outDimWithValue; - for(size_t i = 0; i < outShape.size(); ++i) - outDimWithValue.push_back(std::make_pair(allDimsSeq.getItem(i), outShape[i])); + std::vector outDims; + for(int i = 0; i < allDimsSeq.size(); ++i) + outDims.push_back(allDimsSeq.getItem(i)); struct Cmp { PyTupleOrListRef dimOrderSeq; bool hadError; @@ -1081,20 +1384,15 @@ static PyObject* compareOrCombine( } return dimOrderSeq.size(); } - bool operator()(std::pair a, std::pair b) { - return (*this)(a.first, b.first); - } bool operator()(PyObject* a, PyObject* b) { if(a == b) return false; return getIndex(a) < getIndex(b); } } cmp(dimOrderSeq); - std::stable_sort(outDimWithValue.begin(), outDimWithValue.end(), cmp); + std::stable_sort(outDims.begin(), outDims.end(), cmp); if(cmp.hadError) return NULL; - for(size_t i = 0; i < outShape.size(); ++i) { - PyList_SET_ITEM(allDims.get(), i, outDimWithValue[i].first); - outShape[i] = outDimWithValue[i].second; - } + for(size_t i = 0; i < outDims.size(); ++i) + PyList_SET_ITEM(allDims.get(), i, outDims[i]); } PyObjectScopedRef res; @@ -1109,15 +1407,17 @@ static PyObject* compareOrCombine( } { + std::vector outPermutation; PyObjectScopedRef aRawTensorExt = _permuteAndExtend( rawOpName, permuteOp, reshapeOp, a, aDimsSeq, aRawTensor, aRawShape, - allDimsSeq, outShape); + allDimsSeq, outPermutation); if(!aRawTensorExt) return NULL; + outPermutation.clear(); PyObjectScopedRef bRawTensorExt = _permuteAndExtend( rawOpName, permuteOp, reshapeOp, b, bDimsSeq, bRawTensor, bRawShape, - allDimsSeq, outShape); + allDimsSeq, outPermutation); if(!bRawTensorExt) return NULL; PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs( rawOp, aRawTensorExt.get(), bRawTensorExt.get(), NULL); diff --git a/returnn/frontend/_native/tensor_ops.hpp b/returnn/frontend/_native/tensor_ops.hpp index d707622f10..cca2440018 100644 --- a/returnn/frontend/_native/tensor_ops.hpp +++ b/returnn/frontend/_native/tensor_ops.hpp @@ -17,6 +17,8 @@ PyObject* pyTensorCopy(PyObject *self, PyObject *args, PyObject *kwargs); PyObject* pyTensorCopyTemplate(PyObject *self, PyObject *args, PyObject *kwargs); PyObject* pyTensorRawTensorSetter(PyObject *self, PyObject *const *args, Py_ssize_t nargs); PyObject* pyTensorGetOutPermutationsToDims(PyObject *self, PyObject *const *args, Py_ssize_t nargs); +PyObject* pyTensorCopyCompatibleToDims(PyObject *self, PyObject *const *args, Py_ssize_t nargs); +PyObject* pyTensorCopyCompatibleToDimsRaw(PyObject *self, PyObject *const *args, Py_ssize_t nargs); PyObject* pyConvertToRawTorchTensorLike(PyObject *self, PyObject *const *args, Py_ssize_t nargs); diff --git a/returnn/tensor/_tensor_extra.py b/returnn/tensor/_tensor_extra.py index f5fa11d543..d2078ea780 100644 --- a/returnn/tensor/_tensor_extra.py +++ b/returnn/tensor/_tensor_extra.py @@ -1235,6 +1235,7 @@ def copy_compatible_to( return v # This func matches _native _permuteAndExtend logic. + # This func has a native implementation (_native tensor_get_out_permutation_to_dims). def get_out_permutation_to_dims(self, dims: Sequence[Dim]) -> List[int]: """ :param dims: superset of our dims @@ -1287,6 +1288,7 @@ def get_out_permutation_to_dims(self, dims: Sequence[Dim]) -> List[int]: assert len(out_permutation) == len(dims) return out_permutation + # This function has a native implementation (_native tensor_copy_compatible_to_dims). def copy_compatible_to_dims(self: _t.Tensor, dims: Sequence[Dim]) -> _t.Tensor: """ Simpler variant of :func:`copy_compatible_to` which just takes a list of dims, @@ -1315,18 +1317,15 @@ def copy_compatible_to_dims(self: _t.Tensor, dims: Sequence[Dim]) -> _t.Tensor: ) for i, p in enumerate(out_permutation) ] - return _t.Tensor( - self.name, - out_dims, - self.dtype, - raw_tensor=raw_tensor, - version=self.version, - batch=self.batch, - beam=self.beam, - feature_dim=self.feature_dim, - sparse_dim=self.sparse_dim, - ) + kwargs = self.get_kwargs() + for special_axis_name in self.SpecialAxesNames: + if special_axis_name in kwargs and kwargs[special_axis_name] is not None: + kwargs[special_axis_name] = out_permutation.index(kwargs[special_axis_name]) + kwargs["dims"] = out_dims + kwargs["raw_tensor"] = raw_tensor + return _t.Tensor(**kwargs) + # This function has a native implementation (_native tensor_copy_compatible_to_dims_raw). def copy_compatible_to_dims_raw(self: _t.Tensor, dims: Sequence[Dim]) -> _t.RawTensorType: """ Simpler variant of :func:`copy_compatible_to` which just takes a list of dims, @@ -1335,11 +1334,11 @@ def copy_compatible_to_dims_raw(self: _t.Tensor, dims: Sequence[Dim]) -> _t.RawT :param dims: :return: raw tensor from self with dims permuted and broadcast dims added """ - out_permutation = self.get_out_permutation_to_dims(dims) raw_tensor = self._raw_tensor + assert raw_tensor is not None, f"{self} copy_compatible_to_dims_raw: no raw tensor" + out_permutation = self.get_out_permutation_to_dims(dims) if out_permutation == list(range(len(self._dims))): return raw_tensor - assert raw_tensor is not None, f"{self} copy_compatible_to_dims_raw: no raw tensor" backend = self._raw_backend raw_shape = backend.get_shape_raw(raw_tensor) raw_tensor = backend.transpose_raw(raw_tensor, [p for p in out_permutation if p >= 0]) diff --git a/returnn/tensor/tensor.py b/returnn/tensor/tensor.py index bcce3e0141..30170e11e8 100644 --- a/returnn/tensor/tensor.py +++ b/returnn/tensor/tensor.py @@ -162,7 +162,7 @@ def raw_tensor(self) -> Optional[RawTensorType]: """ return self._raw_tensor - # This is potentially replaced by native implementation (_native pyTensorRawTensorSetter). + # This is potentially replaced by native implementation (_native tensor_raw_tensor_setter). @raw_tensor.setter def raw_tensor(self, value: Optional[RawTensorType]): """ diff --git a/tests/test_torch_frontend.py b/tests/test_torch_frontend.py index 41b8372c6b..a3477a7181 100644 --- a/tests/test_torch_frontend.py +++ b/tests/test_torch_frontend.py @@ -8,7 +8,10 @@ import torch import pytest import math +import sys +import unittest +from returnn.util import better_exchook from returnn.tensor import Tensor, Dim import returnn.frontend as rf @@ -528,3 +531,26 @@ def test_Data_copy_tranpose_match_priority(): assert len(x_.dims) == 2 and x_.dims[0] is in_dim and x_.dims[1] is feat_dim x_np = x_.raw_tensor.detach().numpy() numpy.testing.assert_equal(x_np, raw_np) + + +if __name__ == "__main__": + better_exchook.install() + if len(sys.argv) <= 1: + for k, v in sorted(globals().items()): + if k.startswith("test_"): + print("-" * 40) + print("Executing: %s" % k) + try: + v() + except unittest.SkipTest as exc: + print("SkipTest:", exc) + print("-" * 40) + print("Finished all tests.") + else: + assert len(sys.argv) >= 2 + for arg in sys.argv[1:]: + print("Executing: %s" % arg) + if arg in globals(): + globals()[arg]() # assume function and execute + else: + eval(arg) # assume Python code and execute