Skip to content

Commit

Permalink
Merge pull request #920 from xmos/small_fixes
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
panickal-xmos authored Aug 2, 2024
2 parents 08258e4 + 4a8bf20 commit 59d3064
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void add_lib_vision_ops(
resolver->AddResizeBilinear();
resolver->AddResizeNearestNeighbor();
resolver->AddRound();
resolver->AddRsqrt();
resolver->AddStridedSlice();
resolver->AddSlice();
resolver->AddSub();
Expand Down
7 changes: 7 additions & 0 deletions xformer/Transforms/ConvPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "IR/XCoreOps.td"
include "Utils/Utils.td"

// Unfuse Relu from Conv if output zero point is not -128 for QI8 output
// XC Conv assumes there is no RELU, or that if there is one, it is being
// clamped to -128 Add benefit so that this pattern runs first
def:
Pat<(TFL_Conv2DOp: $output TensorOf<[QI8]>:$input, TensorOf<[QI8]>:$f, AnyTypeOf<[TensorOf<[I32,QI32]>, NoneType]>:$b, $dh, $dw, TFL_AF_Relu, $wf, $sh, $sw),
(TFL_ReluOp (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $wf, $sh, $sw, (returnType $output))), [(IsZeroPointNotEqualTo<-128> $output)], [], (addBenefit 20)>;

def CreateNoneValue : NativeCodeCall<"$_builder.create<TFL::NoValueOp>($0."
"getLoc(), $_builder.getUnitAttr())">;

Expand Down
4 changes: 2 additions & 2 deletions xformer/Transforms/ReplaceAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
auto lhsZeroPoint = lhsQType.getZeroPoint();

auto rhsQType = utils::getQType(addOp.getRhs());
auto rhsScale = rhsQType.getScale();
auto rhsScale = rhsQType.getScale();
auto rhsZeroPoint = rhsQType.getZeroPoint();

auto outputQType = utils::getQType(addOp.getOutput());
Expand All @@ -56,7 +56,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
int shift = int(floor(log2(pow(2, 14) / maxR)));

// For doing subtraction with add op
rhsRatio = negateForSub? -rhsRatio: rhsRatio;
rhsRatio = negateForSub ? -rhsRatio : rhsRatio;

// Multipliers are converted to fixed-point
int m1 = round(lhsRatio * pow(2, shift));
Expand Down
3 changes: 2 additions & 1 deletion xformer/Transforms/ReplaceConcat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct SplitConcatPattern : public OpRewritePattern<TFL::ConcatenationOp> {
auto outputType = concatOp.getOutput().getType().cast<RankedTensorType>();
Type elementType = outputType.getElementType();
ArrayRef<int64_t> outputShape = outputType.getShape();
const int axis = concatOp.getAxis();
const int axis = concatOp.getAxis() == -1 ? outputType.getRank() - 1
: concatOp.getAxis();

int axisShape = 0;
for (int i = 0; i < CONCAT_OP_MAX_INPUTS; i++) {
Expand Down
5 changes: 3 additions & 2 deletions xformer/Transforms/TFLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def : Pat<(TFL_ReluOp(TFL_MinimumOp $lhs, $rhs)), (TFL_ReluOp $lhs),
[(IsSplatAndEqualTo<127> $rhs)]>;

// Merge Relu with Conv
def : Pat<(TFL_ReluOp(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p,
$sh, $sw)),
def : Pat<(TFL_ReluOp
: $output(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p,
$sh, $sw)),
(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_Relu, $p, $sh, $sw)>;

// Unfuse activation functions from binary ops
Expand Down
5 changes: 5 additions & 0 deletions xformer/Utils/Utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class IsSplatAndEqualTo<int n>
"dyn_cast<TFL::QConstOp>($0.getDefiningOp()).getValue()."
"cast<DenseElementsAttr>().getSplatValue<int8_t>() == " #n>>;

class IsZeroPointNotEqualTo<int n>
: Constraint<
CPred<"dyn_cast<mlir::quant::UniformQuantizedType>($0.getType().cast<"
"ShapedType>().getElementType()).getZeroPoint() != " #n>>;

// Get the dimension size as integer attr.
class GetDimAsI32<int n>
: NativeCodeCall<
Expand Down

0 comments on commit 59d3064

Please sign in to comment.