diff --git a/.github/workflows/release-beta.yml b/.github/workflows/release-beta.yml
index 1ae3ca09f..bc01ad9fd 100644
--- a/.github/workflows/release-beta.yml
+++ b/.github/workflows/release-beta.yml
@@ -112,7 +112,7 @@ jobs:
     if: github.event.pull_request.merged == true
     name: Build release wheels for macOS arm64
     needs: [build-release-archive]
-    runs-on: macos-11
+    runs-on: macos-14
     strategy:
       matrix:
         python-version: [3.9]
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index c62e9ad0e..0adebdd57 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -122,7 +122,7 @@ jobs:
   macos-arm-release-wheel:
     name: Build release wheels for macOS arm64
     needs: [build-release-archive]
-    runs-on: macos-11
+    runs-on: macos-14
    strategy:
       matrix:
         python-version: [3.9]
diff --git a/integration_tests/models/8x8/test_sub/1.sh b/integration_tests/models/8x8/test_sub/1.sh
new file mode 100755
index 000000000..4a264cea1
--- /dev/null
+++ b/integration_tests/models/8x8/test_sub/1.sh
@@ -0,0 +1,6 @@
+cp $1 /tmp/
+xcore-opt /tmp/$1 --lce-translate-tfl --mlir-print-ir-after-all -o /tmp/1.tflite >/tmp/1.mlir 2>&1
+cat /tmp/1.mlir | grep -v Tensor > /tmp/2.mlir
+sed -i -e 's/tfl.add/tfl.sub/g' /tmp/2.mlir
+xcore-opt --mlir-io --lce-translate-tfl /tmp/2.mlir -o /tmp/t.tflite
+cp /tmp/t.tflite $1
diff --git a/integration_tests/models/8x8/test_sub/test_sub_0.tflite b/integration_tests/models/8x8/test_sub/test_sub_0.tflite
new file mode 100644
index 000000000..3a3ed2e5c
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_0.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_1.tflite b/integration_tests/models/8x8/test_sub/test_sub_1.tflite
new file mode 100644
index 000000000..933426f7a
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_1.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_10.tflite b/integration_tests/models/8x8/test_sub/test_sub_10.tflite
new file mode 100644
index 000000000..c602e8ee8
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_10.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_11.tflite b/integration_tests/models/8x8/test_sub/test_sub_11.tflite
new file mode 100644
index 000000000..bd73d3a34
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_11.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_12.tflite b/integration_tests/models/8x8/test_sub/test_sub_12.tflite
new file mode 100644
index 000000000..1c983b9f0
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_12.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_13.tflite b/integration_tests/models/8x8/test_sub/test_sub_13.tflite
new file mode 100644
index 000000000..c388b3ce0
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_13.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_14.tflite b/integration_tests/models/8x8/test_sub/test_sub_14.tflite
new file mode 100644
index 000000000..70caf0c7f
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_14.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_15.tflite b/integration_tests/models/8x8/test_sub/test_sub_15.tflite
new file mode 100644
index 000000000..9750f9b06
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_15.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_16.tflite b/integration_tests/models/8x8/test_sub/test_sub_16.tflite
new file mode 100644
index 000000000..81a1e3ceb
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_16.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_17.tflite b/integration_tests/models/8x8/test_sub/test_sub_17.tflite
new file mode 100644
index 000000000..3a3493806
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_17.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_18.tflite b/integration_tests/models/8x8/test_sub/test_sub_18.tflite
new file mode 100644
index 000000000..8ec012fbf
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_18.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_19.tflite b/integration_tests/models/8x8/test_sub/test_sub_19.tflite
new file mode 100644
index 000000000..5cde7b1e5
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_19.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_2.tflite b/integration_tests/models/8x8/test_sub/test_sub_2.tflite
new file mode 100644
index 000000000..578db1d86
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_2.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_3.tflite b/integration_tests/models/8x8/test_sub/test_sub_3.tflite
new file mode 100644
index 000000000..725acfdd7
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_3.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_4.tflite b/integration_tests/models/8x8/test_sub/test_sub_4.tflite
new file mode 100644
index 000000000..7bf123c24
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_4.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_41.tflite b/integration_tests/models/8x8/test_sub/test_sub_41.tflite
new file mode 100644
index 000000000..bf1466c48
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_41.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_42.tflite b/integration_tests/models/8x8/test_sub/test_sub_42.tflite
new file mode 100644
index 000000000..c8fdf4a0b
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_42.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_43.tflite b/integration_tests/models/8x8/test_sub/test_sub_43.tflite
new file mode 100644
index 000000000..937fe3022
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_43.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_44.tflite b/integration_tests/models/8x8/test_sub/test_sub_44.tflite
new file mode 100644
index 000000000..70feb7854
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_44.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_45.tflite b/integration_tests/models/8x8/test_sub/test_sub_45.tflite
new file mode 100644
index 000000000..ad4b49749
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_45.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_46.tflite b/integration_tests/models/8x8/test_sub/test_sub_46.tflite
new file mode 100644
index 000000000..f41c031dd
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_46.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_47.tflite b/integration_tests/models/8x8/test_sub/test_sub_47.tflite
new file mode 100644
index 000000000..8efac2d57
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_47.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_5.tflite b/integration_tests/models/8x8/test_sub/test_sub_5.tflite
new file mode 100644
index 000000000..ca3e769c4
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_5.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_6.tflite b/integration_tests/models/8x8/test_sub/test_sub_6.tflite
new file mode 100644
index 000000000..ddd3ced2c
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_6.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_7.tflite b/integration_tests/models/8x8/test_sub/test_sub_7.tflite
new file mode 100644
index 000000000..3332f8c48
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_7.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_8.tflite b/integration_tests/models/8x8/test_sub/test_sub_8.tflite
new file mode 100644
index 000000000..397d94544
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_8.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_9.tflite b/integration_tests/models/8x8/test_sub/test_sub_9.tflite
new file mode 100644
index 000000000..89143cff6
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_9.tflite differ
diff --git a/integration_tests/models/8x8/test_sub/test_sub_dual_output.tflite b/integration_tests/models/8x8/test_sub/test_sub_dual_output.tflite
new file mode 100644
index 000000000..cf3bea545
Binary files /dev/null and b/integration_tests/models/8x8/test_sub/test_sub_dual_output.tflite differ
diff --git a/xformer/Analysis/MemoryPlan.cpp b/xformer/Analysis/MemoryPlan.cpp
index 98d5cc0f6..d08ccd8f1 100644
--- a/xformer/Analysis/MemoryPlan.cpp
+++ b/xformer/Analysis/MemoryPlan.cpp
@@ -161,43 +161,62 @@ std::vector<int> MemoryPlan::getAllocatedOffsets(const bool overlapOps,
   llvm::DenseSet<Operation *> alreadyVisited;
   if (overlapOps) {
     for (auto o : operations) {
+      // We iterate through overlappable ops which have not been visited yet
       if (o->hasTrait() &&
-          !alreadyVisited.contains(o) && o->getOperand(0).hasOneUse()) {
-        alreadyVisited.insert(o);
-
-        llvm::SmallVector<Value> inputVals;
+          !alreadyVisited.contains(o)) {
         auto inVal = o->getOperand(0);
-        inputVals.push_back(inVal);
-
-        auto outVal = o->getResult(0);
-        auto nextOp = *outVal.getUsers().begin();
-        // Identify chain of overlappable Ops
-        while (outVal.hasOneUse() && !alreadyVisited.contains(nextOp) &&
-               nextOp->hasTrait()) {
-          inVal = outVal;
+
+        // We have binary and unary ops as overlappable
+        // For binary ops, we might have to overlap with the second operand
+        // The complicated if condition below is to check for valid one operand
+        // or two operand cases
+        if ((o->getNumOperands() == 1 && inVal.hasOneUse() &&
+             !vInfo[inVal].isConstant) ||
+            (o->getNumOperands() == 2 &&
+             (inVal.hasOneUse() && !vInfo[inVal].isConstant ||
+              o->getOperand(1).hasOneUse() &&
+                  !vInfo[o->getOperand(1)].isConstant))) {
+          // In case of two operands and first operand is invalid, use the
+          // second one
+          if (o->getNumOperands() == 2 &&
+              (!inVal.hasOneUse() || vInfo[inVal].isConstant)) {
+            inVal = o->getOperand(1);
+          }
+
+          alreadyVisited.insert(o);
+          llvm::SmallVector<Value> inputVals;
           inputVals.push_back(inVal);
-          alreadyVisited.insert(nextOp);
-          outVal = nextOp->getResult(0);
-          nextOp = *outVal.getUsers().begin();
-        }
-        // Set first Used of output Val to the first input Val
-        vInfo[outVal].firstUsed = vInfo[inputVals[0]].firstUsed;
-        auto unalignedSizeOutVal =
-            utils::getShapedTypeSize(outVal.getType().dyn_cast<ShapedType>());
-        size_t maxSizeNeeded = 0;
-        for (auto inV : inputVals) {
-          auto unalignedSizeInV =
-              utils::getShapedTypeSize(inV.getType().dyn_cast<ShapedType>());
-          auto unalignedOffset = unalignedSizeOutVal - unalignedSizeInV;
-          // Align offset up to double word = 8 bytes
-          auto offset = ((unalignedOffset + 7) / 8) * 8;
-          maxSizeNeeded = std::max(vInfo[inV].size + offset, maxSizeNeeded);
-          inOutMap[inV] = {outVal, offset};
+          auto outVal = o->getResult(0);
+          auto nextOp = *outVal.getUsers().begin();
+          // Identify chain of overlappable Ops
+          while (outVal.hasOneUse() && !alreadyVisited.contains(nextOp) &&
+                 nextOp->hasTrait()) {
+            inVal = outVal;
+            inputVals.push_back(inVal);
+            alreadyVisited.insert(nextOp);
+            outVal = nextOp->getResult(0);
+            nextOp = *outVal.getUsers().begin();
+          }
+
+          // Set first Used of output Val to the first input Val
+          vInfo[outVal].firstUsed = vInfo[inputVals[0]].firstUsed;
+          auto unalignedSizeOutVal =
+              utils::getShapedTypeSize(outVal.getType().dyn_cast<ShapedType>());
+          size_t maxSizeNeeded = 0;
+          for (auto inV : inputVals) {
+            auto unalignedSizeInV =
+                utils::getShapedTypeSize(inV.getType().dyn_cast<ShapedType>());
+            auto unalignedOffset = unalignedSizeOutVal - unalignedSizeInV;
+            // Align offset up to double word = 8 bytes
+            auto offset = ((unalignedOffset + 7) / 8) * 8;
+            maxSizeNeeded = std::max(vInfo[inV].size + offset, maxSizeNeeded);
+            inOutMap[inV] = {outVal, offset};
+          }
+          // The aligned input val size plus aligned offset might be larger than
+          // aligned output val size
+          vInfo[outVal].size = std::max(vInfo[outVal].size, maxSizeNeeded);
         }
-        // The aligned input val size plus aligned offset might be larger than
-        // aligned output val size
-        vInfo[outVal].size = std::max(vInfo[outVal].size, maxSizeNeeded);
       }
     }
   }
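The offset arithmetic in the hunk above is easier to follow with concrete numbers. The sketch below is a minimal, standalone illustration (plain C++, not the xformer API; the tensor sizes are hypothetical): each overlappable input is placed at the tail of the output buffer it will share, the offset is rounded up to a double word, and the shared buffer may have to grow past the output size to hold the aligned input.

```cpp
#include <algorithm>
#include <cstdio>

// Standalone sketch of the input/output overlap offsets computed in
// getAllocatedOffsets(). Sizes are hypothetical, not from a real model.
int main() {
  const size_t outSize = 1024;          // unaligned output tensor size
  const size_t inSizes[] = {1021, 640}; // unaligned sizes of a chain of inputs

  size_t maxSizeNeeded = 0;
  for (size_t inSize : inSizes) {
    size_t unalignedOffset = outSize - inSize;        // input sits at the tail
    size_t offset = ((unalignedOffset + 7) / 8) * 8;  // align up to 8 bytes
    maxSizeNeeded = std::max(inSize + offset, maxSizeNeeded);
    std::printf("input %zu bytes -> offset %zu\n", inSize, offset);
  }
  // The 1021-byte input lands at offset 8, so 1029 bytes are needed even
  // though the output itself is only 1024 bytes: the shared buffer grows.
  std::printf("shared buffer size: %zu\n", std::max(outSize, maxSizeNeeded));
  return 0;
}
```

In the pass itself the stored (aligned) sizes from vInfo are used and the result also feeds inOutMap, but the rounding and the max over the whole chain are the same as in this sketch.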
@@ -353,6 +372,7 @@ void MemoryPlan::printMemoryPlan() {
       line[c] = '.';
     }
     int memory_use = 0;
+    int peakSize = 0;
     for (int i = 0; i < nonConstantAllocatedValues.size(); ++i) {
       if ((t < valueInfo[nonConstantAllocatedValues[i]].firstUsed) ||
           (t > valueInfo[nonConstantAllocatedValues[i]].lastUsed)) {
@@ -362,7 +382,12 @@
       if (offset == -1) {
         continue;
       }
+      const int size = valueInfo[nonConstantAllocatedValues[i]].size;
+      if (peakSize < offset + size) {
+        peakSize = offset + size;
+      }
+
       memory_use += size;
       const int line_start = (offset * kLineWidth) / max_size;
       const int line_end = ((offset + size) * kLineWidth) / max_size;
@@ -377,9 +402,10 @@
     line[kLineWidth] = 0;
 
     llvm::outs() << llvm::format(
-        "\n%-20s %s%d: %s (%dk)",
+        "\n%-20s %s%d: %s (%dk), (%dk)",
         operations[t]->getName().stripDialect().str().c_str(),
-        t < 10 ? " " : "", t, (const char *)line, (memory_use + 1023) / 1024);
+        t < 10 ? " " : "", t, (const char *)line, (memory_use + 1023) / 1024,
+        (peakSize + 1023) / 1024);
   }
   llvm::outs() << "\n";
 }
diff --git a/xformer/Test/add_broadcast.mlir b/xformer/Test/add_broadcast.mlir
index 3ff19e237..faf8920be 100644
--- a/xformer/Test/add_broadcast.mlir
+++ b/xformer/Test/add_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: xcore-opt --mlir-io %s --xcore-replace-add | FileCheck %s
+// RUN: xcore-opt --mlir-io %s --xcore-replace-addsub | FileCheck %s
 
 // CHECK-LABEL: add_broadcast
 func.func @add_broadcast(%arg0: tensor<1x15x1x1x!quant.uniform> {tf_saved_model.index_path = ["input_1"]}) -> (tensor> {tf_saved_model.index_path = ["add"]}) attributes {tf.entry_function = {inputs = "serving_default_input_1:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
diff --git a/xformer/Transforms/Passes.cpp b/xformer/Transforms/Passes.cpp
index 8a8866d4d..d0cad10e2 100644
--- a/xformer/Transforms/Passes.cpp
+++ b/xformer/Transforms/Passes.cpp
@@ -17,13 +17,13 @@ void buildXCorePreOpSplitPassPipeline(OpPassManager &pm) {
   pm.addPass(mlir::TFL::CreateTranslateToLCEPass());
   // Convert dynamic shapes in batch dimension to static
   pm.addPass(createRemoveDynamicShapePass());
+}
+
+void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
   // TFL passes
   pm.addPass(createOptimizeTransposePass());
   pm.addPass(createReplaceAvgPoolWithConv2DPass());
   pm.addPass(createReplaceFCWithConv2DPass());
-}
-
-void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
   if (opSplitTensorArenaOption) {
     pm.addPass(createOpSplitPass());
   }
@@ -36,7 +36,7 @@ void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
   pm.addPass(mlir::createCanonicalizerPass());
 
   // XC passes
-  pm.addPass(createReplaceAddPass());
+  pm.addPass(createReplaceAddSubPass());
   pm.addPass(createReplaceMaxPoolPass());
   pm.addPass(createReplaceMulPass());
   pm.addPass(createReplaceTransposeConvPass());
diff --git a/xformer/Transforms/Passes.h b/xformer/Transforms/Passes.h
index c41f4e7b9..32d6d581b 100644
--- a/xformer/Transforms/Passes.h
+++ b/xformer/Transforms/Passes.h
@@ -31,7 +31,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeConv2DPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createOpSplitPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createApplyTFLPatternsPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createRemoveDynamicShapePass();
-std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddSubPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMulPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMaxPoolPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceStridedSlicePass();
diff --git a/xformer/Transforms/ReplaceAdd.cpp b/xformer/Transforms/ReplaceAdd.cpp
deleted file mode 100644
index 1e8947d6a..000000000
--- a/xformer/Transforms/ReplaceAdd.cpp
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the
-// XMOS Public License: Version 1
-
-#include "IR/XCoreOps.h"
-#include "Utils/Util.h"
-
-#include "lib_nn/api/MemCpyFn.hpp"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
-
-namespace mlir::xcore {
-
-namespace {
-// Replace TFL Add with Add for XCore.
-struct ReplaceAdd
-    : public PassWrapper<ReplaceAdd, OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceAdd)
-
-  void getDependentDialects(DialectRegistry &registry) const final {
-    registry.insert();
-  }
-  StringRef getArgument() const final { return "xcore-replace-add"; }
-  StringRef getDescription() const final {
-    return "Replace TFL Add with Add for XCore.";
-  }
-  void runOnOperation() override;
-};
-
-struct ReplaceAddPattern : public OpRewritePattern<TFL::AddOp> {
-  using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TFL::AddOp addOp,
-                                PatternRewriter &rewriter) const override {
-
-    if (!utils::checkBinaryCompatibility(addOp))
-      return failure();
-
-    auto lhsQType = utils::getQType(addOp.getLhs());
-    auto lhsScale = lhsQType.getScale();
-    auto lhsZeroPoint = lhsQType.getZeroPoint();
-
-    auto rhsQType = utils::getQType(addOp.getRhs());
-    auto rhsScale = rhsQType.getScale();
-    auto rhsZeroPoint = rhsQType.getZeroPoint();
-
-    auto outputQType = utils::getQType(addOp.getOutput());
-    auto outputScale = outputQType.getScale();
-    auto outputZeroPoint = outputQType.getZeroPoint();
-
-    double lhsRatio = lhsScale / outputScale;
-    double rhsRatio = rhsScale / outputScale;
-
-    // We find the max in case there is a large difference
-    // between lhs and rhs scales.
-    double maxR = std::max(lhsRatio, rhsRatio);
-    // We want the max shift to be 14 bits
-    int shift = int(floor(log2(pow(2, 14) / maxR)));
-
-    // Multipliers are converted to fixed-point
-    int m1 = round(lhsRatio * pow(2, shift));
-    int m2 = round(rhsRatio * pow(2, shift));
-    int bias = round((outputZeroPoint - (lhsZeroPoint * lhsRatio) -
-                      (rhsZeroPoint * rhsRatio)) *
-                     pow(2, shift));
-
-    auto xcAddOp = rewriter.create<AddOp>(
-        addOp.getLoc(), addOp.getType(), addOp.getLhs(), addOp.getRhs(),
-        rewriter.getStringAttr(addOp.getFusedActivationFunction()),
-        rewriter.getI32IntegerAttr(m1), rewriter.getI32IntegerAttr(m2),
-        rewriter.getI32IntegerAttr(bias), rewriter.getI32IntegerAttr(shift));
-    rewriter.replaceOp(addOp, xcAddOp.getOutput());
-
-    return success();
-  }
-};
-
-void ReplaceAdd::runOnOperation() {
-  auto *ctx = &getContext();
-  func::FuncOp func = getOperation();
-  RewritePatternSet patterns(ctx);
-  patterns.insert<ReplaceAddPattern>(ctx);
-  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
-}
-} // namespace
-
-// Creates an instance of the ReplaceAdd pass.
-std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddPass() {
-  return std::make_unique<ReplaceAdd>();
-}
-
-static PassRegistration<ReplaceAdd> pass;
-
-} // namespace mlir::xcore
diff --git a/xformer/Transforms/ReplaceAddSub.cpp b/xformer/Transforms/ReplaceAddSub.cpp
new file mode 100644
index 000000000..b7bf738c6
--- /dev/null
+++ b/xformer/Transforms/ReplaceAddSub.cpp
@@ -0,0 +1,113 @@
+// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the
+// XMOS Public License: Version 1
+
+#include "IR/XCoreOps.h"
+#include "Utils/Util.h"
+
+#include "lib_nn/api/MemCpyFn.hpp"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir::xcore {
+
+namespace {
+// Replace TFL Add/Sub with Add for XCore.
+struct ReplaceAddSub
+    : public PassWrapper<ReplaceAddSub, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceAddSub)
+
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert();
+  }
+  StringRef getArgument() const final { return "xcore-replace-addsub"; }
+  StringRef getDescription() const final {
+    return "Replace TFL Add/Sub with Add for XCore.";
+  }
+  void runOnOperation() override;
+};
+
+template <typename T>
+LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
+                              bool negateForSub) {
+  if (!utils::checkBinaryCompatibility(addOp))
+    return failure();
+
+  auto lhsQType = utils::getQType(addOp.getLhs());
+  auto lhsScale = lhsQType.getScale();
+  auto lhsZeroPoint = lhsQType.getZeroPoint();
+
+  auto rhsQType = utils::getQType(addOp.getRhs());
+  auto rhsScale = rhsQType.getScale();
+  auto rhsZeroPoint = rhsQType.getZeroPoint();
+
+  auto outputQType = utils::getQType(addOp.getOutput());
+  auto outputScale = outputQType.getScale();
+  auto outputZeroPoint = outputQType.getZeroPoint();
+
+  double lhsRatio = lhsScale / outputScale;
+  double rhsRatio = rhsScale / outputScale;
+
+  // We find the max in case there is a large difference
+  // between lhs and rhs scales.
+  double maxR = std::max(lhsRatio, rhsRatio);
+  // We want the max shift to be 14 bits
+  int shift = int(floor(log2(pow(2, 14) / maxR)));
+
+  // For doing subtraction with add op
+  rhsRatio = negateForSub ? -rhsRatio : rhsRatio;
+
+  // Multipliers are converted to fixed-point
+  int m1 = round(lhsRatio * pow(2, shift));
+  int m2 = round(rhsRatio * pow(2, shift));
+  int bias = round((outputZeroPoint - (lhsZeroPoint * lhsRatio) -
+                    (rhsZeroPoint * rhsRatio)) *
+                   pow(2, shift));
+
+  auto xcAddOp = rewriter.create<AddOp>(
+      addOp.getLoc(), addOp.getType(), addOp.getLhs(), addOp.getRhs(),
+      rewriter.getStringAttr(addOp.getFusedActivationFunction()),
+      rewriter.getI32IntegerAttr(m1), rewriter.getI32IntegerAttr(m2),
+      rewriter.getI32IntegerAttr(bias), rewriter.getI32IntegerAttr(shift));
+  rewriter.replaceOp(addOp, xcAddOp.getOutput());
+
+  return success();
+}
+
+struct ReplaceAddPattern : public OpRewritePattern<TFL::AddOp> {
+  using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    return replaceAddorSub(addOp, rewriter, /*negateForSub=*/false);
+  }
+};
+
+struct ReplaceSubPattern : public OpRewritePattern<TFL::SubOp> {
+  using OpRewritePattern<TFL::SubOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::SubOp subOp,
+                                PatternRewriter &rewriter) const override {
+    return replaceAddorSub(subOp, rewriter, /*negateForSub=*/true);
+  }
+};
+
+void ReplaceAddSub::runOnOperation() {
+  auto *ctx = &getContext();
+  func::FuncOp func = getOperation();
+  RewritePatternSet patterns(ctx);
+  patterns.insert<ReplaceAddPattern>(ctx);
+  patterns.insert<ReplaceSubPattern>(ctx);
+  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+} // namespace
+
+// Creates an instance of the ReplaceAddSub pass.
+std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddSubPass() {
+  return std::make_unique<ReplaceAddSub>();
+}
+
+static PassRegistration<ReplaceAddSub> pass;
+
+} // namespace mlir::xcore
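A worked example of the fixed-point derivation above, as a standalone sketch (plain C++, no MLIR). The int8 quantization parameters are hypothetical, and the kernel is only modelled here as (m1*lhs + m2*rhs + bias) / 2^shift; the exact rounding done by the lib_nn add kernel may differ. The point it illustrates is that negating the rhs ratio before the multipliers and bias are quantized turns the same add kernel into a subtraction.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Standalone sketch of the m1/m2/bias/shift computation in replaceAddorSub().
// All quantization parameters below are hypothetical.
int main() {
  const double lhsScale = 0.05, rhsScale = 0.04, outScale = 0.1;
  const int lhsZero = -3, rhsZero = 2, outZero = 5;
  const bool negateForSub = true; // lower a TFL Sub onto the XC add kernel

  double lhsRatio = lhsScale / outScale;
  double rhsRatio = rhsScale / outScale;
  // Limit the shift so the larger ratio still fits in 14 bits.
  const double maxR = std::max(lhsRatio, rhsRatio);
  const int shift = (int)std::floor(std::log2(std::pow(2, 14) / maxR));
  if (negateForSub)
    rhsRatio = -rhsRatio; // subtraction = addition of the negated operand

  const int m1 = (int)std::round(lhsRatio * std::pow(2, shift));
  const int m2 = (int)std::round(rhsRatio * std::pow(2, shift));
  const int bias = (int)std::round(
      (outZero - lhsZero * lhsRatio - rhsZero * rhsRatio) * std::pow(2, shift));

  // Reference check on one pair of quantized int8 inputs.
  const int lhs = 20, rhs = 7;
  const double kernel = (m1 * lhs + m2 * rhs + bias) / std::pow(2, shift);
  const double ref =
      ((lhs - lhsZero) * lhsScale - (rhs - rhsZero) * rhsScale) / outScale +
      outZero;
  std::printf("m1=%d m2=%d bias=%d shift=%d\n", m1, m2, bias, shift);
  std::printf("kernel approx = %.3f, float reference = %.3f\n", kernel, ref);
  return 0;
}
```

With negateForSub set to false the same sketch reproduces the original ReplaceAdd behaviour unchanged, which is why a single templated helper can serve both the TFL::AddOp and TFL::SubOp patterns.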