Skip to content

Commit

Permalink
Unfuse Relu from Conv if zeropoint is not -128
Browse files Browse the repository at this point in the history
  • Loading branch information
panickal-xmos committed Aug 2, 2024
1 parent 4bd7b52 commit c9e7748
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions xformer/Transforms/ConvPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "IR/XCoreOps.td"
include "Utils/Utils.td"

// Unfuse Relu from Conv if output zero point is not -128 for QI8 output
// XC Conv assumes there is no RELU, or that if there is one, it is being
// clamped to -128 Add benefit so that this pattern runs first
def:
Pat<(TFL_Conv2DOp: $output TensorOf<[QI8]>:$input, TensorOf<[QI8]>:$f, AnyTypeOf<[TensorOf<[I32,QI32]>, NoneType]>:$b, $dh, $dw, TFL_AF_Relu, $wf, $sh, $sw),
(TFL_ReluOp (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $wf, $sh, $sw, (returnType $output))), [(IsZeroPointNotEqualTo<-128> $output)], [], (addBenefit 20)>;

def CreateNoneValue : NativeCodeCall<"$_builder.create<TFL::NoValueOp>($0."
"getLoc(), $_builder.getUnitAttr())">;

Expand Down
5 changes: 5 additions & 0 deletions xformer/Utils/Utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class IsSplatAndEqualTo<int n>
"dyn_cast<TFL::QConstOp>($0.getDefiningOp()).getValue()."
"cast<DenseElementsAttr>().getSplatValue<int8_t>() == " #n>>;

class IsZeroPointNotEqualTo<int n>
: Constraint<
CPred<"dyn_cast<mlir::quant::UniformQuantizedType>($0.getType().cast<"
"ShapedType>().getElementType()).getZeroPoint() != " #n>>;

// Get the dimension size as integer attr.
class GetDimAsI32<int n>
: NativeCodeCall<
Expand Down

0 comments on commit c9e7748

Please sign in to comment.