diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 0f57258dca706..1567da90cacfc 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -131,6 +131,7 @@ option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF)
 option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF)
 option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF)
 option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF)
+option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF)
 option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF)
 option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON)
 option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON)
@@ -1110,7 +1111,7 @@ function(onnxruntime_add_include_to_target dst_target)
 endfunction()
 
 # ACL
-if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002)
+if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308)
   set(onnxruntime_USE_ACL ON)
   if (onnxruntime_USE_ACL_1902)
     add_definitions(-DACL_1902=1)
@@ -1121,7 +1122,11 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905
       if (onnxruntime_USE_ACL_2002)
         add_definitions(-DACL_2002=1)
       else()
-        add_definitions(-DACL_1905=1)
+        if (onnxruntime_USE_ACL_2308)
+          add_definitions(-DACL_2308=1)
+        else()
+          add_definitions(-DACL_1905=1)
+        endif()
       endif()
     endif()
   endif()
diff --git a/onnxruntime/core/providers/acl/math/gemm.h b/onnxruntime/core/providers/acl/math/gemm.h
index d2f297e83aedb..f5288d7f231b0 100644
--- a/onnxruntime/core/providers/acl/math/gemm.h
+++ b/onnxruntime/core/providers/acl/math/gemm.h
@@ -49,11 +49,18 @@ class Gemm : public onnxruntime::Gemm<T> {
   }
 
   Status Compute(OpKernelContext* context) const override {
+#ifdef ACL_2308
+    if (this->packed_b_) {
+      // Prepacked RHS not supported, defaulting to cpu execution provider
+      return onnxruntime::Gemm<T>::Compute(context);
+    }
+#endif
     const auto A = context->Input<Tensor>(0);
     const auto B = context->Input<Tensor>(1);
     const auto C = context->Input<Tensor>(2);
 
-    GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans, C->Shape());
+    GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans,
+                      C != nullptr ? C->Shape() : TensorShape({}));
 
     if (!helper.State().IsOK())
       return helper.State();
@@ -70,7 +77,7 @@ class Gemm : public onnxruntime::Gemm<T> {
       return onnxruntime::Gemm<T>::Compute(context);
    }
 
-    arm_compute::TensorShape cShape = ACLTensorShape(C->Shape());
+    arm_compute::TensorShape cShape = ACLTensorShape(C != nullptr ? C->Shape() : TensorShape({}));
 
    if (useC && (cShape.num_dimensions() > 2 || (cShape.num_dimensions() == 2 && cShape[0] > 1 && cShape[1] > 1))) {
      // Multi-dimensional Bias
@@ -89,8 +96,13 @@ class Gemm : public onnxruntime::Gemm<T> {
           (cShape[1] == 1 && cShape[0] != (long unsigned int)N)) {
         return onnxruntime::Gemm<T>::Compute(context);
       }
+#ifdef ACL_2308
+      cShape = arm_compute::TensorShape(N);
+      LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {" << N << "}";
+#else
       cShape = arm_compute::TensorShape(1, N);
       LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {1," << N << "}";
+#endif
     }
 
     int64_t K = helper.K();
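Note (editorial, not part of the patch): arm_compute::TensorShape lists dimensions fastest-moving first, so the pre-23.08 bias shape {1, N} in gemm.h above is a two-dimensional tensor with a single element along dimension 0, while the ACL_2308 branch hands ACL a plain 1-D bias of N elements. A minimal sketch of the two constructions, using a hypothetical helper name BiasShape:

```cpp
#include <cstddef>

#include "arm_compute/core/TensorShape.h"

// Illustrative only: how the Gemm kernel above describes the bias to ACL.
arm_compute::TensorShape BiasShape(size_t n, bool acl_2308) {
  if (acl_2308) {
    return arm_compute::TensorShape(n);    // 1-D: n elements (ACL 23.08 path)
  }
  return arm_compute::TensorShape(1U, n);  // 2-D: dim0 = 1, dim1 = n (older ACL releases)
}
```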
diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.cc b/onnxruntime/core/providers/acl/nn/batch_norm.cc
index da7fff730c96f..eb6a10074f1db 100755
--- a/onnxruntime/core/providers/acl/nn/batch_norm.cc
+++ b/onnxruntime/core/providers/acl/nn/batch_norm.cc
@@ -44,6 +44,16 @@ Status BatchNorm<T>::Compute(OpKernelContext* context) const {
   const Tensor* M = context->Input<Tensor>(3);  // mean
   const Tensor* V = context->Input<Tensor>(4);  // var
 
+  if (S->Shape().NumDimensions() > 1) {
+    LOGS_DEFAULT(WARNING) << "ACL does not support scale with dimension greater than 1; defaulting to cpu implementation";
+    return onnxruntime::BatchNorm<T>::Compute(context);
+  }
+
+  if (this->is_train_) {
+    LOGS_DEFAULT(WARNING) << "ACL does not have batchnorm training support; defaulting to cpu implementation";
+    return onnxruntime::BatchNorm<T>::Compute(context);
+  }
+
   ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, S, B, M, V));
 
   LOGS_DEFAULT(VERBOSE) << "BatchNorm ACL:";
@@ -70,7 +80,23 @@ Status BatchNorm<T>::Compute(OpKernelContext* context) const {
 
     auto layer = std::make_shared<arm_compute::NEBatchNormalizationLayer>();
 
+#ifdef ACL_2308
+    arm_compute::TensorShape in_x_shape;
+    const TensorShape& x_shape = X->Shape();
+    const auto& dims_vec = x_shape.GetDims();
+    in_x_shape.set(3, onnxruntime::narrow<size_t>(dims_vec[0]));  // N
+    in_x_shape.set(1, 1);                                         // H
+    size_t W = 1;
+    for (size_t i = 2; i < dims_vec.size(); ++i) {
+      W *= narrow<size_t>(dims_vec[i]);
+    }
+    in_x_shape.set(0, W);                                         // W
+    in_x_shape.set(2, onnxruntime::narrow<size_t>(dims_vec[1]));  // C
+
+    tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(in_x_shape, arm_compute::Format::F32));
+#else
     tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32));
+#endif
     tbatch_norm.out->allocator()->init(arm_compute::TensorInfo(tbatch_norm.in->info()->tensor_shape(), arm_compute::Format::F32));
 
     tbatch_norm.scale->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(S->Shape()), arm_compute::Format::F32));
@@ -132,11 +158,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     7, 9,
     kAclExecutionProvider,
     KernelDefBuilder()
-        .TypeConstraint("X", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("scale", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("B", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("mean", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("var", DataTypeImpl::GetTensorType<float>()),
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     BatchNorm<float>);
 
 }  // namespace acl
diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.h b/onnxruntime/core/providers/acl/nn/batch_norm.h
index c9ec08b67a779..264301976e6dc 100755
--- a/onnxruntime/core/providers/acl/nn/batch_norm.h
+++ b/onnxruntime/core/providers/acl/nn/batch_norm.h
@@ -31,9 +31,9 @@ typedef struct {
 typedef std::map<OpKernel*, ACLNEBatchNorm>::iterator BatchNormLayersIterator;
 
 template <typename T>
-class BatchNorm final : public OpKernel {
+class BatchNorm : public onnxruntime::BatchNorm<T> {
  public:
-  explicit BatchNorm(const OpKernelInfo& info) : OpKernel(info) {
+  explicit BatchNorm(const OpKernelInfo& info) : onnxruntime::BatchNorm<T>(info) {
     auto st = info.GetAttr("epsilon", &epsilon_);
     ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
 
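Note (editorial, not part of the patch): the ACL_2308 branch in batch_norm.cc above re-describes the NCHW input by hand because arm_compute::TensorShape orders dimensions fastest-first (W, H, C, N); every spatial dimension is folded into dimension 0 and H is pinned to 1. A self-contained sketch of the same index mapping, using a hypothetical helper name NchwToAclShape:

```cpp
#include <cstddef>
#include <vector>

#include "arm_compute/core/TensorShape.h"

// Illustrative helper (hypothetical name): map an ONNX NCHW shape {N, C, H, W, ...}
// onto ACL's fastest-first layout, mirroring the ACL_2308 branch above.
arm_compute::TensorShape NchwToAclShape(const std::vector<size_t>& dims) {
  arm_compute::TensorShape shape;
  size_t spatial = 1;
  for (size_t i = 2; i < dims.size(); ++i) {
    spatial *= dims[i];      // fold H, W, ... into one axis
  }
  shape.set(0, spatial);     // W (all spatial elements)
  shape.set(1, 1);           // H (flattened away)
  shape.set(2, dims[1]);     // C
  shape.set(3, dims[0]);     // N
  return shape;
}
```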
diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc
index 1613d927d0f74..85bd0cfe96279 100644
--- a/onnxruntime/core/providers/acl/nn/conv.cc
+++ b/onnxruntime/core/providers/acl/nn/conv.cc
@@ -105,7 +105,11 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
   TensorShapeVector Y_dims;
   Y_dims.insert(Y_dims.begin(), {N, M});
   TensorShape input_shape = X->Shape().Slice(2);
+#ifdef ACL_2308
+  ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
+#else
   ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
+#endif
   Tensor* Y = context->Output(0, TensorShape(Y_dims));
 
   LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str();
@@ -222,6 +226,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
                                                                             1 /* depth multiplier */,
                                                                             acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(),
                                                                             arm_compute::Size2D(aclDilation0, dilations[0])));
+#elif defined(ACL_2308)
+      bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayer::validate(tconv.in->info(),
+                                                                                 tconv.k->info(),
+                                                                                 (B != nullptr) ? tconv.b->info() : nullptr,
+                                                                                 tconv.out->info(),
+                                                                                 aclPadStride,
+                                                                                 1 /* depth multiplier */,
+                                                                                 acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(),
+                                                                                 arm_compute::Size2D(aclDilation0, dilations[0])));
 #endif
 
     if (optimizable) {
@@ -230,7 +243,7 @@
       auto layer = std::make_shared<arm_compute::NEDepthwiseConvolutionLayer3x3>();
 #elif defined(ACL_1908)
       auto layer = std::make_shared<arm_compute::NEDepthwiseConvolutionLayerOptimized>();
-#elif defined(ACL_2002)
+#elif defined(ACL_2002) || defined(ACL_2308)
       auto layer = std::make_shared<arm_compute::NEDepthwiseConvolutionLayer>();
 #endif
 
@@ -238,7 +251,7 @@
       layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), aclPadStride,
                        1 /* depth multiplier */,
                        acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo());
-#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002)
+#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) || defined(ACL_2308)
       layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), aclPadStride,
                        1 /* depth multiplier */,
                        acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(),
diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h
index ecb11fb3c8f4e..660d47b4172df 100644
--- a/onnxruntime/core/providers/acl/nn/conv.h
+++ b/onnxruntime/core/providers/acl/nn/conv.h
@@ -8,6 +8,9 @@
 #include "core/providers/acl/acl_execution_provider.h"
 
 // ACL
+#ifdef ACL_2308
+#include "arm_compute/runtime/Tensor.h"
+#endif
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/runtime/TensorAllocator.h"
 #include "arm_compute/runtime/Allocator.h"
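Note (editorial, not part of the patch): the ACL_2308 branch in conv.cc above relies on ACL's validate-then-configure idiom: NEDepthwiseConvolutionLayer::validate() returns an arm_compute::Status that converts to bool, so the kernel can probe support up front and otherwise fall back to the CPU implementation. A minimal standalone sketch of that flow; the shapes and parameters below are arbitrary and not taken from the patch:

```cpp
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"

int main() {
  // NCHW-style shapes in ACL's fastest-first order: (W, H, C).
  arm_compute::Tensor in, weights, out;
  in.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(8U, 8U, 3U), 1, arm_compute::DataType::F32));
  weights.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(3U, 3U, 3U), 1, arm_compute::DataType::F32));
  out.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(6U, 6U, 3U), 1, arm_compute::DataType::F32));

  const arm_compute::PadStrideInfo pad_stride(1, 1, 0, 0);  // stride 1, no padding

  // Probe support first; validate() only needs the ITensorInfo descriptions.
  arm_compute::Status status = arm_compute::NEDepthwiseConvolutionLayer::validate(
      in.info(), weights.info(), nullptr /* no bias */, out.info(), pad_stride, 1 /* depth multiplier */);

  if (bool(status)) {
    arm_compute::NEDepthwiseConvolutionLayer layer;
    layer.configure(&in, &weights, nullptr, &out, pad_stride, 1);
    in.allocator()->allocate();
    weights.allocator()->allocate();
    out.allocator()->allocate();
    layer.run();
  } else {
    // Unsupported configuration: an execution provider would defer to its CPU kernel here.
  }
  return 0;
}
```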
diff --git a/onnxruntime/core/providers/acl/nn/pool.cc b/onnxruntime/core/providers/acl/nn/pool.cc
index dc79ae65bf21e..8fbcba3ed87a7 100644
--- a/onnxruntime/core/providers/acl/nn/pool.cc
+++ b/onnxruntime/core/providers/acl/nn/pool.cc
@@ -61,7 +61,14 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context,
   tpool.out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(Y->Shape(), PREF_DIM), arm_compute::Format::F32));
 
   if (pool_attrs.global_pooling) {
-    layer->configure(tpool.in.get(), tpool.out.get(), arm_compute::PoolingLayerInfo(pool_type));
+    layer->configure(tpool.in.get(),
+                     tpool.out.get(),
+                     arm_compute::PoolingLayerInfo(pool_type
+#ifdef ACL_2308
+                                                   ,
+                                                   arm_compute::DataLayout::NCHW
+#endif
+                                                   ));
   } else {
     TensorShapeVector aclStrides(2);
     aclStrides[0] = (strides.size() == 2) ? strides[1] : 1;
@@ -104,7 +111,13 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context,
     LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}";
     LOGS_DEFAULT(VERBOSE) << "excludePadding: " << excludePadding;
 
-    arm_compute::PoolingLayerInfo pool_info(pool_type, aclSize, aclPadStride, excludePadding);
+    arm_compute::PoolingLayerInfo pool_info(pool_type,
+                                            aclSize,
+#ifdef ACL_2308
+                                            arm_compute::DataLayout::NCHW,
+#endif
+                                            aclPadStride,
+                                            excludePadding);
     layer->configure(tpool.in.get(), tpool.out.get(), pool_info);
   }
 
diff --git a/onnxruntime/core/providers/acl/tensor/concat.cc b/onnxruntime/core/providers/acl/tensor/concat.cc
index 081472729cfcf..75eedaac80aea 100644
--- a/onnxruntime/core/providers/acl/tensor/concat.cc
+++ b/onnxruntime/core/providers/acl/tensor/concat.cc
@@ -10,6 +10,8 @@
 #include "core/providers/acl/acl_common.h"
 #include "core/providers/acl/acl_fwd.h"
 
+#include
+
 #define PREF_DIM 4
 
 namespace onnxruntime {
@@ -22,17 +24,27 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
     return onnxruntime::Concat::Compute(ctx);
   }
 
+  if (axis_ < 0) {
+    LOGS_DEFAULT(WARNING) << "ACL does not have support for negative axis; defaulting to cpu implementation";
+    return onnxruntime::Concat::Compute(ctx);
+  }
+
   // Number of input tensors to concatenate
   auto input_count = Node().InputArgCount().front();
 
   // Hold pointers to the input tensors to be used in the PrepareForCompute() step
   std::vector<const Tensor*> input_tensors;
-  input_tensors.reserve(input_count);
+  int empty_tensors = 0;
   for (int i = 0; i < input_count; ++i) {
+    if (ctx->Input<Tensor>(i)->Shape().Size() == 0) {
+      empty_tensors++;
+      continue;
+    }
     input_tensors.push_back(ctx->Input<Tensor>(i));
   }
+  input_count -= empty_tensors;
 
-  auto output_dims = input_tensors[0]->Shape().AsShapeVector();
+  auto output_dims = ctx->Input<Tensor>(0)->Shape().AsShapeVector();
 
   // 'Concat' mode
   if (!is_stack_) {
@@ -64,7 +76,11 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
   LOGS_DEFAULT(VERBOSE) << "Concat ACL:";
 
   arm_compute::Tensor output;
+#ifdef ACL_2308
+  std::vector<const arm_compute::ITensor*> inputs_vector;
+#else
   std::vector<arm_compute::ITensor*> inputs_vector;
+#endif
   for (int i = 0; i < input_count; i++) {
     arm_compute::Tensor* input = new arm_compute::Tensor();
     auto X = input_tensors[i];
@@ -75,7 +91,9 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
   }
 
   arm_compute::NEConcatenateLayer layer;
-  layer.configure(inputs_vector, &output, 3 - axis_);
+  if (input_count > 0) {
+    layer.configure(inputs_vector, &output, 3 - axis_);
+  }
 
   LOGS_DEFAULT(VERBOSE) << "axis: " << axis_;
   LOGS_DEFAULT(VERBOSE) << std::endl;
@@ -83,7 +101,11 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
   for (int i = 0; i < input_count; i++) {
     auto X = input_tensors[i];
     const T* x_data = X->Data<T>();
+#ifdef ACL_2308
+    arm_compute::Tensor* in = const_cast<arm_compute::Tensor*>(static_cast<const arm_compute::Tensor*>(inputs_vector[i]));
+#else
     arm_compute::Tensor* in = static_cast<arm_compute::Tensor*>(inputs_vector[i]);
+#endif
 
     if (X->Shape().Size() != 0 && in->info()->has_padding()) {
       in->allocator()->allocate();
@@ -101,7 +123,9 @@ Status Concat<T>::Compute(OpKernelContext* ctx) const {
     ACLImportMemory(output.allocator(), (void*)y_data, Y->Shape().Size() * 4);
   }
 
-  layer.run();
+  if (input_count > 0) {
+    layer.run();
+  }
 
   if (Y->Shape().Size() != 0 && output.info()->has_padding()) {
     importDataFromTensor(&output, y_data);
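Note (editorial, not part of the patch): in ACL 23.08 NEConcatenateLayer::configure() takes a std::vector<const arm_compute::ITensor*>, which is why the concat.cc change above switches the element type and later const_casts each entry back before allocating it. A minimal standalone sketch of the call; the shapes are arbitrary and not taken from the patch:

```cpp
#include <vector>

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
#include "arm_compute/runtime/Tensor.h"

int main() {
  // Two 1-D inputs of 4 and 6 elements, concatenated along ACL dimension 0.
  arm_compute::Tensor a, b, out;
  a.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(4U), 1, arm_compute::DataType::F32));
  b.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(6U), 1, arm_compute::DataType::F32));
  out.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(10U), 1, arm_compute::DataType::F32));

  // ACL 23.08 expects const ITensor pointers, hence the vector-type switch above.
  std::vector<const arm_compute::ITensor*> inputs{&a, &b};

  arm_compute::NEConcatenateLayer concat;
  concat.configure(inputs, &out, 0 /* ACL axis 0 = fastest-moving dimension */);

  a.allocator()->allocate();
  b.allocator()->allocate();
  out.allocator()->allocate();
  concat.run();
  return 0;
}
```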
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index c655100fbf475..3d0ec92a7bd23 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -606,7 +606,7 @@ def convert_arg_line_to_args(self, arg_line):
         "--use_acl",
         nargs="?",
         const="ACL_1905",
-        choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002"],
+        choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002", "ACL_2308"],
         help="Build with ACL for ARM architectures.",
     )
     parser.add_argument("--acl_home", help="Path to ACL home dir")
@@ -1031,6 +1031,7 @@ def generate_build_tree(
         "-Donnxruntime_USE_ACL_1905=" + ("ON" if args.use_acl == "ACL_1905" else "OFF"),
         "-Donnxruntime_USE_ACL_1908=" + ("ON" if args.use_acl == "ACL_1908" else "OFF"),
         "-Donnxruntime_USE_ACL_2002=" + ("ON" if args.use_acl == "ACL_2002" else "OFF"),
+        "-Donnxruntime_USE_ACL_2308=" + ("ON" if args.use_acl == "ACL_2308" else "OFF"),
         "-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"),
         "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"),
         "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"),