Skip to content

Commit

Permalink
RF native cleanup and optim
Browse files Browse the repository at this point in the history
Clean up _compareOrCombine_subsetDims to use _permuteAndExtend,
which is just as efficient, since it likewise only permutes or
extends when necessary.

Optimize accessing raw shape:
only access it when really needed.
  • Loading branch information
albertz committed Oct 6, 2023
1 parent c1b2619 commit e176bdc
Showing 1 changed file with 74 additions and 142 deletions.
216 changes: 74 additions & 142 deletions returnn/frontend/_native/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,69 +486,6 @@ PyObject* pyConvertToRawTorchTensorLike(PyObject *self, PyObject *const *args, P
return PyObject_Call(convertOp, args_, kwargs);
}

// Elementwise compare/combine of two tensors where one tensor's dims are a subset
// of the other's (bIsSubset: b's dims are the subset of a's dims; otherwise a's
// dims are the subset of b's; permutedDims: the subset dims appear in a different
// order than in the superset).
// The subset tensor's raw tensor is permuted (if permutedDims) and reshaped with
// size-1 broadcast dims so that it matches the superset tensor, then rawOp is
// applied and the result is wrapped in a new Tensor templated off the superset.
// outPermutation maps each superset dim index to the corresponding subset dim
// index, or a negative value where the subset has no such dim (broadcast position).
// Returns a new reference to the result Tensor, or NULL with a Python exception set.
template<bool bIsSubset, bool permutedDims>
static PyObject* _compareOrCombine_subsetDims(
PyModuleState* modState,
const char* rawOpName, bool resultIsBool,
PyObject* permuteOp, PyObject* reshapeOp, PyObject* rawOp,
PyObject* a, PyObject* b,
PyObject* aRawTensor, PyObject* bRawTensor,
PyObject* aRawShape, PyObject* bRawShape,
PyObject* aDims, PyObject* bDims,
const std::vector<int>& outPermutation
) {
// The tensor with the subset dims will be adapted to the other tensor.
PyObject* rawTensor_ = bIsSubset ? bRawTensor : aRawTensor;
PyObjectScopedRef rawTensorExt; // just for holding the ref and decrefing it later

// Maybe permute the tensor with subset dims, to match the order of the other tensor.
if(permutedDims) {
// Build the permutation tuple from the non-negative entries of outPermutation,
// i.e. the subset dim indices listed in superset order.
PyObjectScopedRef permuteArg = PyTuple_New(PyTuple_GET_SIZE(bIsSubset ? bDims : aDims));
if(!permuteArg) return NULL;
int j = 0;
for(int i = 0; i < (int) outPermutation.size(); ++i) {
if(outPermutation[i] < 0) continue;  // broadcast position, no subset dim here
PyObject* intObj = PyLong_FromLong(outPermutation[i]);
if(!intObj) return NULL;
PyTuple_SET_ITEM(permuteArg.get(), j, intObj);  // SET_ITEM steals the intObj ref
++j;
}
assert(j == PyTuple_GET_SIZE(permuteArg.get()));
rawTensor_ = PyObject_CallFunctionObjArgs(permuteOp, rawTensor_, permuteArg.get(), NULL);
if(!rawTensor_) return NULL;
rawTensorExt = rawTensor_;  // scoped ref now owns the permuted tensor
}

// Reshape the tensor with subset dims, to add broadcast dims, to match the dims of the other tensor.
{
// Target shape: superset raw-shape entry at mapped positions, 1 at broadcast positions.
// NOTE(review): this takes the dim size from the *superset* tensor's raw shape;
// assumes matching dims have equal raw sizes — the replacement code reads the
// subset tensor's own shape instead.
PyObjectScopedRef rawShapeExt = PyTuple_New(outPermutation.size());
if(!rawShapeExt) return NULL;
for(int i = 0; i < (int) outPermutation.size(); ++i) {
PyObject* d;
if(outPermutation[i] >= 0) {
d = PyTuple_GET_ITEM(bIsSubset ? aRawShape : bRawShape, i);
Py_XINCREF(d);  // SET_ITEM below steals a ref, so take one here
}
else
d = PyLong_FromLong(1);  // broadcast dim of size 1
if(!d) return NULL;
PyTuple_SET_ITEM(rawShapeExt.get(), i, d);
}
rawTensor_ = PyObject_CallFunctionObjArgs(reshapeOp, rawTensor_, rawShapeExt.get(), NULL);
if(!rawTensor_) return NULL;
// rawTensorExt takes the new ref (assumes PyObjectScopedRef assignment releases
// any previously held intermediate — confirm against PyObjectScopedRef).
rawTensorExt = rawTensor_;
}

// Now create the result.
// Result template comes from the superset tensor; dtype "bool" for compare ops.
PyObjectScopedRef res = tensorCopyTemplateSimple(modState, bIsSubset ? a : b, rawOpName, resultIsBool ? "bool" : NULL);
if(!res) return NULL;
PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs(
rawOp, bIsSubset ? aRawTensor : rawTensor_, bIsSubset ? rawTensor_ : bRawTensor, NULL);
if(!resRawTensor) return NULL;
if(PyObject_SetAttrString(res, "raw_tensor", resRawTensor) < 0) return NULL;
return res.release();
}

// when it returns with false, some exception should be raised
template<typename ASeqT, typename BSeqT>
static bool _getPermutationSupersetToSubset(const char* funcName, ASeqT subset, BSeqT superset, std::vector<int>& outPermutation) {
Expand Down Expand Up @@ -635,17 +572,19 @@ static bool _getPermutationSupersetToSubset(const char* funcName, ASeqT subset,
return true;
}

// return raw tensor
template<typename OutDimSeqT>
static PyObject* _permuteAndExtend(
const char* rawOpName,
PyObject* permuteOp, PyObject* reshapeOp,
PyObject* tensor, PyTupleOrListStaticRef<true> dims, PyObject* rawTensor, PyObject* rawShape,
PyObject* permuteOp, PyObject* reshapeOp, PyObject* getShapeOp,
PyObject* tensor, PyTupleOrListStaticRef<true> dims, PyObject* rawTensor,
OutDimSeqT outDims,
std::vector<int>& outPermutation
std::vector<int>& outPermutation /* if empty, will get it from outDims */
) {
// First find the mapping.
if(!_getPermutationSupersetToSubset(rawOpName, dims, outDims, outPermutation))
if(outPermutation.empty() && !_getPermutationSupersetToSubset(rawOpName, dims, outDims, outPermutation))
return NULL;
assert((int) outPermutation.size() == outDims.size());

PyObject* rawTensor_ = rawTensor;
PyObjectScopedRef rawTensorExt; // just for holding the ref and decrefing it later
Expand Down Expand Up @@ -679,19 +618,37 @@ static PyObject* _permuteAndExtend(

// Maybe reshape the tensor
if(outDims.size() > dims.size()) {
PyObjectScopedRef rawShape = PyObject_CallFunctionObjArgs(getShapeOp, rawTensor_, NULL);
if(!rawShape) return NULL;
if(!PyTuple_Check(rawShape)) {
PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", rawOpName, rawShape.get());
return NULL;
}
if(PyTuple_GET_SIZE(rawShape.get()) != dims.size()) {
PyErr_Format(
PyExc_ValueError,
"%s: raw_tensor ndim != tensor ndim, from tensor dims %R and raw_tensor shape %R",
rawOpName, dims.get(), rawShape.get());
return NULL;
}

PyObjectScopedRef rawShapeExt = PyTuple_New(outPermutation.size());
if(!rawShapeExt) return NULL;
int j = 0;
for(int i = 0; i < (int) outPermutation.size(); ++i) {
PyObject* d;
if(outPermutation[i] >= 0) {
d = PyTuple_GET_ITEM(rawShape, outPermutation[i]);
assert(j < dims.size());
d = PyTuple_GET_ITEM(rawShape.get(), j);
Py_XINCREF(d);
++j;
}
else
d = PyLong_FromLong(1);
if(!d) return NULL;
PyTuple_SET_ITEM(rawShapeExt.get(), i, d);
}
assert(j == dims.size());
rawTensor_ = PyObject_CallFunctionObjArgs(reshapeOp, rawTensor_, rawShapeExt.get(), NULL);
if(!rawTensor_) return NULL;
rawTensorExt = rawTensor_;
Expand All @@ -702,6 +659,33 @@ static PyObject* _permuteAndExtend(
return rawTensor_;
}

// Elementwise compare/combine of two tensors where one tensor's dims are a subset
// of the other's (bIsSubset: b's dims are the subset of a's dims; otherwise a's
// dims are the subset of b's).
// The subset tensor is permuted/extended via _permuteAndExtend to match the
// superset tensor's dims, then rawOp (the backend elementwise op) is applied, and
// the result is wrapped in a new Tensor created as a copy-template of the superset.
// outPermutation: precomputed mapping superset-dim-index -> subset-dim-index
// (negative for broadcast positions); passed through to _permuteAndExtend so it
// does not need to be recomputed there.
// Returns a new reference to the result Tensor, or NULL with a Python exception set.
template<bool bIsSubset>
static PyObject* _compareOrCombine_subsetDims(
    PyModuleState* modState,
    const char* rawOpName, bool resultIsBool,
    PyObject* permuteOp, PyObject* reshapeOp, PyObject* getShapeOp, PyObject* rawOp,
    PyObject* a, PyObject* b,
    PyObject* aRawTensor, PyObject* bRawTensor,
    PyTupleOrListStaticRef<true> aDims, PyTupleOrListStaticRef<true> bDims,
    std::vector<int>& outPermutation
) {
    // The tensor with the subset dims will be adapted to the other tensor.
    PyObjectScopedRef rawTensorExt;
    if(bIsSubset)
        rawTensorExt = _permuteAndExtend(rawOpName, permuteOp, reshapeOp, getShapeOp, b, bDims, bRawTensor, aDims, outPermutation);
    else
        rawTensorExt = _permuteAndExtend(rawOpName, permuteOp, reshapeOp, getShapeOp, a, aDims, aRawTensor, bDims, outPermutation);
    // Bug fix: propagate failure from _permuteAndExtend. Without this check, a NULL
    // would be passed on to PyObject_CallFunctionObjArgs below; since NULL is the
    // vararg terminator there, the argument list would be silently truncated instead
    // of the pending Python exception being raised to the caller.
    if(!rawTensorExt) return NULL;

    // Now create the result.
    // Result template comes from the superset tensor; dtype "bool" for compare ops.
    PyObjectScopedRef res = tensorCopyTemplateSimple(modState, bIsSubset ? a : b, rawOpName, resultIsBool ? "bool" : NULL);
    if(!res) return NULL;
    // Apply the raw op with the adapted tensor substituted for the subset side.
    PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs(
        rawOp, bIsSubset ? aRawTensor : rawTensorExt.get(), bIsSubset ? rawTensorExt.get() : bRawTensor, NULL);
    if(!resRawTensor) return NULL;
    if(PyObject_SetAttrString(res, "raw_tensor", resRawTensor) < 0) return NULL;
    return res.release();
}

static PyObject* _consistentFeatureDim(PyObject* a, PyObject* b) {
PyObjectScopedRef aFeatureDim = PyObject_GetAttrString(a, "feature_dim");
if(!aFeatureDim) return NULL;
Expand Down Expand Up @@ -803,32 +787,24 @@ static PyObject* tensorCopyCompatibleToDims(const char* funcName, PyModuleState*
return NULL;
}
else if(modState->isTorchTensorType((PyObject*) Py_TYPE(rawTensor))) {
PyObjectScopedRef rawShape = PyObject_GetAttrString(rawTensor, "shape");
if(!rawShape) return NULL;
if(!PyTuple_Check(rawShape)) {
PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", funcName, rawShape.get());
return NULL;
}
PyObject* permuteOp = modState->cachedOp(TOp_Permute, BWCO_Torch);
if(!permuteOp) return NULL;
PyObject* reshapeOp = modState->cachedOp(TOp_Reshape, BWCO_Torch);
if(!reshapeOp) return NULL;
outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, tensor, dimsSeq, rawTensor, rawShape, outDimsSeq, outPermutation);
PyObject* getShapeOp = modState->cachedOp(TOp_GetShape, BWCO_Torch);
if(!getShapeOp) return NULL;
outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, getShapeOp, tensor, dimsSeq, rawTensor, outDimsSeq, outPermutation);
if(!outRawTensor) return NULL;
}
else { // generic backend fallback
PyObject* backend = getBackendForRawTensor(modState, rawTensor);
PyObjectScopedRef rawShape = PyObject_CallMethod(backend, "get_shape_tuple_raw", "O", rawTensor.get());
if(!rawShape) return NULL;
if(!PyTuple_Check(rawShape)) {
PyErr_Format(PyExc_TypeError, "%s: expected raw_tensor.shape to be tuple, got %R", funcName, rawShape.get());
return NULL;
}
PyObjectScopedRef permuteOp = PyObject_GetAttrString(backend, "transpose_raw");
if(!permuteOp) return NULL;
PyObjectScopedRef reshapeOp = PyObject_GetAttrString(backend, "reshape_raw");
if(!reshapeOp) return NULL;
outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, tensor, dimsSeq, rawTensor, rawShape, outDimsSeq, outPermutation);
PyObjectScopedRef getShapeOp = PyObject_GetAttrString(backend, "get_shape_tuple_raw");
if(!getShapeOp) return NULL;
outRawTensor = _permuteAndExtend(funcName, permuteOp, reshapeOp, getShapeOp, tensor, dimsSeq, rawTensor, outDimsSeq, outPermutation);
if(!outRawTensor) return NULL;
}

Expand Down Expand Up @@ -1198,94 +1174,62 @@ static PyObject* compareOrCombine(
return res.release();
}

PyObjectScopedRef aRawShape = PyObject_CallFunctionObjArgs(getShapeOp, aRawTensor.get(), NULL);
if(!aRawShape) return NULL;
if(!PyTuple_Check(aRawShape)) {
PyErr_Format(PyExc_TypeError, "compareOrCombine: expected a.raw_tensor.shape to be tuple, got %R", aRawShape.get());
return NULL;
}

// check if bDims is a subset of aDims, in the same order (fast dim identity check only)
{
std::vector<int> outPermutation;
if(_isSeqSubsetFast(bDimsSeq, aDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(aDimsSeq, dimOrderSeq)))
return _compareOrCombine_subsetDims<true, false>(
return _compareOrCombine_subsetDims<true>(
modState, rawOpName, resultIsBool,
permuteOp, reshapeOp, rawOp,
permuteOp, reshapeOp, getShapeOp, rawOp,
a, b,
aRawTensor, bRawTensor,
aRawShape, NULL,
aDims, bDims,
aDimsSeq, bDimsSeq,
outPermutation);
}

PyObjectScopedRef bRawShape = PyObject_CallFunctionObjArgs(getShapeOp, bRawTensor.get(), NULL);
if(!bRawShape) return NULL;
if(!PyTuple_Check(bRawShape)) {
PyErr_Format(PyExc_TypeError, "compareOrCombine: expected b.raw_tensor.shape to be tuple, got %R", bRawShape.get());
return NULL;
}

// check if aDims is a subset of bDims, in the same order (fast dim identity check only)
{
std::vector<int> outPermutation;
if(_isSeqSubsetFast(aDimsSeq, bDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(bDimsSeq, dimOrderSeq)))
return _compareOrCombine_subsetDims<false, false>(
return _compareOrCombine_subsetDims<false>(
modState, rawOpName, resultIsBool,
permuteOp, reshapeOp, rawOp,
permuteOp, reshapeOp, getShapeOp, rawOp,
a, b,
aRawTensor, bRawTensor,
aRawShape, bRawShape,
aDims, bDims,
aDimsSeq, bDimsSeq,
outPermutation);
}

// check if bDims is a subset of aDims, maybe reordered (fast dim identity check only)
{
std::vector<int> outPermutation;
if(_isSeqSubsetReorderFast(bDimsSeq, aDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(aDimsSeq, dimOrderSeq)))
return _compareOrCombine_subsetDims<true, true>(
return _compareOrCombine_subsetDims<true>(
modState, rawOpName, resultIsBool,
permuteOp, reshapeOp, rawOp,
permuteOp, reshapeOp, getShapeOp, rawOp,
a, b,
aRawTensor, bRawTensor,
aRawShape, bRawShape,
aDims, bDims,
aDimsSeq, bDimsSeq,
outPermutation);
}

// check if aDims is a subset of bDims, maybe reordered (fast dim identity check only)
{
std::vector<int> outPermutation;
if(_isSeqSubsetReorderFast(aDimsSeq, bDimsSeq, outPermutation) && (dimOrder == Py_None || _isSameSeqFast(bDimsSeq, dimOrderSeq)))
return _compareOrCombine_subsetDims<false, true>(
return _compareOrCombine_subsetDims<false>(
modState, rawOpName, resultIsBool,
permuteOp, reshapeOp, rawOp,
permuteOp, reshapeOp, getShapeOp, rawOp,
a, b,
aRawTensor, bRawTensor,
aRawShape, bRawShape,
aDims, bDims,
aDimsSeq, bDimsSeq,
outPermutation);
}

{
// follow the bin_op_out_template code

// collect all dims
if(aDimsSeq.size() != PyTuple_GET_SIZE(aRawShape.get())) {
PyErr_Format(
PyExc_ValueError,
"compareOrCombine: a.dims and a.raw_tensor.shape have different size, from a.dims %R and a.raw_tensor.shape %R",
aDims.get(), aRawShape.get());
return NULL;
}
if(bDimsSeq.size() != PyTuple_GET_SIZE(bRawShape.get())) {
PyErr_Format(
PyExc_ValueError,
"compareOrCombine: b.dims and b.raw_tensor.shape have different size, from b.dims %R and b.raw_tensor.shape %R",
bDims.get(), bRawShape.get());
return NULL;
}
PyObjectScopedRef allDims = PyList_New(0);
if(!allDims) return NULL;
for(int i = 0; i < aDimsSeq.size() + bDimsSeq.size(); ++i) {
Expand All @@ -1298,18 +1242,6 @@ static PyObject* compareOrCombine(
if(contains < 0) return NULL;
if(contains) continue;
}
long dimValue =
i < aDimsSeq.size() ?
PyLong_AsLong(PyTuple_GET_ITEM(aRawShape.get(), i)) :
PyLong_AsLong(PyTuple_GET_ITEM(bRawShape.get(), i - aDimsSeq.size()));
if(dimValue < 0) {
if(!PyErr_Occurred())
PyErr_Format(
PyExc_ValueError,
"compareOrCombine: a.raw_tensor.shape or b.raw_tensor.shape has negative dim, from a.raw_tensor.shape %R and b.raw_tensor.shape %R",
aRawShape.get(), bRawShape.get());
return NULL;
}
// Not simply `all_dims.append(dim)`,
// because a dim might occur multiple times in a.dims or b.dims
// (with different match_priority),
Expand Down Expand Up @@ -1409,14 +1341,14 @@ static PyObject* compareOrCombine(
{
std::vector<int> outPermutation;
PyObjectScopedRef aRawTensorExt = _permuteAndExtend(
rawOpName, permuteOp, reshapeOp,
a, aDimsSeq, aRawTensor, aRawShape,
rawOpName, permuteOp, reshapeOp, getShapeOp,
a, aDimsSeq, aRawTensor,
allDimsSeq, outPermutation);
if(!aRawTensorExt) return NULL;
outPermutation.clear();
PyObjectScopedRef bRawTensorExt = _permuteAndExtend(
rawOpName, permuteOp, reshapeOp,
b, bDimsSeq, bRawTensor, bRawShape,
rawOpName, permuteOp, reshapeOp, getShapeOp,
b, bDimsSeq, bRawTensor,
allDimsSeq, outPermutation);
if(!bRawTensorExt) return NULL;
PyObjectScopedRef resRawTensor = PyObject_CallFunctionObjArgs(
Expand Down

0 comments on commit e176bdc

Please sign in to comment.