Skip to content

Commit

Permalink
Enable Arm Compute Library 23.08 (#17672)
Browse files Browse the repository at this point in the history
### Description

This PR enables onnxruntime to build with the most recent release of Arm
Compute Library

### Motivation and Context

The latest version of Arm Compute Library that onnxruntime can build is
20.02 which is more than 3 years old.
  • Loading branch information
milpuz01 authored Jan 9, 2024
1 parent a2afd92 commit 37ac9d3
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 20 deletions.
9 changes: 7 additions & 2 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF)
option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF)
option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF)
option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF)
option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF)
option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF)
option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON)
option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON)
Expand Down Expand Up @@ -1110,7 +1111,7 @@ function(onnxruntime_add_include_to_target dst_target)
endfunction()

# ACL
if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002)
if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308)
set(onnxruntime_USE_ACL ON)
if (onnxruntime_USE_ACL_1902)
add_definitions(-DACL_1902=1)
Expand All @@ -1121,7 +1122,11 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905
if (onnxruntime_USE_ACL_2002)
add_definitions(-DACL_2002=1)
else()
add_definitions(-DACL_1905=1)
if (onnxruntime_USE_ACL_2308)
add_definitions(-DACL_2308=1)
else()
add_definitions(-DACL_1905=1)
endif()
endif()
endif()
endif()
Expand Down
16 changes: 14 additions & 2 deletions onnxruntime/core/providers/acl/math/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,18 @@ class Gemm : public onnxruntime::Gemm<T> {
}

Status Compute(OpKernelContext* context) const override {
#ifdef ACL_2308
if (this->packed_b_) {
// Prepacked RHS not supported, defaulting to cpu execution provider
return onnxruntime::Gemm<T>::Compute(context);
}
#endif
const auto A = context->Input<Tensor>(0);
const auto B = context->Input<Tensor>(1);
const auto C = context->Input<Tensor>(2);

GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans, C->Shape());
GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans,
C != nullptr ? C->Shape() : TensorShape({}));

if (!helper.State().IsOK())
return helper.State();
Expand All @@ -70,7 +77,7 @@ class Gemm : public onnxruntime::Gemm<T> {
return onnxruntime::Gemm<T>::Compute(context);
}

arm_compute::TensorShape cShape = ACLTensorShape(C->Shape());
arm_compute::TensorShape cShape = ACLTensorShape(C != nullptr ? C->Shape() : TensorShape({}));
if (useC &&
(cShape.num_dimensions() > 2 ||
(cShape.num_dimensions() == 2 && cShape[0] > 1 && cShape[1] > 1))) { // Multi-dimensional Bias
Expand All @@ -89,8 +96,13 @@ class Gemm : public onnxruntime::Gemm<T> {
(cShape[1] == 1 && cShape[0] != (long unsigned int)N)) {
return onnxruntime::Gemm<T>::Compute(context);
}
#ifdef ACL_2308
cShape = arm_compute::TensorShape(N);
LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {" << N << "}";
#else
cShape = arm_compute::TensorShape(1, N);
LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {1," << N << "}";
#endif
}

int64_t K = helper.K();
Expand Down
32 changes: 27 additions & 5 deletions onnxruntime/core/providers/acl/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ Status BatchNorm<T>::Compute(OpKernelContext* context) const {
const Tensor* M = context->Input<Tensor>(3); // mean
const Tensor* V = context->Input<Tensor>(4); // var

if (S->Shape().NumDimensions() > 1) {
LOGS_DEFAULT(WARNING) << "ACL does not support scale with dimension greater then 1; defaulting to cpu implementation";
return onnxruntime::BatchNorm<T>::Compute(context);
}

if (this->is_train_) {
LOGS_DEFAULT(WARNING) << "ACL does not have batchnorm training support; defaulting to cpu implementation";
return onnxruntime::BatchNorm<T>::Compute(context);
}

ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, S, B, M, V));

LOGS_DEFAULT(VERBOSE) << "BatchNorm ACL:";
Expand All @@ -70,7 +80,23 @@ Status BatchNorm<T>::Compute(OpKernelContext* context) const {

auto layer = std::make_shared<arm_compute::NEBatchNormalizationLayer>();

#ifdef ACL_2308
arm_compute::TensorShape in_x_shape;
const TensorShape& x_shape = X->Shape();
const auto& dims_vec = x_shape.GetDims();
in_x_shape.set(3, onnxruntime::narrow<size_t>(dims_vec[0])); // N
in_x_shape.set(1, 1); // H
size_t W = 1;
for (size_t i = 2; i < dims_vec.size(); ++i) {
W *= narrow<size_t>(dims_vec[i]);
}
in_x_shape.set(0, W); // W
in_x_shape.set(2, onnxruntime::narrow<size_t>(dims_vec[1])); // C

tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(in_x_shape, arm_compute::Format::F32));
#else
tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32));
#endif
tbatch_norm.out->allocator()->init(arm_compute::TensorInfo(tbatch_norm.in->info()->tensor_shape(), arm_compute::Format::F32));

tbatch_norm.scale->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(S->Shape()), arm_compute::Format::F32));
Expand Down Expand Up @@ -132,11 +158,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
7, 9,
kAclExecutionProvider,
KernelDefBuilder()
.TypeConstraint("X", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("scale", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("B", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("mean", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("var", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);

} // namespace acl
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/acl/nn/batch_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ typedef struct {
typedef std::map<OpKernel*, ACLNEBatchNorm>::iterator BatchNormLayersIterator;

template <typename T>
class BatchNorm final : public OpKernel {
class BatchNorm : public onnxruntime::BatchNorm<T> {
public:
explicit BatchNorm(const OpKernelInfo& info) : OpKernel(info) {
explicit BatchNorm(const OpKernelInfo& info) : onnxruntime::BatchNorm<T>(info) {
auto st = info.GetAttr<float>("epsilon", &epsilon_);
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());

Expand Down
17 changes: 15 additions & 2 deletions onnxruntime/core/providers/acl/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
TensorShapeVector Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
#ifdef ACL_2308
ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
#else
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
#endif
Tensor* Y = context->Output(0, TensorShape(Y_dims));
LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str();

Expand Down Expand Up @@ -222,6 +226,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
1 /* depth multiplier */,
acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(),
arm_compute::Size2D(aclDilation0, dilations[0])));
#elif defined(ACL_2308)
bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayer::validate(tconv.in->info(),
tconv.k->info(),
(B != nullptr) ? tconv.b->info() : nullptr,
tconv.out->info(),
aclPadStride,
1 /* depth multiplier */,
acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(),
arm_compute::Size2D(aclDilation0, dilations[0])));
#endif

if (optimizable) {
Expand All @@ -230,15 +243,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
auto layer = std::make_shared<arm_compute::NEDepthwiseConvolutionLayer3x3>();
#elif defined(ACL_1908)
auto layer = std::make_shared<arm_compute::NEDepthwiseConvolutionLayerOptimized>();
#elif defined(ACL_2002)
#elif defined(ACL_2002) || defined(ACL_2308)
auto layer = std::make_shared<arm_compute::NEDepthwiseConvolutionLayer>();
#endif

#ifdef ACL_1902
layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(),
aclPadStride, 1 /* depth multiplier */,
acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo());
#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002)
#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) || defined(ACL_2308)
layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(),
aclPadStride, 1 /* depth multiplier */,
acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(),
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/acl/nn/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "core/providers/acl/acl_execution_provider.h"

// ACL
#ifdef ACL_2308
#include "arm_compute/runtime/Tensor.h"
#endif
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "arm_compute/runtime/Allocator.h"
Expand Down
17 changes: 15 additions & 2 deletions onnxruntime/core/providers/acl/nn/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context,
tpool.out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(Y->Shape(), PREF_DIM), arm_compute::Format::F32));

if (pool_attrs.global_pooling) {
layer->configure(tpool.in.get(), tpool.out.get(), arm_compute::PoolingLayerInfo(pool_type));
layer->configure(tpool.in.get(),
tpool.out.get(),
arm_compute::PoolingLayerInfo(pool_type
#ifdef ACL_2308
,
arm_compute::DataLayout::NCHW
#endif
));
} else {
TensorShapeVector aclStrides(2);
aclStrides[0] = (strides.size() == 2) ? strides[1] : 1;
Expand Down Expand Up @@ -104,7 +111,13 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context,
LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}";
LOGS_DEFAULT(VERBOSE) << "excludePadding: " << excludePadding;

arm_compute::PoolingLayerInfo pool_info(pool_type, aclSize, aclPadStride, excludePadding);
arm_compute::PoolingLayerInfo pool_info(pool_type,
aclSize,
#ifdef ACL_2308
arm_compute::DataLayout::NCHW,
#endif
aclPadStride,
excludePadding);
layer->configure(tpool.in.get(), tpool.out.get(), pool_info);
}

Expand Down
32 changes: 28 additions & 4 deletions onnxruntime/core/providers/acl/tensor/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "core/providers/acl/acl_common.h"
#include "core/providers/acl/acl_fwd.h"

#include <iostream>

#define PREF_DIM 4

namespace onnxruntime {
Expand All @@ -22,17 +24,27 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
return onnxruntime::Concat::Compute(ctx);
}

if (axis_ < 0) {
LOGS_DEFAULT(WARNING) << "ACL does not have support for negative axis; defaulting to cpu implementation";
return onnxruntime::Concat::Compute(ctx);
}

// Number of input tensors to concatenate
auto input_count = Node().InputArgCount().front();

// Hold pointers to the input tensors to be used in the PrepareForCompute() step
std::vector<const Tensor*> input_tensors;
input_tensors.reserve(input_count);
int empty_tensors = 0;
for (int i = 0; i < input_count; ++i) {
if (ctx->Input<Tensor>(i)->Shape().Size() == 0) {
empty_tensors++;
continue;
}
input_tensors.push_back(ctx->Input<Tensor>(i));
}
input_count -= empty_tensors;

auto output_dims = input_tensors[0]->Shape().AsShapeVector();
auto output_dims = ctx->Input<Tensor>(0)->Shape().AsShapeVector();

// 'Concat' mode
if (!is_stack_) {
Expand Down Expand Up @@ -64,7 +76,11 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
LOGS_DEFAULT(VERBOSE) << "Concat ACL:";

arm_compute::Tensor output;
#ifdef ACL_2308
std::vector<const arm_compute::ITensor*> inputs_vector;
#else
std::vector<arm_compute::ITensor*> inputs_vector;
#endif
for (int i = 0; i < input_count; i++) {
arm_compute::Tensor* input = new arm_compute::Tensor();
auto X = input_tensors[i];
Expand All @@ -75,15 +91,21 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
}

arm_compute::NEConcatenateLayer layer;
layer.configure(inputs_vector, &output, 3 - axis_);
if (input_count > 0) {
layer.configure(inputs_vector, &output, 3 - axis_);
}

LOGS_DEFAULT(VERBOSE) << "axis: " << axis_;
LOGS_DEFAULT(VERBOSE) << std::endl;

for (int i = 0; i < input_count; i++) {
auto X = input_tensors[i];
const T* x_data = X->Data<T>();
#ifdef ACL_2308
arm_compute::Tensor* in = const_cast<arm_compute::Tensor*>(static_cast<const arm_compute::Tensor*>(inputs_vector[i]));
#else
arm_compute::Tensor* in = static_cast<arm_compute::Tensor*>(inputs_vector[i]);
#endif

if (X->Shape().Size() != 0 && in->info()->has_padding()) {
in->allocator()->allocate();
Expand All @@ -101,7 +123,9 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
ACLImportMemory(output.allocator(), (void*)y_data, Y->Shape().Size() * 4);
}

layer.run();
if (input_count > 0) {
layer.run();
}

if (Y->Shape().Size() != 0 && output.info()->has_padding()) {
importDataFromTensor<T>(&output, y_data);
Expand Down
3 changes: 2 additions & 1 deletion tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def convert_arg_line_to_args(self, arg_line):
"--use_acl",
nargs="?",
const="ACL_1905",
choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002"],
choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002", "ACL_2308"],
help="Build with ACL for ARM architectures.",
)
parser.add_argument("--acl_home", help="Path to ACL home dir")
Expand Down Expand Up @@ -1031,6 +1031,7 @@ def generate_build_tree(
"-Donnxruntime_USE_ACL_1905=" + ("ON" if args.use_acl == "ACL_1905" else "OFF"),
"-Donnxruntime_USE_ACL_1908=" + ("ON" if args.use_acl == "ACL_1908" else "OFF"),
"-Donnxruntime_USE_ACL_2002=" + ("ON" if args.use_acl == "ACL_2002" else "OFF"),
"-Donnxruntime_USE_ACL_2308=" + ("ON" if args.use_acl == "ACL_2308" else "OFF"),
"-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"),
"-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"),
"-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"),
Expand Down

0 comments on commit 37ac9d3

Please sign in to comment.