From 02b1ff5fa2c41dc026022ca29c9249628f71f026 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Thu, 4 Jan 2024 13:32:48 -0800
Subject: [PATCH] [QNN EP] Support multithreaded inference of a single session (#18981)

### Description
- Add a mutex to protect the QNN API calls that execute a graph and extract the
  corresponding profile data.
- Ensure QNN EP's execute function does not store unnecessary state (i.e., input
  and output buffer pointers no longer need to be stored as class members).

### Motivation and Context
Allow calling `session.Run()` from multiple threads when using QNN EP.
---
 .../core/providers/qnn/builder/qnn_def.cc     |   9 +
 .../core/providers/qnn/builder/qnn_def.h      |   1 +
 .../core/providers/qnn/builder/qnn_model.cc   | 107 ++++++----
 .../core/providers/qnn/builder/qnn_model.h    |  19 +-
 .../test/providers/qnn/qnn_basic_test.cc      | 194 +++++++++++++++++-
 .../azure-pipelines/linux-qnn-ci-pipeline.yml |   8 +-
 .../win-qnn-arm64-ci-pipeline.yml             |   6 +-
 .../azure-pipelines/win-qnn-ci-pipeline.yml   |   4 +-
 8 files changed, 292 insertions(+), 56 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc
index a77ac16cf624b..55e72670a6971 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc
@@ -89,6 +89,15 @@ void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector<uint32_t>& client_buf) {
   }
 }
 
+void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size) {
+  if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) {
+    qnn_tensor.v1.clientBuf.data = buf_data;
+    qnn_tensor.v1.clientBuf.dataSize = buf_size;
+  } else {
+    ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version);
+  }
+}
+
 void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size) {
   if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) {
     qnn_tensor.v1.clientBuf.dataSize = client_buf_size;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h
index f6a3b1bd360ec..c202f2bf79c57 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_def.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h
@@ -100,6 +100,7 @@ void SetQnnTensorDim(Qnn_Tensor_t& qnn_tensor, const std::vector<uint32_t>& dimensions);
 void SetQnnTensorMemType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorMemType_t mem_type);
 void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector<uint8_t>& client_buf);
 void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector<uint32_t>& client_buf);
+void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size);
 void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size);
 void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data);
 void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index fd3a95b5f1f78..869d9326d9232 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -166,14 +166,14 @@ Status QnnModel::FinalizeGraphs() {
 
 Status QnnModel::SetupQnnInputOutput() {
   LOGS(logger_, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name();
 
-  auto result = SetupTensors(qnn_inputs_, graph_info_->InputTensors());
+  auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());
 
   if (Status::OK() != result) {
     LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
   }
 
-  result = SetupTensors(qnn_outputs_, graph_info_->OutputTensors(), false);
+  result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
   if (Status::OK() != result) {
     LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
   }
@@ -186,8 +186,8 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
   LOGS(logger_, VERBOSE) << "QnnModel::ExecuteGraphs";
   const size_t num_inputs = context.GetInputCount();
   const size_t num_outputs = context.GetOutputCount();
-  ORT_RETURN_IF_NOT(qnn_inputs_.size() <= num_inputs, "Inconsistent input sizes");
-  ORT_RETURN_IF_NOT(qnn_outputs_.size() == num_outputs, "Inconsistent output sizes");
+  ORT_RETURN_IF_NOT(qnn_input_infos_.size() <= num_inputs, "Inconsistent input sizes");
+  ORT_RETURN_IF_NOT(qnn_output_infos_.size() == num_outputs, "Inconsistent output sizes");
 
   using namespace qnn::utils;
   auto TensorDataSize = [&](auto ort_tensor) -> size_t {
@@ -198,49 +198,67 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
     return element_size * length;
   };
 
-  for (auto& qnn_input_tensor : qnn_inputs_) {
-    const std::string& model_input_name(GetQnnTensorName(qnn_input_tensor));
-    auto index = GetOrtInputIndex(model_input_name);
-    LOGS(logger_, VERBOSE) << "model_input = " << model_input_name << " index = " << index;
-    auto ort_input_tensor = context.GetInput(index);
-    auto qnn_tensor_size = GetQnnTensorClientBuf(qnn_input_tensor).dataSize;
+  std::vector<Qnn_Tensor_t> qnn_inputs;
+  qnn_inputs.reserve(qnn_input_infos_.size());
+
+  for (const auto& qnn_input_info : qnn_input_infos_) {
+    LOGS(logger_, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName()
+                           << " index = " << qnn_input_info.ort_index;
+    auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index);
     auto ort_tensor_size = TensorDataSize(ort_input_tensor);
-    LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_tensor_size << "Ort tensor size: " << ort_tensor_size;
-    ORT_ENFORCE(qnn_tensor_size == ort_tensor_size,
+    LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size
+                           << "Ort tensor size: " << ort_tensor_size;
+    ORT_ENFORCE(qnn_input_info.tensor_byte_size == ort_tensor_size,
                 "ORT Tensor data size does not match QNN tensor data size.");
-    SetQnnTensorClientBufData(qnn_input_tensor,
-                              const_cast<void*>(ort_input_tensor.GetTensorData<void>()));
+
+    qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor());
+    SetQnnTensorClientBuf(qnn_inputs.back(),
+                          const_cast<void*>(ort_input_tensor.GetTensorData<void>()), qnn_input_info.tensor_byte_size);
   }
 
-  for (auto& qnn_output_tensor : qnn_outputs_) {
-    const std::string& model_output_name(GetQnnTensorName(qnn_output_tensor));
-    auto index = GetOutputIndex(model_output_name);
-    LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << index;
-    const auto& output_info = GetOutputInfo(model_output_name);
-    const std::vector<int64_t>& output_shape = output_info->shape_;
-    auto output_tensor = context.GetOutput(index, output_shape.data(), output_shape.size());
-    auto qnn_tensor_size = GetQnnTensorClientBuf(qnn_output_tensor).dataSize;
-    auto ort_tensor_size = TensorDataSize(output_tensor);
size: " << ort_tensor_size; - ORT_ENFORCE(qnn_tensor_size == ort_tensor_size, + std::vector qnn_outputs; + qnn_outputs.reserve(qnn_output_infos_.size()); + + for (auto& qnn_output_info : qnn_output_infos_) { + const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); + LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; + const auto& ort_output_info = GetOutputInfo(model_output_name); + const std::vector& output_shape = ort_output_info->shape_; + auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); + auto ort_tensor_size = TensorDataSize(ort_output_tensor); + LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size + << "Ort tensor size: " << ort_tensor_size; + ORT_ENFORCE(qnn_output_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size"); - SetQnnTensorClientBufData(qnn_output_tensor, - const_cast(output_tensor.GetTensorData())); + + qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); + SetQnnTensorClientBuf(qnn_outputs.back(), + const_cast(ort_output_tensor.GetTensorData()), qnn_output_info.tensor_byte_size); } LOGS(logger_, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; - execute_status = qnn_interface.graphExecute(graph_info_->Graph(), - qnn_inputs_.data(), - static_cast(qnn_inputs_.size()), - qnn_outputs_.data(), - static_cast(qnn_outputs_.size()), - profile_backend_handle, - nullptr); - ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); + { + // Acquire mutex before calling graphExecute and profiling APIs to support calling session.Run() + // from multiple threads. + std::lock_guard lock(graph_exec_mutex_); + execute_status = qnn_interface.graphExecute(graph_info_->Graph(), + qnn_inputs.data(), + static_cast(qnn_inputs.size()), + qnn_outputs.data(), + static_cast(qnn_outputs.size()), + profile_backend_handle, + nullptr); + + // NOTE: This function returns immediately when profiling is disabled. + // Extracting profiling data can be expensive, but it is typically only enabled for debugging purposes + // and not in production. We can improve synchronization for event profiling if it becomes an issue. + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); + } + if (QNN_GRAPH_NO_ERROR != execute_status) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN graph execute error. Error code: ", execute_status); } @@ -262,14 +280,13 @@ Status QnnModel::GetQnnTensorDataLength(const std::vector& dims, return Status::OK(); } -// Setup details for Qnn_Tensor_t for execution -// based on information in QnnTensorWrapper -Status QnnModel::SetupTensors(std::vector& qnn_tensors, +// Setup information for Qnn inputs/outputs used during execution. 
+Status QnnModel::SetupTensors(std::vector<QnnTensorInfo>& qnn_tensor_infos,
                               const std::vector<QnnTensorWrapper>& tensor_wrappers,
                               bool is_input) {
   size_t tensor_count = tensor_wrappers.size();
   ORT_RETURN_IF(0 == tensor_count, "Zero tensor size!");
-  qnn_tensors.resize(tensor_count);
+  qnn_tensor_infos.resize(tensor_count);
 
   for (auto& tensor_wrapper : tensor_wrappers) {
     size_t length = 0;
@@ -277,10 +294,14 @@ Status QnnModel::SetupTensors(std::vector<Qnn_Tensor_t>& qnn_tensors,
     ORT_RETURN_IF_ERROR(GetQnnTensorDataLength(tensor_wrapper.GetTensorDims(),
                                                tensor_wrapper.GetTensorDataType(),
                                                length));
-    auto tensor_name = tensor_wrapper.GetName();
-    auto index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name);
-    qnn_tensors[index] = tensor_wrapper.GetQnnTensor();
-    SetQnnTensorClientBufSize(qnn_tensors[index], static_cast<uint32_t>(length));
+    const auto& tensor_name = tensor_wrapper.GetName();
+    auto qnn_index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name);
+    auto ort_index = is_input ? GetOrtInputIndex(tensor_name) : qnn_index;
+
+    QnnTensorInfo& qnn_tensor_info = qnn_tensor_infos[qnn_index];
+    qnn_tensor_info.tensor_wrapper = &tensor_wrapper;
+    qnn_tensor_info.tensor_byte_size = static_cast<uint32_t>(length);
+    qnn_tensor_info.ort_index = ort_index;
   }
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h
index de4f872f73ccf..d0dd091cb1688 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h
@@ -3,8 +3,11 @@
 
 #pragma once
 
+#include <vector>
+
 #include "core/common/status.h"
 #include "core/graph/graph_viewer.h"
+#include "core/platform/ort_mutex.h"
 #include "core/providers/qnn/builder/qnn_def.h"
 #include "core/providers/qnn/builder/qnn_model_wrapper.h"
 #include "core/providers/qnn/builder/qnn_backend_manager.h"
@@ -14,6 +17,12 @@ namespace onnxruntime {
 
 namespace qnn {
 
+struct QnnTensorInfo {
+  const QnnTensorWrapper* tensor_wrapper = nullptr;
+  uint32_t tensor_byte_size = 0;
+  size_t ort_index = 0;
+};
+
 class QnnModel {
  public:
   QnnModel(const logging::Logger& logger,
@@ -103,7 +112,8 @@ class QnnModel {
                                  Qnn_DataType_t data_type,
                                  size_t& data_length) const;
 
-  Status SetupTensors(std::vector<Qnn_Tensor_t>& tensors, const std::vector<QnnTensorWrapper>& tensor_wrappers, bool is_input = true);
+  Status SetupTensors(std::vector<QnnTensorInfo>& tensors, const std::vector<QnnTensorWrapper>& tensor_wrappers,
+                      bool is_input = true);
 
   QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
 
@@ -126,9 +136,12 @@ class QnnModel {
   std::vector<std::string> output_names_;
   std::unordered_map<std::string, OnnxTensorInfo> inputs_info_;
   std::unordered_map<std::string, OnnxTensorInfo> outputs_info_;
-  std::vector<Qnn_Tensor_t> qnn_inputs_;
-  std::vector<Qnn_Tensor_t> qnn_outputs_;
+  std::vector<QnnTensorInfo> qnn_input_infos_;
+  std::vector<QnnTensorInfo> qnn_output_infos_;
   QnnBackendType qnn_backend_type_ = QnnBackendType::CPU;
+
+  // Mutex acquired during graph execution to support multi-threaded inference of a single session.
+  OrtMutex graph_exec_mutex_;
 };
 
 }  // namespace qnn
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index 391d7bebc9589..f9064cad3fe12 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -1,8 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include <string>
 #include <filesystem>
+#include <string>
+#include <thread>
 
 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
@@ -287,8 +288,199 @@ TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) {
   EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin"));
 }
 
+struct ModelAndBuilder {
+  ModelAndBuilder(Graph& graph) : builder(graph) {}
+  std::string model_data;
+  ModelTestBuilder builder;
+};
+
+// Creates a model in memory. Input feeds and output names can be accessed from result.builder.
+static void CreateModelInMemory(std::unique_ptr<ModelAndBuilder>& result,
+                                const GetTestModelFn& model_build_fn,
+                                const std::string& model_name,
+                                int opset_version = 18) {
+  const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
+  auto& logging_manager = DefaultLoggingManager();
+
+  // Create float model and serialize it to a string.
+  onnxruntime::Model model(model_name, false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  result = std::make_unique<ModelAndBuilder>(model.MainGraph());
+  model_build_fn(result->builder);
+  result->builder.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+  model.ToProto().SerializeToString(&result->model_data);
+}
+
+// Runs a session and verifies the outputs. Can be run by individual threads.
+static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds,
+                                const std::vector<std::string>& output_names,
+                                const std::vector<std::vector<int64_t>>& output_shapes,
+                                const std::vector<std::vector<float>>& expected_values) {
+  std::vector<OrtValue> fetches;
+  auto status = session.Run(run_options, feeds, output_names, &fetches);
+  ASSERT_TRUE(status.IsOK());
+
+  for (size_t i = 0; i < fetches.size(); i++) {
+    auto& tensor = fetches[i].Get<Tensor>();
+    TensorShape expected_shape(output_shapes[i]);
+    ASSERT_EQ(expected_shape, tensor.Shape());
+
+    gsl::span<const float> actual = tensor.DataAsSpan<float>();
+    gsl::span<const float> expected(expected_values[i].data(), expected_values[i].size());
+    ASSERT_EQ(expected, actual);
+  }
+}
+
+// Returns a function that builds a float32 model that adds 3 tensors.
+static GetTestModelFn F32BuildAdd3Tensors(const TestInputDef<float>& input0_def,
+                                          const TestInputDef<float>& input1_def,
+                                          const TestInputDef<float>& input2_def) {
+  return [input0_def, input1_def, input2_def](ModelTestBuilder& builder) {
+    NodeArg* input0 = MakeTestInput(builder, input0_def);
+    NodeArg* input1 = MakeTestInput(builder, input1_def);
+    NodeArg* input2 = MakeTestInput(builder, input1_def);
+
+    auto* add0_out = builder.MakeIntermediate();
+    builder.AddNode("Add", {input0, input1}, {add0_out});
+
+    auto* output = builder.MakeOutput();
+    builder.AddNode("Add", {add0_out, input2}, {output});
+  };
+}
+
+// Tests running a single session in multiple threads on the CPU backend.
+TEST_F(QnnCPUBackendTests, MultithreadSessionRun) {
+  std::unique_ptr<ModelAndBuilder> model;
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  std::vector<int64_t> shape = {1, 3, 2};
+  std::vector<std::vector<int64_t>> output_shapes = {shape};
+  std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+  CreateModelInMemory(model,
+                      F32BuildAdd3Tensors(TestInputDef<float>(shape, false, input_data),
+                                          TestInputDef<float>(shape, false, input_data),
+                                          TestInputDef<float>(shape, false, input_data)),
+                      "add3.f32");
+
+  SessionOptions session_opts;
+  session_opts.session_logid = "logger0";
+
+  RunOptions run_opts;
+  run_opts.run_tag = session_opts.session_logid;
+
+  InferenceSession session_obj{session_opts, GetEnvironment()};
+  onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+  options["backend_path"] = "QnnCpu.dll";
+#else
+  options["backend_path"] = "libQnnCpu.so";
+#endif
+
+  auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+  EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+  auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_obj.Initialize();
+  ASSERT_TRUE(status.IsOK());
+
+  std::vector<std::thread> threads;
+  constexpr int num_threads = 5;
+
+  for (int i = 0; i < num_threads; i++) {
+    threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+                                  model->builder.feeds_, model->builder.output_names_,
+                                  output_shapes, output_values));
+  }
+
+  for (auto& th : threads) {
+    th.join();
+  }
+}
+
 #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+// Returns a function that builds a QDQ model that adds 3 tensors. Forces all scales and zero-points to be (1.0f, 0),
+// so it is only accurate when using non-fractional positive inputs.
+template <typename QuantType>
+static GetTestModelFn QDQBuildAdd3Tensors(const TestInputDef<float>& input0_def,
+                                          const TestInputDef<float>& input1_def,
+                                          const TestInputDef<float>& input2_def) {
+  return [input0_def, input1_def, input2_def](ModelTestBuilder& builder) {
+    NodeArg* input0 = MakeTestInput(builder, input0_def);
+    NodeArg* input0_after_qdq = AddQDQNodePair<QuantType>(builder, input0, 1.0f, 0);
+    NodeArg* input1 = MakeTestInput(builder, input1_def);
+    NodeArg* input1_after_qdq = AddQDQNodePair<QuantType>(builder, input1, 1.0f, 0);
+    NodeArg* input2 = MakeTestInput(builder, input1_def);
+    NodeArg* input2_after_qdq = AddQDQNodePair<QuantType>(builder, input2, 1.0f, 0);
+
+    auto* add0_out = builder.MakeIntermediate();
+    builder.AddNode("Add", {input0_after_qdq, input1_after_qdq}, {add0_out});
+
+    auto* add0_out_dq = AddQDQNodePair<QuantType>(builder, add0_out, 1.0f, 0);
+
+    auto* add1_out = builder.MakeIntermediate();
+    builder.AddNode("Add", {add0_out_dq, input2_after_qdq}, {add1_out});
+
+    // op_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, add1_out, 1.0f, 0);
+  };
+}
+
+// Tests running a single session in multiple threads on the HTP backend.
+TEST_F(QnnHTPBackendTests, MultithreadSessionRun) {
+  std::unique_ptr<ModelAndBuilder> model;
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  std::vector<int64_t> shape = {1, 3, 2};
+  std::vector<std::vector<int64_t>> output_shapes = {shape};
+  std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+  CreateModelInMemory(model,
+                      QDQBuildAdd3Tensors<uint8_t>(TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data)),
+                      "add3.qdq");
+
+  SessionOptions session_opts;
+  session_opts.session_logid = "logger0";
+
+  RunOptions run_opts;
+  run_opts.run_tag = session_opts.session_logid;
+
+  InferenceSession session_obj{session_opts, GetEnvironment()};
+  onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+  options["backend_path"] = "QnnHtp.dll";
+#else
+  options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+  EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+  auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_obj.Initialize();
+  ASSERT_TRUE(status.IsOK());
+
+  std::vector<std::thread> threads;
+  constexpr int num_threads = 5;
+
+  for (int i = 0; i < num_threads; i++) {
+    threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+                                  model->builder.feeds_, model->builder.output_names_,
+                                  output_shapes, output_values));
+  }
+
+  for (auto& th : threads) {
+    th.join();
+  }
+}
+
 // Test shape inference of QDQ NHWC Resize operator (opset 18) that uses
 // the sizes input. Use the QNN HTP backend.
 TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) {
diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
index 07e69ff496720..d286c4f3a46fe 100644
--- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
@@ -86,7 +86,7 @@ jobs:
     inputs:
       script: |
         ./build/Release/onnx_test_runner -e qnn \
-        -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
+        -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
        cmake/external/onnx/onnx/backend/test/data/node
 
   - task: CmdLine@2
@@ -94,7 +94,7 @@ jobs:
     inputs:
       script: |
         ./build/Release/onnx_test_runner -e qnn \
-        -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
+        -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
        /data/float32_models
 
   - task: CmdLine@2
@@ -102,7 +102,7 @@ jobs:
     inputs:
       script: |
         ./build/Release/onnx_test_runner -e qnn \
-        -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
+        -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
        /data/qdq_models
 
   - task: CmdLine@2
@@ -110,5 +110,5 @@ jobs:
     inputs:
       script: |
         ./build/Release/onnx_test_runner -e qnn \
-        -v -f -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
+        -v -f -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
        /data/qdq_models/mobilenetv2-1.0_add_transpose_quant
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 5e35cbfed6692..6dc428d6606af 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -84,17 +84,17 @@ jobs:
     displayName: 'Run unit tests'
 
   - script: |
-      .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
+      .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
     workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
     displayName: 'Run ONNX Tests'
 
   - script: |
-      .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" C:\data\float32_models
+      .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" C:\data\float32_models
    workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
     displayName: 'Run float32 model tests'
 
   - script: |
-      .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnHtp.dll" C:\data\qdq_models
+      .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnHtp.dll" C:\data\qdq_models
     workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
     displayName: 'Run QDQ model tests'
     enabled: false
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index 65b2924c8be60..fbec572fd346c 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -88,11 +88,11 @@ jobs:
     displayName: 'Run unit tests'
 
   - script: |
-      .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
+      .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
     workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
     displayName: 'Run ONNX Tests'
 
   - script: |
-      .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models
+      .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models
     workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
     displayName: 'Run float32 model tests'
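
---

For context (not part of the patch): the following is a minimal sketch of what this change enables from the public ONNX Runtime C++ API, i.e. several threads sharing one session and calling `Run()` concurrently with the QNN EP registered. The model path, tensor names, and shapes are placeholders for illustration only.

```cpp
// Sketch: concurrent Run() calls on a single session with the QNN EP.
// "model.onnx", "input", "output", and the shapes are hypothetical placeholders.
#include <array>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_mt_example");
  Ort::SessionOptions session_options;

  // Register the QNN EP. On Linux use "libQnnHtp.so" / "libQnnCpu.so"; on Windows "QnnHtp.dll" / "QnnCpu.dll".
  std::unordered_map<std::string, std::string> qnn_options{{"backend_path", "QnnHtp.dll"}};
  session_options.AppendExecutionProvider("QNN", qnn_options);

  // One session shared by all threads.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

  auto run_once = [&session]() {
    std::array<float, 6> input_data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
    std::array<int64_t, 3> shape{1, 3, 2};
    Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value input = Ort::Value::CreateTensor<float>(mem_info, input_data.data(), input_data.size(),
                                                       shape.data(), shape.size());
    const char* input_names[] = {"input"};
    const char* output_names[] = {"output"};
    // With this patch, concurrent Run() calls on the same session are safe when using the QNN EP.
    auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input, 1, output_names, 1);
  };

  std::vector<std::thread> threads;
  for (int i = 0; i < 5; ++i) threads.emplace_back(run_once);
  for (auto& t : threads) t.join();
  return 0;
}
```

Internally, each `Run()` still serializes on the per-graph mutex around `graphExecute`, so this provides thread safety rather than parallel execution of a single QNN graph.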