Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable user to set QNN HTP performance mode for every session run #19521

Merged
merged 14 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Node;
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

struct OrtRunOptions;

namespace onnxruntime {

/**
Expand All @@ -51,6 +53,8 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};

using RunOptions = OrtRunOptions;

enum class DataLayout {
NCHW,
NHWC,
Expand Down Expand Up @@ -184,15 +188,17 @@ class IExecutionProvider {
Run may not be finished on device This function should be regarded as the
point after which a new Run would start to submit commands from CPU
*/
virtual common::Status OnRunStart() { return Status::OK(); }
virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }

/**
Called when InferenceSession::Run ended
NOTE that due to async execution in provider, the actual work of this Run
may not be finished on device This function should be regarded as the point
that all commands of current Run has been submmited by CPU
*/
virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
return Status::OK();
}

/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor
// Per default it will be set to '0'
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";

// Set HTP performance mode for QNN HTP backend before session run.
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
// "sustained_high_performance". Default to "default".
static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";

// Set HTP performance mode for QNN HTP backend post session run.
static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";

// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
4 changes: 3 additions & 1 deletion onnxruntime/core/framework/stream_execution_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
}

#ifdef USE_CANN
// Leave it to CANN EP to fill the gap if they want to use run_options
static onnxruntime::RunOptions run_options;
// For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool,
// which is different from CUDA Runtime API, but similar to CUDA Driver API.
auto& execution_providers = ctx.GetSessionState().GetExecutionProviders();
for (auto& xp : execution_providers) {
auto status = xp->OnRunStart();
auto status = xp->OnRunStart(run_options);
if (!status.IsOK()) {
ctx.SetStatus(status);
return;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() {
}

// All threads share the same context and stream
Status CANNExecutionProvider::OnRunStart() {
Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id));

return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider {
explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info);
virtual ~CANNExecutionProvider();

Status OnRunStart() override;
Status OnRunStart(const onnxruntime::RunOptions& run_options) override;

template <typename T>
Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ Status CUDAExecutionProvider::Sync() const {
return Status::OK();
}

Status CUDAExecutionProvider::OnRunStart() {
Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
// always set CUDA device when session::Run() in case it runs in a worker thread
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
Expand All @@ -396,7 +396,7 @@ Status CUDAExecutionProvider::OnRunStart() {
return Status::OK();
}

Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider {

Status Sync() const override;

Status OnRunStart() override;
Status OnRunStart(const onnxruntime::RunOptions& run_options) override;

Status OnRunEnd(bool sync_stream) override;
Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;

DataLayout GetPreferredLayout() const override;

Expand Down Expand Up @@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg);
~PerThreadContext();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);

cublasHandle_t CublasHandle() const {
return cublas_handle_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@
return m_impl->OnSessionInitializationEnd();
}

virtual onnxruntime::Status Sync() const final override
onnxruntime::Status Sync() const final override

Check warning on line 273 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 "override" is redundant since function is already declared as "final" [readability/inheritance] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h:273: "override" is redundant since function is already declared as "final" [readability/inheritance] [4]
{
// Completely wait until the device has completed all preceding tasks.
// The application could have called SynchronizeBoundOutputs().
m_impl->WaitForOutstandingWork();
return Status::OK();
}

virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override
onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override

Check warning on line 281 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 "override" is redundant since function is already declared as "final" [readability/inheritance] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h:281: "override" is redundant since function is already declared as "final" [readability/inheritance] [4]
{
// Flush any pending work to the GPU, but don't block for completion, permitting it
// to overlap other work.
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -756,15 +756,15 @@ std::unique_ptr<onnxruntime::IDataTransfer> JsExecutionProvider::GetDataTransfer
JsExecutionProvider::~JsExecutionProvider() {
}

Status JsExecutionProvider::OnRunStart() {
Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model";
EM_ASM({ Module.jsepCaptureBegin(); });
}
return Status::OK();
}

Status JsExecutionProvider::OnRunEnd(bool sync_stream) {
Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !IsGraphCaptured()) {
if (IsGraphCaptureAllowed()) {
EM_ASM({ Module.jsepCaptureEnd(); });
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/js/js_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class JsExecutionProvider : public IExecutionProvider {

std::vector<AllocatorPtr> CreatePreferredAllocators() override;

Status OnRunStart() override;
Status OnRunEnd(bool sync_stream) override;
Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,11 +1383,11 @@ Status MIGraphXExecutionProvider::Sync() const {
return Status::OK();
}

Status MIGraphXExecutionProvider::OnRunStart() {
Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
return Status::OK();
}

Status MIGraphXExecutionProvider::OnRunEnd(bool) {
Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
auto status = hipStreamQuery(stream_);

if (status != hipSuccess) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
#ifdef MIGRAPHX_STREAM_SYNC
Status Sync() const override;

Status OnRunStart() override;
Status OnRunStart(const onnxruntime::RunOptions& run_options) override;

Status OnRunEnd(bool sync_stream) override;
Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
#endif

std::vector<std::unique_ptr<ComputeCapability>>
Expand Down
75 changes: 45 additions & 30 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,19 +629,14 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
LOGS(logger, VERBOSE) << "CreateContext succeed.";
}

if (htp_performance_mode_ != HtpPerformanceMode::kHtpDefault) {
ORT_RETURN_IF_ERROR(SetHtpPowerConfig());
LOGS(logger, VERBOSE) << "SetHtpPowerConfig succeed.";
}

LOGS(logger, VERBOSE) << "QNN SetupBackend succeed";

backend_setup_completed_ = true;

return Status::OK();
}

Status QnnBackendManager::SetHtpPowerConfig() {
Status QnnBackendManager::CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");
Expand All @@ -651,23 +646,37 @@ Status QnnBackendManager::SetHtpPowerConfig() {
"HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;
// Get power client id
status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_);
status = htp_perf_infra.createPowerConfigId(device_id, core_id, &htp_power_config_id);
ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed.");

return Status::OK();
}

Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id,
HtpPerformanceMode htp_performance_mode) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");

auto* htp_infra = static_cast<QnnHtpDevice_Infrastructure_t*>(qnn_device_infra);
ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType,
"HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;

constexpr const int kNumConfigs = 1;
std::vector<QnnHtpPerfInfrastructure_PowerConfig_t> power_configs(
kNumConfigs);
QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0];
dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3;
QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config;
dcvs_v3.contextId = htp_power_config_client_id_;
dcvs_v3.contextId = htp_power_config_client_id;
dcvs_v3.setSleepDisable = 0;
dcvs_v3.sleepDisable = 0;
dcvs_v3.setDcvsEnable = 1;
dcvs_v3.dcvsEnable = kDcvsDisable;
dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE;
// choose performance mode
switch (htp_performance_mode_) {
switch (htp_performance_mode) {
case HtpPerformanceMode::kHtpBurst:
dcvs_v3.setSleepLatency = 1; // true
dcvs_v3.sleepLatency = kSleepMinLatency;
Expand Down Expand Up @@ -766,25 +775,40 @@ Status QnnBackendManager::SetHtpPowerConfig() {
dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
break;
default:
ORT_THROW("Invalid performance profile %d", static_cast<int>(htp_performance_mode_));
ORT_THROW("Invalid performance profile %d", static_cast<int>(htp_performance_mode));
break;
}
std::vector<const QnnHtpPerfInfrastructure_PowerConfig_t*> perf_power_configs_ptr = ObtainNullTermPtrVector(power_configs);
status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data());
status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data());
ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode.");

// Set rpc control latency here, but note that v68 doesn't support rpc polling mode.
if (rpc_control_latency_ != 0) {
return Status::OK();
}

Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id,
uint32_t rpc_control_latency) {
if (rpc_control_latency != 0) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");

auto* htp_infra = static_cast<QnnHtpDevice_Infrastructure_t*>(qnn_device_infra);
ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType,
"HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;

// Set rpc control latency here, but note that v68 doesn't support rpc polling mode.
constexpr int kNumRpcPollingPowerConfigs = 2;
std::vector<QnnHtpPerfInfrastructure_PowerConfig_t> rpc_power_configs(kNumRpcPollingPowerConfigs);
QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency = rpc_power_configs[0];
QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0];
// v68 doesn't support this.
QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1];
rpc_control_latency.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY;
rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY;
rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME;
rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_;
perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs);
status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data());
rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency;
std::vector<const QnnHtpPerfInfrastructure_PowerConfig_t*> perf_power_configs_ptr =
ObtainNullTermPtrVector(rpc_power_configs);
status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data());
ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency.");
}

Expand All @@ -805,11 +829,7 @@ void QnnBackendManager::Split(std::vector<std::string>& split_string,
}
}

Status QnnBackendManager::DestroyHTPPowerConfigID() {
if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) {
return Status::OK();
}

Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");
Expand All @@ -819,7 +839,7 @@ Status QnnBackendManager::DestroyHTPPowerConfigID() {
"HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;

Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_);
Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_id);
ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed.");
return Status::OK();
}
Expand All @@ -829,12 +849,7 @@ void QnnBackendManager::ReleaseResources() {
return;
}

auto result = DestroyHTPPowerConfigID();
if (Status::OK() != result) {
ORT_THROW("Failed to DestroyHTPPowerConfigID.");
}

result = ReleaseContext();
auto result = ReleaseContext();
if (Status::OK() != result) {
ORT_THROW("Failed to ReleaseContext.");
}
Expand Down
Loading
Loading