Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental: convert onnx to linalg #1891

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ llvm::cl::opt<bool> allowSorting("allowSorting",
llvm::cl::desc("Perform topological sort on onnx graph"),
llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<bool> enableLinalg("enableLinalg",
llvm::cl::desc("Enable ONNX to Linalg conversion and related passes"),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

// Configuration states associated with certain options.
// For example, when maccel is specified, NNPA can register
// dependent libdnn.
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ extern llvm::cl::opt<int> onnxOpTransformThreshold;
extern llvm::cl::opt<bool> onnxOpTransformReport;
extern llvm::cl::opt<bool> enableParallel;
extern llvm::cl::opt<bool> enableSimdDataLayout;
extern llvm::cl::opt<bool> enableLinalg;

// The customEnvFlags must be scanned before the normal options.
bool parseCustomEnvFlagsCommandLineOption(int argc, const char *const *argv,
Expand Down
19 changes: 19 additions & 0 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -121,17 +122,35 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
if (enableInstrumentONNXSignature)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentONNXSignaturePass());
if (enableLinalg) {
pm.addPass(onnx_mlir::createLowerONNXToLinalgPass());

// Convert tensor.EmptyOp to bufferization.alloc_tensor
// This pass has to come before Linalg Bufferize pass.
// Otherwise, the bufferization.alloc_tensor will not be lowered
pm.addNestedPass<func::FuncOp>(
bufferization::createEmptyTensorToAllocTensorPass());

// Linalg bufferization can be before or after LowerToKrnlPass
pm.addNestedPass<func::FuncOp>(createLinalgBufferizePass());
}
pm.addPass(onnx_mlir::createLowerToKrnlPass(optLevel, enableParallel));
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// opportunities.

// For Linalg and Krnl mixed IR:
// Canonicalization pass will clean up bufferization::to_tensor and to_memref
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We needed to add a new pass to accomplish eliminate type conversions, and looking at your IR output it seems you still have bufferization::to_memref showing up.

Strictly speaking, I don't know if such a pass should be necessary, and in your case it may be because you aren't doing one shot bufferization, but if that doesn't clean things up, you might want to look at adding a pass for cleaning up type conversions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this particular test case, the to_tensor/to_memref are cleaned up by canonicalization pass. Since ONNX has data types other than tensor, such as Sequence and Map type, I think we may have to convert to_tensor/to_memref to UnrealizedConversionCastOp so that they can be cancelled out by canonicalization. The cast Op can handle any type.

pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createDisconnectKrnlDimFromAllocPass());
pm.addPass(mlir::createCanonicalizerPass());
} // namespace onnx_mlir

void addKrnlToAffinePasses(mlir::PassManager &pm) {
  // When the experimental Linalg path is enabled, lower remaining linalg ops
  // to affine loops first, so that by the end of this pipeline stage the
  // function body is expressed in the affine dialect regardless of whether an
  // op came through the Linalg or the Krnl lowering.
  if (enableLinalg) {
    pm.addNestedPass<func::FuncOp>(createConvertLinalgToAffineLoopsPass());
  }
  // Krnl-to-Affine runs unconditionally; it handles all krnl.* ops.
  pm.addNestedPass<func::FuncOp>(
      onnx_mlir::krnl::createConvertKrnlToAffinePass());
}
Expand Down
1 change: 1 addition & 0 deletions src/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_subdirectory(KrnlToLLVM)
add_subdirectory(KrnlToAffine)
add_subdirectory(KrnlSeqToMemref)
add_subdirectory(ONNXToTOSA)
add_subdirectory(ONNXToLinalg)

if (ONNX_MLIR_ENABLE_MHLO)
add_subdirectory(ONNXToMhlo)
Expand Down
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ add_onnx_mlir_library(OMONNXToKrnl

LINK_LIBS PUBLIC
OMAccelerator
OMCompilerOptions
OMConstPropHelper
OMONNXOps
OMSupport
Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "src/Compiler/CompilerOptions.hpp"
Expand Down Expand Up @@ -350,6 +351,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
target.addIllegalOp<mlir::AffineLoadOp>();
target.addIllegalOp<mlir::memref::StoreOp>();
target.addIllegalOp<mlir::AffineStoreOp>();
target.addIllegalOp<bufferization::AllocTensorOp>();

// If `emitDealloc` is turned off, make sure we don't have buffer deallocation
// at this level. Will use MLIR buffer-deallocation for this purpose instead.
Expand Down
26 changes: 21 additions & 5 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "src/Accelerators/Accelerator.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
Expand Down Expand Up @@ -724,8 +727,16 @@ KrnlTypeConverter::KrnlTypeConverter() {
if (inputs.size() != 1)
return llvm::None;

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
// Use ToTensorOp instead of UnrealizedConversionCastOp
// because Linalg use ToTensor, though they are the same in semantic
// Since UnrealizedConversionCastOp is used in other places and will not
// be replaced in this PR
if (enableLinalg)
return builder.create<bufferization::ToTensorOp>(loc, resultType, inputs)
.getResult();
else
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addTargetMaterialization([&](OpBuilder &builder, Type resultType,
Expand All @@ -734,8 +745,13 @@ KrnlTypeConverter::KrnlTypeConverter() {
if (inputs.size() != 1)
return llvm::None;

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
// Replace UnrealizedConversionCastOp
if (enableLinalg)
return builder.create<bufferization::ToMemrefOp>(loc, resultType, inputs)
.getResult();
else
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
}

Expand Down
16 changes: 16 additions & 0 deletions src/Conversion/ONNXToLinalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

# Please keep in alphabetical order.
add_onnx_mlir_library(OMONNXToLinalg
ConvertONNXToLinalg.cpp
ONNXToLinalgCommon.cpp
Math/MatMul.cpp

LINK_LIBS PUBLIC
OMAccelerator
OMConstPropHelper
OMONNXOps
OMSupport
MLIRFuncDialect
MLIRFuncTransforms
)
132 changes: 132 additions & 0 deletions src/Conversion/ONNXToLinalg/ConvertONNXToLinalg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===----- ConvertONNXToLinalg.cpp - ONNX dialects to Linalg lowering -----===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file implements the (partial) lowering of frontend ONNX operations to
// the Linalg dialect; operations not handled here are left for later passes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "src/Compiler/CompilerOptions.hpp"

#include "src/Accelerators/Accelerator.hpp"
#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

// Populate `patterns` with every ONNX-to-Linalg lowering pattern this
// conversion currently provides (only MatMul at the moment). Grouped by
// operator category, mirroring the ONNXToKrnl organization.
void populateONNXToLinalgConversionPattern(RewritePatternSet &patterns,
    TypeConverter &typeConverter, MLIRContext *ctx) {

  // Math
  populateLoweringONNXMatMulOpLinalgPattern(patterns, typeConverter, ctx);
}

//===----------------------------------------------------------------------===//
// ONNX to Linalg Dialect lowering pass
//===----------------------------------------------------------------------===//

/// This is a partial lowering of the ONNX operations to the Linalg dialect.
struct ONNXToLinalgLoweringPass
    : public PassWrapper<ONNXToLinalgLoweringPass, OperationPass<ModuleOp>> {

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToLinalgLoweringPass)

  // Command-line name under which this pass is registered.
  StringRef getArgument() const override { return "convert-onnx-to-linalg"; }

  StringRef getDescription() const override {
    return "Lower ONNX ops to Linalg dialect.";
  }

  // Make sure that we have a valid default constructor and copy
  // constructor to make sure that the options are initialized properly.
  ONNXToLinalgLoweringPass() = default;
  ONNXToLinalgLoweringPass(const ONNXToLinalgLoweringPass &pass)
      : PassWrapper<ONNXToLinalgLoweringPass, OperationPass<ModuleOp>>() {}

  // Defined below; performs the partial dialect conversion on the module.
  void runOnOperation() final;
};

void ONNXToLinalgLoweringPass::runOnOperation() {
  ModuleOp module = getOperation();

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering.
  target.addLegalDialect<KrnlDialect, AffineDialect, arith::ArithDialect,
      func::FuncDialect, linalg::LinalgDialect, math::MathDialect,
      memref::MemRefDialect, shape::ShapeDialect, scf::SCFDialect,
      tensor::TensorDialect>();
  // Needed to support unsigned int computations. To be removed if we use a
  // scheme that does not rely on the UnrealizedConversionCastOp.
  target.addLegalOp<::mlir::UnrealizedConversionCastOp>();
  // Make ONNXNoneOp legal so that other ONNX ops can use it during the
  // lowering. ONNXNoneOp will be dangling and removed by calling
  // canonicalization after the lowering.
  target.addLegalOp<::mlir::ONNXNoneOp>();
  // Ops produced by the Linalg lowering patterns (already covered by the
  // legal dialects above; listed explicitly for documentation purposes).
  target.addLegalOp<linalg::MatmulOp>();
  target.addLegalOp<tensor::EmptyOp>();

  // The following requirements are from Krnl and they are kept if ONNXToKrnl
  // runs after this pass.
  // Since the Linalg lowering works on tensors instead of memrefs, this
  // lowering does not itself generate memref or affine load/store. However,
  // these requirements may be an issue if ops are lowered to something other
  // than Krnl. Use krnl.load/store instead of std.load/store and
  // affine.load/store; krnl.load/store is lowered to std.load/store and
  // affine.load/store by the `convert-krnl-to-affine` pass.
  target.addIllegalOp<mlir::memref::LoadOp>();
  target.addIllegalOp<mlir::AffineLoadOp>();
  target.addIllegalOp<mlir::memref::StoreOp>();
  target.addIllegalOp<mlir::AffineStoreOp>();

  // ONNX ops that this pass must convert (the conversion fails otherwise).
  target.addIllegalOp<ONNXMatMulOp>();

  // TODO: add any other ops which are considered legal.
  // Some operations can be marked as being still legal.
  // Example: target.addLegalOp<mlir::OpName>();

  // For future: Handle the accelerator target.
  // for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators())
  //   accel->conversionTargetONNXToLinalg(target);

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  RewritePatternSet patterns(&getContext());

  // Convert types to legal types for the Linalg dialect.
  LinalgTypeConverter linalgTypeConverter;

  // Define patterns.
  populateONNXToLinalgConversionPattern(
      patterns, linalgTypeConverter, &getContext());

  // For future: Rewrite patterns for accelerators.
  // for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators())
  //   accel->rewritePatternONNXToLinalg(patterns, krnlTypeConverter,
  //       &getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
    signalPassFailure();
  }
}

// Factory function: create an instance of the ONNX-to-Linalg lowering pass
// (registered as `convert-onnx-to-linalg`).
std::unique_ptr<Pass> createLowerONNXToLinalgPass() {
  return std::make_unique<ONNXToLinalgLoweringPass>();
}

} // namespace onnx_mlir
63 changes: 63 additions & 0 deletions src/Conversion/ONNXToLinalg/Math/MatMul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===----------------- Matmul.cpp - Lowering Matmul Op --------------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Matmul Operator to Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"

#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"

#define DEBUG_TYPE "matmul"

using namespace mlir;

namespace onnx_mlir {

// Lower ONNXMatMulOp to linalg.matmul operating on tensors.
//
// linalg::MatmulOp only supports the 2-D x 2-D case, so this pattern rejects
// (match failure) every other rank combination; those cases must be handled
// by another lowering (e.g. Krnl, or linalg.generic in the future).
struct ONNXMatMulOpLinalgLowering : public ConversionPattern {
  ONNXMatMulOpLinalgLowering(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(
            typeConverter, mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();

    auto outputType = op->getResult(0).getType().cast<ShapedType>();

    // Guard: linalg.matmul is strictly 2-D. The previous version indexed
    // dims 0/1 unconditionally and emitted linalg.matmul for any rank,
    // which produces invalid IR for 1-D or batched (>2-D) matmuls.
    auto lhsType = operands[0].getType().dyn_cast<ShapedType>();
    auto rhsType = operands[1].getType().dyn_cast<ShapedType>();
    if (!lhsType || !rhsType || lhsType.getRank() != 2 ||
        rhsType.getRank() != 2 || outputType.getRank() != 2)
      return rewriter.notifyMatchFailure(
          op, "linalg.matmul requires 2-D operands and result");

    // Collect dynamic output dimensions: rows (dim 0) come from the LHS,
    // columns (dim 1) from the RHS.
    SmallVector<Value> dynamicDims;
    if (outputType.isDynamicDim(0)) {
      dynamicDims.emplace_back(
          rewriter.create<tensor::DimOp>(loc, operands[0], 0));
    }
    if (outputType.isDynamicDim(1)) {
      dynamicDims.emplace_back(
          rewriter.create<tensor::DimOp>(loc, operands[1], 1));
    }

    // Create the output tensor with tensor.empty; it is converted to
    // bufferization.alloc_tensor and bufferized by later passes.
    auto outV = rewriter.create<tensor::EmptyOp>(
        loc, outputType.getShape(), outputType.getElementType(), dynamicDims);

    SmallVector<Value, 1> outputs;
    outputs.emplace_back(outV);
    auto newOp =
        rewriter.create<linalg::MatmulOp>(loc, outputType, operands, outputs);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
}; // struct ONNXMatMulOpLinalgLowering

// Register the ONNXMatMulOp lowering pattern with the ONNX-to-Linalg
// pattern set (called from populateONNXToLinalgConversionPattern).
void populateLoweringONNXMatMulOpLinalgPattern(RewritePatternSet &patterns,
    TypeConverter &typeConverter, MLIRContext *ctx) {
  patterns.insert<ONNXMatMulOpLinalgLowering>(typeConverter, ctx);
}

} // namespace onnx_mlir
Loading