-
Notifications
You must be signed in to change notification settings - Fork 323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Experimental: convert onnx to linalg #1891
base: main
Are you sure you want to change the base?
Changes from all commits
eaad3f0
c5755c6
55d3e56
b8ba626
01a31f8
178160c
732fc25
07bedd5
b9221b7
cc67107
6081399
52baddb
70c4be6
d1b6b37
81de4ba
4b1ad6e
8ed111f
304e330
64e8bca
bc1376d
a82c102
9ecf0a9
c94c71f
b534d31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Please keep in alphabetical order. | ||
# OMONNXToLinalg: library built from the ONNX-to-Linalg conversion sources | ||
# listed below (pass driver, common helpers, and the MatMul lowering). | ||
add_onnx_mlir_library(OMONNXToLinalg | ||
ConvertONNXToLinalg.cpp | ||
ONNXToLinalgCommon.cpp | ||
Math/MatMul.cpp | ||
|
||
# Libraries linked publicly into OMONNXToLinalg. | ||
LINK_LIBS PUBLIC | ||
OMAccelerator | ||
OMConstPropHelper | ||
OMONNXOps | ||
OMSupport | ||
MLIRFuncDialect | ||
MLIRFuncTransforms | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====----- ConvertONNXToLinalg.cpp - ONNX dialects to Linalg lowering ----===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file implements the lowering of frontend operations to a combination of | ||
// Linalg IR and standard operations. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Shape/IR/Shape.h" | ||
#include "src/Compiler/CompilerOptions.hpp" | ||
|
||
#include "src/Accelerators/Accelerator.hpp" | ||
#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
/// Collects every ONNX-to-Linalg lowering pattern into `patterns`.
///
/// Each supported ONNX operation category contributes its patterns through a
/// dedicated `populateLowering...` helper; at present only the MatMul lowering
/// is registered.
///
/// \param patterns       Pattern set that receives the lowering patterns.
/// \param typeConverter  Type converter handed to every registered pattern.
/// \param ctx            MLIR context the patterns are created in.
void populateONNXToLinalgConversionPattern(RewritePatternSet &patterns,
    TypeConverter &typeConverter, MLIRContext *ctx) {
  // Category: Math.
  populateLoweringONNXMatMulOpLinalgPattern(patterns, typeConverter, ctx);
}
|
||
//===----------------------------------------------------------------------===// | ||
// ONNX to Linalg Dialect lowering pass | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// This is a partial lowering to Krnl loops of the ONNX operations. | ||
struct ONNXToLinalgLoweringPass | ||
: public PassWrapper<ONNXToLinalgLoweringPass, OperationPass<ModuleOp>> { | ||
|
||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToLinalgLoweringPass) | ||
|
||
StringRef getArgument() const override { return "convert-onnx-to-linalg"; } | ||
|
||
StringRef getDescription() const override { | ||
return "Lower ONNX ops to Linalg dialect."; | ||
} | ||
|
||
// Make sure that we have a valid default constructor and copy | ||
// constructor to make sure that the options are initialized properly. | ||
ONNXToLinalgLoweringPass() = default; | ||
ONNXToLinalgLoweringPass(const ONNXToLinalgLoweringPass &pass) | ||
: PassWrapper<ONNXToLinalgLoweringPass, OperationPass<ModuleOp>>() {} | ||
|
||
void runOnOperation() final; | ||
}; | ||
|
||
void ONNXToLinalgLoweringPass::runOnOperation() { | ||
ModuleOp module = getOperation(); | ||
|
||
// The first thing to define is the conversion target. This will define the | ||
// final target for this lowering. | ||
ConversionTarget target(getContext()); | ||
|
||
// We define the specific operations, or dialects, that are legal targets for | ||
// this lowering. | ||
target.addLegalDialect<KrnlDialect, AffineDialect, arith::ArithDialect, | ||
func::FuncDialect, linalg::LinalgDialect, math::MathDialect, | ||
memref::MemRefDialect, shape::ShapeDialect, scf::SCFDialect, | ||
tensor::TensorDialect>(); | ||
// Needed to support unsigned int computations. To be removed if we use a | ||
// scheme that does not rely on the UnrealizedConversionCastOp. | ||
target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); | ||
// Make ONNXNoneOp legal so that other ONNX ops can use it during the | ||
// lowering. ONNXNoneOp will be dangling and removed by calling | ||
// canonicalization after the lowering. | ||
target.addLegalOp<::mlir::ONNXNoneOp>(); | ||
target.addLegalOp<linalg::MatmulOp>(); | ||
target.addLegalOp<tensor::EmptyOp>(); | ||
|
||
// The following requirements are from Krnl and they are kept if ONNXToKrnl | ||
// is after this pass. | ||
// If the Linalg is on tensor instead of memref, this lowering will not | ||
// generate memref or Affine load/store. However, these requiremnts will may | ||
// be an issue if Ops are lowered other than Krnl Use krnl.load/store instead | ||
// of std.load/store and affine.load/store. krnl.load/store will be lowered to | ||
// std.load/store and affine.load/store by `convert-krnl-to-affine` pass. | ||
target.addIllegalOp<mlir::memref::LoadOp>(); | ||
target.addIllegalOp<mlir::AffineLoadOp>(); | ||
target.addIllegalOp<mlir::memref::StoreOp>(); | ||
target.addIllegalOp<mlir::AffineStoreOp>(); | ||
|
||
target.addIllegalOp<ONNXMatMulOp>(); | ||
|
||
// TODO: add any other ops which are considered legal. | ||
// Some operations can be marked as being still legal. | ||
// Example: target.addLegalOp<mlir::OpName>(); | ||
|
||
// For future: Handle the accelerator target. | ||
// for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) | ||
// accel->conversionTargetONNXToLinalg(target); | ||
|
||
// Now that the conversion target has been defined, we just need to provide | ||
// the set of patterns that will lower the frontend operations. | ||
RewritePatternSet patterns(&getContext()); | ||
|
||
// Convert types to legal types for the Krnl dialect. | ||
LinalgTypeConverter linalgTypeConverter; | ||
|
||
// Define patterns. | ||
populateONNXToLinalgConversionPattern( | ||
patterns, linalgTypeConverter, &getContext()); | ||
|
||
// For future: Rewrite patterns for accelerators. | ||
// for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) | ||
// accel->rewritePatternONNXToLinalg(patterns, krnlTypeConverter, | ||
// &getContext()); | ||
|
||
// With the target and rewrite patterns defined, we can now attempt the | ||
// conversion. The conversion will signal failure if any of our `illegal` | ||
// operations were not converted successfully. | ||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
std::unique_ptr<Pass> createLowerONNXToLinalgPass() { | ||
return std::make_unique<ONNXToLinalgLoweringPass>(); | ||
} | ||
|
||
} // namespace onnx_mlir |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===----------------- MatMul.cpp - Lowering MatMul Op --------------------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers the ONNX MatMul Operator to Linalg dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "llvm/Support/Debug.h" | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
|
||
#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp" | ||
#include "src/Dialect/Mlir/DialectBuilder.hpp" | ||
|
||
#define DEBUG_TYPE "matmul" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
struct ONNXMatMulOpLinalgLowering : public ConversionPattern { | ||
ONNXMatMulOpLinalgLowering(TypeConverter &typeConverter, MLIRContext *ctx) | ||
: ConversionPattern( | ||
typeConverter, mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} | ||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
Location loc = op->getLoc(); | ||
|
||
auto outputType = op->getResult(0).getType().cast<ShapedType>(); | ||
|
||
SmallVector<Value> dynamicDims; | ||
if (outputType.isDynamicDim(0)) { | ||
dynamicDims.emplace_back( | ||
rewriter.create<tensor::DimOp>(loc, operands[0], 0)); | ||
} | ||
if (outputType.isDynamicDim(1)) { | ||
dynamicDims.emplace_back( | ||
rewriter.create<tensor::DimOp>(loc, operands[1], 1)); | ||
} | ||
|
||
auto outV = rewriter.create<tensor::EmptyOp>( | ||
loc, outputType.getShape(), outputType.getElementType(), dynamicDims); | ||
|
||
SmallVector<Value, 1> outputs; | ||
outputs.emplace_back(outV); | ||
auto newOp = | ||
rewriter.create<linalg::MatmulOp>(loc, outputType, operands, outputs); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You may be aware, but linalg::MatMulOp only works for 2D MatMul, and you will need to call different MatMul ops depending on tensor shape. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. I need to check the shape of tensor for linalg.matmul. Other cases can be either lowered to Krnl, or linalg.generic. |
||
rewriter.replaceOp(op, newOp.getResults()); | ||
return success(); | ||
} | ||
}; // namespace onnx_mlir | ||
|
||
/// Registers the `onnx.MatMul` to Linalg lowering pattern.
///
/// \param patterns       Pattern set that receives the new pattern.
/// \param typeConverter  Type converter forwarded to the pattern constructor.
/// \param ctx            MLIR context forwarded to the pattern constructor.
void populateLoweringONNXMatMulOpLinalgPattern(RewritePatternSet &patterns,
    TypeConverter &typeConverter, MLIRContext *ctx) {
  patterns.add<ONNXMatMulOpLinalgLowering>(typeConverter, ctx);
}
|
||
} // namespace onnx_mlir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We needed to add a new pass to eliminate type conversions, and looking at your IR output it seems you still have bufferization::to_memref showing up.
Strictly speaking, I don't know if such a pass should be necessary, and in your case it may be because you aren't doing one shot bufferization, but if that doesn't clean things up, you might want to look at adding a pass for cleaning up type conversions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this particular test case, the to_tensor/to_memref are cleaned up by canonicalization pass. Since ONNX has data types other than tensor, such as Sequence and Map type, I think we may have to convert to_tensor/to_memref to UnrealizedConversionCastOp so that they can be cancelled out by canonicalization. The cast Op can handle any type.