From 7c4eef2c2162c96f094b27c953cf30a14f777215 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Fri, 2 Aug 2024 13:47:00 +0100 Subject: [PATCH 1/5] Format --- xformer/Transforms/ReplaceAddSub.cpp | 4 ++-- xformer/Transforms/TFLPatterns.td | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/xformer/Transforms/ReplaceAddSub.cpp b/xformer/Transforms/ReplaceAddSub.cpp index b7bf738c6..afb73ecd3 100644 --- a/xformer/Transforms/ReplaceAddSub.cpp +++ b/xformer/Transforms/ReplaceAddSub.cpp @@ -39,7 +39,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter, auto lhsZeroPoint = lhsQType.getZeroPoint(); auto rhsQType = utils::getQType(addOp.getRhs()); - auto rhsScale = rhsQType.getScale(); + auto rhsScale = rhsQType.getScale(); auto rhsZeroPoint = rhsQType.getZeroPoint(); auto outputQType = utils::getQType(addOp.getOutput()); @@ -56,7 +56,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter, int shift = int(floor(log2(pow(2, 14) / maxR))); // For doing subtraction with add op - rhsRatio = negateForSub? -rhsRatio: rhsRatio; + rhsRatio = negateForSub ? -rhsRatio : rhsRatio; // Multipliers are converted to fixed-point int m1 = round(lhsRatio * pow(2, shift)); diff --git a/xformer/Transforms/TFLPatterns.td b/xformer/Transforms/TFLPatterns.td index 9854fdc29..077d2c8f2 100644 --- a/xformer/Transforms/TFLPatterns.td +++ b/xformer/Transforms/TFLPatterns.td @@ -56,8 +56,9 @@ def : Pat<(TFL_ReluOp(TFL_MinimumOp $lhs, $rhs)), (TFL_ReluOp $lhs), [(IsSplatAndEqualTo<127> $rhs)]>; // Merge Relu with Conv -def : Pat<(TFL_ReluOp(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p, - $sh, $sw)), +def : Pat<(TFL_ReluOp + : $output(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p, + $sh, $sw)), (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_Relu, $p, $sh, $sw)>; // Unfuse activation functions from binary ops From eaf9987cea14630460b649571206fdb82b61b80c Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Fri, 2 Aug 2024 13:47:24 +0100 Subject: [PATCH 2/5] Add rsqrt ref op --- python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc b/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc index 34bf8b147..3e0bd3d92 100644 --- a/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc +++ b/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc @@ -34,6 +34,7 @@ void add_lib_vision_ops( resolver->AddResizeBilinear(); resolver->AddResizeNearestNeighbor(); resolver->AddRound(); + resolver->AddRsqrt(); resolver->AddStridedSlice(); resolver->AddSlice(); resolver->AddSub(); From 4bd7b522e0b3db684d8038c0014c3a92afddd41c Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Fri, 2 Aug 2024 13:47:46 +0100 Subject: [PATCH 3/5] Handle when concat axis is -1 --- xformer/Transforms/ReplaceConcat.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xformer/Transforms/ReplaceConcat.cpp b/xformer/Transforms/ReplaceConcat.cpp index 63c9d07cf..4235f0c02 100644 --- a/xformer/Transforms/ReplaceConcat.cpp +++ b/xformer/Transforms/ReplaceConcat.cpp @@ -62,7 +62,8 @@ struct SplitConcatPattern : public OpRewritePattern { auto outputType = concatOp.getOutput().getType().cast(); Type elementType = outputType.getElementType(); ArrayRef outputShape = outputType.getShape(); - const int axis = concatOp.getAxis(); + int axis = concatOp.getAxis(); + axis = -1 ? outputType.getRank() - 1 : axis; int axisShape = 0; for (int i = 0; i < CONCAT_OP_MAX_INPUTS; i++) { From c9e77482fce8f90de2c5c4e5ac3c3781a883da38 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Fri, 2 Aug 2024 13:48:10 +0100 Subject: [PATCH 4/5] Unfuse Relu from Conv if zeropoint is not -128 --- xformer/Transforms/ConvPatterns.td | 7 +++++++ xformer/Utils/Utils.td | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/xformer/Transforms/ConvPatterns.td b/xformer/Transforms/ConvPatterns.td index 305cb2557..82383c3ab 100644 --- a/xformer/Transforms/ConvPatterns.td +++ b/xformer/Transforms/ConvPatterns.td @@ -8,6 +8,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "IR/XCoreOps.td" include "Utils/Utils.td" +// Unfuse Relu from Conv if output zero point is not -128 for QI8 output +// XC Conv assumes there is no RELU, or that if there is one, it is being +// clamped to -128 Add benefit so that this pattern runs first +def: +Pat<(TFL_Conv2DOp: $output TensorOf<[QI8]>:$input, TensorOf<[QI8]>:$f, AnyTypeOf<[TensorOf<[I32,QI32]>, NoneType]>:$b, $dh, $dw, TFL_AF_Relu, $wf, $sh, $sw), + (TFL_ReluOp (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $wf, $sh, $sw, (returnType $output))), [(IsZeroPointNotEqualTo<-128> $output)], [], (addBenefit 20)>; + def CreateNoneValue : NativeCodeCall<"$_builder.create($0." "getLoc(), $_builder.getUnitAttr())">; diff --git a/xformer/Utils/Utils.td b/xformer/Utils/Utils.td index 081163d89..6bc1339c4 100644 --- a/xformer/Utils/Utils.td +++ b/xformer/Utils/Utils.td @@ -57,6 +57,11 @@ class IsSplatAndEqualTo "dyn_cast($0.getDefiningOp()).getValue()." "cast().getSplatValue() == " #n>>; +class IsZeroPointNotEqualTo + : Constraint< + CPred<"dyn_cast($0.getType().cast<" + "ShapedType>().getElementType()).getZeroPoint() != " #n>>; + // Get the dimension size as integer attr. class GetDimAsI32 : NativeCodeCall< From 4a8bf20f2fd2217b5696169d23bf5dbb320ed659 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Fri, 2 Aug 2024 14:59:02 +0100 Subject: [PATCH 5/5] Fix typo --- xformer/Transforms/ReplaceConcat.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xformer/Transforms/ReplaceConcat.cpp b/xformer/Transforms/ReplaceConcat.cpp index 4235f0c02..aa42958dc 100644 --- a/xformer/Transforms/ReplaceConcat.cpp +++ b/xformer/Transforms/ReplaceConcat.cpp @@ -62,8 +62,8 @@ struct SplitConcatPattern : public OpRewritePattern { auto outputType = concatOp.getOutput().getType().cast(); Type elementType = outputType.getElementType(); ArrayRef outputShape = outputType.getShape(); - int axis = concatOp.getAxis(); - axis = -1 ? outputType.getRank() - 1 : axis; + const int axis = concatOp.getAxis() == -1 ? outputType.getRank() - 1 + : concatOp.getAxis(); int axisShape = 0; for (int i = 0; i < CONCAT_OP_MAX_INPUTS; i++) {