
Commit

Add sub op using add
panickal-xmos committed Jul 11, 2024
1 parent 55b9cf1 commit 6f89d0b
Showing 4 changed files with 115 additions and 100 deletions.
8 changes: 4 additions & 4 deletions xformer/Transforms/Passes.cpp
@@ -17,13 +17,13 @@ void buildXCorePreOpSplitPassPipeline(OpPassManager &pm) {
pm.addPass(mlir::TFL::CreateTranslateToLCEPass());
// Convert dynamic shapes in batch dimension to static
pm.addPass(createRemoveDynamicShapePass());
}

void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
// TFL passes
pm.addPass(createOptimizeTransposePass());
pm.addPass(createReplaceAvgPoolWithConv2DPass());
pm.addPass(createReplaceFCWithConv2DPass());
}

void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
if (opSplitTensorArenaOption) {
pm.addPass(createOpSplitPass());
}
@@ -36,7 +36,7 @@ void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
pm.addPass(mlir::createCanonicalizerPass());

// XC passes
pm.addPass(createReplaceAddPass());
pm.addPass(createReplaceAddSubPass());
pm.addPass(createReplaceMaxPoolPass());
pm.addPass(createReplaceMulPass());
pm.addPass(createReplaceTransposeConvPass());
2 changes: 1 addition & 1 deletion xformer/Transforms/Passes.h
@@ -31,7 +31,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeConv2DPass();
std::unique_ptr<OperationPass<func::FuncOp>> createOpSplitPass();
std::unique_ptr<OperationPass<func::FuncOp>> createApplyTFLPatternsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRemoveDynamicShapePass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddSubPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMulPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMaxPoolPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceStridedSlicePass();
95 changes: 0 additions & 95 deletions xformer/Transforms/ReplaceAdd.cpp

This file was deleted.

110 changes: 110 additions & 0 deletions xformer/Transforms/ReplaceAddSub.cpp
@@ -0,0 +1,110 @@
// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the
// XMOS Public License: Version 1

#include "IR/XCoreOps.h"
#include "Utils/Util.h"

#include "lib_nn/api/MemCpyFn.hpp"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir::xcore {

namespace {
// Replace TFL Add/Sub with Add for XCore.
struct ReplaceAddSub
: public PassWrapper<ReplaceAddSub, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceAddSub)

void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<TFL::TensorFlowLiteDialect>();
}
StringRef getArgument() const final { return "xcore-replace-addsub"; }
StringRef getDescription() const final {
return "Replace TFL Add/Sub with Add for XCore.";
}
void runOnOperation() override;
};

template <typename T>
LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
bool negateForSub) {
if (!utils::checkBinaryCompatibility(addOp))
return failure();

auto lhsQType = utils::getQType(addOp.getLhs());
auto lhsScale = lhsQType.getScale();
auto lhsZeroPoint = lhsQType.getZeroPoint();

auto rhsQType = utils::getQType(addOp.getRhs());
auto rhsScale = negateForSub ? -rhsQType.getScale() : rhsQType.getScale();
auto rhsZeroPoint = rhsQType.getZeroPoint();

auto outputQType = utils::getQType(addOp.getOutput());
auto outputScale = outputQType.getScale();
auto outputZeroPoint = outputQType.getZeroPoint();

double lhsRatio = lhsScale / outputScale;
double rhsRatio = rhsScale / outputScale;

// We find the max in case there is a large difference
// between lhs and rhs scales.
double maxR = std::max(lhsRatio, rhsRatio);
// We want the max shift to be 14 bits
int shift = int(floor(log2(pow(2, 14) / maxR)));

// Multipliers are converted to fixed-point
int m1 = round(lhsRatio * pow(2, shift));
int m2 = round(rhsRatio * pow(2, shift));
int bias = round((outputZeroPoint - (lhsZeroPoint * lhsRatio) -
(rhsZeroPoint * rhsRatio)) *
pow(2, shift));
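// Assuming the XC add kernel applies these constants as a fused
// multiply-accumulate, the integer computation
//   (m1 * lhs_q + m2 * rhs_q + bias) >> shift
// approximates round((lhs_real +/- rhs_real) / outputScale) + outputZeroPoint,
// since m1 / 2^shift and m2 / 2^shift are the input-to-output scale ratios
// (the rhs ratio is negated for Sub).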

auto xcAddOp = rewriter.create<AddOp>(
addOp.getLoc(), addOp.getType(), addOp.getLhs(), addOp.getRhs(),
rewriter.getStringAttr(addOp.getFusedActivationFunction()),
rewriter.getI32IntegerAttr(m1), rewriter.getI32IntegerAttr(m2),
rewriter.getI32IntegerAttr(bias), rewriter.getI32IntegerAttr(shift));
rewriter.replaceOp(addOp, xcAddOp.getOutput());

return success();
}

struct ReplaceAddPattern : public OpRewritePattern<TFL::AddOp> {
using OpRewritePattern<TFL::AddOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TFL::AddOp addOp,
PatternRewriter &rewriter) const override {
return replaceAddorSub(addOp, rewriter, /*negateForSub=*/false);
}
};

struct ReplaceSubPattern : public OpRewritePattern<TFL::SubOp> {
using OpRewritePattern<TFL::SubOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TFL::SubOp subOp,
PatternRewriter &rewriter) const override {
return replaceAddorSub(subOp, rewriter, /*negateForSub=*/true);
}
};

void ReplaceAddSub::runOnOperation() {
auto *ctx = &getContext();
func::FuncOp func = getOperation();
RewritePatternSet patterns(ctx);
patterns.insert<ReplaceAddPattern>(ctx);
patterns.insert<ReplaceSubPattern>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
} // namespace

// Creates an instance of the ReplaceAddSub pass.
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddSubPass() {
return std::make_unique<ReplaceAddSub>();
}

static PassRegistration<ReplaceAddSub> pass;

} // namespace mlir::xcore
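
The fixed-point scheme above can be sanity-checked in isolation. The sketch below is not part of the commit: it recomputes m1, m2, bias and shift for a made-up pair of int8 quantization parameters and compares the shifted integer result against the floating-point reference. The exact rounding and saturation behaviour of the lib_nn add kernel is an assumption here; only the constant derivation mirrors the pass.

// Standalone check of the requantization constants computed in replaceAddorSub.
// All quantization parameters below are hypothetical example values.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical int8 quantization parameters (scale, zero point).
  const double lhsScale = 0.05, rhsScale = 0.03, outScale = 0.06;
  const int lhsZero = -1, rhsZero = 4, outZero = 3;
  const bool isSub = true; // model TFL::SubOp by negating the rhs scale

  const double lhsRatio = lhsScale / outScale;
  const double rhsRatio = (isSub ? -rhsScale : rhsScale) / outScale;

  // Same shift selection as the pass: keep the larger ratio within 14 bits.
  const double maxR = std::max(lhsRatio, rhsRatio);
  const int shift = int(std::floor(std::log2(std::pow(2, 14) / maxR)));

  const int m1 = int(std::round(lhsRatio * std::pow(2, shift)));
  const int m2 = int(std::round(rhsRatio * std::pow(2, shift)));
  const int bias =
      int(std::round((outZero - lhsZero * lhsRatio - rhsZero * rhsRatio) *
                     std::pow(2, shift)));

  // One example input pair, compared against the float reference.
  const int8_t lhsQ = 25, rhsQ = -7;
  const double lhsReal = lhsScale * (lhsQ - lhsZero);
  const double rhsReal = rhsScale * (rhsQ - rhsZero);
  const double real = isSub ? lhsReal - rhsReal : lhsReal + rhsReal;
  const int refQ = int(std::round(real / outScale)) + outZero;

  // Assumed kernel form: (m1*x1 + m2*x2 + bias) >> shift, before saturation.
  const int64_t acc = int64_t(m1) * lhsQ + int64_t(m2) * rhsQ + bias;
  const int outQ = int(acc >> shift);

  std::printf("m1=%d m2=%d bias=%d shift=%d -> out=%d (reference %d)\n", m1, m2,
              bias, shift, outQ, refQ);
}

Because the pass is registered via PassRegistration with the argument "xcore-replace-addsub", it can also be exercised on its own through the project's opt-style driver (binary name assumed), independent of the pipeline wiring in Passes.cpp.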
