Skip to content

Commit

Permalink
Add CANN EP (#12416)
Browse files Browse the repository at this point in the history
**Description**: This PR adds Ascend CANN execution provider support.

**Motivation and Context**
- Why is this change required? What problem does it solve?
As the info shown in the issue. CANN is the API layer for Ascend
processor. Add CANN EP can allow user run onnx model on Ascend hardware
via onnxruntime
  The detail change:
  1. Added CANN EP framework.
  2. Added the basic operators to support ResNet and VGG model.
  3. Added C/C++、Python API support
- If it fixes an open issue, please link to the issue here.
   #11477

Author: 
lijiawei <[email protected]>
wangxiyuan <[email protected]>

Co-authored-by: FFrog <[email protected]>
  • Loading branch information
2 people authored and linnealovespie committed Sep 30, 2022
1 parent a6c216d commit fcd3b12
Show file tree
Hide file tree
Showing 71 changed files with 3,876 additions and 2 deletions.
6 changes: 6 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch No
option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for Windows (ETW)" OFF)
option(onnxruntime_USE_TELEMETRY "Build with Telemetry" OFF)
option(onnxruntime_USE_MIMALLOC "Override new/delete and arena allocator with mimalloc" OFF)
option(onnxruntime_USE_CANN "Build with CANN support" OFF)
#The onnxruntime_PREFER_SYSTEM_LIB is mainly designed for package managers like apt/yum/vcpkg.
#Please note, by default Protobuf_USE_STATIC_LIBS is OFF but it's recommended to turn it ON on Windows. You should set it properly when onnxruntime_PREFER_SYSTEM_LIB is ON otherwise you'll hit linkage errors.
#If you have already installed protobuf(or the others) in your system at the default system paths(like /usr/include), then it's better to set onnxruntime_PREFER_SYSTEM_LIB ON. Otherwise onnxruntime may see two different protobuf versions and we won't know which one will be used, the worst case could be onnxruntime picked up header files from one of them but the binaries from the other one.
Expand Down Expand Up @@ -1259,6 +1260,11 @@ if (onnxruntime_USE_XNNPACK)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_XNNPACK=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES xnnpack)
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES cann)
endif()

function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PUBLIC EIGEN_USE_THREADS)
Expand Down
35 changes: 35 additions & 0 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ endif()
if(onnxruntime_USE_SNPE)
include(onnxruntime_snpe_provider.cmake)
endif()
if (onnxruntime_USE_CANN)
set(PROVIDERS_CANN onnxruntime_providers_cann)
endif()

source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs})

Expand Down Expand Up @@ -1535,6 +1538,38 @@ if (onnxruntime_USE_XNNPACK)
endif()
endif()

if (onnxruntime_USE_CANN)
add_definitions(-DUSE_CANN=1)
file(GLOB_RECURSE onnxruntime_providers_cann_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cann/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cann/*.cc"
)

# The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled.
file(GLOB_RECURSE onnxruntime_providers_cann_shared_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
)

source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cann_cc_srcs} ${onnxruntime_providers_cann_shared_srcs})
set(onnxruntime_providers_cann_src ${onnxruntime_providers_cann_cc_srcs} ${onnxruntime_providers_cann_shared_srcs})

onnxruntime_add_shared_library_module(onnxruntime_providers_cann ${onnxruntime_providers_cann_src})
onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers)

add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler nsync_cpp ${ABSEIL_LIBS} onnxruntime_providers_shared)
target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64)
target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include)
set_target_properties(onnxruntime_providers_cann PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_cann PROPERTIES FOLDER "ONNXRuntime")

install(TARGETS onnxruntime_providers_cann
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()

if (NOT onnxruntime_BUILD_SHARED_LIB)
install(TARGETS onnxruntime_providers
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
Expand Down
13 changes: 13 additions & 0 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT
if(onnxruntime_USE_CUDA AND onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()
if(onnxruntime_USE_CANN)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include)
endif()
if(onnxruntime_USE_ROCM)
target_compile_options(onnxruntime_pybind11_state PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
Expand Down Expand Up @@ -812,6 +815,16 @@ if (onnxruntime_USE_CUDA)
)
endif()

if (onnxruntime_USE_CANN)
add_custom_command(
TARGET onnxruntime_pybind11_state POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:onnxruntime_providers_cann>
$<TARGET_FILE:onnxruntime_providers_shared>
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
)
endif()

if (onnxruntime_USE_ROCM)
add_custom_command(
TARGET onnxruntime_pybind11_state POST_BUILD
Expand Down
11 changes: 11 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src})
endif()

if (onnxruntime_USE_CANN)
file(GLOB_RECURSE onnxruntime_test_providers_cann_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/cann/*"
)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cann_src})
endif()

if (onnxruntime_ENABLE_TRAINING)
file(GLOB_RECURSE orttraining_test_trainingops_cpu_src CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/compare_provider_test_utils.cc"
Expand Down Expand Up @@ -443,6 +450,10 @@ if(onnxruntime_USE_CUDA)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda)
endif()

if(onnxruntime_USE_CANN)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cann)
endif()

if(onnxruntime_USE_NNAPI_BUILTIN)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi)
endif()
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace onnxruntime {
constexpr const char* CPU = "Cpu";
constexpr const char* CUDA = "Cuda";
constexpr const char* CUDA_PINNED = "CudaPinned";
constexpr const char* CANN = "Cann";
constexpr const char* CANN_PINNED = "CannPinned";
constexpr const char* DML = "DML";
constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/framework/ortdevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ struct OrtDevice {
static const DeviceType CPU = 0;
static const DeviceType GPU = 1; // Nvidia or AMD
static const DeviceType FPGA = 2;
static const DeviceType NPU = 3; // Ascend

struct MemType {
// Pre-defined memory types.
static const MemoryType DEFAULT = 0;
static const MemoryType CUDA_PINNED = 1;
static const MemoryType HIP_PINNED = 2;
static const MemoryType CANN_PINNED = 3;
};

constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ constexpr const char* kCoreMLExecutionProvider = "CoreMLExecutionProvider";
constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider";
constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider";
constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider";
constexpr const char* kCannExecutionProvider = "CANNExecutionProvider";

constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
Expand Down
18 changes: 18 additions & 0 deletions include/onnxruntime/core/providers/cann/cann_provider_options.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "onnxruntime_c_api.h"
#include "core/framework/arena_extend_strategy.h"

struct OrtCANNProviderOptions {
int device_id; // CANN device id
int max_opqueue_num; // CANN operator cache information aging configuration
size_t npu_mem_limit; // BFC Arena memory limit for CANN
onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena
int do_copy_in_default_stream; // Flag indicating if copying needs to take place on the
// same stream as the compute stream in the CANN EP
OrtArenaCfg* default_memory_arena_cfg; // CANN memory arena configuration parameters
};
65 changes: 65 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ ORT_RUNTIME_CLASS(ArenaCfg);
ORT_RUNTIME_CLASS(PrepackedWeightsContainer);
ORT_RUNTIME_CLASS(TensorRTProviderOptionsV2);
ORT_RUNTIME_CLASS(CUDAProviderOptionsV2);
ORT_RUNTIME_CLASS(CANNProviderOptions);
ORT_RUNTIME_CLASS(Op);
ORT_RUNTIME_CLASS(OpAttr);

Expand Down Expand Up @@ -3496,6 +3497,70 @@ struct OrtApi {
*/
const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version) NO_EXCEPTION;

/** \brief Append CANN provider to session options
*
* If CANN is not available (due to a non CANN enabled build, or if CANN is not installed on the system), this function will return failure.
*
* \param[in] options
* \param[in] cann_options
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.13.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CANN,
_In_ OrtSessionOptions* options, _In_ const OrtCANNProviderOptions* cann_options);

/** \brief Create an OrtCANNProviderOptions
*
* \param[out] out created ::OrtCANNProviderOptions. Must be released with OrtApi::ReleaseCANNProviderOptions
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.13.
*/
ORT_API2_STATUS(CreateCANNProviderOptions, _Outptr_ OrtCANNProviderOptions** out);

/** \brief Set options in a CANN Execution Provider.
*
* \param[in] cann_options
* \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys
* \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values
* \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.13.
*/
ORT_API2_STATUS(UpdateCANNProviderOptions, _Inout_ OrtCANNProviderOptions* cann_options,
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);

/** \brief Get serialized CANN provider options string.
*
* \param[in] cann_options OrtCANNProviderOptions instance
* \param[in] allocator a ptr to an instance of OrtAllocator obtained with CreateAllocator()
* or GetAllocatorWithDefaultOptions(), the specified allocator will be used to allocate
* continuous buffers for output strings and lengths.
* \param[out] ptr is a UTF-8 null terminated string allocated using 'allocator'.
* The caller is responsible for using the same allocator to free it.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.13.
*/
ORT_API2_STATUS(GetCANNProviderOptionsAsString, _In_ const OrtCANNProviderOptions* cann_options,
_Inout_ OrtAllocator* allocator, _Outptr_ char** ptr);

/** \brief Release an OrtCANNProviderOptions
*
* \param[in] the pointer of OrtCANNProviderOptions which will been deleted
*
* \since Version 1.13.
*/
void(ORT_API_CALL* ReleaseCANNProviderOptions)(_Frees_ptr_opt_ OrtCANNProviderOptions* input);

#ifdef __cplusplus
OrtApi(const OrtApi&)=delete; // Prevent users from accidentally copying the API structure, it should always be passed as a pointer
#endif
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
SessionOptions& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
SessionOptions& AppendExecutionProvider(const std::string& provider_name,
const std::unordered_map<std::string, std::string>& provider_options = {});
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const Or
return *this;
}

inline SessionOptions& SessionOptions::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(p_, &provider_options));
return *this;
}

inline SessionOptions& SessionOptions::AppendExecutionProvider(
const std::string& provider_name,
const std::unordered_map<std::string, std::string>& provider_options) {
Expand Down
101 changes: 101 additions & 0 deletions onnxruntime/core/providers/cann/activation/activations.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cann/activation/activations.h"

using onnxruntime::common::Status;
namespace onnxruntime {
namespace cann {

template <typename T>
Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) const {
const aclDataType aclType = getACLType<T>();
aclFormat format = ACL_FORMAT_ND;

const Tensor* X = ctx->Input<Tensor>(0);
Tensor* Y = ctx->Output(0, X->Shape());

ORT_TRY {
CANN_PREPARE_INPUTDESC(prepare, aclType, X->Shape().NumDimensions(), X->Shape().GetDims().data(), format);
CANN_PREPARE_OUTPUTDESC(prepare, aclType, X->Shape().NumDimensions(), X->Shape().GetDims().data(), format);

CANN_PREPARE_INPUTBUFFER(prepare, const_cast<T*>(X->template Data<T>()), X->SizeInBytes());
CANN_PREPARE_OUTPUTBUFFER(prepare, Y->template MutableData<T>(), Y->SizeInBytes());
}
ORT_CATCH(const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
}

return Status::OK();
}

#define REGISTER_ACTIVATION_TYPED_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
CannPreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
prepare.inputDesc_.size(), \
prepare.inputDesc_.data(), \
prepare.inputBuffers_.data(), \
prepare.outputDesc_.size(), \
prepare.outputDesc_.data(), \
prepare.outputBuffers_.data(), \
prepare.opAttr_, \
ACL_ENGINE_SYS, \
ACL_COMPILE_SYS, \
NULL, \
Stream())); \
return Status::OK(); \
}

#define REGISTER_ACTIVATION_TYPED_KERNEL(x, class_name, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kCannExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
class_name<T>);

#define REGISTER_ACTIVATION_VERSIONED_TYPED_KERNEL(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kCannExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);

#define REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, T) \
REGISTER_ACTIVATION_VERSIONED_TYPED_KERNEL(name, startver, endver, T)

#define REGISTER_ACTIVATION_TYPED(name, ver, T) \
REGISTER_ACTIVATION_TYPED_KERNEL(name, name, ver, T) \
REGISTER_ACTIVATION_TYPED_COMPUTE(name, T)

#define REGISTER_ACTIVATION_VERSIONED_HFD(name, startver, endver) \
REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, float) \
REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, double)

#define REGISTER_ACTIVATION_CSIHFD(name, ver) \
REGISTER_ACTIVATION_TYPED(name, ver, int8_t) \
REGISTER_ACTIVATION_TYPED(name, ver, int16_t) \
REGISTER_ACTIVATION_TYPED(name, ver, int32_t) \
REGISTER_ACTIVATION_TYPED(name, ver, MLFloat16) \
REGISTER_ACTIVATION_TYPED(name, ver, float) \
REGISTER_ACTIVATION_TYPED(name, ver, double)

REGISTER_ACTIVATION_VERSIONED_HFD(Relu, 6, 12)

REGISTER_ACTIVATION_VERSIONED_HFD(Relu, 13, 13)

REGISTER_ACTIVATION_CSIHFD(Relu, 14)

} // namespace cann
} // namespace onnxruntime
Loading

0 comments on commit fcd3b12

Please sign in to comment.