-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
**Description**: This PR adds Ascend CANN execution provider support. **Motivation and Context** - Why is this change required? What problem does it solve? As the info shown in the issue. CANN is the API layer for Ascend processor. Add CANN EP can allow user run onnx model on Ascend hardware via onnxruntime The detail change: 1. Added CANN EP framework. 2. Added the basic operators to support ResNet and VGG model. 3. Added C/C++、Python API support - If it fixes an open issue, please link to the issue here. #11477 Author: lijiawei <[email protected]> wangxiyuan <[email protected]> Co-authored-by: FFrog <[email protected]>
- Loading branch information
1 parent
a6c216d
commit fcd3b12
Showing
71 changed files
with
3,876 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
include/onnxruntime/core/providers/cann/cann_provider_options.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Copyright (c) Huawei. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "onnxruntime_c_api.h" | ||
#include "core/framework/arena_extend_strategy.h" | ||
|
||
struct OrtCANNProviderOptions { | ||
int device_id; // CANN device id | ||
int max_opqueue_num; // CANN operator cache information aging configuration | ||
size_t npu_mem_limit; // BFC Arena memory limit for CANN | ||
onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena | ||
int do_copy_in_default_stream; // Flag indicating if copying needs to take place on the | ||
// same stream as the compute stream in the CANN EP | ||
OrtArenaCfg* default_memory_arena_cfg; // CANN memory arena configuration parameters | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
onnxruntime/core/providers/cann/activation/activations.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Copyright (c) Huawei. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/cann/activation/activations.h" | ||
|
||
using onnxruntime::common::Status; | ||
namespace onnxruntime { | ||
namespace cann { | ||
|
||
template <typename T> | ||
Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) const { | ||
const aclDataType aclType = getACLType<T>(); | ||
aclFormat format = ACL_FORMAT_ND; | ||
|
||
const Tensor* X = ctx->Input<Tensor>(0); | ||
Tensor* Y = ctx->Output(0, X->Shape()); | ||
|
||
ORT_TRY { | ||
CANN_PREPARE_INPUTDESC(prepare, aclType, X->Shape().NumDimensions(), X->Shape().GetDims().data(), format); | ||
CANN_PREPARE_OUTPUTDESC(prepare, aclType, X->Shape().NumDimensions(), X->Shape().GetDims().data(), format); | ||
|
||
CANN_PREPARE_INPUTBUFFER(prepare, const_cast<T*>(X->template Data<T>()), X->SizeInBytes()); | ||
CANN_PREPARE_OUTPUTBUFFER(prepare, Y->template MutableData<T>(), Y->SizeInBytes()); | ||
} | ||
ORT_CATCH(const std::exception& e) { | ||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
|
||
#define REGISTER_ACTIVATION_TYPED_COMPUTE(x, T) \ | ||
template <> \ | ||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \ | ||
CannPreparation prepare; \ | ||
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \ | ||
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \ | ||
prepare.inputDesc_.size(), \ | ||
prepare.inputDesc_.data(), \ | ||
prepare.inputBuffers_.data(), \ | ||
prepare.outputDesc_.size(), \ | ||
prepare.outputDesc_.data(), \ | ||
prepare.outputBuffers_.data(), \ | ||
prepare.opAttr_, \ | ||
ACL_ENGINE_SYS, \ | ||
ACL_COMPILE_SYS, \ | ||
NULL, \ | ||
Stream())); \ | ||
return Status::OK(); \ | ||
} | ||
|
||
#define REGISTER_ACTIVATION_TYPED_KERNEL(x, class_name, ver, T) \ | ||
ONNX_OPERATOR_TYPED_KERNEL_EX( \ | ||
x, \ | ||
kOnnxDomain, \ | ||
ver, \ | ||
T, \ | ||
kCannExecutionProvider, \ | ||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \ | ||
class_name<T>); | ||
|
||
#define REGISTER_ACTIVATION_VERSIONED_TYPED_KERNEL(x, startver, endver, T) \ | ||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ | ||
x, \ | ||
kOnnxDomain, \ | ||
startver, \ | ||
endver, \ | ||
T, \ | ||
kCannExecutionProvider, \ | ||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \ | ||
x<T>); | ||
|
||
#define REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, T) \ | ||
REGISTER_ACTIVATION_VERSIONED_TYPED_KERNEL(name, startver, endver, T) | ||
|
||
#define REGISTER_ACTIVATION_TYPED(name, ver, T) \ | ||
REGISTER_ACTIVATION_TYPED_KERNEL(name, name, ver, T) \ | ||
REGISTER_ACTIVATION_TYPED_COMPUTE(name, T) | ||
|
||
#define REGISTER_ACTIVATION_VERSIONED_HFD(name, startver, endver) \ | ||
REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, MLFloat16) \ | ||
REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, float) \ | ||
REGISTER_ACTIVATION_VERSIONED_TYPED(name, startver, endver, double) | ||
|
||
#define REGISTER_ACTIVATION_CSIHFD(name, ver) \ | ||
REGISTER_ACTIVATION_TYPED(name, ver, int8_t) \ | ||
REGISTER_ACTIVATION_TYPED(name, ver, int16_t) \ | ||
REGISTER_ACTIVATION_TYPED(name, ver, int32_t) \ | ||
REGISTER_ACTIVATION_TYPED(name, ver, MLFloat16) \ | ||
REGISTER_ACTIVATION_TYPED(name, ver, float) \ | ||
REGISTER_ACTIVATION_TYPED(name, ver, double) | ||
|
||
REGISTER_ACTIVATION_VERSIONED_HFD(Relu, 6, 12) | ||
|
||
REGISTER_ACTIVATION_VERSIONED_HFD(Relu, 13, 13) | ||
|
||
REGISTER_ACTIVATION_CSIHFD(Relu, 14) | ||
|
||
} // namespace cann | ||
} // namespace onnxruntime |
Oops, something went wrong.