diff --git a/xformer/Transforms/ConvPatterns.td b/xformer/Transforms/ConvPatterns.td
index 305cb2557..82383c3ab 100644
--- a/xformer/Transforms/ConvPatterns.td
+++ b/xformer/Transforms/ConvPatterns.td
@@ -8,6 +8,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 include "IR/XCoreOps.td"
 include "Utils/Utils.td"
 
+// Unfuse Relu from Conv if the output zero point is not -128 for QI8 output.
+// XC Conv assumes there is no Relu, or that if there is one, it is clamped
+// to -128. Add benefit so that this pattern runs first.
+def:
+Pat<(TFL_Conv2DOp: $output TensorOf<[QI8]>:$input, TensorOf<[QI8]>:$f, AnyTypeOf<[TensorOf<[I32, QI32]>, NoneType]>:$b, $dh, $dw, TFL_AF_Relu, $wf, $sh, $sw),
+    (TFL_ReluOp (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $wf, $sh, $sw, (returnType $output))), [(IsZeroPointNotEqualTo<-128> $output)], [], (addBenefit 20)>;
+
 def CreateNoneValue : NativeCodeCall<"$_builder.create($0."
                                      "getLoc(), $_builder.getUnitAttr())">;
diff --git a/xformer/Utils/Utils.td b/xformer/Utils/Utils.td
index 081163d89..6bc1339c4 100644
--- a/xformer/Utils/Utils.td
+++ b/xformer/Utils/Utils.td
@@ -57,6 +57,11 @@ class IsSplatAndEqualTo
     "dyn_cast($0.getDefiningOp()).getValue()."
     "cast().getSplatValue() == " #n>>;
 
+class IsZeroPointNotEqualTo<int n>
+    : Constraint<
+        CPred<"dyn_cast<mlir::quant::UniformQuantizedType>($0.getType().cast<"
+              "ShapedType>().getElementType()).getZeroPoint() != " #n>>;
+
 // Get the dimension size as integer attr.
 class GetDimAsI32 : NativeCodeCall<
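
For context, the `CPred` inside `IsZeroPointNotEqualTo` is plain C++ evaluated against the matched `$output` value at pattern-match time. Below is a minimal standalone sketch of the equivalent check, not the actual generated code: the helper name `hasZeroPointOtherThan` is invented for illustration, and the include path assumes the older `mlir/Dialect/Quant/QuantTypes.h` layout, but `ShapedType`, `quant::UniformQuantizedType`, and `getZeroPoint()` are the MLIR APIs the predicate relies on.

```cpp
#include "mlir/Dialect/Quant/QuantTypes.h" // path may differ across MLIR versions
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Hypothetical helper mirroring IsZeroPointNotEqualTo<n>: true when the
// value's element type is uniformly quantized and its zero point differs
// from `n`. The pattern above instantiates it with n == -128, so the Relu
// is only unfused when the QI8 output zero point is something other than
// -128 (the clamp value XC Conv already assumes).
static bool hasZeroPointOtherThan(mlir::Value val, int64_t n) {
  auto elementType = val.getType().cast<mlir::ShapedType>().getElementType();
  auto qType = elementType.dyn_cast<mlir::quant::UniformQuantizedType>();
  // Unlike the raw CPred string, guard against non-quantized element types.
  return qType && qType.getZeroPoint() != n;
}
```

When the constraint holds, the pattern replaces the fused `TFL_AF_Relu` on `TFL_Conv2DOp` with `TFL_AF_None` plus a separate `TFL_ReluOp` on the result, and `(addBenefit 20)` makes this rewrite run before the other conv patterns.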