-
Notifications
You must be signed in to change notification settings - Fork 328
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TOSA] Update type converter and unary ops (#1553)
Signed-off-by: Philipp Braun <[email protected]> Co-authored-by: Alexandre Eichenberger <[email protected]>
- Loading branch information
1 parent
172c226
commit 64c3d18
Showing
6 changed files
with
223 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===---------------- Elementwise.cpp - Elementwise Op --------------------===// | ||
// | ||
// Copyright (c) 2022 Advanced Micro Devices, Inc. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers ONNX element-wise operators to TOSA dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/IR/TypeUtilities.h" | ||
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
template <> | ||
struct TOSADialectOp<ONNXNegOp> { | ||
using Op = tosa::NegateOp; | ||
}; | ||
|
||
namespace { | ||
|
||
// Element-wise unary ops lowering to TOSA dialect. | ||
//===----------------------------------------------------------------------===// | ||
template <typename ElementwiseUnaryOp> | ||
class ONNXElementwiseUnaryOpLoweringToTOSA | ||
: public OpConversionPattern<ElementwiseUnaryOp> { | ||
public: | ||
using OpConversionPattern<ElementwiseUnaryOp>::OpConversionPattern; | ||
using OpAdaptor = typename ElementwiseUnaryOp::Adaptor; | ||
LogicalResult matchAndRewrite(ElementwiseUnaryOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
rewriter.replaceOpWithNewOp<TOSAOp<ElementwiseUnaryOp>>( | ||
op, op.getType(), adaptor.X()); | ||
return success(); | ||
} | ||
}; | ||
|
||
class ONNXFloorOpLoweringToTOSA : public OpConversionPattern<ONNXFloorOp> { | ||
public: | ||
using OpConversionPattern<ONNXFloorOp>::OpConversionPattern; | ||
using OpAdaptor = typename ONNXFloorOp::Adaptor; | ||
LogicalResult matchAndRewrite(ONNXFloorOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
auto scalarType = getElementTypeOrSelf(adaptor.X()); | ||
if (!isTOSAFloat(scalarType)) | ||
return rewriter.notifyMatchFailure( | ||
op, "`tosa.floor` only supports float types"); | ||
|
||
rewriter.replaceOpWithNewOp<tosa::FloorOp>(op, op.getType(), adaptor.X()); | ||
return success(); | ||
} | ||
}; | ||
|
||
class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Value input = adaptor.X(); | ||
|
||
// Quantized types are not supported right now (in type conversion). | ||
// Once they are, the input should be rescaled for quantized types. (TBD) | ||
// Maps to `tosa.clamp` which has both int and fp limits. | ||
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input, | ||
rewriter.getI64IntegerAttr(0), | ||
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()), | ||
rewriter.getF32FloatAttr(0.0f), | ||
rewriter.getF32FloatAttr(std::numeric_limits<float>::max())); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, | ||
RewritePatternSet &patterns, TypeConverter &typeConverter, | ||
MLIRContext *ctx) { | ||
patterns.insert<ONNXElementwiseUnaryOpLoweringToTOSA<ONNXNegOp>, | ||
ONNXFloorOpLoweringToTOSA, ONNXReluOpLoweringToTOSA>(typeConverter, ctx); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// | ||
// | ||
// Copyright (c) 2022 Advanced Micro Devices, Inc. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file contains common code shared by the functions performing the | ||
// lowering to the TOSA dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Quant/QuantTypes.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
|
||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#include "src/Dialect/ONNX/DialectBuilder.hpp" | ||
#include "src/Dialect/ONNX/ONNXOps.hpp" | ||
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" | ||
#include "src/Pass/Passes.hpp" | ||
#include "src/Transform/ONNX/ConstPropHelper.hpp" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Functions to add lowering patterns for frontend operations. | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace onnx_mlir { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Check for valid TOSA types. | ||
//===----------------------------------------------------------------------===// | ||
|
||
inline bool isTOSASignedInt(Type type) { | ||
IntegerType intType = type.dyn_cast<IntegerType>(); | ||
std::set<unsigned> intWidth{8, 16, 32, 48, 64}; | ||
return intType && intType.isSigned() && | ||
(intWidth.find(intType.getWidth()) != intWidth.end()); | ||
} | ||
|
||
inline bool isTOSAFloat(Type type) { | ||
return type.isa<BFloat16Type, Float16Type, Float32Type>(); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// This is to get a TOSA operation of a given type for a specific operation. | ||
//===----------------------------------------------------------------------===// | ||
template <typename ONNXOp> | ||
struct TOSADialectOp { | ||
using Op = void; | ||
}; | ||
|
||
template <typename Op> | ||
using TOSAOp = typename TOSADialectOp<Op>::Op; | ||
|
||
// `Math` directory methods: | ||
void populateLoweringONNXElementwiseOpToTOSAPattern( | ||
ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *); | ||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s | ||
|
||
func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
%0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
"func.return"(%0) : (tensor<10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_relu | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> | ||
// CHECK-NEXT: } | ||
} | ||
|
||
func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { | ||
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> | ||
"func.return"(%0) : (tensor<*xf32>) -> () | ||
// CHECK-LABEL: func @test_relu_dynamic | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32> | ||
// CHECK-NEXT: } | ||
} | ||
|
||
func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
%0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
"func.return"(%0) : (tensor<10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_neg | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
} | ||
|
||
func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
%0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
"func.return"(%0) : (tensor<10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_floor | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
} |
This file was deleted.
Oops, something went wrong.