Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes #920

Merged
merged 5 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/xmos_ai_tools/xinterpreters/src/dll_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void add_lib_vision_ops(
resolver->AddResizeBilinear();
resolver->AddResizeNearestNeighbor();
resolver->AddRound();
resolver->AddRsqrt();
resolver->AddStridedSlice();
resolver->AddSlice();
resolver->AddSub();
Expand Down
7 changes: 7 additions & 0 deletions xformer/Transforms/ConvPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "IR/XCoreOps.td"
include "Utils/Utils.td"

// Unfuse Relu from Conv if output zero point is not -128 for QI8 output
// XC Conv assumes there is no RELU, or that if there is one, it is being
// clamped to -128 Add benefit so that this pattern runs first
def:
Pat<(TFL_Conv2DOp: $output TensorOf<[QI8]>:$input, TensorOf<[QI8]>:$f, AnyTypeOf<[TensorOf<[I32,QI32]>, NoneType]>:$b, $dh, $dw, TFL_AF_Relu, $wf, $sh, $sw),
(TFL_ReluOp (TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $wf, $sh, $sw, (returnType $output))), [(IsZeroPointNotEqualTo<-128> $output)], [], (addBenefit 20)>;

def CreateNoneValue : NativeCodeCall<"$_builder.create<TFL::NoValueOp>($0."
"getLoc(), $_builder.getUnitAttr())">;

Expand Down
4 changes: 2 additions & 2 deletions xformer/Transforms/ReplaceAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
auto lhsZeroPoint = lhsQType.getZeroPoint();

auto rhsQType = utils::getQType(addOp.getRhs());
auto rhsScale = rhsQType.getScale();
auto rhsScale = rhsQType.getScale();
auto rhsZeroPoint = rhsQType.getZeroPoint();

auto outputQType = utils::getQType(addOp.getOutput());
Expand All @@ -56,7 +56,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
int shift = int(floor(log2(pow(2, 14) / maxR)));

// For doing subtraction with add op
rhsRatio = negateForSub? -rhsRatio: rhsRatio;
rhsRatio = negateForSub ? -rhsRatio : rhsRatio;

// Multipliers are converted to fixed-point
int m1 = round(lhsRatio * pow(2, shift));
Expand Down
3 changes: 2 additions & 1 deletion xformer/Transforms/ReplaceConcat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct SplitConcatPattern : public OpRewritePattern<TFL::ConcatenationOp> {
auto outputType = concatOp.getOutput().getType().cast<RankedTensorType>();
Type elementType = outputType.getElementType();
ArrayRef<int64_t> outputShape = outputType.getShape();
const int axis = concatOp.getAxis();
const int axis = concatOp.getAxis() == -1 ? outputType.getRank() - 1
: concatOp.getAxis();

int axisShape = 0;
for (int i = 0; i < CONCAT_OP_MAX_INPUTS; i++) {
Expand Down
5 changes: 3 additions & 2 deletions xformer/Transforms/TFLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def : Pat<(TFL_ReluOp(TFL_MinimumOp $lhs, $rhs)), (TFL_ReluOp $lhs),
[(IsSplatAndEqualTo<127> $rhs)]>;

// Merge Relu with Conv
def : Pat<(TFL_ReluOp(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p,
$sh, $sw)),
def : Pat<(TFL_ReluOp
: $output(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_None, $p,
$sh, $sw)),
(TFL_Conv2DOp $input, $f, $b, $dh, $dw, TFL_AF_Relu, $p, $sh, $sw)>;

// Unfuse activation functions from binary ops
Expand Down
5 changes: 5 additions & 0 deletions xformer/Utils/Utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class IsSplatAndEqualTo<int n>
"dyn_cast<TFL::QConstOp>($0.getDefiningOp()).getValue()."
"cast<DenseElementsAttr>().getSplatValue<int8_t>() == " #n>>;

class IsZeroPointNotEqualTo<int n>
: Constraint<
CPred<"dyn_cast<mlir::quant::UniformQuantizedType>($0.getType().cast<"
"ShapedType>().getElementType()).getZeroPoint() != " #n>>;

// Get the dimension size as integer attr.
class GetDimAsI32<int n>
: NativeCodeCall<
Expand Down
Loading