diff --git a/returnn/frontend/_native/tensor_ops.cpp b/returnn/frontend/_native/tensor_ops.cpp index 262081c3bd..effd499914 100644 --- a/returnn/frontend/_native/tensor_ops.cpp +++ b/returnn/frontend/_native/tensor_ops.cpp @@ -486,69 +486,6 @@ PyObject* pyConvertToRawTorchTensorLike(PyObject *self, PyObject *const *args, P return PyObject_Call(convertOp, args_, kwargs); } -template -static PyObject* _compareOrCombine_subsetDims( - PyModuleState* modState, - const char* rawOpName, bool resultIsBool, - PyObject* permuteOp, PyObject* reshapeOp, PyObject* rawOp, - PyObject* a, PyObject* b, - PyObject* aRawTensor, PyObject* bRawTensor, - PyObject* aRawShape, PyObject* bRawShape, - PyObject* aDims, PyObject* bDims, - const std::vector& outPermutation -) { - // The tensor with the subset dims will be adapted to the other tensor. - PyObject* rawTensor_ = bIsSubset ? bRawTensor : aRawTensor; - PyObjectScopedRef rawTensorExt; // just for holding the ref and decrefing it later - - // Maybe permute the tensor with subset dims, to match the order of the other tensor. - if(permutedDims) { - PyObjectScopedRef permuteArg = PyTuple_New(PyTuple_GET_SIZE(bIsSubset ? bDims : aDims)); - if(!permuteArg) return NULL; - int j = 0; - for(int i = 0; i < (int) outPermutation.size(); ++i) { - if(outPermutation[i] < 0) continue; - PyObject* intObj = PyLong_FromLong(outPermutation[i]); - if(!intObj) return NULL; - PyTuple_SET_ITEM(permuteArg.get(), j, intObj); - ++j; - } - assert(j == PyTuple_GET_SIZE(permuteArg.get())); - rawTensor_ = PyObject_CallFunctionObjArgs(permuteOp, rawTensor_, permuteArg.get(), NULL); - if(!rawTensor_) return NULL; - rawTensorExt = rawTensor_; - } - - // Reshape the tensor with subset dims, to add broadcast dims, to match the dims of the other tensor. - { - PyObjectScopedRef rawShapeExt = PyTuple_New(outPermutation.size()); - if(!rawShapeExt) return NULL; - for(int i = 0; i < (int) outPermutation.size(); ++i) { - PyObject* d; - if(outPermutation[i] >= 0) { - d = PyTuple_GET_ITEM(bIsSubset ? aRawShape : bRawShape, i); - Py_XINCREF(d); - } - else - d = PyLong_FromLong(1); - if(!d) return NULL; - PyTuple_SET_ITEM(rawShapeExt.get(), i, d); - } - rawTensor_ = PyObject_CallFunctionObjArgs(reshapeOp, rawTensor_, rawShapeExt.get(), NULL); - if(!rawTensor_) return NULL; - rawTensorExt = rawTensor_; - } - - // Now create the result. - PyObjectScopedRef res = tensorCopyTemplateSimple(modState, bIsSubset ? a : b, rawOpName, resultIsBool ? "bool" : NULL); - if(!res) return NULL; - PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs( - rawOp, bIsSubset ? aRawTensor : rawTensor_, bIsSubset ? rawTensor_ : bRawTensor, NULL); - if(!resRawTensor) return NULL; - if(PyObject_SetAttrString(res, "raw_tensor", resRawTensor) < 0) return NULL; - return res.release(); -} - // when it returns with false, some exception should be raised template static bool _getPermutationSupersetToSubset(const char* funcName, ASeqT subset, BSeqT superset, std::vector& outPermutation) { @@ -635,17 +572,19 @@ static bool _getPermutationSupersetToSubset(const char* funcName, ASeqT subset, return true; } +// return raw tensor template static PyObject* _permuteAndExtend( const char* rawOpName, - PyObject* permuteOp, PyObject* reshapeOp, - PyObject* tensor, PyTupleOrListStaticRef dims, PyObject* rawTensor, PyObject* rawShape, + PyObject* permuteOp, PyObject* reshapeOp, PyObject* getShapeOp, + PyObject* tensor, PyTupleOrListStaticRef dims, PyObject* rawTensor, OutDimSeqT outDims, - std::vector& outPermutation + std::vector& outPermutation /* if empty, will get it from outDims */ ) { // First find the mapping. - if(!_getPermutationSupersetToSubset(rawOpName, dims, outDims, outPermutation)) + if(outPermutation.empty() && !_getPermutationSupersetToSubset(rawOpName, dims, outDims, outPermutation)) return NULL; + assert((int) outPermutation.size() == outDims.size()); PyObject* rawTensor_ = rawTensor; PyObjectScopedRef rawTensorExt; // just for holding the ref and decrefing it later @@ -679,19 +618,37 @@ static PyObject* _permuteAndExtend( // Maybe reshape the tensor if(outDims.size() > dims.size()) { + PyObjectScopedRef rawShape = PyObject_CallFunctionObjArgs(getShapeOp, rawTensor_, NULL); + if(!rawShape) return NULL; + if(!PyTuple_Check(rawShape)) { + PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", rawOpName, rawShape.get()); + return NULL; + } + if(PyTuple_GET_SIZE(rawShape.get()) != dims.size()) { + PyErr_Format( + PyExc_ValueError, + "%s: raw_tensor ndim != tensor ndim, from tensor dims %R and raw_tensor shape %R", + rawOpName, dims.get(), rawShape.get()); + return NULL; + } + PyObjectScopedRef rawShapeExt = PyTuple_New(outPermutation.size()); if(!rawShapeExt) return NULL; + int j = 0; for(int i = 0; i < (int) outPermutation.size(); ++i) { PyObject* d; if(outPermutation[i] >= 0) { - d = PyTuple_GET_ITEM(rawShape, outPermutation[i]); + assert(j < dims.size()); + d = PyTuple_GET_ITEM(rawShape.get(), j); Py_XINCREF(d); + ++j; } else d = PyLong_FromLong(1); if(!d) return NULL; PyTuple_SET_ITEM(rawShapeExt.get(), i, d); } + assert(j == dims.size()); rawTensor_ = PyObject_CallFunctionObjArgs(reshapeOp, rawTensor_, rawShapeExt.get(), NULL); if(!rawTensor_) return NULL; rawTensorExt = rawTensor_; @@ -702,6 +659,33 @@ static PyObject* _permuteAndExtend( return rawTensor_; } +template +static PyObject* _compareOrCombine_subsetDims( + PyModuleState* modState, + const char* rawOpName, bool resultIsBool, + PyObject* permuteOp, PyObject* reshapeOp, PyObject* getShapeOp, PyObject* rawOp, + PyObject* a, PyObject* b, + PyObject* aRawTensor, PyObject* bRawTensor, + PyTupleOrListStaticRef aDims, PyTupleOrListStaticRef bDims, + std::vector& outPermutation +) { + // The tensor with the subset dims will be adapted to the other tensor. + PyObjectScopedRef rawTensorExt; + if(bIsSubset) + rawTensorExt = _permuteAndExtend(rawOpName, permuteOp, reshapeOp, getShapeOp, b, bDims, bRawTensor, aDims, outPermutation); + else + rawTensorExt = _permuteAndExtend(rawOpName, permuteOp, reshapeOp, getShapeOp, a, aDims, aRawTensor, bDims, outPermutation); + + // Now create the result. + PyObjectScopedRef res = tensorCopyTemplateSimple(modState, bIsSubset ? a : b, rawOpName, resultIsBool ? "bool" : NULL); + if(!res) return NULL; + PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs( + rawOp, bIsSubset ? aRawTensor : rawTensorExt.get(), bIsSubset ? rawTensorExt.get() : bRawTensor, NULL); + if(!resRawTensor) return NULL; + if(PyObject_SetAttrString(res, "raw_tensor", resRawTensor) < 0) return NULL; + return res.release(); +} + static PyObject* _consistentFeatureDim(PyObject* a, PyObject* b) { PyObjectScopedRef aFeatureDim = PyObject_GetAttrString(a, "feature_dim"); if(!aFeatureDim) return NULL; @@ -803,32 +787,24 @@ static PyObject* tensorCopyCompatibleToDims(const char* funcName, PyModuleState* return NULL; } else if(modState->isTorchTensorType((PyObject*) Py_TYPE(rawTensor))) { - PyObjectScopedRef rawShape = PyObject_GetAttrString(rawTensor, "shape"); - if(!rawShape) return NULL; - if(!PyTuple_Check(rawShape)) { - PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", funcName, rawShape.get()); - return NULL; - } PyObject* permuteOp = modState->cachedOp(TOp_Permute, BWCO_Torch); if(!permuteOp) return NULL; PyObject* reshapeOp = modState->cachedOp(TOp_Reshape, BWCO_Torch); if(!reshapeOp) return NULL; - outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, tensor, dimsSeq, rawTensor, rawShape, outDimsSeq, outPermutation); + PyObject* getShapeOp = modState->cachedOp(TOp_GetShape, BWCO_Torch); + if(!getShapeOp) return NULL; + outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, getShapeOp, tensor, dimsSeq, rawTensor, outDimsSeq, outPermutation); if(!outRawTensor) return NULL; } else { // generic backend fallback PyObject* backend = getBackendForRawTensor(modState, rawTensor); - PyObjectScopedRef rawShape = PyObject_CallMethod(backend, "get_shape_tuple_raw", "O", rawTensor.get()); - if(!rawShape) return NULL; - if(!PyTuple_Check(rawShape)) { - PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", funcName, rawShape.get()); - return NULL; - } PyObjectScopedRef permuteOp = PyObject_GetAttrString(backend, "transpose_raw"); if(!permuteOp) return NULL; PyObjectScopedRef reshapeOp = PyObject_GetAttrString(backend, "reshape_raw"); if(!reshapeOp) return NULL; - outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, tensor, dimsSeq, rawTensor, rawShape, outDimsSeq, outPermutation); + PyObjectScopedRef getShapeOp = PyObject_GetAttrString(backend, "get_shape_tuple_raw"); + if(!getShapeOp) return NULL; + outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, getShapeOp, tensor, dimsSeq, rawTensor, outDimsSeq, outPermutation); if(!outRawTensor) return NULL; } @@ -1198,45 +1174,29 @@ static PyObject* compareOrCombine( return res.release(); } - PyObjectScopedRef aRawShape = PyObject_CallFunctionObjArgs(getShapeOp, aRawTensor.get(), NULL); - if(!aRawShape) return NULL; - if(!PyTuple_Check(aRawShape)) { - PyErr_Format(PyExc_TypeError, "compareOrCombine: expected a.raw_tensor.shape to be tuple, got %R", aRawShape.get()); - return NULL; - } - // check if bDims is a subset of aDims, in the same order (fast dim identity check only) { std::vector outPermutation; if(_isSeqSubsetFast(bDimsSeq, aDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(aDimsSeq, dimOrderSeq))) - return _compareOrCombine_subsetDims( + return _compareOrCombine_subsetDims( modState, rawOpName, resultIsBool, - permuteOp, reshapeOp, rawOp, + permuteOp, reshapeOp, getShapeOp, rawOp, a, b, aRawTensor, bRawTensor, - aRawShape, NULL, - aDims, bDims, + aDimsSeq, bDimsSeq, outPermutation); } - PyObjectScopedRef bRawShape = PyObject_CallFunctionObjArgs(getShapeOp, bRawTensor.get(), NULL); - if(!bRawShape) return NULL; - if(!PyTuple_Check(bRawShape)) { - PyErr_Format(PyExc_TypeError, "compareOrCombine: expected b.raw_tensor.shape to be tuple, got %R", bRawShape.get()); - return NULL; - } - // check if aDims is a subset of bDims, in the same order (fast dim identity check only) { std::vector outPermutation; if(_isSeqSubsetFast(aDimsSeq, bDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(bDimsSeq, dimOrderSeq))) - return _compareOrCombine_subsetDims( + return _compareOrCombine_subsetDims( modState, rawOpName, resultIsBool, - permuteOp, reshapeOp, rawOp, + permuteOp, reshapeOp, getShapeOp, rawOp, a, b, aRawTensor, bRawTensor, - aRawShape, bRawShape, - aDims, bDims, + aDimsSeq, bDimsSeq, outPermutation); } @@ -1244,13 +1204,12 @@ static PyObject* compareOrCombine( { std::vector outPermutation; if(_isSeqSubsetReorderFast(bDimsSeq, aDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(aDimsSeq, dimOrderSeq))) - return _compareOrCombine_subsetDims( + return _compareOrCombine_subsetDims( modState, rawOpName, resultIsBool, - permuteOp, reshapeOp, rawOp, + permuteOp, reshapeOp, getShapeOp, rawOp, a, b, aRawTensor, bRawTensor, - aRawShape, bRawShape, - aDims, bDims, + aDimsSeq, bDimsSeq, outPermutation); } @@ -1258,13 +1217,12 @@ static PyObject* compareOrCombine( { std::vector outPermutation; if(_isSeqSubsetReorderFast(aDimsSeq, bDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(bDimsSeq, dimOrderSeq))) - return _compareOrCombine_subsetDims( + return _compareOrCombine_subsetDims( modState, rawOpName, resultIsBool, - permuteOp, reshapeOp, rawOp, + permuteOp, reshapeOp, getShapeOp, rawOp, a, b, aRawTensor, bRawTensor, - aRawShape, bRawShape, - aDims, bDims, + aDimsSeq, bDimsSeq, outPermutation); } @@ -1272,20 +1230,6 @@ static PyObject* compareOrCombine( // follow the bin_op_out_template code // collect all dims - if(aDimsSeq.size() != PyTuple_GET_SIZE(aRawShape.get())) { - PyErr_Format( - PyExc_ValueError, - "compareOrCombine: a.dims and a.raw_tensor.shape have different size, from a.dims %R and a.raw_tensor.shape %R", - aDims.get(), aRawShape.get()); - return NULL; - } - if(bDimsSeq.size() != PyTuple_GET_SIZE(bRawShape.get())) { - PyErr_Format( - PyExc_ValueError, - "compareOrCombine: b.dims and b.raw_tensor.shape have different size, from b.dims %R and b.raw_tensor.shape %R", - bDims.get(), bRawShape.get()); - return NULL; - } PyObjectScopedRef allDims = PyList_New(0); if(!allDims) return NULL; for(int i = 0; i < aDimsSeq.size() + bDimsSeq.size(); ++i) { @@ -1298,18 +1242,6 @@ static PyObject* compareOrCombine( if(contains < 0) return NULL; if(contains) continue; } - long dimValue = - i < aDimsSeq.size() ? - PyLong_AsLong(PyTuple_GET_ITEM(aRawShape.get(), i)) : - PyLong_AsLong(PyTuple_GET_ITEM(bRawShape.get(), i - aDimsSeq.size())); - if(dimValue < 0) { - if(!PyErr_Occurred()) - PyErr_Format( - PyExc_ValueError, - "compareOrCombine: a.raw_tensor.shape or b.raw_tensor.shape has negative dim, from a.raw_tensor.shape %R and b.raw_tensor.shape %R", - aRawShape.get(), bRawShape.get()); - return NULL; - } // Not simply `all_dims.append(dim)`, // because a dim might occur multiple times in a.dims or b.dims // (with different match_priority), @@ -1409,14 +1341,14 @@ static PyObject* compareOrCombine( { std::vector outPermutation; PyObjectScopedRef aRawTensorExt = _permuteAndExtend( - rawOpName, permuteOp, reshapeOp, - a, aDimsSeq, aRawTensor, aRawShape, + rawOpName, permuteOp, reshapeOp, getShapeOp, + a, aDimsSeq, aRawTensor, allDimsSeq, outPermutation); if(!aRawTensorExt) return NULL; outPermutation.clear(); PyObjectScopedRef bRawTensorExt = _permuteAndExtend( - rawOpName, permuteOp, reshapeOp, - b, bDimsSeq, bRawTensor, bRawShape, + rawOpName, permuteOp, reshapeOp, getShapeOp, + b, bDimsSeq, bRawTensor, allDimsSeq, outPermutation); if(!bRawTensorExt) return NULL; PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs(