diff --git a/Jenkinsfile b/Jenkinsfile
index 296fa07c5..8017e00f2 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -18,6 +18,7 @@ def dailyDeviceTest = {
     sh "pytest examples/app_mobilenetv2"
   }
   runPytestDevice("8x8/test_broadcast", "-n 1 --tc 1", "broadcast_1")
+  runPytestDevice("8x8/test_mean", "-n 1 --tc 1", "mean_1")
   runPytestDevice("8x8/test_lstm", "-n 1 --tc 1", "lstm_1")
   runPytestDevice("8x8/test_lstm", "-n 1", "lstm_5")
   runPytestDevice("complex_models/8x8/test_cnn_classifier", "-n 1 --tc 1", "cnn_classifier_1")
diff --git a/integration_tests/models/8x8/test_mean/generate.py b/integration_tests/models/8x8/test_mean/generate.py
new file mode 100644
index 000000000..1743e1a27
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/generate.py
@@ -0,0 +1,47 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow import lite as tfl
+
+i = 0
+
+
+def generate_mean_model(input_shape, axes):
+    input_data = tf.keras.Input(shape=input_shape, dtype=tf.int8, batch_size=1)
+    mean_output = tf.keras.backend.mean(input_data, axis=axes)
+    model = tf.keras.Model(inputs=input_data, outputs=mean_output)
+    converter = tfl.TFLiteConverter.from_keras_model(model)
+
+    def representative_dataset_gen():
+        for _ in range(100):
+            yield [
+                np.random.uniform(low=-127, high=127, size=input_shape).astype(np.int8)
+            ]
+
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.representative_dataset = representative_dataset_gen
+    converter.target_spec.supported_ops = [tfl.OpsSet.TFLITE_BUILTINS_INT8]
+    converter.inference_input_type = tf.int8
+    converter.inference_output_type = tf.int8
+
+    tflite_model = converter.convert()
+    global i
+    model_name = f"test_mean_{i}.tflite"
+    i += 1
+    with open(model_name, "wb") as f:
+        f.write(tflite_model)
+    print(f"Model saved: {model_name}")
+
+
+input_shapes_and_axes = [
+    ((10,), [0]),
+    ((8, 16), [0]),
+    ((8, 16), [1]),
+    ((8, 16), [0, 1]),
+    ((8, 15, 32), [0]),
+    ((8, 15, 32), [1]),
+    ((8, 15, 32), [2]),
+    ((8, 15, 32), [0, 2]),
+]
+
+for shape, axes in input_shapes_and_axes:
+    generate_mean_model(shape, axes)
diff --git a/integration_tests/models/8x8/test_mean/params.yaml b/integration_tests/models/8x8/test_mean/params.yaml
new file mode 100644
index 000000000..aecf2f6bd
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/params.yaml
@@ -0,0 +1 @@
+MAX_ABS_ERROR: 1.0
diff --git a/integration_tests/models/8x8/test_mean/test_mean_0.tflite b/integration_tests/models/8x8/test_mean/test_mean_0.tflite
new file mode 100644
index 000000000..d6ba76ad0
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_0.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_1.mlir b/integration_tests/models/8x8/test_mean/test_mean_1.mlir
new file mode 100644
index 000000000..86e91f1ab
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_1.mlir
@@ -0,0 +1,5 @@
+func.func @main(%arg0: tensor<1x5x8x16x!quant.uniform> {tf_saved_model.index_path = ["input_2"]}) -> (tensor<1x5x1x16x!quant.uniform> {tf_saved_model.index_path = ["tf.mean_1"]}) attributes {tf.entry_function = {inputs = "serving_default_input_2:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<1x5x8x16x!quant.uniform>, tensor<1xi32>) -> tensor<1x5x1x16x!quant.uniform>
+  return %1 : tensor<1x5x1x16x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_1.tflite b/integration_tests/models/8x8/test_mean/test_mean_1.tflite
new file mode 100644
index 000000000..6fd030d11
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_1.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_10.mlir b/integration_tests/models/8x8/test_mean/test_mean_10.mlir
new file mode 100644
index 000000000..f49a62393
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_10.mlir
@@ -0,0 +1,6 @@
+// This test reduces the 2nd and 3rd axes of a 4D tensor with consecutive axes and keep_dims = true.
+func.func @main(%arg0: tensor<8x5x10x12x!quant.uniform>) -> (tensor<8x1x1x12x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<8x5x10x12x!quant.uniform>, tensor<2xi32>) -> tensor<8x1x1x12x!quant.uniform>
+  return %1 : tensor<8x1x1x12x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_10.tflite b/integration_tests/models/8x8/test_mean/test_mean_10.tflite
new file mode 100644
index 000000000..0e35f8a89
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_10.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_11.mlir b/integration_tests/models/8x8/test_mean/test_mean_11.mlir
new file mode 100644
index 000000000..f49a62393
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_11.mlir
@@ -0,0 +1,6 @@
+// This test reduces the 2nd and 3rd axes of a 4D tensor with consecutive axes and keep_dims = true.
+func.func @main(%arg0: tensor<8x5x10x12x!quant.uniform>) -> (tensor<8x1x1x12x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<8x5x10x12x!quant.uniform>, tensor<2xi32>) -> tensor<8x1x1x12x!quant.uniform>
+  return %1 : tensor<8x1x1x12x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_11.tflite b/integration_tests/models/8x8/test_mean/test_mean_11.tflite
new file mode 100644
index 000000000..d4311b854
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_11.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_2.mlir b/integration_tests/models/8x8/test_mean/test_mean_2.mlir
new file mode 100644
index 000000000..7b006436c
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_2.mlir
@@ -0,0 +1,6 @@
+// This test reduces the 3rd axis of a 4D tensor without keeping dimensions.
+func.func @main(%arg0: tensor<2x3x4x5x!quant.uniform>) -> (tensor<2x3x5x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<2x3x4x5x!quant.uniform>, tensor<1xi32>) -> tensor<2x3x5x!quant.uniform>
+  return %1 : tensor<2x3x5x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_2.tflite b/integration_tests/models/8x8/test_mean/test_mean_2.tflite
new file mode 100644
index 000000000..d9cf8d847
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_2.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_3.mlir b/integration_tests/models/8x8/test_mean/test_mean_3.mlir
new file mode 100644
index 000000000..1c964c9f2
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_3.mlir
@@ -0,0 +1,6 @@
+// This test reduces the 2nd and 4th axes of a 5D tensor while keeping dimensions.
+func.func @main(%arg0: tensor<4x3x5x7x6x!quant.uniform>) -> (tensor<4x1x5x1x6x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<4x3x5x7x6x!quant.uniform>, tensor<2xi32>) -> tensor<4x1x5x1x6x!quant.uniform>
+  return %1 : tensor<4x1x5x1x6x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_3.tflite b/integration_tests/models/8x8/test_mean/test_mean_3.tflite
new file mode 100644
index 000000000..d98f2ba43
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_3.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_4.mlir b/integration_tests/models/8x8/test_mean/test_mean_4.mlir
new file mode 100644
index 000000000..595ef06c6
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_4.mlir
@@ -0,0 +1,6 @@
+// This test reduces the 1st axis of a 3D tensor without keeping dimensions.
+func.func @main(%arg0: tensor<10x20x30x!quant.uniform>) -> (tensor<20x30x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<10x20x30x!quant.uniform>, tensor<1xi32>) -> tensor<20x30x!quant.uniform>
+  return %1 : tensor<20x30x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_4.tflite b/integration_tests/models/8x8/test_mean/test_mean_4.tflite
new file mode 100644
index 000000000..01fb8e2cb
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_4.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_5.mlir b/integration_tests/models/8x8/test_mean/test_mean_5.mlir
new file mode 100644
index 000000000..bfbc54527
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_5.mlir
@@ -0,0 +1,6 @@
+// This test reduces all axes of a 2D tensor while keeping dimensions.
+func.func @main(%arg0: tensor<5x7x!quant.uniform>) -> (tensor<1x1x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<5x7x!quant.uniform>, tensor<2xi32>) -> tensor<1x1x!quant.uniform>
+  return %1 : tensor<1x1x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_5.tflite b/integration_tests/models/8x8/test_mean/test_mean_5.tflite
new file mode 100644
index 000000000..ea0352d0d
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_5.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_6.mlir b/integration_tests/models/8x8/test_mean/test_mean_6.mlir
new file mode 100644
index 000000000..9d8f2be6d
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_6.mlir
@@ -0,0 +1,6 @@
+// This test reduces a 1D tensor to a scalar.
+func.func @main(%arg0: tensor<15x!quant.uniform>) -> (tensor<!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<15x!quant.uniform>, tensor<1xi32>) -> tensor<!quant.uniform>
+  return %1 : tensor<!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_6.tflite b/integration_tests/models/8x8/test_mean/test_mean_6.tflite
new file mode 100644
index 000000000..21f781ab9
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_6.tflite differ
diff --git a/integration_tests/models/8x8/test_mean/test_mean_7.mlir b/integration_tests/models/8x8/test_mean/test_mean_7.mlir
new file mode 100644
index 000000000..36c6a0739
--- /dev/null
+++ b/integration_tests/models/8x8/test_mean/test_mean_7.mlir
@@ -0,0 +1,6 @@
+// This test reduces the 2nd and 3rd axes of a 3D tensor with different input/output quantization parameters.
+func.func @main(%arg0: tensor<5x6x7x!quant.uniform>) -> (tensor<5x!quant.uniform>) {
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<5x6x7x!quant.uniform>, tensor<2xi32>) -> tensor<5x!quant.uniform>
+  return %1 : tensor<5x!quant.uniform>
+}
diff --git a/integration_tests/models/8x8/test_mean/test_mean_7.tflite b/integration_tests/models/8x8/test_mean/test_mean_7.tflite
new file mode 100644
index 000000000..5949ca73e
Binary files /dev/null and b/integration_tests/models/8x8/test_mean/test_mean_7.tflite differ
diff --git a/integration_tests/models/8x8/test_softmax/params.yaml b/integration_tests/models/8x8/test_softmax/params.yaml
index b684ee66c..aecf2f6bd 100644
--- a/integration_tests/models/8x8/test_softmax/params.yaml
+++ b/integration_tests/models/8x8/test_softmax/params.yaml
@@ -1 +1 @@
-MAX_ABS_ERROR: 0.0
+MAX_ABS_ERROR: 1.0
diff --git a/integration_tests/runner.py b/integration_tests/runner.py
index e717c021d..046139a18 100644
--- a/integration_tests/runner.py
+++ b/integration_tests/runner.py
@@ -311,6 +311,7 @@ def test_model(request: FixtureRequest, filename: str) -> None:
         errors = np.concatenate(
             [(a - b).reshape(-1) for a, b in zip(ref_outputs, xf_outputs)]
         )
+        if not len(errors):
             continue
diff --git a/third_party/lib_nn b/third_party/lib_nn
index c0449fd1c..435469856 160000
--- a/third_party/lib_nn
+++ b/third_party/lib_nn
@@ -1 +1 @@
-Subproject commit c0449fd1c3d12ac907f4cd737bb3f829f0f4ed64
+Subproject commit 4354698561b399ec9f121c1774812c6d746568e5
diff --git a/third_party/lib_tflite_micro b/third_party/lib_tflite_micro
index 668bcb10e..4a55f5910 160000
--- a/third_party/lib_tflite_micro
+++ b/third_party/lib_tflite_micro
@@ -1 +1 @@
-Subproject commit 668bcb10e5258edc8a37f744d6d060637eddcccc
+Subproject commit 4a55f5910b0e09da297a9193feed5311df4e851c
diff --git a/xformer/IR/XCoreOps.td b/xformer/IR/XCoreOps.td
index 54be58203..35d556952 100644
--- a/xformer/IR/XCoreOps.td
+++ b/xformer/IR/XCoreOps.td
@@ -181,6 +181,25 @@ def XC_AddOp : XC_Op<"add", [Pure, XC_MemoryOverlappable]> {
   let results = (outs TensorOf<[QI8]> : $output);
 }
 
+def XC_MeanOp : XC_Op<"mean", [Pure]> {
+  let summary = "Mean op";
+
+  let description = [{Mean op.}];
+
+  let arguments = (ins
+    TensorOf<[QI8]>:$input,
+
+    I32Attr:$start,
+    I32Attr:$mean,
+    I32Attr:$end,
+    F32Attr:$in_zero_point,
+    F32Attr:$out_zero_point,
+    F32Attr:$scale_mul
+  );
+
+  let results = (outs TensorOf<[QI8]> : $output);
+}
+
 def XC_MulOp : XC_Op<"mul", [Pure, XC_MemoryOverlappable]> {
   let summary = "Mul op";
 
diff --git a/xformer/Transforms/Passes.cpp b/xformer/Transforms/Passes.cpp
index f87f1facc..0f482b99b 100644
--- a/xformer/Transforms/Passes.cpp
+++ b/xformer/Transforms/Passes.cpp
@@ -39,6 +39,7 @@ void buildXCoreRemainingPassPipeline(OpPassManager &pm) {
   pm.addPass(createReplaceAddSubPass());
   pm.addPass(createReplaceMaxPoolPass());
   pm.addPass(createReplaceMulPass());
+  pm.addPass(createReplaceMeanPass());
   pm.addPass(createReplaceTransposeConvPass());
   pm.addPass(createReplaceConv2DPass());
   pm.addPass(createReplacePadPass());
diff --git a/xformer/Transforms/Passes.h b/xformer/Transforms/Passes.h
index 32d6d581b..edae71024 100644
--- a/xformer/Transforms/Passes.h
+++ b/xformer/Transforms/Passes.h
@@ -33,6 +33,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createApplyTFLPatternsPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createRemoveDynamicShapePass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddSubPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMulPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMeanPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMaxPoolPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceStridedSlicePass();
 std::unique_ptr<OperationPass<func::FuncOp>> createReplaceSlicePass();
diff --git a/xformer/Transforms/ReplaceMean.cpp b/xformer/Transforms/ReplaceMean.cpp
new file mode 100644
index 000000000..59c20b387
--- /dev/null
+++ b/xformer/Transforms/ReplaceMean.cpp
@@ -0,0 +1,121 @@
+// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the
+// XMOS Public License: Version 1
+
+#include "IR/XCoreOps.h"
+#include "Utils/Util.h"
+
+extern "C" {
+#include "lib_nn/api/nn_layers.h"
+}
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/utils/validators.h"
+
+namespace mlir::xcore {
+
+namespace {
+// Replace TFL Mean with Mean for XCore.
+struct ReplaceMean
+    : public PassWrapper<ReplaceMean, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceMean)
+
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<XCoreDialect>();
+  }
+  StringRef getArgument() const final { return "xcore-replace-mean"; }
+  StringRef getDescription() const final {
+    return "Replace TFL Mean with Mean for XCore.";
+  }
+  void runOnOperation() override;
+};
+
+struct ReplaceMeanPattern : public OpRewritePattern<TFL::MeanOp> {
+  using OpRewritePattern<TFL::MeanOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::MeanOp meanOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto input = meanOp.getInput();
+    auto output = meanOp.getOutput();
+
+    // The axis operand must be a compile-time constant.
+    DenseElementsAttr axisAttr;
+    if (!matchPattern(meanOp.getAxis(), m_Constant(&axisAttr))) {
+      return failure();
+    }
+    auto axisValues = axisAttr.getValues<int32_t>();
+    std::vector<int32_t> axis(axisValues.begin(), axisValues.end());
+    int32_t minAxis = *std::min_element(axis.begin(), axis.end());
+    int32_t maxAxis = *std::max_element(axis.begin(), axis.end());
+    // Only consecutive reduction axes are supported.
+    if (maxAxis - minAxis > axis.size() - 1) {
+      return failure();
+    }
+
+    auto inputType = input.getType().cast<ShapedType>();
+    auto outputType = output.getType().cast<ShapedType>();
+    if (!utils::isNBitSignedQType<8>(inputType.getElementType()) ||
+        !utils::isNBitSignedQType<8>(outputType.getElementType())) {
+      return failure();
+    }
+
+    auto inputShape = inputType.getShape();
+    auto outputShape = outputType.getShape();
+
+    int rank = inputShape.size();
+
+    int beginDims = 1;
+    for (int i = 0; i < minAxis; i++) {
+      beginDims *= inputShape[i];
+    }
+
+    int endDims = 1;
+    for (int i = maxAxis + 1; i < rank; i++) {
+      endDims *= inputShape[i];
+    }
+
+    int meanDims = 1;
+    for (int i = minAxis; i <= maxAxis; i++) {
+      meanDims *= inputShape[i];
+    }
+
+    auto inputQType = utils::getQType(input);
+    auto outputQType = utils::getQType(output);
+
+    float inZeroPoint = static_cast<float>(inputQType.getZeroPoint());
+    float outZeroPoint = static_cast<float>(outputQType.getZeroPoint());
+    float scaleMul = inputQType.getScale() / outputQType.getScale() /
+                     static_cast<float>(meanDims);
+
+    auto beginDimsAttr = rewriter.getI32IntegerAttr(beginDims);
+    auto endDimsAttr = rewriter.getI32IntegerAttr(endDims);
+    auto meanDimsAttr = rewriter.getI32IntegerAttr(meanDims);
+    auto inZeroPointAttr = rewriter.getF32FloatAttr(inZeroPoint);
+    auto outZeroPointAttr = rewriter.getF32FloatAttr(outZeroPoint);
+    auto scaleMulAttr = rewriter.getF32FloatAttr(scaleMul);
+
+    auto xcMeanOp = rewriter.create<MeanOp>(
+        meanOp.getLoc(), meanOp.getType(), meanOp.getInput(), beginDimsAttr,
+        meanDimsAttr, endDimsAttr, inZeroPointAttr, outZeroPointAttr,
+        scaleMulAttr);
+    rewriter.replaceOp(meanOp, xcMeanOp.getOutput());
+
+    return success();
+  }
+};
+
+void ReplaceMean::runOnOperation() {
+  auto *ctx = &getContext();
+  func::FuncOp func = getOperation();
+  RewritePatternSet patterns(ctx);
+  patterns.insert<ReplaceMeanPattern>(ctx);
+  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+} // namespace
+
+// Creates an instance of the ReplaceMean pass.
+std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMeanPass() {
+  return std::make_unique<ReplaceMean>();
+}
+
+static PassRegistration<ReplaceMean> pass;
+
+} // namespace mlir::xcore
diff --git a/xformer/Transforms/TFLPatterns.td b/xformer/Transforms/TFLPatterns.td
index 077d2c8f2..f0ebc6fe3 100644
--- a/xformer/Transforms/TFLPatterns.td
+++ b/xformer/Transforms/TFLPatterns.td
@@ -105,30 +105,6 @@ def:
     [(HasUnequalShape $input2, $output)]>;
 }
 
-// If MeanOp with spatial axis and rank 2 output, expand output to rank 4, which
-// we later lower to AveragePool2D
-def : Pat<(TFL_MeanOp
-          : $output TensorOf<[QI8]>:$input,
-          (TFL_ConstOp
-           : $axis_op $axis),
-          $kd),
-          (TFL_ReshapeOp(TFL_MeanOp $input, $axis_op, $kd,
-                         (returnType(getExpandedShape $output))),
-           (TFL_ConstOp(getExpandedShapeAttr $output))),
-          [(HasSpatialAxisForMean $axis), (HasRank<2> $output)]>;
-
-// Lower MeanOp with spatial axis to AveragePool2D
-def : Pat<(TFL_MeanOp
-          : $output TensorOf<[QI8]>:$input, (TFL_ConstOp $axis), $kd),
-          (TFL_QuantizeOp(
-               TFL_AveragePool2DOp $input, (GetDimAsI32<1> $input),
-               (GetDimAsI32<2> $input), TFL_PAD_Valid,
-               ConstantAttr, ConstantAttr,
-               TFL_AF_None,
-               (returnType(getTypeOf1WithQParamsOf0 $input, $output))),
-           (getTypeAttrOf1WithQParamsOf0 $output, $output)),
-          [(HasSpatialAxisForMean $axis), (HasRank<4> $output)]>;
-
 // PadChannel(PadSpatial) to PadSpatial(PadChannel)
 // Match cases where arith constant op and tfl constant op are both used
 foreach constOp = [Arith_ConstantOp, TFL_ConstOp] in {
diff --git a/xformer/Transforms/TranslateToCustomOp.cpp b/xformer/Transforms/TranslateToCustomOp.cpp
index c4c302751..40805bf6c 100644
--- a/xformer/Transforms/TranslateToCustomOp.cpp
+++ b/xformer/Transforms/TranslateToCustomOp.cpp
@@ -66,6 +66,20 @@ std::vector<uint8_t> MulOp::buildCustomOptions() {
   return fbb.GetBuffer();
 }
 
+std::vector<uint8_t> MeanOp::buildCustomOptions() {
+  flexbuffers::Builder fbb;
+  auto rootMap = fbb.StartMap();
+  fbb.Int("s", (int32_t)getStart());
+  fbb.Int("m", (int32_t)getMean());
+  fbb.Int("e", (int32_t)getEnd());
+  fbb.IndirectFloat("i", getInZeroPoint().convertToFloat());
+  fbb.IndirectFloat("o", getOutZeroPoint().convertToFloat());
+  fbb.IndirectFloat("sm", getScaleMul().convertToFloat());
+  fbb.EndMap(rootMap);
+  fbb.Finish();
+  return fbb.GetBuffer();
+}
+
 std::vector<uint8_t> SliceOp::buildCustomOptions() {
   flexbuffers::Builder fbb;
   auto rootMap = fbb.StartMap();
@@ -252,6 +266,7 @@ void TranslateToCustomOp::runOnOperation() {
   patterns.insert>(ctx);
   patterns.insert>(ctx);
   patterns.insert>(ctx);
+  patterns.insert<RewriteToCustomOp<MeanOp>>(ctx);
   patterns.insert>(ctx);
   patterns.insert>(ctx);
   patterns.insert>(ctx);
diff --git a/xformer/lib_tflite_micro.BUILD b/xformer/lib_tflite_micro.BUILD
index 23612a7b1..3e9a724ea 100644
--- a/xformer/lib_tflite_micro.BUILD
+++ b/xformer/lib_tflite_micro.BUILD
@@ -41,6 +41,7 @@ filegroup(
         "lib_tflite_micro/src/tflite-xcore-kernels/xcore_slice.cc",
         "lib_tflite_micro/src/tflite-xcore-kernels/xcore_broadcast.cc",
         "lib_tflite_micro/src/tflite-xcore-kernels/xcore_mul.cc",
+        "lib_tflite_micro/src/tflite-xcore-kernels/xcore_mean.cc",
         "lib_tflite_micro/src/tflite-xcore-kernels/xcore_binaryi16.cc",
         "lib_tflite_micro/src/tflite-xcore-kernels/xcore_unaryi16.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_beta_activationf32.cc",