From 57d6819212464f49b30db047528be0f409dadc67 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 22 Feb 2024 00:08:47 +0800 Subject: [PATCH 01/16] [js/web] Fix fused-conv is not included in npm test (#19581) BUG: https://github.com/microsoft/onnxruntime/issues/18855 ### Description ### Motivation and Context --- js/web/test/suite-test-list.jsonc | 1 + 1 file changed, 1 insertion(+) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 1c61518ddcdd2..b43b1ac37e37d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1354,6 +1354,7 @@ "expand.jsonc", "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", From e5ce81ae847d0b347a3dfe95abfc9e407e2f0469 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 21 Feb 2024 15:24:41 -0500 Subject: [PATCH 02/16] [java] Adding ML program flag for CoreML (#19551) ### Description Adds the new CoreML enum flags to enable ML Program support in Java. ### Motivation and Context Adds support for #19347 to the Java API. --- .../ai/onnxruntime/providers/CoreMLFlags.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f3..cec3fadf446ca 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs. */ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. + */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; From 3afb38cfb7d4263f262dea33bcfa16d35c67fede Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 Feb 2024 12:46:16 -0800 Subject: [PATCH 03/16] [CUDA] Add use_tf32 cuda provider option (for FP32 Conv) (#19426) Follow up of https://github.com/microsoft/onnxruntime/pull/19357 to apply the use_tf32 option on fp32 cuDNN convolution. When use_tf32 = 0, we will disable TF32 in cuDNN convolution for FP32 inputs. https://docs.nvidia.com/deeplearning/cudnn/api/cudnn-graph-library.html#cudnnmathtype-t **CUDNN_FMA_MATH** - Restricted to only kernels that use FMA instructions. - On pre-NVIDIA A100 GPU devices, CUDNN_DEFAULT_MATH and CUDNN_FMA_MATH have the same behavior: Tensor Core kernels will not be selected. - With NVIDIA Ampere architecture and CUDA toolkit 11, CUDNN_DEFAULT_MATH permits TF32 Tensor Core operation and CUDNN_FMA_MATH does not. - The TF32 behavior for CUDNN_DEFAULT_MATH and the other Tensor Core math types can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0. --- onnxruntime/core/providers/cuda/nn/conv.cc | 17 ++++++++++++++--- onnxruntime/core/providers/cuda/nn/conv.h | 3 ++- .../core/providers/cuda/nn/conv_transpose.cc | 10 ++++++++-- .../training_ops/cuda/nn/conv_grad.cc | 3 ++- .../training_ops/cuda/nn/conv_shared.cc | 6 ++++-- .../training_ops/cuda/nn/conv_shared.h | 2 +- .../training_ops/cuda/nn/conv_transpose_grad.cc | 6 ++++-- 7 files changed, 35 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 82f3503919237..a417be5a86c32 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -326,7 +326,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -351,8 +352,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; @@ -399,6 +405,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); if (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; } else { perf.mathType = CUDNN_DEFAULT_MATH; } @@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set( const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type) { + cudnnDataType_t data_type, + bool use_tf32) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); @@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set( CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); if (data_type == CUDNN_DATA_HALF) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index bcaa4d855b81e..181fbc99fd8e9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final { const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type); + cudnnDataType_t data_type, + bool use_tf32); operator cudnnConvolutionDescriptor_t() const { return desc_; } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 55dceaa2698e8..939b9959af818 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -167,7 +167,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (has_bias) { const auto& b_shape = p.B->Shape(); @@ -187,8 +188,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f6c58445c0a5d..fc5d9b65d0f89 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -114,7 +114,8 @@ Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args_.params.data_type)); + args_.params.data_type, + UseTF32())); if (dB) { const TensorShape& db_shape = dB->Shape(); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index 5dc16c68f6210..d23905496c9bb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const } template -Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32) { perf_results.resize(1); perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; if (args.params.data_type == CUDNN_DATA_HALF) { perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) { + perf_results[0].mathType = CUDNN_FMA_MATH; } else { perf_results[0].mathType = CUDNN_DEFAULT_MATH; } @@ -256,7 +258,7 @@ Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const std::vector perf_results; ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) + ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32()) : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); for (auto& algo_perf : perf_results) { if (f(algo_perf) == Status::OK()) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h index a2d4bf3bdc006..3fdb4306bfbbb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -75,7 +75,7 @@ class AlgoIterator { Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f); - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32); private: const ConvArgs& args_; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index 5f7206fc121ec..d3f5a89434a48 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -182,7 +182,8 @@ Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tenso ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); } return Status::OK(); @@ -287,7 +288,8 @@ Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, cons ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); if (dB) { const auto& b_shape = dB->Shape(); From ebd220b0730f9898aaa0275ef0d8195ce70057d0 Mon Sep 17 00:00:00 2001 From: Matttttt <18152455+martholomew@users.noreply.github.com> Date: Wed, 21 Feb 2024 21:38:18 +0000 Subject: [PATCH 04/16] Misspelling in README.md (#19433) Fixed a misspelling. --- js/web/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da28..906c78a1b7ec4 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: From 38c34323939bac03b9648b2e59dbbe8de0bd7092 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:58:53 -0800 Subject: [PATCH 05/16] Bump ip from 1.1.8 to 1.1.9 in /js/react_native (#19582) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 4dca90d7415cf..bbb0c4f3d1e22 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" From 5197db19802a39e47d19ac829cd08a94bacbdfbb Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 21 Feb 2024 15:45:44 -0800 Subject: [PATCH 06/16] Diable __cpuid call for ARM64EC (#19592) Diable __cpuid call for ARM64EC Co-authored-by: Sheil Kumar --- winml/lib/Api/HardwareCoreEnumerator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index b6b44690f4f6c..d04e276347170 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -84,7 +84,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); -#if !defined(_M_ARM64) && !defined(__aarch64__) +#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" int regs_leaf0[4]; int regs_leaf7[4]; From 3d88487c96bf467c4b83dff179c9e282602e2d64 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 22 Feb 2024 10:35:26 +0800 Subject: [PATCH 07/16] Minor Triton Fix (#19589) Including removing a unnecessary assert, and add support of passing string attribute from ONNX node attribute to python functoin kwargs (mainly for passing debug info from graph to python for now). --- .../orttraining/core/framework/triton/triton_op_executor.cc | 2 ++ orttraining/orttraining/python/training/ort_triton/_utils.py | 3 ++- orttraining/orttraining/training_ops/cpu/triton/triton_op.h | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc index 092ab89d5d760..f30d6ddee253a 100644 --- a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc +++ b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc @@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first))); } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first))); + } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str())); } else { ORT_THROW("Unsupported kwargs data type: ", kv.second.second); } diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index 95e6703be8783..877eacc0b775f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl def next_power_of_2(n: int) -> int: - assert n <= 2**32, "32-bit only" + """Return the smallest power of 2 greater than or equal to n""" n -= 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 + n |= n >> 32 n += 1 return n diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index f226db76f7ed7..db8e8558ab884 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -25,12 +25,15 @@ class TritonOp final : public OpKernel { attr.first == "onnx_string") { continue; } - // Support int64 and float only for now, skip other types. + // Support int64, float and string only for now, skip other types. if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); } else if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) { + kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}}); } } } From 8354329086ebb190db9ea0cb6a3fa72f53f8f881 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:45 +0800 Subject: [PATCH 08/16] [ROCm] SkipGroupNorm triton (#19408) Change GroupNorm triton to support SkipGroupNorm --- .../rocm/diffusion/group_norm_triton.cuh | 23 ++++++++--- .../rocm/diffusion/group_norm_triton.py | 39 +++++++++++++++++-- .../kernel_explorer/kernels/groupnorm_test.py | 12 ++++++ 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b3d3e92209b39..c6ca16bfdfc80 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", @@ -61,23 +59,36 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, params->channels_per_group, - params->epsilon}; + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 5368cb1cf635b..5ba96ebc117f0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,13 +12,19 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, ACTIVATION_SILU: tl.constexpr, @@ -36,14 +42,35 @@ def group_norm_kernel( offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,9 +84,13 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta if ACTIVATION_SILU: @@ -77,7 +108,7 @@ def group_norm_kernel( hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" @@ -88,7 +119,7 @@ def get_function_table(): silu_suffix = "Silu" if silu else "Pass" name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 8334d20e47c86..400a9d8a7a187 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -80,6 +80,18 @@ def run_group_norm( ) use_silu = silu broadcast_skip = False + if has_skip: + skip_x_shape = skip_x.shape + b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels + b4 = ( + len(skip_x_shape) == 4 + and skip_x_shape[0] == batch_size + and skip_x_shape[1] == 1 + and skip_x_shape[2] == 1 + and skip_x_shape[3] == num_channels + ) + if b2 or b4: + broadcast_skip = True channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x.astype(dtype)) From 05ed89f46980b7e5a5328bc20af8b32ca9f1f715 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:55 +0800 Subject: [PATCH 09/16] [ROCm] Add excluded libs for ROCm python package (#19586) The rocm lib version has changed in rocm 6.0 Using libs packaged in whl might cause errors. For example, `libamdhip64.so.6` packaged in whl will cause compute error when training gpt2 model. The root cause still in investigating. --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 03e1cb75ba581..9a5fc29dd5e02 100644 --- a/setup.py +++ b/setup.py @@ -205,18 +205,23 @@ def run(self): rocm_dependencies = [ "libamd_comgr.so.2", "libamdhip64.so.5", + "libamdhip64.so.6", "libdrm.so.2", "libdrm_amdgpu.so.1", "libelf.so.1", "libhipfft.so.0", "libhiprtc.so.5", + "libhiprtc.so.6", "libhsa-runtime64.so.1", "libMIOpen.so.1", "libnuma.so.1", "librccl.so.1", "librocblas.so.3", + "librocblas.so.4", "librocfft.so.0", + "libroctx64.so.4", "librocm_smi64.so.5", + "librocm_smi64.so.6", "libroctracer64.so.4", "libtinfo.so.6", "libmigraphx_c.so.3", From 6b73ab3e3e72a9f2008e8d0e221b0be77d2993b1 Mon Sep 17 00:00:00 2001 From: cao lei Date: Thu, 22 Feb 2024 10:19:08 -0800 Subject: [PATCH 10/16] Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream (#19515) ### Description Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream. So if a NodeArg is an input of several Ops across different streams and reuses other NodeArg, the reused NodeArg won't be involved when computing the second stream's reuse plan. ### Motivation and Context This is to fix https://github.com/microsoft/onnxruntime/issues/19480, which is a crash for the scenario mentioned above. --------- Co-authored-by: Lei Cao --- .../core/framework/allocation_planner.cc | 44 ++++++------ .../test/framework/allocation_planner_test.cc | 68 ++++++++++++++++++ .../multi_stream_models/issue_19480.onnx | Bin 0 -> 760 bytes 3 files changed, 91 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ea7a6432a7507..158ab8ed610f4 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -182,7 +182,6 @@ class PlannerImpl { // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream InlinedHashMap> dependence_graph_; - InlinedHashMap> value_consumer_map_; InlinedHashMap value_node_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -295,7 +294,7 @@ class PlannerImpl { } #endif - // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. + // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { *is_strided_tensor = false; @@ -530,6 +529,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); + for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i; } bool HasExternalOutputs(const Node& node) const { @@ -1065,7 +1065,8 @@ class PlannerImpl { // build the consumer list for each value int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumer_map_.reserve(num_ml_values); + InlinedHashMap> value_consumer_map; + value_consumer_map.reserve(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1078,10 +1079,10 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer - value_consumer_map_[origin].insert(node_index); + value_consumer_map[origin].insert(node_index); } } return Status::OK(); @@ -1138,8 +1139,8 @@ class PlannerImpl { std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl; allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1168,8 +1169,8 @@ class PlannerImpl { allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1187,11 +1188,11 @@ class PlannerImpl { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1266,7 +1267,7 @@ class PlannerImpl { } bool all_covered = true; - for (auto consumer : value_consumer_map_[output_idx_global]) { + for (auto consumer : value_consumer_map[output_idx_global]) { if (deps->find(consumer) == deps->end()) { all_covered = false; break; @@ -1277,9 +1278,9 @@ class PlannerImpl { allocation_plan[downstream_value].reused_buffer = output_idx_global; get_reused = true; // add new consumer for the value to be reused - value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]); - value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(), - value_consumer_map_[downstream_value].end()); + value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); + value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1342,8 +1343,9 @@ class PlannerImpl { ort_value_usecount.reserve(ort_value_info_.size()); #endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough! ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1693,8 +1695,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1889,7 +1891,7 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. // in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index d7b1de5c930c5..3e0d94e94e48c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1974,6 +1974,74 @@ TEST_F(PlannerTest, TestCpuIf) { ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep); } } + +// model looks like: +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// Shape ----------------> Reshape --> Shape ------------------> Reshape +// ^ ^ +// InstanceNormalization ----| InstanceNormalization ------| +// +// Python script to create this model: +// def CreateModelFor19480(): +// #shape->reshape->shape->reshape, 4 gather +// graphNodes = [] +// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8'])) +// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output'])) +// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293'])) +// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4'])) +// g = h.make_graph(graphNodes, 'issue_19480', +// [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), +// h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale0', tp.FLOAT, [32]), +// h.make_tensor_value_info('B0', tp.FLOAT, [32]), +// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale1', tp.FLOAT, [32]), +// h.make_tensor_value_info('B1', tp.FLOAT, [32]), +// h.make_tensor_value_info('indices1', tp.INT32, []), +// h.make_tensor_value_info('indices2', tp.INT32, []), +// h.make_tensor_value_info('indices3', tp.INT32, []), +// h.make_tensor_value_info('indices4', tp.INT32, [])], +// [h.make_tensor_value_info('output0', tp.FLOAT, None), +// h.make_tensor_value_info('output1', tp.INT64, None), +// h.make_tensor_value_info('output2', tp.INT64, None), +// h.make_tensor_value_info('output3', tp.INT64, None), +// h.make_tensor_value_info('output4', tp.INT64, None)]) +// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name') +// onnx.save(model, 'issue_19480.onnx') +// +TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { + SessionOptions sess_opt; + sess_opt.graph_optimization_level = TransformerLevel::Default; + + InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx")); + auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); + status = sess.Load(); + status = sess.Initialize(); + ASSERT_TRUE(status.IsOK()) << "No crash"; + const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan(); + ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; + + int gather_count = 0; + for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { + if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { + const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); + if (node->OpType() == "Gather") + gather_count++; + else + FAIL() << "CPU stream should contain only gather ops"; + } + } + ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream"; +} #endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc7d39206dd49f4ef6daf65b7d58c5b456ecf331 GIT binary patch literal 760 zcmaixKTm@|7>Bw3f%9#m_0_6_F_p#1gab^#v5RqW(2b?J(o1>?g{IKO$uH`6@t{2^ z#m0q%=Y9D7j(aJ^(>&%f;j=_M&c!l&{_evy4DtnEiK$Fin*vE__dm*aU~nQ+XN$p9 zA11aA3GP#64zs z+VGAUzBYVq;6Ud2Mod}g2Tt_Ry!IQoq685v?9X@+FQ7}m2pC{Q_TCzB1Q$v>tF;at zE9X-02LY%OdZ2hTthTjJs;u4R{+GpCSxtg_7ivO}nrK8dbFt05KbWuCO#ReubJg)l W4Oj)N8n}nRI|Tj~OnP7p&wl_GGP8{U literal 0 HcmV?d00001 From 3bdb10d5ca4f258ec444863bcd5e839eeac5c238 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Thu, 22 Feb 2024 10:56:25 -0800 Subject: [PATCH 11/16] Move import to when needed to avoid circular dependency error (#19579) ### Description Move import to when needed to avoid circular dependency error ### Motivation and Context Fixes dependency error described here: https://github.com/microsoft/DeepSpeed/issues/5140 --------- Co-authored-by: Thiago Crepaldi --- .../python/training/ortmodule/_graph_execution_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 779b6bfe50422..fda6e345da235 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -20,7 +20,6 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype -from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._fallback import ( @@ -143,6 +142,9 @@ def __init__( self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: + # Move import to here to avoid circular dependency error + from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 # type: ignore[import] + # Cannot toggle feature enabling/disabling after the first time enabled. configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) From fe82fccf1a4d7ea6c24c8448d7264df36605c370 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 23 Feb 2024 05:09:28 +0800 Subject: [PATCH 12/16] [js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596) This is used in sam-h-decoder-f16. ### Description ### Motivation and Context --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index b5b6a2a15cd8c..11c8778b72335 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo = const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); inputVariables.push(bias); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'pads', type: 'i32', length: pads.length} ]; appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}`; }; From 09622418c45b265977a8f1f17581e15719357423 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 Feb 2024 13:15:13 -0800 Subject: [PATCH 13/16] Add special handling if there is only 1 graph inside the cached QNN context binary (#19594) Add special handling if there is only 1 graph inside the cached QNN context binary. No need to make the EPContext node name match the QNN graph name. This is for better backward compatibility in case the QNN context model is generated before the PR for QNN context binary model support multi-partition. --- .../qnn/builder/onnx_ctx_model_helper.cc | 6 +- .../qnn/builder/onnx_ctx_model_helper.h | 3 +- .../qnn/builder/qnn_backend_manager.cc | 15 ++-- .../providers/qnn/qnn_execution_provider.cc | 3 +- .../test/providers/qnn/qnn_ep_context_test.cc | 83 ++++++++++++++++++- 5 files changed, 99 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index c2e71081b898e..2d8ec295d613b 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -151,12 +151,14 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - std::unordered_map>& qnn_models) { + std::unordered_map>& qnn_models, + const logging::Logger& logger) { Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage()); + LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. ", status.ErrorMessage()); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index b1360b4e576fa..7d56b45a1dbcd 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -56,7 +56,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - std::unordered_map>& qnn_models); + std::unordered_map>& qnn_models, + const logging::Logger& logger); Status CreateEPContextNodes(Model* model, unsigned char* buffer, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5f0b87c7cb9d7..ca34a1efa6ca7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -573,11 +573,16 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t // More work to support multiple partition, how to map the graph name in compile to qnn graph name // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile - for (uint32_t i = 0; i < graph_count; ++i) { - std::string graph_name(graphs_info[i].graphInfoV1.graphName); - auto qnn_model_pos = qnn_models.find(graph_name); - ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); - ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + if (1 == graph_count) { + auto qnn_model_pose = qnn_models.begin(); + ORT_RETURN_IF_ERROR(qnn_model_pose->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); + } else { + for (uint32_t i = 0; i < graph_count; ++i) { + std::string graph_name(graphs_info[i].graphInfoV1.graphName); + auto qnn_model_pos = qnn_models.find(graph_name); + ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); + ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + } } qnn_sys_interface_.systemContextFree(sys_ctx_handle); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index f5a166d36b15a..9a6540a3efea5 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -670,7 +670,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, context_cache_path, qnn_backend_manager_.get(), - qnn_models)); + qnn_models, + logger)); for (auto fused_node_and_graph : fused_nodes_and_graphs) { const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index b1f3b52e77553..eaef6f6315157 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -463,7 +463,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { InferenceSessionWrapper session_object{so, GetEnvironment()}; - std::string provider_type = kCpuExecutionProvider; ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); // Verify the return status with code INVALID_GRAPH @@ -486,7 +485,6 @@ std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { auto* graph_output = helper.MakeOutput(shape); Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); ep_context_node.AddAttribute("embed_mode", static_cast(0)); - // The .. in the path will cause INVALID_GRAPH ep_context_node.AddAttribute("ep_cache_context", external_bin_path); ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); ep_context_node.AddAttribute("source", "QNN"); @@ -651,6 +649,87 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +// Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node +// Create another Onnx model which also reference to the bin file, +// but the node name is not same with the QNN graph name inside the bin file. +// This is to support backward compitable for the models generated before the PR that +// make context generation support multi-partition +TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphNameInCtx) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::remove(context_binary_file.c_str()); + std::remove(context_bin.string().c_str()); + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. + // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_bin)); + + const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + std::vector shape = {1, 2, 3}; + NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, false, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); + auto* graph_output = helper.MakeOutput(shape); + Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); + ep_context_node.AddAttribute("embed_mode", static_cast(0)); + ep_context_node.AddAttribute("ep_cache_context", context_bin.string()); + ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); + ep_context_node.AddAttribute("source", "QNNExecutionProvider"); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + // loads and run from Onnx skeleton file + Qnn context cache binary file + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test From 76a2a487a12c7ec579f453a36932429164494ef6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:58:17 -0800 Subject: [PATCH 14/16] Bump ip from 1.1.8 to 1.1.9 in /js/react_native/e2e (#19583) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index 9e20a286c4e27..6f05faf046098 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -3351,9 +3351,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-accessor-descriptor@^0.1.6: version "0.1.6" From 5e5c36f6df95dfbb25787ea385f733f8c9ef691e Mon Sep 17 00:00:00 2001 From: AtomicVar Date: Fri, 23 Feb 2024 09:03:56 +0800 Subject: [PATCH 15/16] Fix citation author name issue (#19597) Use `name` rather than `given-names` to set author name. ### Motivation and Context The old CITATION.cff uses `given-names` to set author names, which won't be rendered properly with some bibtex style of LaTeX: image The problem is that **the `"ONNX Runtime developers"` is regarded as a human name**. How to fix: by using `name` to set author name, the generated Bibtex entry will use `{}` to enclose the `"ONNX Runtime developers"`. Then it is displayed literally: image --- CITATION.cff | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 82bcac5a7b750..10b7290022aef 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,8 +3,7 @@ title: ONNX Runtime message: "Please use this information to cite ONNX Runtime in research or other publications." authors: - - affiliation: Microsoft Corporation - given-names: ONNX Runtime developers + - name: ONNX Runtime developers date-released: 2018-11-29 url: "https://onnxruntime.ai" repository-code: "https://github.com/microsoft/onnxruntime" From 4ab497603e915ca992b96ef1ec25bfcf8b9a2ad5 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 Feb 2024 17:04:59 -0800 Subject: [PATCH 16/16] Enable user to set QNN HTP performance mode for every session run (#19521) ### Description Currently, the QNN HTP performance mode is set during session creation, there's no way to change it afterwards. There's requirement to set it high performance mode for high priority request and set it back to low performance mode later to save the power when the incoming request is idle for example. Now, still keeps the performance mode at the session level in QNN EP options which is used at the default one. Ort QNN EP will set it once if user set it. And there are setting (qnn.htp_perf_mode and qnn.htp_perf_mode_post_run) in run option to change the performance mode before and after session run. There's recommended scenario that user set the mode to high performance mode before the the inference sun so that user can get the result back ASAP. And set the mode to low performance mode after the inference to save the power. --- .../core/framework/execution_provider.h | 10 +- .../onnxruntime_run_options_config_keys.h | 12 + .../framework/stream_execution_context.cc | 4 +- .../providers/cann/cann_execution_provider.cc | 2 +- .../providers/cann/cann_execution_provider.h | 2 +- .../providers/cuda/cuda_execution_provider.cc | 4 +- .../providers/cuda/cuda_execution_provider.h | 5 +- .../src/ExecutionProvider.h | 4 +- .../providers/js/js_execution_provider.cc | 4 +- .../core/providers/js/js_execution_provider.h | 4 +- .../migraphx/migraphx_execution_provider.cc | 4 +- .../migraphx/migraphx_execution_provider.h | 4 +- .../qnn/builder/qnn_backend_manager.cc | 75 +++--- .../qnn/builder/qnn_backend_manager.h | 19 +- .../providers/qnn/qnn_execution_provider.cc | 198 +++++++++++++++- .../providers/qnn/qnn_execution_provider.h | 73 +++++- .../providers/rocm/rocm_execution_provider.cc | 4 +- .../providers/rocm/rocm_execution_provider.h | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 4 +- .../tensorrt/tensorrt_execution_provider.h | 4 +- onnxruntime/core/session/inference_session.cc | 12 +- .../cuda_execution_provider_test.cc | 13 +- .../test/providers/qnn/qnn_basic_test.cc | 217 ++++++++++++++++-- 23 files changed, 577 insertions(+), 105 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 31c988f500779..c1cc69edc17d8 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -33,6 +33,8 @@ class Node; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtRunOptions; + namespace onnxruntime { /** @@ -51,6 +53,8 @@ struct NodeComputeInfo { DestroyFunctionStateFunc release_state_func; }; +using RunOptions = OrtRunOptions; + enum class DataLayout { NCHW, NHWC, @@ -184,7 +188,7 @@ class IExecutionProvider { Run may not be finished on device This function should be regarded as the point after which a new Run would start to submit commands from CPU */ - virtual common::Status OnRunStart() { return Status::OK(); } + virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } /** Called when InferenceSession::Run ended @@ -192,7 +196,9 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { + return Status::OK(); + } /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 1f5fcd50e185c..b0a17e175fef3 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor // Per default it will be set to '0' // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index 875e7f395bfa8..dd7f4d35b34bd 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess } #ifdef USE_CANN + // Leave it to CANN EP to fill the gap if they want to use run_options + static onnxruntime::RunOptions run_options; // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, // which is different from CUDA Runtime API, but similar to CUDA Driver API. auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); for (auto& xp : execution_providers) { - auto status = xp->OnRunStart(); + auto status = xp->OnRunStart(run_options); if (!status.IsOK()) { ctx.SetStatus(status); return; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 752b742805a7c..9a242919665bb 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() { } // All threads share the same context and stream -Status CANNExecutionProvider::OnRunStart() { +Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id)); return Status::OK(); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 63ae980869c65..d83bd88d6958f 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider { explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info); virtual ~CANNExecutionProvider(); - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; template Status Fill(Tensor* y, void* addr, aclrtStream stream) const { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 48a952e6dd98f..0dd568c5ecc05 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -386,7 +386,7 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart() { +Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set CUDA device when session::Run() in case it runs in a worker thread CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -396,7 +396,7 @@ Status CUDAExecutionProvider::OnRunStart() { return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) { +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 55f0b5570e0ee..5f62f313b86a2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; DataLayout GetPreferredLayout() const override; @@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); cublasHandle_t CublasHandle() const { return cublas_handle_; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 5617bc7bdcac6..841d6244a983e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -270,7 +270,7 @@ namespace Dml return m_impl->OnSessionInitializationEnd(); } - virtual onnxruntime::Status Sync() const final override + onnxruntime::Status Sync() const final override { // Completely wait until the device has completed all preceding tasks. // The application could have called SynchronizeBoundOutputs(). @@ -278,7 +278,7 @@ namespace Dml return Status::OK(); } - virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override + onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override { // Flush any pending work to the GPU, but don't block for completion, permitting it // to overlap other work. diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 799d4172f2b64..62c3981682cfc 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -756,7 +756,7 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } -Status JsExecutionProvider::OnRunStart() { +Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; EM_ASM({ Module.jsepCaptureBegin(); }); @@ -764,7 +764,7 @@ Status JsExecutionProvider::OnRunStart() { return Status::OK(); } -Status JsExecutionProvider::OnRunEnd(bool sync_stream) { +Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 91a3256ec2bd5..b4518c67d1e60 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -59,8 +59,8 @@ class JsExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 40e76a0a67782..50782569ee80a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1383,11 +1383,11 @@ Status MIGraphXExecutionProvider::Sync() const { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunStart() { +Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunEnd(bool) { +Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { auto status = hipStreamQuery(stream_); if (status != hipSuccess) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d582338c7e067..c3617f409e72c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -56,9 +56,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { #ifdef MIGRAPHX_STREAM_SYNC Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; #endif std::vector> diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index ca34a1efa6ca7..e354bf6562722 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -634,11 +634,6 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ LOGS(logger, VERBOSE) << "CreateContext succeed."; } - if (htp_performance_mode_ != HtpPerformanceMode::kHtpDefault) { - ORT_RETURN_IF_ERROR(SetHtpPowerConfig()); - LOGS(logger, VERBOSE) << "SetHtpPowerConfig succeed."; - } - LOGS(logger, VERBOSE) << "QNN SetupBackend succeed"; backend_setup_completed_ = true; @@ -646,7 +641,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ return Status::OK(); } -Status QnnBackendManager::SetHtpPowerConfig() { +Status QnnBackendManager::CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -656,23 +651,37 @@ Status QnnBackendManager::SetHtpPowerConfig() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; // Get power client id - status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_); + status = htp_perf_infra.createPowerConfigId(device_id, core_id, &htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed."); + return Status::OK(); +} + +Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + constexpr const int kNumConfigs = 1; std::vector power_configs( kNumConfigs); QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0]; dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3; QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config; - dcvs_v3.contextId = htp_power_config_client_id_; + dcvs_v3.contextId = htp_power_config_client_id; dcvs_v3.setSleepDisable = 0; dcvs_v3.sleepDisable = 0; dcvs_v3.setDcvsEnable = 1; dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; // choose performance mode - switch (htp_performance_mode_) { + switch (htp_performance_mode) { case HtpPerformanceMode::kHtpBurst: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMinLatency; @@ -771,25 +780,40 @@ Status QnnBackendManager::SetHtpPowerConfig() { dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; break; default: - ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode_)); + ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode)); break; } std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data()); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode."); - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - if (rpc_control_latency_ != 0) { + return Status::OK(); +} + +Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency) { + if (rpc_control_latency != 0) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + + // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. constexpr int kNumRpcPollingPowerConfigs = 2; std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency = rpc_power_configs[0]; + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; // v68 doesn't support this. QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_; - perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data()); + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + std::vector perf_power_configs_ptr = + ObtainNullTermPtrVector(rpc_power_configs); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency."); } @@ -810,11 +834,7 @@ void QnnBackendManager::Split(std::vector& split_string, } } -Status QnnBackendManager::DestroyHTPPowerConfigID() { - if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) { - return Status::OK(); - } - +Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -824,7 +844,7 @@ Status QnnBackendManager::DestroyHTPPowerConfigID() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_); + Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed."); return Status::OK(); } @@ -834,12 +854,7 @@ void QnnBackendManager::ReleaseResources() { return; } - auto result = DestroyHTPPowerConfigID(); - if (Status::OK() != result) { - ORT_THROW("Failed to DestroyHTPPowerConfigID."); - } - - result = ReleaseContext(); + auto result = ReleaseContext(); if (Status::OK() != result) { ORT_THROW("Failed to ReleaseContext."); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 36375522b5a0a..ff97c4c3a991c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -33,8 +33,6 @@ class QnnBackendManager { public: QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level, - uint32_t rpc_control_latency, - HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, std::string&& qnn_saver_path, uint32_t device_id, @@ -42,8 +40,6 @@ class QnnBackendManager { uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), - rpc_control_latency_(rpc_control_latency), - htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), qnn_saver_path_(qnn_saver_path), device_id_(device_id), @@ -92,7 +88,13 @@ class QnnBackendManager { Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); - Status SetHtpPowerConfig(); + Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); + + Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode); + + Status SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -141,6 +143,8 @@ class QnnBackendManager { const std::string& GetSdkVersion() { return sdk_build_version_; } + Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -150,8 +154,6 @@ class QnnBackendManager { Status UnloadLib(void* handle); - Status DestroyHTPPowerConfigID(); - void* LibFunction(void* handle, const char* symbol, std::string& error_msg); template @@ -232,15 +234,12 @@ class QnnBackendManager { QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; std::vector op_package_paths_; - uint32_t rpc_control_latency_ = 0; - HtpPerformanceMode htp_performance_mode_; ContextPriority context_priority_; std::string sdk_build_version_ = ""; #ifdef _WIN32 std::set mod_handles_; #endif const std::string qnn_saver_path_; - uint32_t htp_power_config_client_id_ = 0; uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 9a6540a3efea5..3d9cfd92b7922 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -7,6 +7,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph_viewer.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" #include "core/platform/env.h" @@ -18,11 +19,36 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/framework/run_options.h" namespace onnxruntime { constexpr const char* QNN = "QNN"; +static std::unique_ptr>> s_run_on_unload_; + +void RunOnUnload(std::function function) { + OrtMutex mutex; + std::lock_guard guard(mutex); + if (!s_run_on_unload_) { + s_run_on_unload_ = std::make_unique>>(); + } + s_run_on_unload_->push_back(std::move(function)); +} + +struct OnUnload { + ~OnUnload() { + if (!s_run_on_unload_) + return; + + for (auto& function : *s_run_on_unload_) + function(); + + s_run_on_unload_.reset(); + } + +} g_on_unload; + static void ParseProfilingLevel(std::string profiling_level_string, qnn::ProfilingLevel& profiling_level) { std::transform(profiling_level_string.begin(), @@ -193,18 +219,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency"; - uint32_t rpc_control_latency = 0; auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY); if (latency_pos != provider_options_map.end()) { - rpc_control_latency = static_cast(std::stoul(latency_pos->second)); - LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + default_rpc_control_latency_ = static_cast(std::stoul(latency_pos->second)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << default_rpc_control_latency_; } - qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + // default_htp_performance_mode from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode"; auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE); if (htp_performance_mode_pos != provider_options_map.end()) { - ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode); + ParseHtpPerformanceMode(htp_performance_mode_pos->second, default_htp_performance_mode_); } htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; @@ -241,15 +267,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_DEVICE_ID = "device_id"; - uint32_t device_id = 0; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { int value = std::stoi(dev_id_pos->second); if (value < 0) { LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value - << "', only >= 0 allowed. Set to " << device_id << "."; + << "', only >= 0 allowed. Set to " << device_id_ << "."; } else { - device_id = static_cast(value); + device_id_ = static_cast(value); } } @@ -276,15 +301,23 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, - rpc_control_latency, - htp_performance_mode, context_priority, std::move(qnn_saver_path), - device_id, + device_id_, htp_arch, soc_model); } +QNNExecutionProvider::~QNNExecutionProvider() { + // clean up thread local context caches + std::lock_guard lock(context_state_.mutex); + for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { + const auto cache = cache_weak.lock(); + if (!cache) continue; + ORT_IGNORE_RETURN_VALUE(cache->erase(this)); + } +} + bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { const std::string& op_type = node_unit.OpType(); @@ -725,4 +758,147 @@ const InlinedVector QNNExecutionProvider::GetEpContextNodes() const return ep_context_nodes; } + +QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, + uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency) + : qnn_backend_manager_(qnn_backend_manager) { + Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); + is_htp_power_config_id_valid_ = rt.IsOK(); + // default_htp_performance_mode and default_rpc_control_latency are from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run + if (is_htp_power_config_id_valid_) { + if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, + default_htp_performance_mode)); + } + if (default_rpc_control_latency > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, + default_rpc_control_latency)); + } + } +} + +QNNExecutionProvider::PerThreadContext::~PerThreadContext() { + if (is_htp_power_config_id_valid_) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_)); + } +} + +QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + // try to use cached context + auto cached_context_it = per_thread_context_cache->find(this); + if (cached_context_it != per_thread_context_cache->end()) { + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + return *cached_context; + } + + // get context and update cache + std::shared_ptr context; + { + std::lock_guard lock(context_state_.mutex); + + // get or create a context + if (context_state_.retired_context_pool.empty()) { + uint32_t core_id = 0; + context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, + default_htp_performance_mode_, default_rpc_control_latency_); + } else { + context = context_state_.retired_context_pool.back(); + context_state_.retired_context_pool.pop_back(); + } + + // insert into active_contexts, should not already be present + const auto active_contexts_insert_result = context_state_.active_contexts.insert(context); + ORT_ENFORCE(active_contexts_insert_result.second); + + // insert into caches_to_update_on_destruction, may already be present + ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache)); + } + + per_thread_context_cache->insert(std::make_pair(this, context)); + + return *context; +} + +void QNNExecutionProvider::ReleasePerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + auto cached_context_it = per_thread_context_cache->find(this); + ORT_ENFORCE(cached_context_it != per_thread_context_cache->end()); + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + + { + std::lock_guard lock(context_state_.mutex); + context_state_.active_contexts.erase(cached_context); + context_state_.retired_context_pool.push_back(cached_context); + } + + per_thread_context_cache->erase(cached_context_it); +} + +Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + std::string rpc_latency = ""; + uint32_t rpc_control_latency = 0; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) { + rpc_control_latency = static_cast(std::stoul(rpc_latency)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + } + + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + if (rpc_control_latency > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency)); + } + } + + return Status::OK(); +} + +Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + if (!GetPerThreadContext().IsHtpPowerConfigIdValid()) { + return Status::OK(); + } + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + return Status::OK(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 0bcaa39b22f6d..43b5e7bff827e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -12,14 +12,19 @@ #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "HTP/QnnHtpGraph.h" +#include +#include +#include namespace onnxruntime { +void RunOnUnload(std::function function); + // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options); - virtual ~QNNExecutionProvider() = default; + virtual ~QNNExecutionProvider(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider); // we implement the Compile that takes FusedNodeAndGraph instances @@ -40,6 +45,10 @@ class QNNExecutionProvider : public IExecutionProvider { const InlinedVector GetEpContextNodes() const override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const; @@ -72,6 +81,68 @@ class QNNExecutionProvider : public IExecutionProvider { int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_; + uint32_t device_id_ = 0; + qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; + uint32_t default_rpc_control_latency_ = 0; + + class PerThreadContext final { + public: + PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency); + ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); + + bool IsHtpPowerConfigIdValid() { return is_htp_power_config_id_valid_; } + + uint32_t GetHtpPowerConfigId() { return htp_power_config_id_; } + + private: + bool is_htp_power_config_id_valid_ = false; + uint32_t htp_power_config_id_ = 0; + qnn::QnnBackendManager* qnn_backend_manager_; + }; + + using PerThreadContextMap = std::unordered_map>; + + struct ContextCacheHolder { + ContextCacheHolder() { + RunOnUnload([&, weak_p_ = std::weak_ptr(p)] { + if (auto lock = weak_p_.lock()) + p.reset(); + }); + } + + std::shared_ptr p = std::make_shared(); + }; + + static const std::shared_ptr& PerThreadContextCache() { + thread_local const ContextCacheHolder per_thread_context_cache; + return per_thread_context_cache.p; + } + + struct PerThreadContextState { + // contexts that are currently active + std::set, std::owner_less>> active_contexts; + // contexts available for reuse + std::vector> retired_context_pool; + // weak references to thread local caches from which this QNNExecutionProvider instance's entry should be removed + // upon destruction + std::set, std::owner_less>> + caches_to_update_on_destruction; + // synchronizes access to PerThreadContextState members + OrtMutex mutex; + }; + + // The execution provider maintains the PerThreadContexts in this structure. + // Synchronization is required to update the contained structures. + // On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time, + // so synchronization is not required for that. + mutable PerThreadContextState context_state_; + + PerThreadContext& GetPerThreadContext() const; + void ReleasePerThreadContext() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index ee3578326ac6d..3fd5423681b81 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -353,7 +353,7 @@ Status ROCMExecutionProvider::Sync() const { return Status::OK(); } -Status ROCMExecutionProvider::OnRunStart() { +Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -363,7 +363,7 @@ Status ROCMExecutionProvider::OnRunStart() { return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 37d5f7b42210f..da671d9e863bb 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -28,9 +28,9 @@ class ROCMExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; const void* GetExecutionHandle() const noexcept override { // The ROCM interface does not return anything interesting. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c0bf29e486c88..81346671f2aad 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1818,11 +1818,11 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(); } -Status TensorrtExecutionProvider::OnRunStart() { +Status TensorrtExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { +Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index e86f997b6597a..26f6b2dcc3020 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -233,8 +233,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; ProviderOptions GetProviderOptions() const override { return TensorrtExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b045f30a59797..efd7db4ea7629 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2289,8 +2289,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2326,7 +2326,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(/*sync_stream*/ false); + auto status = xp->OnRunEnd(/*sync_stream*/ false, run_options); ORT_CHECK_AND_SET_RETVAL(status); } @@ -2448,8 +2448,8 @@ Status InferenceSession::Run(const RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, &run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2490,7 +2490,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; - auto status = xp->OnRunEnd(synchronize_execution_providers); + auto status = xp->OnRunEnd(synchronize_execution_providers, run_options); ORT_CHECK_AND_SET_RETVAL(status); } diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index a70e439cdf755..5505d689381c9 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -22,6 +22,8 @@ TEST(TestDeferredRelease, WithArena) { CUDAExecutionProvider ep(info); AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0]; + RunOptions run_opts; + run_opts.run_tag = "log1"; // Allocator for call cudaMallocHost and cudaFreeHost // For details, see CUDAPinnedAllocator in cuda_allocator.cc. AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; @@ -31,7 +33,7 @@ TEST(TestDeferredRelease, WithArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. auto pinned_buffer = IAllocator::MakeUniquePtr(cpu_pinned_alloc, n_bytes); @@ -44,7 +46,7 @@ TEST(TestDeferredRelease, WithArena) { cpu_pinned_alloc->GetStats(&stats); ASSERT_EQ(stats.num_allocs, n_allocs); ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } TEST(TestDeferredRelease, WithoutArena) { @@ -52,6 +54,9 @@ TEST(TestDeferredRelease, WithoutArena) { CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); + RunOptions run_opts; + run_opts.run_tag = "log1"; + OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; // Create allocator without BFCArena AllocatorCreationInfo pinned_memory_info( @@ -70,7 +75,7 @@ TEST(TestDeferredRelease, WithoutArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. auto pinned_buffer = IAllocator::MakeUniquePtr(cuda_pinned_alloc, n_bytes); @@ -79,7 +84,7 @@ TEST(TestDeferredRelease, WithoutArena) { } ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 4e1aef2c40b2b..8f07c2ce77e77 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -7,6 +7,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU #include "core/session/inference_session.h" @@ -332,19 +333,23 @@ static void CreateModelInMemory(std::unique_ptr& result, static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& output_shapes, - const std::vector>& expected_values) { - std::vector fetches; - auto status = session.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); - - for (size_t i = 0; i < fetches.size(); i++) { - auto& tensor = fetches[i].Get(); - TensorShape expected_shape(output_shapes[i]); - ASSERT_EQ(expected_shape, tensor.Shape()); - - gsl::span actual = tensor.DataAsSpan(); - gsl::span expected(expected_values[i].data(), expected_values[i].size()); - ASSERT_EQ(expected, actual); + const std::vector>& expected_values, + int loop_count = 10) { + // Let it run for a while + for (int it = 0; it < loop_count; ++it) { + std::vector fetches; + auto status = session.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); + + for (size_t i = 0; i < fetches.size(); i++) { + auto& tensor = fetches[i].Get(); + TensorShape expected_shape(output_shapes[i]); + ASSERT_EQ(expected_shape, tensor.Shape()); + + gsl::span actual = tensor.DataAsSpan(); + gsl::span expected(expected_values[i].data(), expected_values[i].size()); + ASSERT_EQ(expected, actual); + } } } @@ -404,11 +409,11 @@ TEST_F(QnnCPUBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; - + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); } for (auto& th : threads) { @@ -484,11 +489,191 @@ TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with run option to set power config +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with EP option to set default power config +TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + for (int i = 0; i < num_threads; i++) { + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with +// EP option to set default power config + run option to set power config for each run +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); } for (auto& th : threads) {