Skip to content

Commit

Permalink
Merge pull request #55 from cyLi-Tiger/main
Browse files Browse the repository at this point in the history
feat: add onnx to MegCC
  • Loading branch information
yeasoon authored Oct 30, 2023
2 parents 51e242a + 08f1655 commit 4e35d92
Show file tree
Hide file tree
Showing 16 changed files with 1,275 additions and 1 deletion.
6 changes: 6 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@
[submodule "third_party/llvm-project"]
path = third_party/llvm-project
url = https://github.com/llvm/llvm-project.git
[submodule "third_party/onnx"]
path = third_party/onnx
url = https://github.com/onnx/onnx.git
[submodule "third_party/protobuf"]
path = third_party/protobuf
url = https://github.com/protocolbuffers/protobuf.git
24 changes: 24 additions & 0 deletions compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,30 @@ add_library(mgb_imported INTERFACE)
target_link_libraries(mgb_imported INTERFACE ${MGB_INSTALL_LIB})
target_include_directories(mgb_imported INTERFACE ${MGB_INCLUDE_DIR})

find_library(
ONNX_INSTALL_LIB
NAMES libonnx.a
PATHS ${PROJECT_SOURCE_DIR}/../third_party/onnx/install/lib/ REQUIRED)
find_library(
ONNX_PROTO_INSTALL_LIB
NAMES libonnx_proto.a
PATHS ${PROJECT_SOURCE_DIR}/../third_party/onnx/install/lib/ REQUIRED)

list(APPEND ONNX_LIBS -Wl,--whole-archive ${ONNX_PROTO_INSTALL_LIB}
-Wl,--no-whole-archive)
list(APPEND ONNX_LIBS ${ONNX_INSTALL_LIB})
set(ONNX_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/../third_party/onnx/install/include)

list(APPEND PROTOBUF_LIBS
${PROJECT_SOURCE_DIR}/../third_party/protobuf/install/lib/libprotobuf.a
${PROJECT_SOURCE_DIR}/../third_party/protobuf/install/lib/libprotoc.a)
set(PROTOBUF_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/../third_party/protobuf/install/include)

add_library(onnx_imported INTERFACE)
target_link_libraries(onnx_imported INTERFACE ${ONNX_LIB} ${PROTOBUF_LIBS})
target_include_directories(onnx_imported INTERFACE ${ONNX_INCLUDE_DIR}
${PROTOBUF_INCLUDE_DIR})

if(APPLE)
if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES arm64)
set(TCC_INSTALL_LIB
Expand Down
60 changes: 60 additions & 0 deletions compiler/include/compiler/Target/onnx/helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <map>
#include <vector>

#include "compiler/Common/Logger.h"
#include "compiler/Common/MemoryStatus.h"

#include "megdnn/basic_types.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

#include "llvm/Support/raw_ostream.h"

#include "compiler/Dialect/MGB/IR/MGBDialect.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "onnx/common/ir.h"

namespace mlir {
namespace ONNX {
static inline mlir::Type elemTypeToType(
mlir::MLIRContext* context, const int32_t& elem_type) {
switch (elem_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
return mlir::FloatType::getF32(context);
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
return mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed);
default:
CC_ABORT << "Unsupported dtype " << elem_type << "\n";
break;
}
return mlir::Type();
}

static inline mlir::ShapedType valueToShapedType(
mlir::MLIRContext* context, ONNX_NAMESPACE::Value* value) {
std::vector<int64_t> dims;
for (auto dim : value->sizes()) {
dims.emplace_back(dim.dim);
}
LOG_DEBUG << "Create RankedTensorType in Value with shape= " << dims << "\n";
mlir::ShapedType res;
if (dims.size() > 0) {
res = mlir::RankedTensorType::get(
dims, elemTypeToType(context, value->elemType()));
} else {
LOG_WARN << "Shape is unknown, compiler just make 1 dim dynamic tensor "
"type\n";
res = mlir::RankedTensorType::get(
{-1}, elemTypeToType(context, value->elemType()));
}
return res;
}

} // namespace ONNX
} // namespace mlir
21 changes: 21 additions & 0 deletions compiler/include/compiler/Target/onnx/import.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include <map>
#include <vector>

#include "megdnn/basic_types.h"
#include "mlir/IR/BuiltinOps.h"

namespace mlir {
namespace ONNX {

struct ONNXImporterOptions {
std::string module_name;
std::string model_path;
std::string input_shape_str;
};

mlir::LogicalResult import_onnx(mlir::ModuleOp module, std::string model_path);

} // namespace ONNX
} // namespace mlir
2 changes: 1 addition & 1 deletion compiler/lib/KernelGen/Common/ConvKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ConvImpl : public KernelFunc {
static bool is_channel_broadcast_bias(TContext* ctx) {
if (is_bias(ctx)) {
CCOperand bias = ctx->getAttrOprand("operand:2");
return bias.shape[0] == 1 && bias.shape[2] == 1 && bias.shape[3] == 1;
return (bias.shape[0] == 1 && bias.shape[2] == 1 && bias.shape[3] == 1) || bias.shape.size() == 1;
}
return false;
}
Expand Down
1 change: 1 addition & 0 deletions compiler/lib/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(MGB)
add_subdirectory(TinyNN)
add_subdirectory(Hako)
add_subdirectory(onnx)
22 changes: 22 additions & 0 deletions compiler/lib/Target/onnx/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
set(LLVM_OPTIONAL_SOURCES onnx_importer.cpp)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DONNX_ML=1 -DONNX_NAMESPACE=onnx")
add_compile_definitions("ONNX_NO_EXCEPTIONS")

add_mlir_translation_library(
MLIRONNXImporter
importer.cpp
DEPENDS
MLIRMGBIncGen
LINK_LIBS
PUBLIC
MLIRIR
MLIRMGB
MLIRStandard)
# detail obj library created in llvm_add_library
target_include_directories(
obj.MLIRONNXImporter PRIVATE ${MGB_INCLUDE_DIR} ${ONNX_INCLUDE_DIR}
${PROTOBUF_INCLUDE_DIR})
# add onnx-imported
target_link_libraries(MLIRONNXImporter PUBLIC $<BUILD_INTERFACE:onnx_imported>)
# target_compile_options(MLIRONNXImporter PUBLIC -fexceptions)
Loading

0 comments on commit 4e35d92

Please sign in to comment.