From 3ad7a08d2ccdf03b08c8a01d531bc5c64be655d5 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Thu, 11 Jul 2024 12:33:25 +0100 Subject: [PATCH 1/5] Add initial support for fake slice op --- xformer/IR/XCoreOps.td | 11 ++++ xformer/Transforms/OpSplit.cpp | 113 +++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) diff --git a/xformer/IR/XCoreOps.td b/xformer/IR/XCoreOps.td index 4b302ef26..75f9b711c 100644 --- a/xformer/IR/XCoreOps.td +++ b/xformer/IR/XCoreOps.td @@ -96,6 +96,17 @@ def XC_SliceOp : XC_Op<"slice", [Pure]> { let results = (outs TensorOf<[QI8, QI16, F32, I8, I32]> : $output); } +def XC_FakeSliceOp + : XC_Op<"fake_slice", [Pure, SameOperandsAndResultType]> { + let summary = "Fake slice op"; + + let description = [{Fake slice op.}]; + + let arguments = (ins AnyTensor : $input); + + let results = (outs AnyTensor : $output); +} + def XC_BroadcastOp : XC_Op<"broadcast", [Pure]> { let summary = "Broadcast op"; diff --git a/xformer/Transforms/OpSplit.cpp b/xformer/Transforms/OpSplit.cpp index 612e3d47e..67cce0793 100644 --- a/xformer/Transforms/OpSplit.cpp +++ b/xformer/Transforms/OpSplit.cpp @@ -178,6 +178,118 @@ struct RaiseSliceHorizontalAddPattern : public OpRewritePattern { } }; +struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::SliceOp slice, + PatternRewriter &rewriter) const override { + // Only raise slices that have been inserted with op split pass + if (!((slice->hasAttr(opSplitLabel)))) + return failure(); + + // If slice does not have a defining op, return failure + if (!(slice.getInput().getDefiningOp())) { + return failure(); + } + + if (!isa(slice.getInput().getDefiningOp())) { + return failure(); + } + + auto addOriginal = llvm::cast(slice.getInput().getDefiningOp()); + + // Do not raise slice if op does not have op split label + if (!(addOriginal->hasAttr(opSplitLabel))) + return failure(); + + auto sliceOutShape = 
utils::getValShape(slice.getOutput()); + + +// // if broadcast, create fakeslice op with no slices +// // find if lhs or rhs has broadcast +// auto lhsType = addOriginal.getLhs().getType().cast(); +// auto rhsType = addOriginal.getRhs().getType().cast(); +// auto outputType = addOriginal.getOutput().getType().cast(); +// if (!hasSameShape(rhsType, outputType) && +// !hasSameShape(lhsType, outputType)) { +// return failure(); +// } +// // RHS needs broadcast +// if (!hasSameShape(rhsType, outputType) { + +// } else { + +// } + +// o l +// mul +// s s s s +// op + +// l +// fs fs fs fs + +// fs +// l +// fs fs fs + +// o l +// s fs +// m m +// s s s + + +// fc +// 10 +// conv +// 40x10x1 + +// // raisefsaboveop +// // if op can be split spatially, then raise fs and convert to slice +// // if another fs already above op, then merge with that one +// // otherwise, raise above op + + +// conv +// 1x1x1x16 +// fs(4) + + + + +// conv +// 1x40x10x10 +// s s s s +// 1x10x10 + + // Create new slice for above add + auto sliceLHS = llvm::cast(rewriter.clone(*slice)); + sliceLHS.setOperand(0, addOriginal.getLhs()); + RankedTensorType sliceLHSType = RankedTensorType::get( + sliceOutShape, utils::getValElementType(addOriginal.getLhs())); + sliceLHS->getResult(0).setType(sliceLHSType); + + // Create new slice for above add + auto sliceRHS = llvm::cast(rewriter.clone(*slice)); + sliceRHS.setOperand(0, addOriginal.getRhs()); + RankedTensorType sliceRHSType = RankedTensorType::get( + sliceOutShape, utils::getValElementType(addOriginal.getRhs())); + sliceRHS->getResult(0).setType(sliceRHSType); + + auto addReplacement = llvm::cast(rewriter.clone(*addOriginal)); + RankedTensorType addReplacementType = RankedTensorType::get( + sliceOutShape, utils::getValElementType(addOriginal.getOutput())); + addReplacement->getResult(0).setType(addReplacementType); + addReplacement.setOperand(0, sliceLHS); + addReplacement.setOperand(1, sliceRHS); + + // replace slice with new slice -> new add + 
rewriter.replaceOp(slice, addReplacement.getOutput()); + + return success(); + } +}; + template struct RaiseSliceHorizontalPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -744,6 +856,7 @@ void OpSplit::runOnOperation() { RewritePatternSet patterns2(ctx); patterns2.insert(ctx); + patterns2.insert(ctx); patterns2.insert(ctx); patterns2.insert>(ctx); patterns2.insert>(ctx); From f48d6b0b816488211949923a45b23d4c4f3191e8 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Thu, 11 Jul 2024 12:33:38 +0100 Subject: [PATCH 2/5] Handle batched softmax for three ranks --- xformer/Transforms/XCPatterns.td | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/xformer/Transforms/XCPatterns.td b/xformer/Transforms/XCPatterns.td index e03074c40..d4cdf8b48 100644 --- a/xformer/Transforms/XCPatterns.td +++ b/xformer/Transforms/XCPatterns.td @@ -46,8 +46,13 @@ def getExpLookupF32 def isSingleSegment : Constraint().getRank() == 2">>; -def isSingleBatch : Constraint().getDimSize(0) == 1">>; -def isMultiBatch : Constraint().getDimSize(0) != 1">>; +def isSingleBatch + : Constraint().getDimSize(0) == 1">>; +def isMultiBatch + : Constraint().getRank() == 2 && " + "$0.getType().cast().getDimSize(0) != 1) || " + "($0.getType().cast().getRank() == 3 && " + "$0.getType().cast().getDimSize(1) != 1)">>; def betaIsOne : Constraint>; @@ -61,7 +66,7 @@ def: Pat<(TFL_SoftmaxOp : $output TensorOf<[QI8]>:$input, $beta), (XC_BatchedSoftmaxOp $input, (Arith_ConstantOp (getExpLookupF32 - $output))), [(betaIsOne $beta), (isSingleSegment $input), (isMultiBatch $input)]>; + $output))), [(betaIsOne $beta), (isMultiBatch $input)]>; // Beta float activation lookup def getActivationType From aa0acb0ca70360381d9ef347a3b41f3591b8b8c7 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Wed, 17 Jul 2024 11:02:04 +0100 Subject: [PATCH 3/5] Opsplit changes for fakeslice --- xformer/IR/XCoreOps.td | 3 +- xformer/Transforms/OpSplit.cpp | 588 +++++++++++---------- 
xformer/Transforms/TranslateToCustomOp.cpp | 3 + 3 files changed, 310 insertions(+), 284 deletions(-) diff --git a/xformer/IR/XCoreOps.td b/xformer/IR/XCoreOps.td index 75f9b711c..54be58203 100644 --- a/xformer/IR/XCoreOps.td +++ b/xformer/IR/XCoreOps.td @@ -96,8 +96,7 @@ def XC_SliceOp : XC_Op<"slice", [Pure]> { let results = (outs TensorOf<[QI8, QI16, F32, I8, I32]> : $output); } -def XC_FakeSliceOp - : XC_Op<"fake_slice", [Pure, SameOperandsAndResultType]> { +def XC_FakeSliceOp : XC_Op<"fake_slice", [Pure]> { let summary = "Fake slice op"; let description = [{Fake slice op.}]; diff --git a/xformer/Transforms/OpSplit.cpp b/xformer/Transforms/OpSplit.cpp index 67cce0793..ebd5aae42 100644 --- a/xformer/Transforms/OpSplit.cpp +++ b/xformer/Transforms/OpSplit.cpp @@ -1,6 +1,7 @@ // Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the // XMOS Public License: Version 1 +#include "IR/XCoreOps.h" #include "Transforms/Options.h" #include "Utils/Util.h" @@ -15,6 +16,8 @@ namespace mlir::xcore { namespace { static constexpr char opSplitLabel[] = "opSplitLabel"; static constexpr char opSplitLabelNumSplits[] = "opSplitLabelNumSplits"; +static constexpr char opSplitLabelSavedNumSplits[] = + "opSplitLabelSavedNumSplits"; // OpSplit struct OpSplit : public PassWrapper> { @@ -22,6 +25,7 @@ struct OpSplit : public PassWrapper> { void getDependentDialects(DialectRegistry ®istry) const final { registry.insert(); + registry.insert(); } StringRef getArgument() const final { return "xcore-op-split"; } StringRef getDescription() const final { return "Op Split."; } @@ -43,6 +47,88 @@ TFL::SliceOp createSliceOp(PatternRewriter &rewriter, Location loc, Value input, return sliceOp; } +LogicalResult isRaisableSlice(PatternRewriter &rewriter, TFL::SliceOp slice) { + // Only raise slices that have been inserted with op split pass + if (!slice->hasAttr(opSplitLabel)) + return failure(); + + auto definingOp = slice.getInput().getDefiningOp(); + // Do not raise slice if 
defining op does not have op split label + if (!definingOp->hasAttr(opSplitLabel)) + return failure(); + + // all other uses of defining op must be eligible slices + // we currently only opsplit ops with one result + int numEligibleSlices = 0; + for (const mlir::OpOperand &use : definingOp->getResult(0).getUses()) { + mlir::Operation *op = use.getOwner(); + if (auto sliceOp = dyn_cast_or_null(op)) { + if (!sliceOp->hasAttr(opSplitLabel)) { + return failure(); + } else { + numEligibleSlices++; + } + } else { + return failure(); + } + } + + // no of eligible slices must be greater than or equal to set num of splits + // If more slices, we should try to combine before raising + if (!definingOp->hasAttr(opSplitLabelSavedNumSplits)) + return failure(); + auto attr = definingOp->getAttr(opSplitLabelSavedNumSplits); + int numSplits = attr.cast().getInt(); + if (numSplits != -1 && numEligibleSlices < numSplits) { + return failure(); + } else { + definingOp->setAttr(opSplitLabelSavedNumSplits, + rewriter.getI32IntegerAttr(-1)); + } + + return success(); +} + +LogicalResult combineSliceWithExisting(PatternRewriter &rewriter, + TFL::SliceOp slice) { + auto definingOp = slice.getInput().getDefiningOp(); + + // all other uses of defining op must be eligible slices + // we currently only opsplit ops with one result + SmallVector sliceOps; + for (const mlir::OpOperand &use : definingOp->getResult(0).getUses()) { + mlir::Operation *op = use.getOwner(); + if (auto sliceOp = dyn_cast_or_null(op)) { + // dont push current slice op + if (sliceOp != slice) { + sliceOps.push_back(sliceOp); + } + } + } + + auto f = slice->getParentOfType(); + int i; + for (i = 0; i < sliceOps.size(); i++) { + // if slice op matches with another op in list + // remove current one and attach to that + if (slice.getBegin() == sliceOps[i].getBegin() && + slice.getSize() == sliceOps[i].getSize()) { + break; + } + } + + if (i < sliceOps.size()) { + // slice.getOutput().replaceAllUsesWith(sliceOps[i]); + // 
rewriter.eraseOp(slice); + rewriter.replaceOp(slice, sliceOps[i].getOutput()); + return success(); + } + // // replace slice with new slice -> new add + // rewriter.replaceOp(fakeSlice, opReplacement->getResult(0)); + + return failure(); +} + template struct OpSplitHorizontalPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -129,41 +215,43 @@ struct RaiseSliceHorizontalAddPattern : public OpRewritePattern { LogicalResult matchAndRewrite(TFL::SliceOp slice, PatternRewriter &rewriter) const override { - // Only raise slices that have been inserted with op split pass - if (!((slice->hasAttr(opSplitLabel)))) + if (!slice.getInput().getDefiningOp() || + !isa(slice.getInput().getDefiningOp())) { return failure(); + } - // If slice does not have a defining op, return failure - if (!(slice.getInput().getDefiningOp())) { + if (failed(isRaisableSlice(rewriter, slice))) { return failure(); } - if (!isa(slice.getInput().getDefiningOp())) { - return failure(); + // combineslice with existing + // go through all uses of defining op + // find other slices and see if there is a match + // if so, erase this slice and attach to that one, remove opsplitlabel from + // attached new slice + if (succeeded(combineSliceWithExisting(rewriter, slice))) { + return success(); } auto addOriginal = llvm::cast(slice.getInput().getDefiningOp()); - // Do not raise slice if op does not have op split label - if (!(addOriginal->hasAttr(opSplitLabel))) - return failure(); - auto sliceOutShape = utils::getValShape(slice.getOutput()); - // Create new slice for above add - auto sliceLHS = llvm::cast(rewriter.clone(*slice)); - sliceLHS.setOperand(0, addOriginal.getLhs()); - RankedTensorType sliceLHSType = RankedTensorType::get( - sliceOutShape, utils::getValElementType(addOriginal.getLhs())); - sliceLHS->getResult(0).setType(sliceLHSType); - - // Create new slice for above add - auto sliceRHS = llvm::cast(rewriter.clone(*slice)); - sliceRHS.setOperand(0, 
addOriginal.getRhs()); - RankedTensorType sliceRHSType = RankedTensorType::get( - sliceOutShape, utils::getValElementType(addOriginal.getRhs())); - sliceRHS->getResult(0).setType(sliceRHSType); - + auto outputType = + addOriginal.getOutput().getType().cast(); + auto getSliceOp = [&](Value arg) -> Value { + auto argType = arg.getType().cast(); + auto newSlice = llvm::cast(rewriter.clone(*slice)); + newSlice.setOperand(0, arg); + RankedTensorType newSliceType = + RankedTensorType::get(sliceOutShape, utils::getValElementType(arg)); + newSlice->getResult(0).setType(newSliceType); + return newSlice; + }; + + // Create new slice for above adds + auto sliceLHS = getSliceOp(addOriginal.getLhs()); + auto sliceRHS = getSliceOp(addOriginal.getRhs()); auto addReplacement = llvm::cast(rewriter.clone(*addOriginal)); RankedTensorType addReplacementType = RankedTensorType::get( sliceOutShape, utils::getValElementType(addOriginal.getOutput())); @@ -178,108 +266,217 @@ struct RaiseSliceHorizontalAddPattern : public OpRewritePattern { } }; -struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct RaiseFakeSliceHorizontalPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TFL::SliceOp slice, + LogicalResult matchAndRewrite(FakeSliceOp fakeSlice, PatternRewriter &rewriter) const override { - // Only raise slices that have been inserted with op split pass - if (!((slice->hasAttr(opSplitLabel)))) + auto definingOp = fakeSlice.getInput().getDefiningOp(); + // If fakeslice does not have a defining op, remove it + if (!definingOp) { + rewriter.replaceOp(fakeSlice, fakeSlice->getResult(0)); + return success(); + } + + // no of fake slices must be equal to set num of splits + if (!definingOp->hasAttr(opSplitLabelSavedNumSplits)) + return failure(); + auto attr = definingOp->getAttr(opSplitLabelSavedNumSplits); + int numSplits = attr.cast().getInt(); + auto beginAttr = 
fakeSlice->getAttr("begin").cast(); + if (beginAttr.size() != numSplits) { + return failure(); + } + + if (!dyn_cast_or_null(definingOp) && + !dyn_cast_or_null(definingOp) && + !dyn_cast_or_null(definingOp)) { return failure(); + } + + // Do not raise slice if op does not have op split label + if (!(definingOp->hasAttr(opSplitLabel))) + return failure(); + + auto sliceOutShape = utils::getValShape(fakeSlice.getOutput()); + // Create new slice for above adds + auto sliceReplacement = rewriter.clone(*fakeSlice); + sliceReplacement->setOperand(0, definingOp->getOperand(0)); + sliceReplacement->getResult(0).setType(definingOp->getOperand(0).getType()); + auto opReplacement = rewriter.clone(*definingOp); + opReplacement->setOperand(0, sliceReplacement->getResult(0)); + + // replace slice with new slice -> new add + rewriter.replaceOp(fakeSlice, opReplacement->getResult(0)); + + return success(); + } +}; + +struct RaiseFakeSliceHorizontalMeanPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FakeSliceOp fakeSlice, + PatternRewriter &rewriter) const override { // If slice does not have a defining op, return failure - if (!(slice.getInput().getDefiningOp())) { + if (!(fakeSlice.getInput().getDefiningOp())) { return failure(); } - if (!isa(slice.getInput().getDefiningOp())) { + auto meanOriginal = + dyn_cast_or_null(fakeSlice.getInput().getDefiningOp()); + if (!meanOriginal) { return failure(); } - auto addOriginal = llvm::cast(slice.getInput().getDefiningOp()); + fakeSlice.dump(); // Do not raise slice if op does not have op split label - if (!(addOriginal->hasAttr(opSplitLabel))) + if (!(meanOriginal->hasAttr(opSplitLabel))) return failure(); - auto sliceOutShape = utils::getValShape(slice.getOutput()); - - -// // if broadcast, create fakeslice op with no slices -// // find if lhs or rhs has broadcast -// auto lhsType = addOriginal.getLhs().getType().cast(); -// auto rhsType = 
addOriginal.getRhs().getType().cast(); -// auto outputType = addOriginal.getOutput().getType().cast(); -// if (!hasSameShape(rhsType, outputType) && -// !hasSameShape(lhsType, outputType)) { -// return failure(); -// } -// // RHS needs broadcast -// if (!hasSameShape(rhsType, outputType) { + auto beginAttr = fakeSlice->getAttr("begin").cast(); + auto sizeAttr = fakeSlice->getAttr("size").cast(); + assert(beginAttr.size() == sizeAttr.size()); -// } else { + SmallVector meanOps; -// } + for (int i = 0; i < beginAttr.size(); i++) { + auto begin = beginAttr.getValue()[i].cast(); + auto beginVector = std::vector{ + begin.getValues().begin(), begin.getValues().end()}; -// o l -// mul -// s s s s -// op + auto size = sizeAttr.getValue()[i].cast(); + auto sizeVector = std::vector{size.getValues().begin(), + size.getValues().end()}; -// l -// fs fs fs fs + // create slice and mean op + auto sliceReplacement = createSliceOp( + rewriter, fakeSlice.getLoc(), meanOriginal.getInput(), beginVector, + sizeVector, utils::getValElementType(meanOriginal.getInput())); -// fs -// l -// fs fs fs + auto meanReplacement = + llvm::cast(rewriter.clone(*meanOriginal)); + meanReplacement.setOperand(0, sliceReplacement); -// o l -// s fs -// m m -// s s s + meanOps.push_back(meanReplacement.getResult()); + } + // create concat and final mean op + RankedTensorType newOutputType = RankedTensorType::get( + {1, 4, 1, 16}, utils::getValElementType(meanOriginal.getOutput())); + auto newConcatOp = rewriter.create( + meanOriginal.getLoc(), newOutputType, meanOps, /*axis=*/1, "NONE"); + auto meanReplacement = + llvm::cast(rewriter.clone(*meanOriginal)); + meanReplacement.setOperand(0, newConcatOp); -// fc -// 10 -// conv -// 40x10x1 + // auto sliceOutShape = utils::getValShape(fakeSlice.getOutput()); -// // raisefsaboveop -// // if op can be split spatially, then raise fs and convert to slice -// // if another fs already above op, then merge with that one -// // otherwise, raise above op + // // Create 
new slice for above adds + // auto sliceReplacement = rewriter.clone(*fakeSlice); + // sliceReplacement->setOperand(0, definingOp.getOperand()); + // sliceReplacement->getResult(0).setType(definingOp.getResult().getType()); + // auto opReplacement = rewriter.clone(*definingOp); + // opReplacement->setOperand(0, sliceReplacement->getResult(0)); -// conv -// 1x1x1x16 -// fs(4) + // // replace slice with new slice -> new add + rewriter.replaceOp(fakeSlice, meanReplacement->getResult(0)); + return success(); + } +}; +struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TFL::SliceOp slice, + PatternRewriter &rewriter) const override { + auto f = slice->getParentOfType(); + // If slice does not have a defining op, return failure + if (!slice.getInput().getDefiningOp() || + !isa(slice.getInput().getDefiningOp())) { + return failure(); + } -// conv -// 1x40x10x10 -// s s s s -// 1x10x10 + if (failed(isRaisableSlice(rewriter, slice))) { + return failure(); + } - // Create new slice for above add - auto sliceLHS = llvm::cast(rewriter.clone(*slice)); - sliceLHS.setOperand(0, addOriginal.getLhs()); - RankedTensorType sliceLHSType = RankedTensorType::get( - sliceOutShape, utils::getValElementType(addOriginal.getLhs())); - sliceLHS->getResult(0).setType(sliceLHSType); + auto addOriginal = llvm::cast(slice.getInput().getDefiningOp()); - // Create new slice for above add - auto sliceRHS = llvm::cast(rewriter.clone(*slice)); - sliceRHS.setOperand(0, addOriginal.getRhs()); - RankedTensorType sliceRHSType = RankedTensorType::get( - sliceOutShape, utils::getValElementType(addOriginal.getRhs())); - sliceRHS->getResult(0).setType(sliceRHSType); + DenseElementsAttr beginAttr, sizeAttr; + if (!matchPattern(slice.getBegin(), m_Constant(&beginAttr))) { + return failure(); + } + if (!matchPattern(slice.getSize(), m_Constant(&sizeAttr))) { + return failure(); + } + auto sliceOutShape = 
utils::getValShape(slice.getOutput()); auto addReplacement = llvm::cast(rewriter.clone(*addOriginal)); RankedTensorType addReplacementType = RankedTensorType::get( sliceOutShape, utils::getValElementType(addOriginal.getOutput())); addReplacement->getResult(0).setType(addReplacementType); + + auto outputType = + addOriginal.getOutput().getType().cast(); + auto getSliceOp = [&](int argNo, Value arg) -> Value { + auto argType = arg.getType().cast(); + if (utils::hasSameShape(argType, outputType)) { + rewriter.setInsertionPoint(addReplacement); + auto newSlice = llvm::cast(rewriter.clone(*slice)); + newSlice.setOperand(0, arg); + RankedTensorType newSliceType = + RankedTensorType::get(sliceOutShape, utils::getValElementType(arg)); + newSlice->getResult(0).setType(newSliceType); + return newSlice; + } else { + auto fakeSlice = dyn_cast_or_null(arg.getDefiningOp()); + if (!fakeSlice) { + rewriter.setInsertionPoint(addOriginal); + auto newFsOp = + rewriter.create(arg.getLoc(), arg.getType(), arg); + + llvm::SmallVector beginVals; + beginVals.push_back(beginAttr); + newFsOp->setAttr("begin", rewriter.getArrayAttr(beginVals)); + + llvm::SmallVector sizeVals; + sizeVals.push_back(sizeAttr); + newFsOp->setAttr("size", rewriter.getArrayAttr(sizeVals)); + + auto addReplacement = + llvm::cast(rewriter.clone(*addOriginal)); + addReplacement.setOperand(argNo, newFsOp); + addOriginal.getOutput().replaceAllUsesWith(addReplacement); + rewriter.eraseOp(addOriginal); + return newFsOp; + } else { + auto begin = fakeSlice->getAttr("begin").cast(); + llvm::SmallVector beginVals = { + begin.getValue().begin(), begin.getValue().end()}; + beginVals.push_back(beginAttr); + fakeSlice->setAttr("begin", rewriter.getArrayAttr(beginVals)); + + auto size = fakeSlice->getAttr("size").cast(); + llvm::SmallVector sizeVals = { + size.getValue().begin(), size.getValue().end()}; + sizeVals.push_back(sizeAttr); + fakeSlice->setAttr("size", rewriter.getArrayAttr(sizeVals)); + } + return fakeSlice; + } + }; 
+ + // Create new slice for above adds + auto sliceLHS = getSliceOp(0, addOriginal.getLhs()); + auto sliceRHS = getSliceOp(1, addOriginal.getRhs()); + // auto addReplacement = + // llvm::cast(rewriter.clone(*addOriginal2)); addReplacement.setOperand(0, sliceLHS); addReplacement.setOperand(1, sliceRHS); @@ -296,26 +493,28 @@ struct RaiseSliceHorizontalPattern : public OpRewritePattern { LogicalResult matchAndRewrite(TFL::SliceOp slice, PatternRewriter &rewriter) const override { - // Only raise slices that have been inserted with op split pass - if (!(slice->hasAttr(opSplitLabel))) - return failure(); - // If slice does not have a defining op, return failure - if (!(slice.getInput().getDefiningOp())) { + if (!slice.getInput().getDefiningOp() || + !isa(slice.getInput().getDefiningOp())) { return failure(); } - if (!isa(slice.getInput().getDefiningOp())) { + if (failed(isRaisableSlice(rewriter, slice))) { return failure(); } + // combineslice with existing + // go through all uses of defining op + // find other slices and see if there is a match + // if so, erase this slice and attach to that one, remove opsplitlabel from + // attached new slice + if (succeeded(combineSliceWithExisting(rewriter, slice))) { + return success(); + } + // Get data from conv needed to raise slice auto convOriginal = llvm::cast(slice.getInput().getDefiningOp()); - // Do not raise slice if op does not have op split label - if (!(convOriginal->hasAttr(opSplitLabel))) - return failure(); - auto convOriginalInput = convOriginal.getInput().getType().template cast(); auto inputHeight = convOriginalInput.getDimSize(1); @@ -635,189 +834,6 @@ void OpSplit::runOnOperation() { return; } - if (numSplits.empty()) { - int memoryThreshold = opSplitTargetSizeOption.getValue(); - // Initialize operation counter, tensor vectors, and size variables - int opNum = 0; - - std::vector unconsumedTensors; - std::vector newUnconsumedTensors; - - std::map> opSize; - std::vector sizeInfo; - - size_t 
currentTensorArenaSize; - size_t inputSize; - size_t outputSize; - size_t residualSize; - - // Keep a pointer to the previous operation - Operation *prevOp = nullptr; - - // Walk through each operation in the function - func.walk([&](Operation *op) { - // Ignore constant and quantized constant operations - if (!op->hasTrait() && - !llvm::isa(op)) { - - // Helper function to compute the size of a tensor - auto computeTensorSize = [](mlir::Type type) -> size_t { - mlir::TensorType tensorType = type.cast(); - mlir::ArrayRef shape = tensorType.getShape(); - size_t tensorSize = 1; - - for (int64_t dim : shape) { - tensorSize *= dim; - } - - return tensorSize; - }; - - // Clear the contents of the vector - newUnconsumedTensors.clear(); - // Iterate over unconsumed tensors and remove those consumed by the - // current operation - for (const mlir::Value &tensor : unconsumedTensors) { - bool shouldRemove = false; - for (mlir::Value::use_iterator it = tensor.use_begin(), - e = tensor.use_end(); - it != e; ++it) { - if ((*it).getOwner() == op) { - shouldRemove = true; - break; - } - } - if (!shouldRemove) { - newUnconsumedTensors.push_back(tensor); - } - } - // Update unconsumed tensors with the new vector - unconsumedTensors = newUnconsumedTensors; - - currentTensorArenaSize = 0; - - residualSize = 0; - // Iterate over the unconsumed tensors and compute their sizes - for (mlir::Value tensor : unconsumedTensors) { - residualSize += computeTensorSize(tensor.getType()); - currentTensorArenaSize += computeTensorSize(tensor.getType()); - } - - inputSize = 0; - // Iterate over the input operands and compute their sizes - for (mlir::Value input : op->getOperands()) { - if (!input.getType().isa()) { - continue; - } - if (input.getDefiningOp() && - (input.getDefiningOp()->hasTrait() || - llvm::isa(input.getDefiningOp()))) { - continue; - } - - inputSize += computeTensorSize(input.getType()); - currentTensorArenaSize += computeTensorSize(input.getType()); - - // If input tensor has 
more than one use and was created by the - // previous operation, add it to unconsumed tensors - if ((std::distance(input.use_begin(), input.use_end()) > 1) && - (input.getDefiningOp() == prevOp)) { - unconsumedTensors.push_back(input); - } - } - - outputSize = 0; - // Iterate over the output results and compute their sizes - for (mlir::Value output : op->getResults()) { - if (!output.getType().isa()) { - continue; - } - if (output.getDefiningOp() && - (output.getDefiningOp()->hasTrait() || - llvm::isa(output.getDefiningOp()))) { - continue; - } - outputSize += computeTensorSize(output.getType()); - currentTensorArenaSize += computeTensorSize(output.getType()); - } - - sizeInfo = {currentTensorArenaSize, inputSize, outputSize, - residualSize}; - opSize[opNum] = sizeInfo; - - // Increment operation counter - opNum++; - - // Update the previous operation pointer - prevOp = op; - } - }); - - double size = 0; - std::vector aboveThreshold; - std::vector belowThreshold; - bool crossedThreshold = false; - - for (auto it = opSize.rbegin(); it != opSize.rend(); ++it) { - size = it->second[0]; - auto opId = it->first; - if (size > memoryThreshold) { - if (!crossedThreshold) { - outputSize = it->second[2]; - // if 2 * output size is greater than the threshold, - // concat will be greater than the threshold - // so add the next op - if (2 * outputSize > memoryThreshold) { - aboveThreshold.push_back(opId + 1); - } else { - aboveThreshold.push_back(opId); - } - crossedThreshold = true; - } - } else { - if (crossedThreshold) { - belowThreshold.push_back(opId); - crossedThreshold = false; - } - } - } - - // If the first operation was above the threshold, add it, 0, to - // belowThreshold - if (crossedThreshold) { - belowThreshold.push_back(0); - } - - // adjust threshold trackers if size goes below threshold for only one - // operation - for (size_t i = 0; i < aboveThreshold.size(); ++i) { - if (i > 0 && belowThreshold[i - 1] - aboveThreshold[i] <= 1) { - 
aboveThreshold.erase(aboveThreshold.begin() + i); - belowThreshold.erase(belowThreshold.begin() + i - 1); - // Decrement the indices to account for the removed elements - --i; - } - } - - // Clear the llvm::cl::list containers first - startOps.clear(); - endOps.clear(); - // Copy the elements from the std::vector containers - for (int value : aboveThreshold) { - startOps.push_back(value); - } - for (int value : belowThreshold) { - endOps.push_back(value); - } - for (size_t i = 0; i < startOps.size(); ++i) { - numSplits.push_back(8); - } - - } // if numSplits - OpBuilder builder(func); for (int i = 0; i < startOps.size(); ++i) { int k = 0; @@ -835,9 +851,13 @@ void OpSplit::runOnOperation() { } else { // add label to insert slice under op later op->setAttr(opSplitLabelNumSplits, builder.getI32IntegerAttr(numSplits[i])); + op->setAttr(opSplitLabelSavedNumSplits, + builder.getI32IntegerAttr(numSplits[i])); } } else if (k < startOps[i] && k >= endOps[i]) { op->setAttr(opSplitLabel, builder.getUnitAttr()); + op->setAttr(opSplitLabelSavedNumSplits, + builder.getI32IntegerAttr(numSplits[i])); } k++; } @@ -861,7 +881,11 @@ void OpSplit::runOnOperation() { patterns2.insert>(ctx); patterns2.insert>(ctx); + patterns2.insert(ctx); + patterns2.insert(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns2)); + } // void OpSplit::runOnOperation() { } // namespace diff --git a/xformer/Transforms/TranslateToCustomOp.cpp b/xformer/Transforms/TranslateToCustomOp.cpp index 5a4ba6eef..c4c302751 100644 --- a/xformer/Transforms/TranslateToCustomOp.cpp +++ b/xformer/Transforms/TranslateToCustomOp.cpp @@ -29,6 +29,8 @@ std::vector BinaryI16Op::buildCustomOptions() { return fbb.GetBuffer(); } +std::vector FakeSliceOp::buildCustomOptions() { return {}; } + std::vector Beta_ActivationF32Op::buildCustomOptions() { flexbuffers::Builder fbb; fbb.Map([&]() { fbb.Int("type", (int32_t)getType()); }); @@ -263,6 +265,7 @@ void TranslateToCustomOp::runOnOperation() { 
patterns.insert>(ctx); patterns.insert>(ctx); patterns.insert>(ctx); + patterns.insert>(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } From e08352f6d148c850233c8149ef56ecc5b8f43be0 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Tue, 30 Jul 2024 14:12:51 +0100 Subject: [PATCH 4/5] Cleanup fakeslice passes --- xformer/Transforms/OpSplit.cpp | 357 +++++++++++++++++---------------- 1 file changed, 189 insertions(+), 168 deletions(-) diff --git a/xformer/Transforms/OpSplit.cpp b/xformer/Transforms/OpSplit.cpp index ebd5aae42..29fcefaae 100644 --- a/xformer/Transforms/OpSplit.cpp +++ b/xformer/Transforms/OpSplit.cpp @@ -15,9 +15,8 @@ namespace mlir::xcore { namespace { static constexpr char opSplitLabel[] = "opSplitLabel"; +static constexpr char opSplitLabelStartSplits[] = "opSplitLabelStartSplits"; static constexpr char opSplitLabelNumSplits[] = "opSplitLabelNumSplits"; -static constexpr char opSplitLabelSavedNumSplits[] = - "opSplitLabelSavedNumSplits"; // OpSplit struct OpSplit : public PassWrapper> { @@ -57,8 +56,8 @@ LogicalResult isRaisableSlice(PatternRewriter &rewriter, TFL::SliceOp slice) { if (!definingOp->hasAttr(opSplitLabel)) return failure(); - // all other uses of defining op must be eligible slices - // we currently only opsplit ops with one result + // All other uses of defining op must be eligible slices + // We currently only opsplit ops with one result int numEligibleSlices = 0; for (const mlir::OpOperand &use : definingOp->getResult(0).getUses()) { mlir::Operation *op = use.getOwner(); @@ -73,17 +72,17 @@ LogicalResult isRaisableSlice(PatternRewriter &rewriter, TFL::SliceOp slice) { } } - // no of eligible slices must be greater than or equal to set num of splits - // If more slices, we should try to combine before raising - if (!definingOp->hasAttr(opSplitLabelSavedNumSplits)) + // No of eligible slices must be greater than or equal to set num of splits + // If there are more slices, we should try to combine before 
raising + if (!definingOp->hasAttr(opSplitLabelNumSplits)) return failure(); - auto attr = definingOp->getAttr(opSplitLabelSavedNumSplits); + + auto attr = definingOp->getAttr(opSplitLabelNumSplits); int numSplits = attr.cast().getInt(); if (numSplits != -1 && numEligibleSlices < numSplits) { return failure(); } else { - definingOp->setAttr(opSplitLabelSavedNumSplits, - rewriter.getI32IntegerAttr(-1)); + definingOp->setAttr(opSplitLabelNumSplits, rewriter.getI32IntegerAttr(-1)); } return success(); @@ -93,13 +92,24 @@ LogicalResult combineSliceWithExisting(PatternRewriter &rewriter, TFL::SliceOp slice) { auto definingOp = slice.getInput().getDefiningOp(); - // all other uses of defining op must be eligible slices - // we currently only opsplit ops with one result + // All other uses of defining op must be slices + // We currently only opsplit ops with one result SmallVector sliceOps; for (const mlir::OpOperand &use : definingOp->getResult(0).getUses()) { mlir::Operation *op = use.getOwner(); if (auto sliceOp = dyn_cast_or_null(op)) { - // dont push current slice op + // We only support slices on height dimension + // Slice must have rank 4 and dim 0, 2, 3 must be the same for input and + // output + auto inType = sliceOp.getInput().getType().cast(); + auto outType = sliceOp.getOutput().getType().cast(); + if (!inType.getRank() == 4 || + inType.getDimSize(0) != outType.getDimSize(0) || + inType.getDimSize(2) != outType.getDimSize(2) || + inType.getDimSize(3) != outType.getDimSize(3)) { + return failure(); + } + // Dont push current slice op if (sliceOp != slice) { sliceOps.push_back(sliceOp); } @@ -107,36 +117,78 @@ LogicalResult combineSliceWithExisting(PatternRewriter &rewriter, } auto f = slice->getParentOfType(); + + // Get begin index for slice + DenseElementsAttr attr; + int sliceBegin, sliceSize, candidateBegin, candidateSize; + if (!matchPattern(slice.getBegin(), m_Constant(&attr))) { + return failure(); + } + sliceBegin = attr.getValues()[1]; + if 
(!matchPattern(slice.getSize(), m_Constant(&attr))) { + return failure(); + } + sliceSize = attr.getValues()[1]; + int i; for (i = 0; i < sliceOps.size(); i++) { - // if slice op matches with another op in list + // If slice op matches with another op in list, // remove current one and attach to that - if (slice.getBegin() == sliceOps[i].getBegin() && - slice.getSize() == sliceOps[i].getSize()) { + // Only need to consider height dimension as we only slice on that + if (!matchPattern(sliceOps[i].getBegin(), m_Constant(&attr))) { + return failure(); + } + candidateBegin = attr.getValues()[1]; + if (!matchPattern(sliceOps[i].getSize(), m_Constant(&attr))) { + return failure(); + } + candidateSize = attr.getValues()[1]; + + if (sliceBegin >= candidateBegin && + sliceBegin + sliceSize <= candidateBegin + candidateSize) { break; } } if (i < sliceOps.size()) { - // slice.getOutput().replaceAllUsesWith(sliceOps[i]); - // rewriter.eraseOp(slice); - rewriter.replaceOp(slice, sliceOps[i].getOutput()); + // Slice can be removed + if (sliceBegin == candidateBegin && sliceSize == candidateSize) { + rewriter.replaceOp(slice, sliceOps[i].getOutput()); + } else { + // Create new slice + if (!matchPattern(sliceOps[i].getBegin(), m_Constant(&attr))) { + return failure(); + } + int32_t newBeginAttr[4] = { + attr.getValues()[0], sliceBegin - candidateBegin, + attr.getValues()[2], attr.getValues()[3]}; + if (!matchPattern(slice.getSize(), m_Constant(&attr))) { + return failure(); + } + // Same as slice size attr + int32_t newSizeAttr[4] = { + attr.getValues()[0], attr.getValues()[1], + attr.getValues()[2], attr.getValues()[3]}; + auto newSlice = createSliceOp( + rewriter, sliceOps[i].getLoc(), sliceOps[i], newBeginAttr, + newSizeAttr, slice.getOutput().getType().getElementType()); + newSlice->removeAttr(opSplitLabel); + rewriter.replaceOp(slice, newSlice.getOutput()); + } return success(); } - // // replace slice with new slice -> new add - // rewriter.replaceOp(fakeSlice, 
opReplacement->getResult(0)); - return failure(); } template -struct OpSplitHorizontalPattern : public OpRewritePattern { +struct OpSplitPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TargetOp targetOp, PatternRewriter &rewriter) const override { // Do not split ops already split - if (!(targetOp->hasAttr(opSplitLabelNumSplits))) + // Only start splitting if the label is present + if (!(targetOp->hasAttr(opSplitLabelStartSplits))) return failure(); int numSplits = 0; @@ -210,112 +262,73 @@ struct OpSplitHorizontalPattern : public OpRewritePattern { } }; -struct RaiseSliceHorizontalAddPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::SliceOp slice, - PatternRewriter &rewriter) const override { - if (!slice.getInput().getDefiningOp() || - !isa(slice.getInput().getDefiningOp())) { - return failure(); - } - - if (failed(isRaisableSlice(rewriter, slice))) { - return failure(); - } - - // combineslice with existing - // go through all uses of defining op - // find other slices and see if there is a match - // if so, erase this slice and attach to that one, remove opsplitlabel from - // attached new slice - if (succeeded(combineSliceWithExisting(rewriter, slice))) { - return success(); - } - - auto addOriginal = llvm::cast(slice.getInput().getDefiningOp()); - - auto sliceOutShape = utils::getValShape(slice.getOutput()); - - auto outputType = - addOriginal.getOutput().getType().cast(); - auto getSliceOp = [&](Value arg) -> Value { - auto argType = arg.getType().cast(); - auto newSlice = llvm::cast(rewriter.clone(*slice)); - newSlice.setOperand(0, arg); - RankedTensorType newSliceType = - RankedTensorType::get(sliceOutShape, utils::getValElementType(arg)); - newSlice->getResult(0).setType(newSliceType); - return newSlice; - }; - - // Create new slice for above adds - auto sliceLHS = getSliceOp(addOriginal.getLhs()); - auto sliceRHS = 
getSliceOp(addOriginal.getRhs()); - auto addReplacement = llvm::cast(rewriter.clone(*addOriginal)); - RankedTensorType addReplacementType = RankedTensorType::get( - sliceOutShape, utils::getValElementType(addOriginal.getOutput())); - addReplacement->getResult(0).setType(addReplacementType); - addReplacement.setOperand(0, sliceLHS); - addReplacement.setOperand(1, sliceRHS); - - // replace slice with new slice -> new add - rewriter.replaceOp(slice, addReplacement.getOutput()); - - return success(); - } -}; - -struct RaiseFakeSliceHorizontalPattern : public OpRewritePattern { +template +struct RaiseFakeSlicePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FakeSliceOp fakeSlice, PatternRewriter &rewriter) const override { auto definingOp = fakeSlice.getInput().getDefiningOp(); - // If fakeslice does not have a defining op, remove it - if (!definingOp) { - rewriter.replaceOp(fakeSlice, fakeSlice->getResult(0)); + // If fakeslice is only defined by const ops, remove it as we have reached + // the top + if (dyn_cast_or_null(definingOp) || + dyn_cast_or_null(definingOp)) { + rewriter.replaceOp(fakeSlice, fakeSlice->getOperand(0)); return success(); } - // no of fake slices must be equal to set num of splits - if (!definingOp->hasAttr(opSplitLabelSavedNumSplits)) + // Needs a defining op (input may be a block argument); no of fake slices must be equal to set num of splits + if (!definingOp || !definingOp->hasAttr(opSplitLabelNumSplits)) return failure(); - auto attr = definingOp->getAttr(opSplitLabelSavedNumSplits); + auto attr = definingOp->getAttr(opSplitLabelNumSplits); int numSplits = attr.cast().getInt(); auto beginAttr = fakeSlice->getAttr("begin").cast(); if (beginAttr.size() != numSplits) { return failure(); } - if (!dyn_cast_or_null(definingOp) && - !dyn_cast_or_null(definingOp) && - !dyn_cast_or_null(definingOp)) { + if (!dyn_cast_or_null(definingOp)) { return failure(); } - // Do not raise slice if op does not have op split label + // Do not raise fake slice if op does
not have op split label if (!(definingOp->hasAttr(opSplitLabel))) return failure(); auto sliceOutShape = utils::getValShape(fakeSlice.getOutput()); - // Create new slice for above adds + // Create new fake slice above op auto sliceReplacement = rewriter.clone(*fakeSlice); sliceReplacement->setOperand(0, definingOp->getOperand(0)); sliceReplacement->getResult(0).setType(definingOp->getOperand(0).getType()); auto opReplacement = rewriter.clone(*definingOp); opReplacement->setOperand(0, sliceReplacement->getResult(0)); - // replace slice with new slice -> new add + // replace fakeslice with new fakeslice -> op rewriter.replaceOp(fakeSlice, opReplacement->getResult(0)); return success(); } }; -struct RaiseFakeSliceHorizontalMeanPattern - : public OpRewritePattern { +template +struct RaiseFakeSliceConstPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FakeSliceOp fakeSlice, + PatternRewriter &rewriter) const override { + auto definingOp = fakeSlice.getInput().getDefiningOp(); + // If fakeslice is only defined by const ops, remove it as we have reached + // the top + if (dyn_cast_or_null(definingOp)) { + rewriter.replaceOp(fakeSlice, fakeSlice->getOperand(0)); + return success(); + } + return failure(); + } +}; + +struct RaiseFakeSliceToSliceMeanPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FakeSliceOp fakeSlice, @@ -331,8 +344,6 @@ struct RaiseFakeSliceHorizontalMeanPattern return failure(); } - fakeSlice.dump(); - // Do not raise slice if op does not have op split label if (!(meanOriginal->hasAttr(opSplitLabel))) return failure(); @@ -341,8 +352,18 @@ struct RaiseFakeSliceHorizontalMeanPattern auto sizeAttr = fakeSlice->getAttr("size").cast(); assert(beginAttr.size() == sizeAttr.size()); - SmallVector meanOps; + // We only support all equal slices at the moment + assert(sizeAttr.size() > 1); + auto size = sizeAttr.getValue()[0].cast(); + for 
(int i = 1; i < sizeAttr.size(); i++) { + auto size2 = sizeAttr.getValue()[i].cast(); + if (size != size2) { + return failure(); + } + } + // For each begin and size attr, create slice and corresponding mean op + SmallVector meanOps; for (int i = 0; i < beginAttr.size(); i++) { auto begin = beginAttr.getValue()[i].cast(); auto beginVector = std::vector{ @@ -352,7 +373,7 @@ struct RaiseFakeSliceHorizontalMeanPattern auto sizeVector = std::vector{size.getValues().begin(), size.getValues().end()}; - // create slice and mean op + // Create slice and mean op auto sliceReplacement = createSliceOp( rewriter, fakeSlice.getLoc(), meanOriginal.getInput(), beginVector, sizeVector, utils::getValElementType(meanOriginal.getInput())); @@ -363,9 +384,11 @@ struct RaiseFakeSliceHorizontalMeanPattern meanOps.push_back(meanReplacement.getResult()); } - // create concat and final mean op + // Create concat and final mean op + auto meanOutShape = utils::getValShape(meanOriginal.getOutput()); RankedTensorType newOutputType = RankedTensorType::get( - {1, 4, 1, 16}, utils::getValElementType(meanOriginal.getOutput())); + {1, static_cast(beginAttr.size()), 1, meanOutShape[3]}, + utils::getValElementType(meanOriginal.getOutput())); auto newConcatOp = rewriter.create( meanOriginal.getLoc(), newOutputType, meanOps, /*axis=*/1, "NONE"); @@ -373,24 +396,15 @@ struct RaiseFakeSliceHorizontalMeanPattern llvm::cast(rewriter.clone(*meanOriginal)); meanReplacement.setOperand(0, newConcatOp); - // auto sliceOutShape = utils::getValShape(fakeSlice.getOutput()); - - // // Create new slice for above adds - // auto sliceReplacement = rewriter.clone(*fakeSlice); - // sliceReplacement->setOperand(0, definingOp.getOperand()); - // sliceReplacement->getResult(0).setType(definingOp.getResult().getType()); - - // auto opReplacement = rewriter.clone(*definingOp); - // opReplacement->setOperand(0, sliceReplacement->getResult(0)); - - // // replace slice with new slice -> new add + // Replace fake slice with new 
slices -> mean ops -> concat -> mean rewriter.replaceOp(fakeSlice, meanReplacement->getResult(0)); return success(); } }; -struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { +template +struct RaiseSliceBinaryPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::SliceOp slice, @@ -398,7 +412,7 @@ struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { auto f = slice->getParentOfType(); // If slice does not have a defining op, return failure if (!slice.getInput().getDefiningOp() || - !isa(slice.getInput().getDefiningOp())) { + !isa(slice.getInput().getDefiningOp())) { return failure(); } @@ -406,7 +420,11 @@ struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { return failure(); } - auto addOriginal = llvm::cast(slice.getInput().getDefiningOp()); + if (succeeded(combineSliceWithExisting(rewriter, slice))) { + return success(); + } + + auto opOriginal = llvm::cast(slice.getInput().getDefiningOp()); DenseElementsAttr beginAttr, sizeAttr; if (!matchPattern(slice.getBegin(), m_Constant(&beginAttr))) { @@ -417,17 +435,17 @@ struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { } auto sliceOutShape = utils::getValShape(slice.getOutput()); - auto addReplacement = llvm::cast(rewriter.clone(*addOriginal)); - RankedTensorType addReplacementType = RankedTensorType::get( - sliceOutShape, utils::getValElementType(addOriginal.getOutput())); - addReplacement->getResult(0).setType(addReplacementType); + auto opReplacement = llvm::cast(rewriter.clone(*opOriginal)); + RankedTensorType opReplacementType = RankedTensorType::get( + sliceOutShape, utils::getValElementType(opOriginal.getOutput())); + opReplacement->getResult(0).setType(opReplacementType); auto outputType = - addOriginal.getOutput().getType().cast(); + opOriginal.getOutput().getType().template cast(); auto getSliceOp = [&](int argNo, Value arg) -> Value { auto argType = arg.getType().cast(); if 
(utils::hasSameShape(argType, outputType)) { - rewriter.setInsertionPoint(addReplacement); + rewriter.setInsertionPoint(opReplacement); auto newSlice = llvm::cast(rewriter.clone(*slice)); newSlice.setOperand(0, arg); RankedTensorType newSliceType = @@ -437,7 +455,7 @@ struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { } else { auto fakeSlice = dyn_cast_or_null(arg.getDefiningOp()); if (!fakeSlice) { - rewriter.setInsertionPoint(addOriginal); + rewriter.setInsertionPoint(opOriginal); auto newFsOp = rewriter.create(arg.getLoc(), arg.getType(), arg); @@ -449,11 +467,11 @@ struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { sizeVals.push_back(sizeAttr); newFsOp->setAttr("size", rewriter.getArrayAttr(sizeVals)); - auto addReplacement = - llvm::cast(rewriter.clone(*addOriginal)); - addReplacement.setOperand(argNo, newFsOp); - addOriginal.getOutput().replaceAllUsesWith(addReplacement); - rewriter.eraseOp(addOriginal); + auto opReplacement = + llvm::cast(rewriter.clone(*opOriginal)); + opReplacement.setOperand(argNo, newFsOp); + opOriginal.getOutput().replaceAllUsesWith(opReplacement); + rewriter.eraseOp(opOriginal); return newFsOp; } else { auto begin = fakeSlice->getAttr("begin").cast(); @@ -472,23 +490,21 @@ struct RaiseSliceHorizontalMulPattern : public OpRewritePattern { } }; - // Create new slice for above adds - auto sliceLHS = getSliceOp(0, addOriginal.getLhs()); - auto sliceRHS = getSliceOp(1, addOriginal.getRhs()); - // auto addReplacement = - // llvm::cast(rewriter.clone(*addOriginal2)); - addReplacement.setOperand(0, sliceLHS); - addReplacement.setOperand(1, sliceRHS); + // Create new slices for above op + auto sliceLHS = getSliceOp(0, opOriginal.getLhs()); + auto sliceRHS = getSliceOp(1, opOriginal.getRhs()); + opReplacement.setOperand(0, sliceLHS); + opReplacement.setOperand(1, sliceRHS); - // replace slice with new slice -> new add - rewriter.replaceOp(slice, addReplacement.getOutput()); + // replace slice with new slice -> new 
op + rewriter.replaceOp(slice, opReplacement.getOutput()); return success(); } }; template -struct RaiseSliceHorizontalPattern : public OpRewritePattern { +struct RaiseSlicePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::SliceOp slice, @@ -503,11 +519,6 @@ struct RaiseSliceHorizontalPattern : public OpRewritePattern { return failure(); } - // combineslice with existing - // go through all uses of defining op - // find other slices and see if there is a match - // if so, erase this slice and attach to that one, remove opsplitlabel from - // attached new slice if (succeeded(combineSliceWithExisting(rewriter, slice))) { return success(); } @@ -680,7 +691,7 @@ struct RaiseSliceHorizontalPattern : public OpRewritePattern { } }; -struct RaiseSliceHorizontalPadPattern : public OpRewritePattern { +struct RaiseSlicePadPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::SliceOp slice, @@ -717,16 +728,16 @@ struct RaiseSliceHorizontalPadPattern : public OpRewritePattern { auto outputWidth = padOriginalOutput.getDimSize(2); auto outputChannels = padOriginalOutput.getDimSize(3); - int64_t padVertical, padHorizontal; + int64_t padVertical, pad; int64_t padTop, padBottom, padLeft, padRight; padVertical = outputHeight - inputHeight; - padHorizontal = outputWidth - inputWidth; + pad = outputWidth - inputWidth; padTop = padVertical / 2; padBottom = padVertical - padTop; - padLeft = padHorizontal / 2; - padRight = padHorizontal - padLeft; + padLeft = pad / 2; + padRight = pad - padLeft; // Get original slice's output height auto sliceOutput = slice.getOutput().getType().cast(); @@ -780,13 +791,13 @@ struct RaiseSliceHorizontalPadPattern : public OpRewritePattern { auto paddedHeight = newOutputHeight + padTop + padBottom; auto paddedWidth = inputWidth + padLeft + padRight; - DenseIntElementsAttr pad; - if (!matchPattern(padOriginal.getPadding(), 
m_Constant(&pad))) { + DenseIntElementsAttr padAttr; + if (!matchPattern(padOriginal.getPadding(), m_Constant(&padAttr))) { return failure(); } // Keep padding values the same in the last dimension - auto padVal = pad.getValues(); + auto padVal = padAttr.getValues(); std::vector paddingValues{0, 0, @@ -849,14 +860,13 @@ void OpSplit::runOnOperation() { sliceOp.getInput().getDefiningOp()->setAttr(opSplitLabel, builder.getUnitAttr()); } else { // add label to insert slice under op later + op->setAttr(opSplitLabelStartSplits, builder.getUnitAttr()); op->setAttr(opSplitLabelNumSplits, builder.getI32IntegerAttr(numSplits[i])); - op->setAttr(opSplitLabelSavedNumSplits, - builder.getI32IntegerAttr(numSplits[i])); } } else if (k < startOps[i] && k >= endOps[i]) { op->setAttr(opSplitLabel, builder.getUnitAttr()); - op->setAttr(opSplitLabelSavedNumSplits, + op->setAttr(opSplitLabelNumSplits, builder.getI32IntegerAttr(numSplits[i])); } k++; @@ -866,25 +876,36 @@ void OpSplit::runOnOperation() { RewritePatternSet patterns1(ctx); - patterns1.insert>(ctx); - patterns1.insert>(ctx); - patterns1.insert>(ctx); - patterns1.insert>(ctx); + patterns1.insert>(ctx); + patterns1.insert>(ctx); + patterns1.insert>(ctx); + patterns1.insert>(ctx); + patterns1.insert>(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(patterns1)); RewritePatternSet patterns2(ctx); - - patterns2.insert(ctx); - patterns2.insert(ctx); - patterns2.insert(ctx); - patterns2.insert>(ctx); - patterns2.insert>(ctx); - - patterns2.insert(ctx); - patterns2.insert(ctx); - - (void)applyPatternsAndFoldGreedily(func, std::move(patterns2)); + // We are restricting pattern matching to only move slices above when they + // have reached the number of set splits. This means many pass iterations + // would fail as they don't meet the criteria. We increase the maxIterations + // count here so that more iterations are tried before the rewriter decides + // failure. 
+ GreedyRewriteConfig config; + config.maxIterations = 50; + + patterns2.insert>(ctx); + patterns2.insert>(ctx); + patterns2.insert(ctx); + patterns2.insert>(ctx); + patterns2.insert>(ctx); + + patterns2.insert>(ctx); + patterns2.insert>(ctx); + patterns2.insert>(ctx); + patterns2.insert>(ctx); + patterns2.insert(ctx); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns2), config); } // void OpSplit::runOnOperation() { } // namespace From f1911a5c74ccea57a337f1b2142d9413d4954131 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Tue, 30 Jul 2024 16:37:13 +0100 Subject: [PATCH 5/5] Update submodule --- third_party/lib_tflite_micro | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/lib_tflite_micro b/third_party/lib_tflite_micro index 105420122..668bcb10e 160000 --- a/third_party/lib_tflite_micro +++ b/third_party/lib_tflite_micro @@ -1 +1 @@ -Subproject commit 1054201227b35e75dd40a8c8db8f424007ba7599 +Subproject commit 668bcb10e5258edc8a37f744d6d060637eddcccc