diff --git a/xformer/Transforms/ReplaceConcat.cpp b/xformer/Transforms/ReplaceConcat.cpp index 63c9d07cf..4235f0c02 100644 --- a/xformer/Transforms/ReplaceConcat.cpp +++ b/xformer/Transforms/ReplaceConcat.cpp @@ -62,7 +62,8 @@ struct SplitConcatPattern : public OpRewritePattern { auto outputType = concatOp.getOutput().getType().cast(); Type elementType = outputType.getElementType(); ArrayRef outputShape = outputType.getShape(); - const int axis = concatOp.getAxis(); + int axis = concatOp.getAxis(); + axis = -1 ? outputType.getRank() - 1 : axis; int axisShape = 0; for (int i = 0; i < CONCAT_OP_MAX_INPUTS; i++) {