diff --git a/xformer/Transforms/Passes.cpp b/xformer/Transforms/Passes.cpp index 0f482b99b..29d720de2 100644 --- a/xformer/Transforms/Passes.cpp +++ b/xformer/Transforms/Passes.cpp @@ -40,6 +40,7 @@ void buildXCoreRemainingPassPipeline(OpPassManager &pm) { pm.addPass(createReplaceMaxPoolPass()); pm.addPass(createReplaceMulPass()); pm.addPass(createReplaceMeanPass()); + pm.addPass(createReplaceSumPass()); pm.addPass(createReplaceTransposeConvPass()); pm.addPass(createReplaceConv2DPass()); pm.addPass(createReplacePadPass()); diff --git a/xformer/Transforms/Passes.h b/xformer/Transforms/Passes.h index edae71024..bf7ac59ae 100644 --- a/xformer/Transforms/Passes.h +++ b/xformer/Transforms/Passes.h @@ -34,6 +34,7 @@ std::unique_ptr> createRemoveDynamicShapePass(); std::unique_ptr> createReplaceAddSubPass(); std::unique_ptr> createReplaceMulPass(); std::unique_ptr> createReplaceMeanPass(); +std::unique_ptr> createReplaceSumPass(); std::unique_ptr> createReplaceMaxPoolPass(); std::unique_ptr> createReplaceStridedSlicePass(); std::unique_ptr> createReplaceSlicePass(); diff --git a/xformer/Transforms/ReplaceSum.cpp b/xformer/Transforms/ReplaceSum.cpp new file mode 100644 index 000000000..6415b6a62 --- /dev/null +++ b/xformer/Transforms/ReplaceSum.cpp @@ -0,0 +1,120 @@ +// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the +// XMOS Public License: Version 1 + +#include "IR/XCoreOps.h" +#include "Utils/Util.h" + +extern "C" { +#include "lib_nn/api/nn_layers.h" +} +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/validators.h" + +namespace mlir::xcore { + +namespace { +// Replace TFL Sum with Mean for XCore. +struct ReplaceSum + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceSum) + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + } + StringRef getArgument() const final { return "xcore-replace-sum"; } + StringRef getDescription() const final { + return "Replace TFL Sum with mean for XCore."; + } + void runOnOperation() override; +}; + +struct ReplaceSumPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::SumOp sumOp, + PatternRewriter &rewriter) const override { + + auto input = sumOp.getInput(); + auto output = sumOp.getOutput(); + + DenseElementsAttr axisAttr; + matchPattern(sumOp.getAxes(), m_Constant(&axisAttr)); + auto axisValues = axisAttr.getValues(); + std::vector axis(axisValues.begin(), axisValues.end()); + int32_t minAxis = *std::min_element(axis.begin(), axis.end()); + int32_t maxAxis = *std::max_element(axis.begin(), axis.end()); + if (maxAxis - minAxis > axis.size() - 1) { + return failure(); + } + + auto inputType = input.getType().cast(); + auto outputType = output.getType().cast(); + if (!utils::isNBitSignedQType<8>(inputType.getElementType()) || + !utils::isNBitSignedQType<8>(outputType.getElementType())) { + return failure(); + } + + auto inputShape = inputType.getShape(); + auto outputShape = outputType.getShape(); + + int rank = inputShape.size(); + + int beginDims = 1; + for (int i = 0; i < minAxis; i++) { + beginDims *= inputShape[i]; + } + + int endDims = 1; + for (int i = maxAxis + 1; i < rank; i++) { + endDims *= inputShape[i]; + } + + int sumDims = 1; + for (int i = minAxis; i <= maxAxis; i++) { + sumDims *= inputShape[i]; + } + + auto inputQType = utils::getQType(input); + auto outputQType = utils::getQType(output); + + float inZeroPoint = static_cast(inputQType.getZeroPoint()); + float outZeroPoint = static_cast(outputQType.getZeroPoint()); + float scaleMul = inputQType.getScale() / outputQType.getScale(); + + auto beginDimsAttr = rewriter.getI32IntegerAttr(beginDims); + auto endDimsAttr = rewriter.getI32IntegerAttr(endDims); + auto meanDimsAttr = rewriter.getI32IntegerAttr(sumDims); + auto inZeroPointAttr = rewriter.getF32FloatAttr(inZeroPoint); + auto outZeroPointAttr = rewriter.getF32FloatAttr(outZeroPoint); + auto scaleMulAttr = rewriter.getF32FloatAttr(scaleMul); + + auto xcSumOp = rewriter.create( + sumOp.getLoc(), sumOp.getType(), sumOp.getInput(), beginDimsAttr, + meanDimsAttr, endDimsAttr, inZeroPointAttr, outZeroPointAttr, + scaleMulAttr); + rewriter.replaceOp(sumOp, xcSumOp.getOutput()); + + return success(); + } +}; + +void ReplaceSum::runOnOperation() { + auto *ctx = &getContext(); + func::FuncOp func = getOperation(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); +} +} // namespace + +// Creates an instance of the ReplaceSum pass. +std::unique_ptr> createReplaceSumPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace mlir::xcore