diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 31c988f500779..c1cc69edc17d8 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -33,6 +33,8 @@ class Node; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtRunOptions; + namespace onnxruntime { /** @@ -51,6 +53,8 @@ struct NodeComputeInfo { DestroyFunctionStateFunc release_state_func; }; +using RunOptions = OrtRunOptions; + enum class DataLayout { NCHW, NHWC, @@ -184,7 +188,7 @@ class IExecutionProvider { Run may not be finished on device This function should be regarded as the point after which a new Run would start to submit commands from CPU */ - virtual common::Status OnRunStart() { return Status::OK(); } + virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } /** Called when InferenceSession::Run ended @@ -192,7 +196,9 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { + return Status::OK(); + } /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 1f5fcd50e185c..b0a17e175fef3 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor // Per default it will be set to '0' // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index 875e7f395bfa8..dd7f4d35b34bd 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess } #ifdef USE_CANN + // Leave it to CANN EP to fill the gap if they want to use run_options + static onnxruntime::RunOptions run_options; // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, // which is different from CUDA Runtime API, but similar to CUDA Driver API. auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); for (auto& xp : execution_providers) { - auto status = xp->OnRunStart(); + auto status = xp->OnRunStart(run_options); if (!status.IsOK()) { ctx.SetStatus(status); return; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 752b742805a7c..9a242919665bb 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() { } // All threads share the same context and stream -Status CANNExecutionProvider::OnRunStart() { +Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id)); return Status::OK(); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 63ae980869c65..d83bd88d6958f 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider { explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info); virtual ~CANNExecutionProvider(); - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; template Status Fill(Tensor* y, void* addr, aclrtStream stream) const { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 77e682e05a2a4..5646183e1acd3 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -386,7 +386,7 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart() { +Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set CUDA device when session::Run() in case it runs in a worker thread CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -396,7 +396,7 @@ Status CUDAExecutionProvider::OnRunStart() { return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) { +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 55f0b5570e0ee..5f62f313b86a2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; DataLayout GetPreferredLayout() const override; @@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); cublasHandle_t CublasHandle() const { return cublas_handle_; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 5617bc7bdcac6..841d6244a983e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -270,7 +270,7 @@ namespace Dml return m_impl->OnSessionInitializationEnd(); } - virtual onnxruntime::Status Sync() const final override + onnxruntime::Status Sync() const final override { // Completely wait until the device has completed all preceding tasks. // The application could have called SynchronizeBoundOutputs(). @@ -278,7 +278,7 @@ namespace Dml return Status::OK(); } - virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override + onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override { // Flush any pending work to the GPU, but don't block for completion, permitting it // to overlap other work. diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 799d4172f2b64..62c3981682cfc 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -756,7 +756,7 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } -Status JsExecutionProvider::OnRunStart() { +Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; EM_ASM({ Module.jsepCaptureBegin(); }); @@ -764,7 +764,7 @@ Status JsExecutionProvider::OnRunStart() { return Status::OK(); } -Status JsExecutionProvider::OnRunEnd(bool sync_stream) { +Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 91a3256ec2bd5..b4518c67d1e60 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -59,8 +59,8 @@ class JsExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 40e76a0a67782..50782569ee80a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1383,11 +1383,11 @@ Status MIGraphXExecutionProvider::Sync() const { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunStart() { +Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunEnd(bool) { +Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { auto status = hipStreamQuery(stream_); if (status != hipSuccess) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d582338c7e067..c3617f409e72c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -56,9 +56,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { #ifdef MIGRAPHX_STREAM_SYNC Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; #endif std::vector> diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5f0b87c7cb9d7..aa53c4d0b4df5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -629,11 +629,6 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ LOGS(logger, VERBOSE) << "CreateContext succeed."; } - if (htp_performance_mode_ != HtpPerformanceMode::kHtpDefault) { - ORT_RETURN_IF_ERROR(SetHtpPowerConfig()); - LOGS(logger, VERBOSE) << "SetHtpPowerConfig succeed."; - } - LOGS(logger, VERBOSE) << "QNN SetupBackend succeed"; backend_setup_completed_ = true; @@ -641,7 +636,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ return Status::OK(); } -Status QnnBackendManager::SetHtpPowerConfig() { +Status QnnBackendManager::CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -651,23 +646,37 @@ Status QnnBackendManager::SetHtpPowerConfig() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; // Get power client id - status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_); + status = htp_perf_infra.createPowerConfigId(device_id, core_id, &htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed."); + return Status::OK(); +} + +Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + constexpr const int kNumConfigs = 1; std::vector power_configs( kNumConfigs); QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0]; dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3; QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config; - dcvs_v3.contextId = htp_power_config_client_id_; + dcvs_v3.contextId = htp_power_config_client_id; dcvs_v3.setSleepDisable = 0; dcvs_v3.sleepDisable = 0; dcvs_v3.setDcvsEnable = 1; dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; // choose performance mode - switch (htp_performance_mode_) { + switch (htp_performance_mode) { case HtpPerformanceMode::kHtpBurst: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMinLatency; @@ -766,25 +775,40 @@ Status QnnBackendManager::SetHtpPowerConfig() { dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; break; default: - ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode_)); + ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode)); break; } std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data()); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode."); - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - if (rpc_control_latency_ != 0) { + return Status::OK(); +} + +Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency) { + if (rpc_control_latency != 0) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + + // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. constexpr int kNumRpcPollingPowerConfigs = 2; std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency = rpc_power_configs[0]; + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; // v68 doesn't support this. QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_; - perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data()); + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + std::vector perf_power_configs_ptr = + ObtainNullTermPtrVector(rpc_power_configs); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency."); } @@ -805,11 +829,7 @@ void QnnBackendManager::Split(std::vector& split_string, } } -Status QnnBackendManager::DestroyHTPPowerConfigID() { - if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) { - return Status::OK(); - } - +Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -819,7 +839,7 @@ Status QnnBackendManager::DestroyHTPPowerConfigID() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_); + Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed."); return Status::OK(); } @@ -829,12 +849,7 @@ void QnnBackendManager::ReleaseResources() { return; } - auto result = DestroyHTPPowerConfigID(); - if (Status::OK() != result) { - ORT_THROW("Failed to DestroyHTPPowerConfigID."); - } - - result = ReleaseContext(); + auto result = ReleaseContext(); if (Status::OK() != result) { ORT_THROW("Failed to ReleaseContext."); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 36375522b5a0a..ff97c4c3a991c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -33,8 +33,6 @@ class QnnBackendManager { public: QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level, - uint32_t rpc_control_latency, - HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, std::string&& qnn_saver_path, uint32_t device_id, @@ -42,8 +40,6 @@ class QnnBackendManager { uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), - rpc_control_latency_(rpc_control_latency), - htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), qnn_saver_path_(qnn_saver_path), device_id_(device_id), @@ -92,7 +88,13 @@ class QnnBackendManager { Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); - Status SetHtpPowerConfig(); + Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); + + Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode); + + Status SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -141,6 +143,8 @@ class QnnBackendManager { const std::string& GetSdkVersion() { return sdk_build_version_; } + Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -150,8 +154,6 @@ class QnnBackendManager { Status UnloadLib(void* handle); - Status DestroyHTPPowerConfigID(); - void* LibFunction(void* handle, const char* symbol, std::string& error_msg); template @@ -232,15 +234,12 @@ class QnnBackendManager { QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; std::vector op_package_paths_; - uint32_t rpc_control_latency_ = 0; - HtpPerformanceMode htp_performance_mode_; ContextPriority context_priority_; std::string sdk_build_version_ = ""; #ifdef _WIN32 std::set mod_handles_; #endif const std::string qnn_saver_path_; - uint32_t htp_power_config_client_id_ = 0; uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b58f6e10df94c..7f70dbc56f05c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -7,6 +7,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph_viewer.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" #include "core/platform/env.h" @@ -18,11 +19,36 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/framework/run_options.h" namespace onnxruntime { constexpr const char* QNN = "QNN"; +static std::unique_ptr>> s_run_on_unload_; + +void RunOnUnload(std::function function) { + OrtMutex mutex; + std::lock_guard guard(mutex); + if (!s_run_on_unload_) { + s_run_on_unload_ = std::make_unique>>(); + } + s_run_on_unload_->push_back(std::move(function)); +} + +struct OnUnload { + ~OnUnload() { + if (!s_run_on_unload_) + return; + + for (auto& function : *s_run_on_unload_) + function(); + + s_run_on_unload_.reset(); + } + +} g_on_unload; + static void ParseProfilingLevel(std::string profiling_level_string, qnn::ProfilingLevel& profiling_level) { std::transform(profiling_level_string.begin(), @@ -193,18 +219,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency"; - uint32_t rpc_control_latency = 0; auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY); if (latency_pos != provider_options_map.end()) { - rpc_control_latency = static_cast(std::stoul(latency_pos->second)); - LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + default_rpc_control_latency_ = static_cast(std::stoul(latency_pos->second)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << default_rpc_control_latency_; } - qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + // default_htp_performance_mode from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode"; auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE); if (htp_performance_mode_pos != provider_options_map.end()) { - ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode); + ParseHtpPerformanceMode(htp_performance_mode_pos->second, default_htp_performance_mode_); } htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; @@ -241,15 +267,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_DEVICE_ID = "device_id"; - uint32_t device_id = 0; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { int value = std::stoi(dev_id_pos->second); if (value < 0) { LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value - << "', only >= 0 allowed. Set to " << device_id << "."; + << "', only >= 0 allowed. Set to " << device_id_ << "."; } else { - device_id = static_cast(value); + device_id_ = static_cast(value); } } @@ -276,15 +301,23 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, - rpc_control_latency, - htp_performance_mode, context_priority, std::move(qnn_saver_path), - device_id, + device_id_, htp_arch, soc_model); } +QNNExecutionProvider::~QNNExecutionProvider() { + // clean up thread local context caches + std::lock_guard lock(context_state_.mutex); + for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { + const auto cache = cache_weak.lock(); + if (!cache) continue; + ORT_IGNORE_RETURN_VALUE(cache->erase(this)); + } +} + bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const { @@ -706,4 +739,147 @@ const InlinedVector QNNExecutionProvider::GetEpContextNodes() const return ep_context_nodes; } + +QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, + uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency) + : qnn_backend_manager_(qnn_backend_manager) { + Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); + is_htp_power_config_id_valid_ = rt.IsOK(); + // default_htp_performance_mode and default_rpc_control_latency are from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run + if (is_htp_power_config_id_valid_) { + if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, + default_htp_performance_mode)); + } + if (default_rpc_control_latency > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, + default_rpc_control_latency)); + } + } +} + +QNNExecutionProvider::PerThreadContext::~PerThreadContext() { + if (is_htp_power_config_id_valid_) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_)); + } +} + +QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + // try to use cached context + auto cached_context_it = per_thread_context_cache->find(this); + if (cached_context_it != per_thread_context_cache->end()) { + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + return *cached_context; + } + + // get context and update cache + std::shared_ptr context; + { + std::lock_guard lock(context_state_.mutex); + + // get or create a context + if (context_state_.retired_context_pool.empty()) { + uint32_t core_id = 0; + context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, + default_htp_performance_mode_, default_rpc_control_latency_); + } else { + context = context_state_.retired_context_pool.back(); + context_state_.retired_context_pool.pop_back(); + } + + // insert into active_contexts, should not already be present + const auto active_contexts_insert_result = context_state_.active_contexts.insert(context); + ORT_ENFORCE(active_contexts_insert_result.second); + + // insert into caches_to_update_on_destruction, may already be present + ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache)); + } + + per_thread_context_cache->insert(std::make_pair(this, context)); + + return *context; +} + +void QNNExecutionProvider::ReleasePerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + auto cached_context_it = per_thread_context_cache->find(this); + ORT_ENFORCE(cached_context_it != per_thread_context_cache->end()); + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + + { + std::lock_guard lock(context_state_.mutex); + context_state_.active_contexts.erase(cached_context); + context_state_.retired_context_pool.push_back(cached_context); + } + + per_thread_context_cache->erase(cached_context_it); +} + +Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + std::string rpc_latency = ""; + uint32_t rpc_control_latency = 0; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) { + rpc_control_latency = static_cast(std::stoul(rpc_latency)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + } + + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + if (rpc_control_latency > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency)); + } + } + + return Status::OK(); +} + +Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + if (!GetPerThreadContext().IsHtpPowerConfigIdValid()) { + return Status::OK(); + } + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + return Status::OK(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 09bcb24db4dc2..802f7f2f073e8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -12,14 +12,19 @@ #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "HTP/QnnHtpGraph.h" +#include +#include +#include namespace onnxruntime { +void RunOnUnload(std::function function); + // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options); - virtual ~QNNExecutionProvider() = default; + virtual ~QNNExecutionProvider(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider); // we implement the Compile that takes FusedNodeAndGraph instances @@ -40,6 +45,10 @@ class QNNExecutionProvider : public IExecutionProvider { const InlinedVector GetEpContextNodes() const override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::unordered_map& node_unit_supported_result, @@ -73,6 +82,68 @@ class QNNExecutionProvider : public IExecutionProvider { int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_; + uint32_t device_id_ = 0; + qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; + uint32_t default_rpc_control_latency_ = 0; + + class PerThreadContext final { + public: + PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency); + ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); + + bool IsHtpPowerConfigIdValid() { return is_htp_power_config_id_valid_; } + + uint32_t GetHtpPowerConfigId() { return htp_power_config_id_; } + + private: + bool is_htp_power_config_id_valid_ = false; + uint32_t htp_power_config_id_ = 0; + qnn::QnnBackendManager* qnn_backend_manager_; + }; + + using PerThreadContextMap = std::unordered_map>; + + struct ContextCacheHolder { + ContextCacheHolder() { + RunOnUnload([&, weak_p_ = std::weak_ptr(p)] { + if (auto lock = weak_p_.lock()) + p.reset(); + }); + } + + std::shared_ptr p = std::make_shared(); + }; + + static const std::shared_ptr& PerThreadContextCache() { + thread_local const ContextCacheHolder per_thread_context_cache; + return per_thread_context_cache.p; + } + + struct PerThreadContextState { + // contexts that are currently active + std::set, std::owner_less>> active_contexts; + // contexts available for reuse + std::vector> retired_context_pool; + // weak references to thread local caches from which this QNNExecutionProvider instance's entry should be removed + // upon destruction + std::set, std::owner_less>> + caches_to_update_on_destruction; + // synchronizes access to PerThreadContextState members + OrtMutex mutex; + }; + + // The execution provider maintains the PerThreadContexts in this structure. + // Synchronization is required to update the contained structures. + // On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time, + // so synchronization is not required for that. + mutable PerThreadContextState context_state_; + + PerThreadContext& GetPerThreadContext() const; + void ReleasePerThreadContext() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index ee3578326ac6d..3fd5423681b81 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -353,7 +353,7 @@ Status ROCMExecutionProvider::Sync() const { return Status::OK(); } -Status ROCMExecutionProvider::OnRunStart() { +Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -363,7 +363,7 @@ Status ROCMExecutionProvider::OnRunStart() { return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 37d5f7b42210f..da671d9e863bb 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -28,9 +28,9 @@ class ROCMExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; const void* GetExecutionHandle() const noexcept override { // The ROCM interface does not return anything interesting. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c0bf29e486c88..81346671f2aad 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1818,11 +1818,11 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(); } -Status TensorrtExecutionProvider::OnRunStart() { +Status TensorrtExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { +Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index e86f997b6597a..26f6b2dcc3020 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -233,8 +233,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; ProviderOptions GetProviderOptions() const override { return TensorrtExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b045f30a59797..efd7db4ea7629 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2289,8 +2289,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2326,7 +2326,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(/*sync_stream*/ false); + auto status = xp->OnRunEnd(/*sync_stream*/ false, run_options); ORT_CHECK_AND_SET_RETVAL(status); } @@ -2448,8 +2448,8 @@ Status InferenceSession::Run(const RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, &run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2490,7 +2490,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; - auto status = xp->OnRunEnd(synchronize_execution_providers); + auto status = xp->OnRunEnd(synchronize_execution_providers, run_options); ORT_CHECK_AND_SET_RETVAL(status); } diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index a70e439cdf755..5505d689381c9 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -22,6 +22,8 @@ TEST(TestDeferredRelease, WithArena) { CUDAExecutionProvider ep(info); AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0]; + RunOptions run_opts; + run_opts.run_tag = "log1"; // Allocator for call cudaMallocHost and cudaFreeHost // For details, see CUDAPinnedAllocator in cuda_allocator.cc. AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; @@ -31,7 +33,7 @@ TEST(TestDeferredRelease, WithArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. auto pinned_buffer = IAllocator::MakeUniquePtr(cpu_pinned_alloc, n_bytes); @@ -44,7 +46,7 @@ TEST(TestDeferredRelease, WithArena) { cpu_pinned_alloc->GetStats(&stats); ASSERT_EQ(stats.num_allocs, n_allocs); ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } TEST(TestDeferredRelease, WithoutArena) { @@ -52,6 +54,9 @@ TEST(TestDeferredRelease, WithoutArena) { CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); + RunOptions run_opts; + run_opts.run_tag = "log1"; + OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; // Create allocator without BFCArena AllocatorCreationInfo pinned_memory_info( @@ -70,7 +75,7 @@ TEST(TestDeferredRelease, WithoutArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. auto pinned_buffer = IAllocator::MakeUniquePtr(cuda_pinned_alloc, n_bytes); @@ -79,7 +84,7 @@ TEST(TestDeferredRelease, WithoutArena) { } ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 4e1aef2c40b2b..8f07c2ce77e77 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -7,6 +7,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU #include "core/session/inference_session.h" @@ -332,19 +333,23 @@ static void CreateModelInMemory(std::unique_ptr& result, static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& output_shapes, - const std::vector>& expected_values) { - std::vector fetches; - auto status = session.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); - - for (size_t i = 0; i < fetches.size(); i++) { - auto& tensor = fetches[i].Get(); - TensorShape expected_shape(output_shapes[i]); - ASSERT_EQ(expected_shape, tensor.Shape()); - - gsl::span actual = tensor.DataAsSpan(); - gsl::span expected(expected_values[i].data(), expected_values[i].size()); - ASSERT_EQ(expected, actual); + const std::vector>& expected_values, + int loop_count = 10) { + // Let it run for a while + for (int it = 0; it < loop_count; ++it) { + std::vector fetches; + auto status = session.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); + + for (size_t i = 0; i < fetches.size(); i++) { + auto& tensor = fetches[i].Get(); + TensorShape expected_shape(output_shapes[i]); + ASSERT_EQ(expected_shape, tensor.Shape()); + + gsl::span actual = tensor.DataAsSpan(); + gsl::span expected(expected_values[i].data(), expected_values[i].size()); + ASSERT_EQ(expected, actual); + } } } @@ -404,11 +409,11 @@ TEST_F(QnnCPUBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; - + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); } for (auto& th : threads) { @@ -484,11 +489,191 @@ TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with run option to set power config +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with EP option to set default power config +TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + for (int i = 0; i < num_threads; i++) { + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with +// EP option to set default power config + run option to set power config for each run +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); } for (auto& th : threads) {