diff --git a/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc b/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc index 34bf8b147..3e0bd3d92 100644 --- a/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc +++ b/python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc @@ -34,6 +34,7 @@ void add_lib_vision_ops( resolver->AddResizeBilinear(); resolver->AddResizeNearestNeighbor(); resolver->AddRound(); + resolver->AddRsqrt(); resolver->AddStridedSlice(); resolver->AddSlice(); resolver->AddSub(); diff --git a/xformer/Transforms/ConvPatterns.td b/xformer/Transforms/ConvPatterns.td index 305cb2557..82383c3ab 100644 --- a/xformer/Transforms/ConvPatterns.td +++ b/xformer/Transforms/ConvPatterns.td @@ -8,6 +8,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "IR/XCoreOps.td" include "Utils/Utils.td" +// Unfuse Relu from Conv if output zero point is not -128 for QI8 output +// XC Conv assumes there is no RELU, or that if there is one, it is being +// clamped to -128. Add benefit so that this pattern runs first. +def: +Pat<(TFL_Conv2DOp: $output TensorOf<[QI8]>:$input, TensorOf<[QI8]>:$f, AnyTypeOf<[TensorOf<[I32,QI32]>, NoneType]>:$b, $dh, $dw, TFL_AF_Relu, $wf, $sh, $sw), + (TFL_ReluOp (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $wf, $sh, $sw, (returnType $output))), [(IsZeroPointNotEqualTo<-128> $output)], [], (addBenefit 20)>; + def CreateNoneValue : NativeCodeCall<"$_builder.create($0." 
"getLoc(), $_builder.getUnitAttr())">; diff --git a/xformer/Transforms/ReplaceAddSub.cpp b/xformer/Transforms/ReplaceAddSub.cpp index b7bf738c6..afb73ecd3 100644 --- a/xformer/Transforms/ReplaceAddSub.cpp +++ b/xformer/Transforms/ReplaceAddSub.cpp @@ -39,7 +39,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter, auto lhsZeroPoint = lhsQType.getZeroPoint(); auto rhsQType = utils::getQType(addOp.getRhs()); - auto rhsScale = rhsQType.getScale(); + auto rhsScale = rhsQType.getScale(); auto rhsZeroPoint = rhsQType.getZeroPoint(); auto outputQType = utils::getQType(addOp.getOutput()); @@ -56,7 +56,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter, int shift = int(floor(log2(pow(2, 14) / maxR))); // For doing subtraction with add op - rhsRatio = negateForSub? -rhsRatio: rhsRatio; + rhsRatio = negateForSub ? -rhsRatio : rhsRatio; // Multipliers are converted to fixed-point int m1 = round(lhsRatio * pow(2, shift)); diff --git a/xformer/Transforms/ReplaceConcat.cpp b/xformer/Transforms/ReplaceConcat.cpp index 63c9d07cf..aa42958dc 100644 --- a/xformer/Transforms/ReplaceConcat.cpp +++ b/xformer/Transforms/ReplaceConcat.cpp @@ -62,7 +62,8 @@ struct SplitConcatPattern : public OpRewritePattern { auto outputType = concatOp.getOutput().getType().cast(); Type elementType = outputType.getElementType(); ArrayRef outputShape = outputType.getShape(); - const int axis = concatOp.getAxis(); + const int axis = concatOp.getAxis() == -1 ? 
outputType.getRank() - 1 + : concatOp.getAxis(); int axisShape = 0; for (int i = 0; i < CONCAT_OP_MAX_INPUTS; i++) { diff --git a/xformer/Transforms/TFLPatterns.td b/xformer/Transforms/TFLPatterns.td index 9854fdc29..077d2c8f2 100644 --- a/xformer/Transforms/TFLPatterns.td +++ b/xformer/Transforms/TFLPatterns.td @@ -56,8 +56,9 @@ def : Pat<(TFL_ReluOp(TFL_MinimumOp $lhs, $rhs)), (TFL_ReluOp $lhs), [(IsSplatAndEqualTo<127> $rhs)]>; // Merge Relu with Conv -def : Pat<(TFL_ReluOp(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p, - $sh, $sw)), +def : Pat<(TFL_ReluOp + : $output(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p, + $sh, $sw)), (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_Relu, $p, $sh, $sw)>; // Unfuse activation functions from binary ops diff --git a/xformer/Utils/Utils.td b/xformer/Utils/Utils.td index 081163d89..6bc1339c4 100644 --- a/xformer/Utils/Utils.td +++ b/xformer/Utils/Utils.td @@ -57,6 +57,11 @@ class IsSplatAndEqualTo "dyn_cast($0.getDefiningOp()).getValue()." "cast().getSplatValue() == " #n>>; +class IsZeroPointNotEqualTo + : Constraint< + CPred<"dyn_cast($0.getType().cast<" + "ShapedType>().getElementType()).getZeroPoint() != " #n>>; + // Get the dimension size as integer attr. class GetDimAsI32 : NativeCodeCall<