Skip to content

Commit

Permalink
Add tests and update sub op
Browse files Browse the repository at this point in the history
  • Loading branch information
panickal-xmos committed Jul 18, 2024
1 parent 9bb95f6 commit 95007ff
Show file tree
Hide file tree
Showing 30 changed files with 10 additions and 1 deletion.
6 changes: 6 additions & 0 deletions integration_tests/models/8x8/test_sub/1.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
cp $1 /tmp/
xcore-opt /tmp/$1 --lce-translate-tfl --mlir-print-ir-after-all -o /tmp/1.tflite >/tmp/1.mlir 2>&1
cat /tmp/1.mlir | grep -v Tensor > /tmp/2.mlir
sed -i -e 's/tfl.add/tfl.sub/g' /tmp/2.mlir
xcore-opt --mlir-io --lce-translate-tfl /tmp/2.mlir -o /tmp/t.tflite
cp /tmp/t.tflite $1
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 4 additions & 1 deletion xformer/Transforms/ReplaceAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
auto lhsZeroPoint = lhsQType.getZeroPoint();

auto rhsQType = utils::getQType(addOp.getRhs());
auto rhsScale = negateForSub ? -rhsQType.getScale() : rhsQType.getScale();
auto rhsScale = rhsQType.getScale();
auto rhsZeroPoint = rhsQType.getZeroPoint();

auto outputQType = utils::getQType(addOp.getOutput());
Expand All @@ -55,6 +55,9 @@ LogicalResult replaceAddorSub(T addOp, PatternRewriter &rewriter,
// We want the max shift to be 14 bits
int shift = int(floor(log2(pow(2, 14) / maxR)));

// For doing subtraction with add op
rhsRatio = negateForSub? -rhsRatio: rhsRatio;

// Multipliers are converted to fixed-point
int m1 = round(lhsRatio * pow(2, shift));
int m2 = round(rhsRatio * pow(2, shift));
Expand Down

0 comments on commit 95007ff

Please sign in to comment.