From 7db7c4e5c80eeecd75dd66a9fa691ac32c3a8a98 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Fri, 26 Jul 2024 14:54:45 -0700 Subject: [PATCH 01/37] Separating all GPU stages into different Pipelines (#21521) ### Description Separating all GPU stages into different Pipelines --- .../win-gpu-cuda-ci-pipeline.yml | 64 +++++++++++++++++++ .../win-gpu-dml-ci-pipeline.yml | 52 +++++++++++++++ .../win-gpu-doc-gen-ci-pipeline.yml | 61 ++++++++++++++++++ .../win-gpu-training-ci-pipeline.yml | 63 ++++++++++++++++++ tools/ci_build/set-trigger-rules.py | 5 +- 5 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml new file mode 100644 index 0000000000000..78e1624b5d123 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml @@ -0,0 +1,64 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +stages: +- stage: cuda + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + additionalBuildFlags: >- + --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --enable_cuda_profiling --enable_transformers_tool_test + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON + --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: CUDA + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml new file mode 100644 index 0000000000000..904979f39ca31 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml @@ -0,0 +1,52 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +stages: +- stage: dml + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env.bat + buildArch: x64 + additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: DML + WITH_CACHE: false + MachinePool: onnxruntime-Win2022-GPU-dml-A10 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml new file mode 100644 index 0000000000000..4106889331350 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml @@ -0,0 +1,61 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +stages: +- stage: kernelDocumentation + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags + additionalBuildFlags: >- + --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda + --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: false + GenerateDocumentation: true + ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml new file mode 100644 index 0000000000000..3bb6c267f0018 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml @@ -0,0 +1,63 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +stages: +- stage: training + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + additionalBuildFlags: >- + --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --skip_onnx_tests + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: CUDA + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 + isTraining: true diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index d26fec41033ca..0d90061e9c687 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -34,7 +34,10 @@ "orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml", "orttraining-mac-ci-pipeline.yml", "win-ci-pipeline.yml", - "win-gpu-ci-pipeline.yml", + "win-gpu-ci-dml-pipeline.yml", + "win-gpu-ci-cuda-pipeline.yml", + "win-gpu-ci-training-pipeline.yml", + "win-gpu-ci-doc-gen-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-qnn-arm64-ci-pipeline.yml", "win-qnn-ci-pipeline.yml", From fb61e14153b6a1263c15ea3b62d6bbbc5bde9848 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 26 Jul 2024 16:56:44 -0700 Subject: [PATCH 02/37] Add QNN EP option context_node_name_prefix to set EPContext node name prefix (#21236) ### Description Add QNN EP option context_node_name_prefix to set EPContext node name prefix ### Motivation and Context For the case to workaround QNN context PD memory limit, user need split the model into pieces and generate the QNN context model separately. It could happen that the generated EPContext node in separate graph has same node name. This will cause issue if glue those EPContext nodes together into a single model. To avoid this user can set this context_node_name_prefix for each split pieces to make the node name unique. --- .../onnxruntime_session_options_config_keys.h | 4 ++ .../providers/qnn/qnn_execution_provider.cc | 9 ++++- .../providers/qnn/qnn_execution_provider.h | 1 + .../test/providers/qnn/qnn_ep_context_test.cc | 39 +++++++++++++++++++ 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 17ae649e6f174..209fd4279cc99 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -265,6 +265,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // "1": dump the EP context into the Onnx model. (default). static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; +// Specify the EPContext node name prefix to make it unique +// in case user need to merge/connect multiple EPContext nodes in one model +static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; + // Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. // Option values: // - "0": Gemm FastMath mode is not enabled. [DEFAULT] diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 539b456cb657f..c56a47e67497e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -199,6 +199,13 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; + + // For the case that workaround QNN context PD memory limit, user need split the model into pieces and + // generate the QNN context model separately. + // It could happen that the generated EPContext node in separate graph has same node name. + // User can set this context_node_name_prefix for each split pieces to avoid that happens. + context_node_name_prefix_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextNodeNamePrefix, ""); + LOGS_DEFAULT(VERBOSE) << "User specified QNN context node name prefix: " << context_node_name_prefix_; } static const std::string BACKEND_PATH = "backend_path"; @@ -613,7 +620,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto gen_metadef_name = [&]() { uint64_t model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(QNN, "_", model_hash, "_", metadef_id); + return MakeString(QNN, context_node_name_prefix_, "_", model_hash, "_", metadef_id); }; // For model with EPContext, make sure each partition only has one single EPContext node diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e7419dabb14d1..f00ffb6cfdb96 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -80,6 +80,7 @@ class QNNExecutionProvider : public IExecutionProvider { std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; + std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a3768cb98f584..be3bd2cc5dcd7 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -279,6 +279,45 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + std::string node_name_prefix = "node_name_prefix_test"; + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + for (auto& node : model->MainGraph().Nodes()) { + if (node.OpType() == "EPContext") { + EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); + } + } + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + // Run QDQ model on HTP 3 times // 1st run will generate the Qnn context cache onnx file // 2nd run directly loads and run from Qnn context cache model From 64819f6f8cad8387b23d7cc8af1a4b4207e2dfbb Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jul 2024 18:45:14 -0700 Subject: [PATCH 03/37] Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel -- | -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency. --- .../python/transformers/benchmark_mha.cmd | 47 ++ .../test/python/transformers/benchmark_mha.py | 690 +++++++++++++----- .../test/python/transformers/benchmark_mha.sh | 48 +- .../test/python/transformers/test_mha.py | 46 +- 4 files changed, 609 insertions(+), 222 deletions(-) create mode 100644 onnxruntime/test/python/transformers/benchmark_mha.cmd diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd new file mode 100644 index 0000000000000..0a6d0c37b4a35 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -0,0 +1,47 @@ +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" + +set CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +set MKL_DYNAMIC=FALSE +set OMP_NUM_THREADS=1 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=2 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=4 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=8 +python benchmark_mha.py --torch + +set MKL_DYNAMIC= +set OMP_NUM_THREADS= + +set ORT_DISABLE_FLASH_ATTENTION=0 +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py + +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +python benchmark_mha.py --causal --has_past + +set ORT_DISABLE_FLASH_ATTENTION=1 +python benchmark_mha.py +set ORT_DISABLE_FLASH_ATTENTION= + +type benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 111c417479d20..715a92431e6bf 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -4,21 +4,35 @@ # -------------------------------------------------------------------------- """ -Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: -sh benchmark_mha.sh +Benchmark performance of MultiHeadAttention with ORT or PyTorch. + +In Linux, run the the following: + sh benchmark_mha.sh + +In Windows, run the the following: + benchmark_mha.cmd """ +import argparse +import csv import math import os import platform import statistics import time -from typing import List, Optional +from contextlib import nullcontext +from datetime import datetime +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple import torch +import torch.utils.benchmark as benchmark from onnx import TensorProto, helper +from packaging.version import Version +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.functional import scaled_dot_product_attention -from onnxruntime import InferenceSession, get_available_providers +from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession @@ -43,6 +57,20 @@ def get_name_list() -> List[str]: return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"] +class SdpaKernel(IntEnum): + """Bit flags for sdpa_kernel CUDA provider option""" + + DEFAULT = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + TRT_FUSED_ATTENTION = 4 + CUDNN_FLASH_ATTENTION = 8 + MATH = 16 + TRT_FLASH_ATTENTION = 32 + TRT_CROSS_ATTENTION = 64 + TRT_CAUSAL_ATTENTION = 128 + + class MultiHeadAttentionConfig: def __init__( self, @@ -62,6 +90,7 @@ def __init__( use_kv_cache: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, + verbose: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -100,6 +129,7 @@ def __init__( self.input_format = input_format self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H + self.verbose = verbose def __repr__(self): return ( @@ -114,89 +144,93 @@ def __repr__(self): ) def shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), } if self.use_kv_cache: + assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" shapes = { + **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } - else: - shapes = { - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - ) return shapes def symbolic_shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), "key": ("batch_size", self.num_heads, "sequence_length", self.head_size), "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), } if self.use_kv_cache: + assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" shapes = { + **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } - else: - shapes = { - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - ) return shapes def random_inputs(self, seed: int = 123): @@ -215,44 +249,42 @@ def random_inputs(self, seed: int = 123): k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) - if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - return { + if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + feeds = { "query": q.reshape(shape_dict["query"]), - "key": k_bnsh.contiguous(), - "value": v_bnsh.contiguous(), + "key": k.reshape(shape_dict["key"]), + "value": v.reshape(shape_dict["value"]), } - - feeds = {} - if self.use_kv_cache: - feeds.update( - { - "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - } - ) - - if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: - feeds.update( - { - "query": q.reshape(shape_dict["query"]), - "key": k.reshape(shape_dict["key"]), - "value": v.reshape(shape_dict["value"]), - } - ) elif self.input_format == InputFormats.QKV_BSN3H: query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous() + feeds = { + "query": torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous(), + } elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = q.reshape(shape_dict["query"]) - feeds["key"] = torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous() + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous(), + } + else: + assert self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": k_bnsh.contiguous(), + "value": v_bnsh.contiguous(), + } + + if self.use_kv_cache: + feeds = { + **feeds, + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( + mean=0, std=0.1 + ), + } return feeds @@ -318,19 +350,32 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use return model.SerializeToString() -def create_session( +def create_ort_session( config: MultiHeadAttentionConfig, + session_options=None, + attention_kernel=SdpaKernel.DEFAULT, + use_symbolic_shape: bool = True, ) -> CudaSession: - onnx_model_str = create_multi_head_attention_onnx_model(config) + if config.verbose: + print(f"create session for {vars(config)}") + onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=use_symbolic_shape) if config.provider == "CUDAExecutionProvider": device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) + provider_options["sdpa_kernel"] = int(attention_kernel) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + return ort_session + + +def create_session( + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT +) -> CudaSession: + ort_session = create_ort_session(config, session_options, attention_kernel, use_symbolic_shape=False) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -340,11 +385,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__( - self, - config: MultiHeadAttentionConfig, - ): - self.ort_session = create_session(config) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None): + self.ort_session = create_session(config, session_options) self.feed_dict = config.random_inputs() def infer(self): @@ -363,53 +405,90 @@ def flops(batch, sequence_length, head_size, num_heads, causal): def tflops_per_second(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - - -def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str: - # This classification is for Nvidia GPU of Compute Capability 8.* like A100. - # Note that some kernel might not exist in older or newer GPUs. - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - if config.input_format == InputFormats.QKV_BSN3H: - min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") - min_length = int(min_seq_len) if min_seq_len is not None else 513 - if config.sequence_length >= min_length: - return "Flash" - else: - return "Flash" + try: + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + except ZeroDivisionError: + return None + + +def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: + kernel_names = { + SdpaKernel.DEFAULT: "ort:default", + SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", + SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", + SdpaKernel.MATH: "ort:math", + } + assert attention_kernel in kernel_names + return kernel_names[attention_kernel] - if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( - os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" - and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") - ): - return "TRT" - if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": - return "MemEff" +def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: + # CPU Flash Attention does not support causal and kv cache etc. + if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0): + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + return "ort:flash" - return "Unfused" + return "ort:math" -def get_cpu_kernel_name() -> str: - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - return "CPU:Flash" - return "CPU:Unfused" +# ------------------------------------------------------------------ +# Functions for benchmarking PyTorch SDPA +# ------------------------------------------------------------------ +def benchmark_torch_function(func: Callable, *args, **kwargs) -> float: + warmup = 5 + repeats = 100 + for _ in range(warmup): + func(*args, **kwargs) + timer = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + + return timer.timeit(number=repeats).median -def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100): - if use_gpu: - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] - provider = "CUDAExecutionProvider" - print(f"enable_cuda_graph={enable_cuda_graph}") - else: - device_id = 0 - device = torch.device("cpu") - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - enable_cuda_graph = False - provider = "CPUExecutionProvider" +def run_torch_sdpa( + batch_size: int, + q_seq_len: int, + kv_seq_len: int, + num_heads: int, + head_size: int, + causal: bool, + device, + dtype, + has_mask: bool = False, + mask_dim: int = 2, + mask_dtype=torch.bool, + backend: Optional[int] = None, +): + q_shape = (batch_size, num_heads, q_seq_len, head_size) + kv_shape = (batch_size, num_heads, kv_seq_len, head_size) + q = torch.randn(q_shape, device=device, dtype=dtype) + k = torch.randn(kv_shape, device=device, dtype=dtype) + v = torch.randn(kv_shape, device=device, dtype=dtype) + + attn_mask = None + if has_mask: + mask_shape = (batch_size, num_heads, q_seq_len, kv_seq_len) if mask_dim == 4 else (q_seq_len, kv_seq_len) + attn_mask = torch.ones(mask_shape, dtype=mask_dtype, device=device) + + context = sdpa_kernel(backend) if backend is not None else nullcontext() + + with context: + average_latency = benchmark_torch_function( + scaled_dot_product_attention, + q, + k, + v, + is_causal=causal, + attn_mask=attn_mask, + ) + return average_latency + + +def get_test_configs(use_gpu: bool = True): if use_gpu: # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused) configs = [ @@ -450,31 +529,70 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea ] else: configs = [ + # TNLGv4 (1, 128, 0, 32, 128, True), (1, 256, 0, 32, 128, True), (1, 512, 0, 32, 128, True), (1, 1024, 0, 32, 128, True), - (1, 2048, 0, 32, 128, True), + # (1, 2048, 0, 32, 128, True), + # bert-base + (1, 128, 0, 12, 64, True), + (1, 384, 0, 12, 64, True), + (1, 512, 0, 12, 64, True), + (4, 128, 0, 12, 64, True), + (4, 384, 0, 12, 64, True), + (4, 512, 0, 12, 64, True), + # bert-large + (1, 128, 0, 16, 64, True), + (1, 384, 0, 16, 64, True), + (1, 512, 0, 16, 64, True), + (4, 128, 0, 16, 64, True), + (4, 384, 0, 16, 64, True), + (4, 512, 0, 16, 64, True), ] + return configs + + +def get_compute_capability(): + assert torch.cuda.is_available() + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm - # List of environment variables to enable/disable attention kernels - print("Environment Variables:") - env_names = [ - "ORT_DISABLE_FLASH_ATTENTION", - "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", - "ORT_DISABLE_FUSED_ATTENTION", - "ORT_DISABLE_TRT_FLASH_ATTENTION", - "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", - "ORT_DISABLE_FUSED_CROSS_ATTENTION", - "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", - ] - for name in env_names: - value = os.getenv(name) - if value is not None: - print(f"{name}={value}") - print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") - causal = False +def run_tflops_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + enable_cuda_graph: bool = False, + causal: bool = False, + has_past: bool = False, + intra_op_num_threads: int = 0, + repeats: int = 100, +): + print(f"run_tflops_test: causal={causal}") + + if use_gpu: + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] + provider = "CUDAExecutionProvider" + # flash attention is available for sm >= 80 + sm = get_compute_capability() + if sm >= 80: + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + else: + backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] + else: + device_id = 0 + device = torch.device("cpu") + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + enable_cuda_graph = False + provider = "CPUExecutionProvider" + backends = [SdpaKernel.DEFAULT] + + configs = get_test_configs(use_gpu) + + print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: @@ -496,21 +614,27 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea share_past_present_buffer=False, input_format=input_format, ) - - session = create_session(config) + for attention_kernel in backends: + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, attention_kernel=attention_kernel) if use_gpu: - kernel = get_gpu_kernel_name(config) + kernel = get_gpu_kernel_name(attention_kernel) else: - kernel = get_cpu_kernel_name() + kernel = get_cpu_kernel_name(config) - if kernel == "Unfused": + if "math" in kernel: # Skip large sequence length for Unfused kernel to avoid OOM. if not enable_unfused: + if config.verbose: + print(f"skip unfused kernel for {vars(config)}") continue # Unfused kernel does not support packed QKV or packed KV formats. if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + if config.verbose: + print(f"skip input_format for {vars(config)}") continue input_dict = config.random_inputs() @@ -526,19 +650,168 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea del session + format_str = InputFormats.input_format_str(input_format) + # compute TFLOPS per second - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = None + if past_sequence_length == 0: + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": enable_cuda_graph, + "format": format_str, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": intra_op_num_threads, + "average_latency": average_latency, + "tflops": speed, + "kernel": kernel, + } + csv_writer.writerow(row) - format = InputFormats.input_format_str(input_format) + speed = f"{speed:.2f}" if speed is not None else "NA" print( - f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" ) +def run_torch_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + causal: bool = False, +): + configs = get_test_configs(use_gpu) + + if use_gpu: + if not torch.cuda.is_available(): + return + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + backends = [ + None, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + SDPBackend.MATH, + ] + else: + device = torch.device("cpu") + dtype = torch.float32 + backends = [None] + + backend_names = { + SDPBackend.FLASH_ATTENTION: "torch:flash", + SDPBackend.EFFICIENT_ATTENTION: "torch:efficient", + SDPBackend.CUDNN_ATTENTION: "torch:cudnn", + SDPBackend.MATH: "torch:math", + None: "torch:default", + } + + # Test PyTorch latency + for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + for backend in backends: + if backend == SDPBackend.MATH and not enable_unfused: + continue + if backend == SDPBackend.FLASH_ATTENTION and platform.system() != "Linux": + continue + + backend_name = backend_names[backend] + try: + with torch.no_grad(): + torch_latency = run_torch_sdpa( + batch_size, + sequence_length, + sequence_length, + num_heads, + head_size, + causal, + has_mask=False, + mask_dim=2, + mask_dtype=torch.bool, + device=device, + dtype=dtype, + backend=backend, + ) + except RuntimeError: + continue + + speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + input_format = "Q,K,V" + print( + f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{backend_name}" + ) + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": False, + "format": input_format, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": torch.get_num_threads(), + "average_latency": torch_latency, + "tflops": speed, + "kernel": backend_name, + } + csv_writer.writerow(row) + + +def run_tflops_tests(args): + features = "gpu" if args.use_gpu else "cpu" + if args.causal: + features += "_causal" + if args.has_past: + features += "_past" + csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + features, + "torch" if args.torch else "ort", + datetime.now().strftime("%Y%m%d-%H%M%S"), + ) + with open(csv_filename, mode="a", newline="") as csv_file: + column_names = [ + "use_gpu", + "enable_cuda_graph", + "format", + "causal", + "batch_size", + "sequence_length", + "past_sequence_length", + "num_heads", + "head_size", + "intra_op_num_threads", + "average_latency", + "tflops", + "kernel", + ] + csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) + csv_writer.writeheader() + + if args.torch: + run_torch_test(csv_writer, args.use_gpu, args.causal) + else: + run_tflops_test( + csv_writer, + use_gpu=args.use_gpu, + enable_cuda_graph=args.use_cuda_graph, + causal=args.causal, + has_past=args.has_past, + intra_op_num_threads=args.intra_op_num_threads, + ) + + def plot_prompt_performance( - sm: int, model_name: str, batch_size: int, num_heads: int, @@ -558,6 +831,7 @@ def plot_prompt_performance( "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)], } + sm = get_compute_capability() configs = [ triton.testing.Benchmark( x_names=["sequence_length"], @@ -591,13 +865,14 @@ def benchmark( sequence_length=sequence_length, num_heads=num_heads, head_size=head_size, - causal=True, + causal=False, past_sequence_length=0, kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, max_cache_sequence_length=max_seq_len, provider="CUDAExecutionProvider", enable_cuda_graph=False, device=device, + dtype=torch.float16, use_kv_cache=False, input_format=InputFormats.convert(input_format), ) @@ -609,14 +884,14 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_performance_test(sm: int): +def run_bert_performance_test(): """ Run performance tests for prompt and token generation. """ configures = [ - (1, 32, 128, 8192, "TNLGv4"), - (4, 32, 128, 8192, "TNLGv4"), + # (1, 32, 128, 8192, "TNLGv4"), + # (4, 32, 128, 8192, "TNLGv4"), (1, 12, 64, 1024, "BertBase"), (16, 12, 64, 1024, "BertBase"), (1, 16, 64, 1024, "BertLarge"), @@ -625,7 +900,6 @@ def run_performance_test(sm: int): for batch_size, num_heads, head_size, max_seq_len, model_name in configures: plot_prompt_performance( - sm=sm, batch_size=batch_size, num_heads=num_heads, head_size=head_size, @@ -634,18 +908,84 @@ def run_performance_test(sm: int): ) +def _parse_arguments(): + parser = argparse.ArgumentParser(description="Benchmark MultiHeadAttention for ONNX Runtime and PyTorch.") + + parser.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for inference.", + ) + parser.set_defaults(use_gpu=False) + + parser.add_argument( + "--use_cuda_graph", + required=False, + action="store_true", + help="Use cuda graph in onnxruntime.", + ) + parser.set_defaults(use_cuda_graph=False) + + parser.add_argument( + "--intra_op_num_threads", + required=False, + type=int, + choices=[0, 1, 2, 4, 8, 16], + default=0, + help="intra_op_num_threads for onnxruntime. ", + ) + + parser.add_argument( + "--has_past", + required=False, + action="store_true", + help="whether past_sequence_length > 0", + ) + parser.set_defaults(has_past=False) + + parser.add_argument( + "--causal", + required=False, + action="store_true", + help="test unidirectional", + ) + parser.set_defaults(causal=False) + + parser.add_argument( + "--torch", + required=False, + action="store_true", + help="test pytorch instead of onnxruntime", + ) + parser.set_defaults(torch=False) + + args = parser.parse_args() + + return args + + if __name__ == "__main__": - if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): - # Test CUDA provider - major, minor = torch.cuda.get_device_capability() - sm = major * 10 + minor + args = _parse_arguments() + print(f"arguments:{args}") + + if args.has_past: + assert args.causal, "--has_past need --causal specified" + + if args.use_gpu: + assert args.torch or not args.causal, "no causal cuda kernel in MHA op" + assert torch.cuda.is_available() + if not args.torch: + assert "CUDAExecutionProvider" in get_available_providers() + if args.torch: + assert Version(torch.__version__) >= Version("2.3.0") + assert args.has_past is False + + if args.use_gpu and not args.torch: if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_performance_test(sm) - - run_tflops_test(use_gpu=True, enable_cuda_graph=True) + run_bert_performance_test() - # Test CPU provider - run_tflops_test(use_gpu=False, enable_cuda_graph=False) + run_tflops_tests(args) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 7b21cf1cc1e08..613543d0172dd 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -1,14 +1,40 @@ -echo "flash attention v2" -ORT_DISABLE_FLASH_ATTENTION=0 ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV=0 python benchmark_mha.py | tee result.txt +#!/bin/sh -echo "===" -echo "TensorRT attention kernels - cross attention (when kv_seq_len <= 128) or fused attention (when seq_len <= 384) or flash attention (seq_len > 384)" -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- -echo "===" -echo "Memory Efficient attention" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" -echo "===" -echo "Unfused Attention (some configurations might fail)" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +export CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py +ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +# Pytorch SDPA does not support causal attention with past state, we only test ORT here. +python benchmark_mha.py --causal --has_past + +cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index ff473cc2ced92..0fcbd889847e9 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -10,36 +10,15 @@ import concurrent.futures import itertools import unittest -from enum import IntEnum from typing import Dict, List, Optional import numpy import torch -from benchmark_mha import ( - InputFormats, - MultiHeadAttentionConfig, - OrtMultiHeadAttention, - create_multi_head_attention_onnx_model, -) +from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention, SdpaKernel, create_ort_session from einops import rearrange from parameterized import parameterized import onnxruntime -from onnxruntime import InferenceSession - - -class SdpaKernel(IntEnum): - """Bit flags for sdpa_kernel CUDA provider option""" - - DEFAULT = 0 - FLASH_ATTENTION = 1 - EFFICIENT_ATTENTION = 2 - TRT_FUSED_ATTENTION = 4 - CUDNN_FLASH_ATTENTION = 8 - MATH = 16 - TRT_FLASH_ATTENTION = 32 - TRT_CROSS_ATTENTION = 64 - TRT_CAUSAL_ATTENTION = 128 def attention_reference( @@ -466,7 +445,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - sdpa_kernel: int = SdpaKernel.DEFAULT, + attention_kernel: int = SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -476,21 +455,14 @@ def parity_check_mha_multi_threading( if config.causal and config.provider == "CUDAExecutionProvider": return None # Some kernel does not support certain input format. - if sdpa_kernel not in [ + if attention_kernel not in [ SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - if verbose: - print(f"create a shared session with {vars(config)}") - onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True) - if config.provider == "CUDAExecutionProvider": - provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)} - providers = [(config.provider, provider_options), "CPUExecutionProvider"] - else: - providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -613,7 +585,7 @@ def test_mha_cuda(self, config): def test_mha_cpu(self, config): parity_check_mha(config) - def run_mha_cuda_multi_threading(self, spda_kernel): + def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): test_inputs = [] for config in configs: @@ -626,8 +598,10 @@ def run_mha_cuda_multi_threading(self, spda_kernel): config.input_format = old_format test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs}) - exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs)) - assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}" + exception = parity_check_mha_multi_threading( + test_inputs, attention_kernel=attention_kernel, max_threads=len(configs) + ) + assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) From 5af423c7c0561d3861a6b8ed5598abef02715e28 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 27 Jul 2024 13:22:57 +1000 Subject: [PATCH 04/37] Set version and other info in the C# dll (#21517) ### Description Set version and other info in the Microsoft.ML.OnnxRuntime C# dll by setting GenerateAssemblyInfo to true and passing in ORT version in the CI. Minor re-org of the order of properties so related things are grouped a little better. ### Motivation and Context #21475 --- .../Microsoft.ML.OnnxRuntime.csproj | 67 +++++++++++-------- .../azure-pipelines/templates/c-api-cpu.yml | 4 +- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 3c8a49bf93578..deb6b4f884bcf 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -23,7 +23,7 @@ + '$(IncludeMobileTargets)' == 'true'"> net8.0-android @@ -31,6 +31,43 @@ $(BaseTargets);$(MobileTargets) + + Microsoft.ML.OnnxRuntime + Microsoft.ML.OnnxRuntime + + + + 1.0.0 + 0.0.0 + + + + true + Microsoft.ML.OnnxRuntime C# Bindings + Microsoft + © Microsoft Corporation. All rights reserved. + This package contains ONNX Runtime for .Net platforms + + + $(PackageVersion) + + + + + Microsoft + Microsoft.ML.OnnxRuntime.Managed + ONNX;ONNX Runtime;Machine Learning + https://github.com/Microsoft/onnxruntime + LICENSE.txt + ORT_icon_for_light_bg.png + + Release Def: + Branch: $(BUILD_SOURCEBRANCH) + Commit: $(BUILD_SOURCEVERSION) + Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) + + + AnyCPU;x86 default @@ -43,8 +80,6 @@ $(OnnxRuntimeRoot)\csharp x64 - Microsoft.ML.OnnxRuntime - Microsoft.ML.OnnxRuntime false false portable @@ -54,27 +89,8 @@ on their device is not built for training, an exception will be thrown with the following message - "Training is disabled in the current build. Please build onnxruntime from source with the build flags enable_training_apis. "--> - true + true - - - Microsoft.ML.OnnxRuntime.Managed - Microsoft - 1.0.0 - 0.0.0 - $(PackageVersion) - This package contains ONNX Runtime for .Net platforms - ONNX;ONNX Runtime;Machine Learning - https://github.com/Microsoft/onnxruntime - © Microsoft Corporation. All rights reserved. - LICENSE.txt - ORT_icon_for_light_bg.png - - Release Def: - Branch: $(BUILD_SOURCEBRANCH) - Commit: $(BUILD_SOURCEVERSION) - Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) - true @@ -82,7 +98,6 @@ false - false $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb Debug;Release;RelWithDebInfo @@ -158,10 +173,6 @@ $(OrtConstants);__ENABLE_COREML__ - - $(OrtConstants);__XAMARIN__ - - $(DefineConstants);$(OrtConstants) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 7ba1179e7ad4d..ec97da3786fd9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -364,6 +364,8 @@ stages: workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' displayName: 'List artifacts' + - template: set-version-number-variables-step.yml + # Reconstruct the build dir - task: PowerShell@2 displayName: 'Extract native libraries for addition to nuget native package' @@ -403,7 +405,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:PackageVersion=$(OnnxRuntimeVersion)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - ${{ if eq(parameters.DoEsrp, true) }}: From 690d745cbff6f540f95e668be21da76873689a32 Mon Sep 17 00:00:00 2001 From: zz002 Date: Sat, 27 Jul 2024 11:28:55 +0800 Subject: [PATCH 05/37] [VitisAI] 1. KernelDef supports StartVersion and EndVersion (#21519) ### Description [VitisAI] 1. KernelDef supports StartVersion and EndVersion 2. CapabilityOps checks domain ### Motivation and Context Co-authored-by: Zhenze Wang --- onnxruntime/core/providers/vitisai/imp/capability.cc | 6 +++++- onnxruntime/core/providers/vitisai/imp/global_api.cc | 4 ++-- .../core/providers/vitisai/vitisai_execution_provider.cc | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/imp/capability.cc b/onnxruntime/core/providers/vitisai/imp/capability.cc index 58522a45a151e..6d188076fe613 100644 --- a/onnxruntime/core/providers/vitisai/imp/capability.cc +++ b/onnxruntime/core/providers/vitisai/imp/capability.cc @@ -51,7 +51,11 @@ GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph, std::vector node_indexs = graph.GetNodesInTopologicalOrder(); node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end()); - node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end()); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), + [&](NodeIndex index) { + auto node = graph.GetNode(index); + return all_support_optypes_by_eps.count(node->Domain() + ":" + node->OpType()) == 0; }), + node_indexs.end()); std::vector> result; for (auto& n : node_indexs) { diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 8c1dce0d3dc1a..a86a4fb61d54d 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -173,7 +173,7 @@ void create_kernel_registry(std::vector domains) { auto def_builder = KernelDefBuilder::Create(); def_builder->SetName(op->GetName(op)); def_builder->SetDomain(domain->domain_.c_str()); - def_builder->SinceVersion(1); + def_builder->SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); if (op->version > 12) { auto input_count = op->GetInputTypeCount(op); for (auto i = 0u; i < input_count; i++) { @@ -183,7 +183,7 @@ void create_kernel_registry(std::vector domains) { def_builder->Provider(onnxruntime::kVitisAIExecutionProvider); KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { - // out = std::make_unique(info, *op); + out = std::make_unique(info, *op); return Status::OK(); }; std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn)); diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 0f0972d96bcee..58fef537535d2 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -44,7 +44,7 @@ VitisAIExecutionProvider::VitisAIExecutionProvider( void VitisAIExecutionProvider::CreateKernelRegistry() { for (const auto& domain : get_domains_vitisaiep()) { for (const auto* op : domain->custom_ops_) { - vitisai_optypes_.insert(op->GetName(op)); + vitisai_optypes_.insert(domain->domain_ + ":" + op->GetName(op)); } } } From d01fc75ef161a624c4275f89cb068cc1c79d9392 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Fri, 26 Jul 2024 22:15:57 -0700 Subject: [PATCH 06/37] [VitisAI] support vaip create ep context nodes & bug fix (#21506) ### Description 1. We decided to move the context node creation back to our own repo because it is more flexible to modify. 2. We found a bug related the context node. It would change the inference order. So, we fixed in this PR as well. ### Motivation and Context This is crucial for Microsoft Release next month. --------- Co-authored-by: Yueqing Zhang --- .../shared_library/provider_interfaces.h | 1 + .../shared_library/provider_wrappedtypes.h | 1 + .../core/providers/vitisai/imp/global_api.cc | 50 +++++++++++++++++++ .../vitisai/include/vaip/custom_op.h | 11 ++++ .../vitisai/include/vaip/global_api.h | 6 ++- .../vitisai/include/vaip/vaip_ort_api.h | 11 ++-- .../vitisai/vitisai_execution_provider.cc | 14 ++++-- .../core/session/provider_bridge_ort.cc | 1 + 8 files changed, 88 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 382b3ac932520..a9394838aa784 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -388,6 +388,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) = 0; // GraphProto + virtual std::unique_ptr GraphProto__construct() = 0; virtual void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) = 0; virtual void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index de6c1da1d6430..242c7126f3274 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -146,6 +146,7 @@ struct AttributeProto final { }; struct GraphProto final { + static std::unique_ptr Create() { return g_host->GraphProto__construct(); } static void operator delete(void* p) { g_host->GraphProto__operator_delete(reinterpret_cast(p)); } void operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index a86a4fb61d54d..df47fa5cee4ab 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -55,10 +55,15 @@ struct OrtVitisAIEpAPI { uint32_t (*vaip_get_version)(); void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); + void (*create_ep_context_nodes)( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps, + vaip_core::DllSafe>* ret_value) = nullptr; void Ensure() { if (handle_) return; auto& env = Provider_GetHost()->Env__Default(); + auto& logger = *Provider_GetHost()->LoggingManager_GetDefaultLogger(); #ifdef _WIN32 // this dll is already linked to the executable, normally a test program handle_ = reinterpret_cast(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll"))); @@ -81,6 +86,10 @@ struct OrtVitisAIEpAPI { (void**)&vaip_get_version); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache)); + status1 = (env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes)); + if (!status1.IsOK()) { + LOGS(logger, WARNING) << "create_ep_context_nodes is not defined, please upgrade onnxruntime_vitisai_ep.dll. However, it still works."; + } } private: @@ -146,6 +155,24 @@ void restore_backend_compilation_cache(const std::string& cache_dir, const std:: s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path); } +bool has_create_ep_context_nodes() { + return s_library_vitisaiep.create_ep_context_nodes != nullptr; +} + +std::optional> create_ep_context_nodes( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps) { + if (s_library_vitisaiep.create_ep_context_nodes) { + vaip_core::DllSafe> nodes; + s_library_vitisaiep.create_ep_context_nodes(ep_context_graph, eps, &nodes); + if (nodes.get()) { + auto ret = std::vector(*nodes); + return ret; + } + } + return std::nullopt; +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = @@ -405,6 +432,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { graph.AddInitializedTensor(tensor); }; + the_global_api.get_model_path = [](const Graph& graph) -> const std::filesystem::path& { + return graph.ModelPath(); + }; + + the_global_api.create_empty_model = [](const std::filesystem::path& path, const std::vector>& opset) -> Model* { + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); + auto graph_proto = ONNX_NAMESPACE::GraphProto::Create(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + for (const auto& op : opset) { + auto* opset_import = model_proto->add_opset_import(); + *(opset_import->mutable_domain()) = op.first; + opset_import->set_version(op.second); + } + std::ignore = model_proto->mutable_graph(); // create a graph + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model = Model::Create(std::move(*model_proto), path, nullptr, logger); + return model.release(); + }; + + the_global_api.graph_set_inputs = [](Graph& graph, gsl::span inputs) { + graph.SetInputs(inputs); + }; + if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h index d34f7095b704d..5d020e00ff5b7 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h @@ -26,6 +26,17 @@ class ExecutionProvider { virtual DllSafe> get_meta_def_constant_initializer() const = 0; virtual std::unique_ptr compile() const = 0; + + public: + inline void set_fused_node(const onnxruntime::Node* fused_node) { + fused_node_ = fused_node; + } + inline const onnxruntime::Node* get_fused_node() const { + return fused_node_; + } + + private: + const onnxruntime::Node* fused_node_ = nullptr; }; class CustomOp { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 3fdbc60bb0ee6..ae2a513a98e32 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -9,10 +9,14 @@ #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" - +#include void initialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); +std::optional> create_ep_context_nodes( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps); +bool has_create_ep_context_nodes(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 3346739890484..e6aacfe1f0272 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -8,12 +8,13 @@ #include #include #include +#include struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (3u) -#define VAIP_ORT_API_MINOR (1u) +#define VAIP_ORT_API_MAJOR (4u) +#define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { uint32_t magic; // 'VAIP' or something else to make sure the following field @@ -222,7 +223,11 @@ struct OrtApiForVaip { const std::vector& data); // [88] TensorProto* (*tensor_proto_new_bf16)( const std::string& name, const std::vector& shape, - const std::vector& data); // [89] + const std::vector& data); // [89] + const std::filesystem::path& (*get_model_path)(const Graph& graph); // [90] + Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector>& opset); //[91] + void (*graph_set_inputs)(Graph& graph, + gsl::span inputs); // [92] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 58fef537535d2..756bda2199e89 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -58,8 +58,15 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c // All preconditions are supposed to have happened. if (p_ep_ctx_model_) { auto& graph = p_ep_ctx_model_->MainGraph(); - for (const auto* p_node : graph.Nodes()) { - ep_context_node_ptrs.push_back(p_node); + if (has_create_ep_context_nodes()) { + auto nodes = create_ep_context_nodes(graph, **execution_providers_); + if (nodes.has_value()) { + ep_context_node_ptrs.assign(nodes->begin(), nodes->end()); + } + } else { + for (const auto* p_node : graph.Nodes()) { + ep_context_node_ptrs.push_back(p_node); + } } } return ep_context_node_ptrs; @@ -187,6 +194,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectorexecution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get()); compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; @@ -204,7 +212,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectoradd_tensors(); } // GraphProto (wrapped) + std::unique_ptr GraphProto__construct() override { return std::make_unique(); } void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) override { delete p; } const ONNX_NAMESPACE::ValueInfoProto& GraphProto__input(const ONNX_NAMESPACE::GraphProto* p, int index) override { return p->input(index); } From 10b4a3b90bd61fcda8aefecf2a1dce1a45c086e1 Mon Sep 17 00:00:00 2001 From: maggie1059 <34173352+maggie1059@users.noreply.github.com> Date: Fri, 26 Jul 2024 22:26:38 -0700 Subject: [PATCH 07/37] Fix conda failure for onnxruntime-directml (#21526) The change in #21005 works for directly building wheels with `build.py`, but ort-nightly-directml wheels, as well as the 1.18.1 release of the onnxruntime-directml python wheel, still do not work with conda since they're built from the `py-win-gpu.yml` pipeline, which uses `install_third_party_deps.ps1` to set compile flags. --- tools/ci_build/github/windows/install_third_party_deps.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/windows/install_third_party_deps.ps1 b/tools/ci_build/github/windows/install_third_party_deps.ps1 index 07679006fb343..168df90188791 100644 --- a/tools/ci_build/github/windows/install_third_party_deps.ps1 +++ b/tools/ci_build/github/windows/install_third_party_deps.ps1 @@ -27,7 +27,7 @@ $Env:CMAKE_PREFIX_PATH = "$install_prefix" New-Item -Path "$install_prefix" -ItemType Directory -Force # Setup compile flags -$compile_flags = @('/MP', '/guard:cf', '/DWIN32', '/D_WINDOWS', '/DWINVER=0x0A00', '/D_WIN32_WINNT=0x0A00', '/DNTDDI_VERSION=0x0A000000', '/W3') +$compile_flags = @('/MP', '/guard:cf', '/DWIN32', '/D_WINDOWS', '/D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR', '/DWINVER=0x0A00', '/D_WIN32_WINNT=0x0A00', '/DNTDDI_VERSION=0x0A000000', '/W3') $linker_flags=@('/guard:cf') if ($use_cache) { From 1ce160883f964509a547458c484d2449bda047ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 26 Jul 2024 22:31:16 -0700 Subject: [PATCH 08/37] Bump Sixlabors.ImageSharp from 2.1.8 to 2.1.9 in /csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample (#21444) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [Sixlabors.ImageSharp](https://github.com/SixLabors/ImageSharp) from 2.1.8 to 2.1.9.
Release notes

Sourced from Sixlabors.ImageSharp's releases.

v2.1.9

What's Changed

Full Changelog: https://github.com/SixLabors/ImageSharp/compare/v2.1.8...v2.1.9

Commits
  • 9816ca4 Merge pull request #2770 from SixLabors/af/backport-2759-2.1.x
  • b33d666 handle DecodingMode
  • 6b2030b Merge branch 'release/2.1.x' into af/backport-2759-2.1.x
  • 8ffad3f Issue2012BadMinCode should decode now
  • 1f5bf23 skip Issue2758_DecodeWorks
  • 3bf8c57 manual port of 3.1 gif decoder
  • 28c20de Clamp JPEG quality estimation results.
  • 4b910e7 Decode LZW row by row
  • a1f2879 Merge pull request #2756 from SixLabors/af/git-av-2.1
  • 898df7f backport #2749 to 2.1
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=Sixlabors.ImageSharp&package-manager=nuget&previous-version=2.1.8&new-version=2.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj index 647c0bbe6a242..29fc9f3bc382f 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj @@ -8,7 +8,7 @@ - + From 48fb8a7e56a7263a8405dc644756eb5c55560352 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Sat, 27 Jul 2024 11:10:52 -0700 Subject: [PATCH 09/37] Security fuzz address sanitizer fix Bug #2 and #3 (#21528) ### Description Security fuzz test with address sanitizer found several bugs --- onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc | 2 ++ onnxruntime/core/optimizer/attention_fusion.cc | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 34a1da99316a2..030cdb1e1b17f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -143,6 +143,8 @@ Status GptSubgraph::Validate(const std::vector& subgraph_inputs, // Past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads). const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape == nullptr, + "subgraph past state cannot be nullptr"); ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size()); diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 08066f030a381..64a38214caff0 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -210,7 +210,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && - graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.InputDefs().size() > 2) { // Get hidden size from layer norm bias tensor shape. const NodeArg& layer_norm_bias = *(node.InputDefs()[2]); if (!optimizer_utils::IsShapeKnownOnAllDims(layer_norm_bias, 1)) { From 82b2955268e14f26eb71ad2d660452ab8db454d7 Mon Sep 17 00:00:00 2001 From: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Date: Sat, 27 Jul 2024 23:47:22 +0530 Subject: [PATCH 10/37] [AIX]test failure fix using gtest-1.15.0 for AIX (#21497) ### Description Local CI setup for AIX reported tests failure after the gtest 1.15.0 upgrade. ### Motivation and Context Below tests failure is observed after gtest upgrade. The following tests FAILED: 1 - onnxruntime_test_all (ILLEGAL) 7 - onnxruntime_logging_apis_test (Subprocess aborted) To fix this, I am enabling pthread support under gtest. This was disabled with previous version of gtest for some reason. Now by enabling this, above tests are getting passed with gtest 1.15.0. --- cmake/external/onnxruntime_external_deps.cmake | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 14e6ed515fd6e..775576a771529 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -46,9 +46,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(gtest_disable_pthreads ON) endif() - if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(gtest_disable_pthreads ON CACHE BOOL "gtest_disable_pthreads" FORCE) - endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) if (IOS OR ANDROID) # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing From 7e23212de9746ed2452061958f8aae3ffc171cee Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Sat, 27 Jul 2024 15:58:12 -0700 Subject: [PATCH 11/37] Delete tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml (#21529) ### Description Delete tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml ### Motivation and Context This CI pipeline has been divided into 4 different pipeline. --- .../azure-pipelines/win-gpu-ci-pipeline.yml | 125 ------------------ 1 file changed, 125 deletions(-) delete mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml deleted file mode 100644 index c5262880c4c55..0000000000000 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ /dev/null @@ -1,125 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -parameters: -- name: CudaVersion - displayName: CUDA version - type: string - default: '12.2' - values: - - 11.8 - - 12.2 -- name: RunOnnxRuntimeTests - displayName: Run Tests? - type: boolean - default: true - -stages: -- stage: cuda - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - additionalBuildFlags: >- - --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" - --enable_cuda_profiling --enable_transformers_tool_test - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON - --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: CUDA - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 - -- stage: training - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - additionalBuildFlags: >- - --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" - --skip_onnx_tests - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: CUDA - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 - isTraining: true - -- stage: dml - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat - buildArch: x64 - additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: DML - WITH_CACHE: false - MachinePool: onnxruntime-Win2022-GPU-dml-A10 - -- stage: kernelDocumentation - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags - additionalBuildFlags: >- - --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda - --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: false - GenerateDocumentation: true - ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 From a4d3a1ce0c18e1d1b31a9cc0b45beba290ee114c Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Sat, 27 Jul 2024 15:58:36 -0700 Subject: [PATCH 12/37] pick changes from https://github.com/onnx/onnx/pull/6195 to fix heap-buffer-overflow in onnx::convPoolShapeInference (#21507) ### Description onnx 1.16.2 is not available before ort 1.19.0 code freeze. Thus pick the needed change as patch --- cmake/patches/onnx/onnx.patch | 383 ++++++++++++++++++ .../providers/cpu/generator/random_test.cc | 8 +- .../core/graph/training_op_defs.cc | 104 +++-- 3 files changed, 447 insertions(+), 48 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 162d33581a5ca..6ac3555eeecf1 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -86,3 +86,386 @@ index 0aab3e26..398ac2d6 100644 +#endif + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc +index c315a2a7..58963154 100644 +--- a/onnx/defs/math/defs.cc ++++ b/onnx/defs/math/defs.cc +@@ -3472,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA( + } + + auto& input_shape = getInputShape(ctx, 0); ++ if (input_shape.dim_size() < 2) { ++ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), "."); ++ } + auto signal_dim = input_shape.dim(1); + if (!signal_dim.has_dim_value()) { + return; +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index be6a851d..fad595d0 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -126,6 +126,9 @@ void convPoolShapeInference( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +@@ -959,19 +962,21 @@ ONNX_OPERATOR_SET_SCHEMA( + auto w_type = ctx.getInputType(3); + if (nullptr == x_type || nullptr == w_type || x_type->value_case() != TypeProto::kTensorType || + w_type->value_case() != TypeProto::kTensorType) { +- fail_type_inference("inputs are expected to have tensor type."); ++ fail_type_inference("inputs are expected to have tensor type in ", ctx.getDisplayName(), "."); + } + + auto x_zero_point_type = ctx.getInputType(2); + if (nullptr == x_zero_point_type || + x_zero_point_type->tensor_type().elem_type() != x_type->tensor_type().elem_type()) { +- fail_type_inference("input and zero_point pair is expected to have be same type."); ++ fail_type_inference( ++ "input and zero_point pair is expected to have be same type in ", ctx.getDisplayName(), "."); + } + + auto w_zero_point_type = ctx.getInputType(5); + if (nullptr == w_zero_point_type || + w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) { +- fail_type_inference("weight and zero_point pair is expected to have same type."); ++ fail_type_inference( ++ "weight and zero_point pair is expected to have same type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromInputToOutput(ctx, 7, 0); +@@ -2647,7 +2652,8 @@ ONNX_OPERATOR_SET_SCHEMA( + if (!hasNInputShapes(ctx, 1)) { + return; + } +- auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); ++ ++ auto& input_shape = getInputShape(ctx, 0); + int64_t input_ndim = input_shape.dim_size(); + int64_t axis = -1; + auto axis_proto = ctx.getAttribute("axis"); +@@ -2659,7 +2665,16 @@ ONNX_OPERATOR_SET_SCHEMA( + // positive value. + axis += input_ndim; + } +- ++ if (axis < 0) { ++ fail_shape_inference( ++ "Unexpected axis value (", ++ axis, ++ ") rank of first input is ", ++ input_ndim, ++ " in ", ++ ctx.getDisplayName(), ++ "."); ++ } + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index 57f8e2a4..8b2dc07f 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -201,6 +201,9 @@ void convPoolShapeInference_opset19( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h +index a80473b3..d1bcd401 100644 +--- a/onnx/defs/shape_inference.h ++++ b/onnx/defs/shape_inference.h +@@ -105,6 +105,10 @@ struct InferenceContext { + virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0; + // Gets the shape inputs computed by partial data propagation. + virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0; ++ // To display a name the user can use to narrow its search. ++ virtual std::string getDisplayName() const { ++ return ""; ++ } + }; + + // We use data propagation to perform partial evaluation of the model, to compute statically +@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput( + } else { + // This is not expected to happen + fail_type_inference( +- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case); ++ "Output ", ++ outputIndex, ++ " expected to have: ", ++ expected_value_case, ++ " or UNDEFINED. Got: ", ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr + const auto attr_type = attr->type(); + if (attr_type == AttributeProto::TENSOR) { + if (attr->t().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim tensor"); ++ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->t().data_type(); + expected_value_case = TypeProto::kTensorType; + } else if (attr_type == AttributeProto::SPARSE_TENSOR) { + if (attr->sparse_tensor().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim sparse tensor"); ++ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->sparse_tensor().values().data_type(); + expected_value_case = TypeProto::kSparseTensorType; + } else { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case); +@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t + const auto* input_type = ctx.getInputType(n); + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); ++ } ++ if (!hasShape(*input_type)) { ++ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return input_type->tensor_type().shape(); +@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size + + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return &input_type->tensor_type().shape(); +@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + " does not match type of output: ", + outputIndex, + "type: ", +- output_value_case); ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + if (TypeProto::kTensorType == input_value_case) { + auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim(); +@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + *dim = input_type->sparse_tensor_type().shape().dim(static_cast(fromDimIndex)); + } else { + fail_type_inference( +- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type"); ++ "Input ", ++ inputIndex, ++ " and Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType + setTensorElementType(elemType, expected_type, *output_type); + } else { + // This is not expected to happen +- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type); ++ fail_type_inference( ++ "Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type: ", ++ expected_type, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput( + updateOutputElemType(ctx, outputIndex, default_value, expected_type); + return; + } else { +- fail_type_inference("Value of attribute ", attributeName, " not specified"); ++ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), "."); + } + } + if (!attr_proto->has_i()) { +- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type."); ++ fail_type_inference( ++ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), "."); + } + auto attr_value = attr_proto->i(); + auto elem_type = static_cast(attr_value); + if (!TensorProto_DataType_IsValid(elem_type)) { +- fail_type_inference("Attribute ", attributeName, " does not specify a valid type."); ++ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), "."); + } + updateOutputElemType(ctx, outputIndex, elem_type, expected_type); + } +@@ -497,7 +529,7 @@ inline TensorShapeProto* + getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) { + auto output_type = ctx.getOutputType(n); + if (output_type == nullptr) { +- fail_type_inference("Output ", n, " expected to have tensor or sparse type"); ++ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), "."); + } + const auto output_value_case = output_type->value_case(); + if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { +@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ + } else if (output_value_case == TypeProto::VALUE_NOT_SET) { + return getTensorMutableShape(default_type, *output_type); + } else { +- fail_type_inference("Output ", n, " expected to have tensor type"); ++ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), "."); + } + } + +@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput( + auto attr_proto = ctx.getAttribute(attributeName); + if ((nullptr == attr_proto) || (!attr_proto->has_type()) || + (attr_proto->type() != AttributeProto_AttributeType_INTS)) { +- fail_shape_inference("Attribute ", attributeName, " should specify a shape"); ++ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), "."); + } + auto& int_list = attr_proto->ints(); + TensorShapeProto shape; + for (auto dim_size : int_list) { + if (dim_size < 0) { +- fail_shape_inference("Negative values are not allowed in a shape specification"); ++ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), "."); + } + shape.add_dim()->set_dim_value(dim_size); + } +@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect + if (hasInputShape(ctx, input_index)) { + auto rank = getInputShape(ctx, input_index).dim_size(); + if (rank != expected_rank) { +- fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank); ++ fail_shape_inference( ++ "Input ", ++ input_index, ++ " expected to have rank ", ++ expected_rank, ++ " but has rank ", ++ rank, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + } +@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind + // This shape is expected to have rank > dim_index: + if (input_shape.dim_size() <= dim_index) { + fail_shape_inference( +- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size()); ++ "Input ", ++ input_index, ++ " expected to have rank >", ++ dim_index, ++ " but has rank ", ++ input_shape.dim_size(), ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + const Dim& input_dim = input_shape.dim(dim_index); + // Now, unify dim and input_dim: +diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc +index 8723dcd4..8249fc59 100644 +--- a/onnx/shape_inference/implementation.cc ++++ b/onnx/shape_inference/implementation.cc +@@ -906,7 +906,7 @@ struct FunctionInferenceContext : public InferenceContext { + const std::vector& input_types, + const std::vector& attributes, + const ShapeInferenceOptions& options) +- : input_types_(input_types), options_(options) { ++ : input_types_(input_types), options_(options), func_proto_(&func_proto) { + for (const auto& attr : attributes) { + attributesByName_[attr.name()] = &attr; + } +@@ -971,11 +971,25 @@ struct FunctionInferenceContext : public InferenceContext { + return std::move(output_types_); + } + ++ std::string getDisplayName() const override { ++ if (func_proto_ == nullptr) ++ return ""; ++ if (func_proto_->domain().empty()) { ++ if (func_proto_->name().empty()) ++ return ""; ++ return MakeString("function ", func_proto_->name()); ++ } ++ if (func_proto_->name().empty()) ++ return MakeString("function [", func_proto_->domain(), "]"); ++ return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]"); ++ } ++ + private: + const std::vector& input_types_; + std::vector output_types_; + std::unordered_map attributesByName_; + ShapeInferenceOptions options_; ++ const FunctionProto* func_proto_; + }; + + std::vector InferFunctionOutputTypes( +diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h +index 2c63c910..b0e4c32d 100644 +--- a/onnx/shape_inference/implementation.h ++++ b/onnx/shape_inference/implementation.h +@@ -146,7 +146,7 @@ struct InferenceContextImpl : public InferenceContext { + const ShapeInferenceOptions& options, + DataValueMap* generatedShapeData = nullptr, + GraphInferenceContext* graphInferenceContext = nullptr) +- : graphInferenceContext_{graphInferenceContext}, options_(options) { ++ : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) { + for (auto& attr : *n.mutable_attribute()) { + attributesByName_[attr.name()] = &attr; + if (attr.has_g()) { +@@ -277,6 +277,19 @@ struct InferenceContextImpl : public InferenceContext { + return inferencer; + } + ++ std::string getDisplayName() const override { ++ if (node_ == nullptr) ++ return ""; ++ if (node_->domain().empty()) { ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type()); ++ return MakeString("node ", node_->op_type(), " (", node_->name(), ")"); ++ } ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]"); ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")"); ++ } ++ + std::vector allInputData_; + std::vector allInputSparseData_; + std::vector allShapeInputData_; +@@ -289,6 +302,7 @@ struct InferenceContextImpl : public InferenceContext { + // mutable as internal cache of GraphInferencer instances + mutable std::unordered_map> graphAttributeInferencers_; + ShapeInferenceOptions options_; ++ NodeProto* node_; + }; + + struct DataPropagationContextImpl : public DataPropagationContext { diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index ec9b1614488a7..f42f32d63d1fa 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -178,7 +178,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormal) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -194,7 +194,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniform) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -210,7 +210,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormalLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -226,7 +226,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniformLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } } diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 2a8d2de982e79..92f803030ada4 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -181,6 +181,64 @@ static void propagateRecvOutputTensorElemTypes( } } +void SendShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() < 3) { + fail_shape_inference("Send must have at least three inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Send must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Send must be a scalar."); + } + } + + checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); + } + + if (ctx.getNumOutputs() != 1) { + fail_shape_inference("Send must have one output."); + } + + auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_element_type->set_elem_type(TensorProto::BOOL); + ONNX_NAMESPACE::TensorShapeProto output_shape; + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); +} + +void RecvShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() != 2) { + fail_shape_inference("Recv must have two inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Recv must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Recv must be a scalar."); + } + } + } + + if (ctx.getNumOutputs() < 2) { + fail_shape_inference("Recv must have at least two outputs."); + } + + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); +} + TensorProto ToDimensionOneFloatTensor(float value) { auto t = ToTensor(std::vector({value})); t.add_dims(1); @@ -3388,30 +3446,7 @@ Return true if all elements are true and false otherwise. "Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() < 3) { - fail_shape_inference("Send must have at least three inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Send must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Send must be a scalar."); - } - - checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); - } - - if (ctx.getNumOutputs() != 1) { - fail_shape_inference("Send must have one output."); - } - - auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); - output_element_type->set_elem_type(TensorProto::BOOL); - ONNX_NAMESPACE::TensorShapeProto output_shape; - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + SendShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(Recv) @@ -3437,26 +3472,7 @@ Return true if all elements are true and false otherwise. "Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() != 2) { - fail_shape_inference("Recv must have two inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Recv must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Recv must be a scalar."); - } - } - - if (ctx.getNumOutputs() < 2) { - fail_shape_inference("Recv must have at least two outputs."); - } - - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); + RecvShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(MegatronF) From dbff0cd09860b60bd0a251c1dbe76785b0b2818c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 28 Jul 2024 13:03:17 -0700 Subject: [PATCH 13/37] [js/node] enable float16 support for Node.js binding (#20581) ### Description enable float16 support for Node.js binding. data of float16 tensor uses `Uint16Array`. --- js/node/src/tensor_helper.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 1c0b141e6a44f..1062d89f76c5f 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -38,13 +38,13 @@ constexpr size_t DATA_TYPE_ELEMENT_SIZE_MAP[] = { 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 INT64 not working in Javascript + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING N/A 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 not working in Javascript + 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 UINT64 not working in Javascript + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported 0 // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported @@ -60,13 +60,13 @@ constexpr napi_typedarray_type DATA_TYPE_TYPEDARRAY_MAP[] = { napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 napi_int16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 napi_int32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 INT64 not working i + napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING not supported napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 not working + napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 uses Uint16Array napi_float64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE napi_uint32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 UINT64 not working + napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported (napi_typedarray_type)(-1) // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported @@ -182,9 +182,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo * char *buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); - // there is a bug in TypedArray::ElementSize(): https://github.com/nodejs/node-addon-api/pull/705 - // TODO: change to TypedArray::ByteLength() in next node-addon-api release. - size_t bufferByteLength = tensorDataTypedArray.ElementLength() * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]; + size_t bufferByteLength = tensorDataTypedArray.ByteLength(); return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType); } From 5bc12bf209304e7f5800845bd612bb3e7b7ab918 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Mon, 29 Jul 2024 23:47:41 +0800 Subject: [PATCH 14/37] [js/webgpu] Add activation for conv3d naive (#21466) ### Description ### Motivation and Context --- .../ops/3rd-party/conv3d_naive_webgpu.ts | 64 +++++----- js/web/test/data/ops/fused-conv3dncdhw.jsonc | 112 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 3 files changed, 149 insertions(+), 28 deletions(-) create mode 100644 js/web/test/data/ops/fused-conv3dncdhw.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts index f428293add599..a2e5428385101 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts @@ -26,6 +26,9 @@ import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; + +import {typeSnippet} from './activation_util'; const arrayProduct = (arr: number[]) => { let product = 1; @@ -218,8 +221,8 @@ export const computeConv3DInfo = export const createConv3DNaiveProgramInfo = (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => { - const isChannelsLast = dataFormat === 'channelsLast'; - const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const isChannelLast = dataFormat === 'channelsLast'; + const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1]; // TODO: enable vec4. const isVec4 = false; const workGroupSize: [number, number, number] = [64, 1, 1]; @@ -228,13 +231,14 @@ export const createConv3DNaiveProgramInfo = LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; + const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides}, {type: DataType.uint32, data: attributes.dilations} ]; + appendActivationUniformsData(attributes, programUniforms); programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length === 3; @@ -251,6 +255,7 @@ export const createConv3DNaiveProgramInfo = {name: 'strides', type: 'u32', length: attributes.strides.length}, {name: 'dilations', type: 'u32', length: attributes.dilations.length} ]; + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 4 : 1; const t = tensorTypeToWsglStorageType(inputs[0].dataType); @@ -266,10 +271,12 @@ export const createConv3DNaiveProgramInfo = inputVariables.push(bias); declareFunctions += ` fn getBiasByOutputCoords(coords : array) -> ${isVec4 ? `vec4<${t}>` : t} { - return bias[${isChannelsLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ + return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ isVec4 ? '/ 4' : ''}]; }`; } + const resType = typeSnippet(innerElementSize, t); + const applyActivation = getActivationSnippet(attributes, resType, t); return ` ${declareFunctions} @@ -287,28 +294,28 @@ export const createConv3DNaiveProgramInfo = let coords = ${output.offsetToIndices('global_idx')}; let batch = ${getElementAt('coords', 0, x.rank)}; let d2 = ${ - isChannelsLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; + isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; let xFRCCorner = vec3(${ - isChannelsLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, - ${isChannelsLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, + isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, + ${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, ${ - isChannelsLast ? getElementAt('coords', 3, x.rank) : - getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; + isChannelLast ? getElementAt('coords', 3, x.rank) : + getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; let xFCorner = xFRCCorner.x; let xRCorner = xFRCCorner.y; let xCCorner = xFRCCorner.z; let xShapeY = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; let xShapeZ = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; let xShapeW = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; let xShapeU = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; let inputDepthNearestVec4 = (xShapeU / 4) * 4; let inputDepthVec4Remainder = xShapeU % 4; - var dotProd = 0.0; + var value = 0.0; for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) { let xF = xFCorner + wF * uniforms.dilations[0]; if (xF < 0 || xF >= xShapeY) { @@ -329,13 +336,13 @@ export const createConv3DNaiveProgramInfo = for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) { ${ - isChannelsLast ? `let xValues = vec4( + isChannelLast ? `let xValues = vec4( getX(batch, xF, xR, xC, d1), getX(batch, xF, xR, xC, d1 + 1), getX(batch, xF, xR, xC, d1 + 2), getX(batch, xF, xR, xC, d1 + 3)); ` : - `let xValues = vec4( + `let xValues = vec4( getX(batch, d1, xF, xR, xC), getX(batch, d1 + 1, xF, xR, xC), getX(batch, d1 + 2, xF, xR, xC), @@ -346,36 +353,36 @@ export const createConv3DNaiveProgramInfo = getW(d2, d1 + 1, wF, wR, wC), getW(d2, d1 + 2, wF, wR, wC), getW(d2, d1 + 3, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } if (inputDepthVec4Remainder == 1) { ${ - isChannelsLast ? `dotProd += getX(batch, xF, xR, xC, inputDepthNearestVec4) + isChannelLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4) * getW(d2, inputDepthNearestVec4, wF, wR, wC);` : - `dotProd += getX(batch, inputDepthNearestVec4, xF, xR, xC) + `value += getX(batch, inputDepthNearestVec4, xF, xR, xC) * getW(d2, inputDepthNearestVec4, wF, wR, wC);`} } else if (inputDepthVec4Remainder == 2) { ${ - isChannelsLast ? `let xValues = vec2( + isChannelLast ? `let xValues = vec2( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1)); ` : - `let xValues = vec2( + `let xValues = vec2( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC)); `} let wValues = vec2( getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } else if (inputDepthVec4Remainder == 3) { ${ - isChannelsLast ? `let xValues = vec3( + isChannelLast ? `let xValues = vec3( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2)); ` : - `let xValues = vec3( + `let xValues = vec3( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC)); @@ -384,19 +391,20 @@ export const createConv3DNaiveProgramInfo = getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } } } } - ${hasBias ? 'dotProd = dotProd + getBiasByOutputCoords(coords)' : ''}; - result[global_idx] = f32(dotProd); + ${hasBias ? 'value = value + getBiasByOutputCoords(coords)' : ''}; + ${applyActivation} + result[global_idx] = f32(value); }`; }; return { name: 'Conv3DNaive', shaderCache: - {hint: `${attributes.cacheKey};${isChannelsLast};${innerElementSize};${hasBias}`, inputDependencies}, + {hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, diff --git a/js/web/test/data/ops/fused-conv3dncdhw.jsonc b/js/web/test/data/ops/fused-conv3dncdhw.jsonc new file mode 100644 index 0000000000000..1801ca380aa09 --- /dev/null +++ b/js/web/test/data/ops/fused-conv3dncdhw.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "fused conv3d with relu, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [-0.125, -0.25, -0.375, 0.5, 0.625, -0.75, -0.875, -1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0625, 0], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv3d with clip", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "activation_params", "data": [1.0, 3.0], "type": "floats" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2.1875], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv3d with HardSigmoid, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "activation_params", "data": [0.1, 0.3], "type": "floats" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.39375001192092896, 0.518750011920929], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 4a3a23bfe91b4..4aaf9d16b2b0e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1358,6 +1358,7 @@ "fast-gelu.jsonc", "floor.jsonc", "fused-conv.jsonc", + "fused-conv3dncdhw.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", From 94eb70d98348d83343207e113f9abaa0e7c6ea37 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 29 Jul 2024 23:50:14 +0800 Subject: [PATCH 15/37] [WebNN EP] Add labels for all WebNN operators (#21516) In order to provide more diagnosable error messages for developers. Spec change: https://github.com/webmachinelearning/webnn/pull/742 --- .../builders/impl/activation_op_builder.cc | 13 ++++--- .../builders/impl/argmax_min_op_builder.cc | 1 + .../webnn/builders/impl/binary_op_builder.cc | 15 ++++--- .../webnn/builders/impl/cast_op_builder.cc | 5 ++- .../webnn/builders/impl/clip_op_builder.cc | 1 + .../webnn/builders/impl/concat_op_builder.cc | 5 ++- .../webnn/builders/impl/conv_op_builder.cc | 16 +++++++- .../impl/dequantizeLinear_op_builder.cc | 17 ++++++-- .../impl/dynamicQuantizeLinear_op_builder.cc | 3 +- .../webnn/builders/impl/expand_op_builder.cc | 6 ++- .../webnn/builders/impl/flatten_op_builder.cc | 4 +- .../webnn/builders/impl/gather_op_builder.cc | 1 + .../webnn/builders/impl/gemm_op_builder.cc | 39 +++++++++++++++---- .../webnn/builders/impl/logical_op_builder.cc | 12 +++--- .../webnn/builders/impl/max_min_op_builder.cc | 10 +++-- .../builders/impl/normalization_op_builder.cc | 16 ++++++-- .../webnn/builders/impl/pad_op_builder.cc | 6 ++- .../webnn/builders/impl/pool_op_builder.cc | 1 + .../builders/impl/reduction_op_builder.cc | 1 + .../webnn/builders/impl/reshape_op_builder.cc | 7 +++- .../webnn/builders/impl/resize_op_builder.cc | 1 + .../webnn/builders/impl/shape_op_builder.cc | 9 ++++- .../webnn/builders/impl/slice_op_builder.cc | 5 ++- .../webnn/builders/impl/softmax_op_builder.cc | 4 +- .../webnn/builders/impl/split_op_builder.cc | 1 + .../impl/squeeze_unsqueeze_op_builder.cc | 8 +++- .../webnn/builders/impl/ternary_op_builder.cc | 4 +- .../builders/impl/transpose_op_builder.cc | 1 + .../builders/impl/triangular_op_builder.cc | 1 + .../webnn/builders/impl/unary_op_builder.cc | 30 +++++++------- 30 files changed, 180 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index af0f0133b497a..626aaf5c71b74 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -36,6 +36,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Elu") { options.set("alpha", helper.Get("alpha", 1.0f)); output = model_builder.GetBuilder().call("elu", input, options); @@ -46,20 +47,20 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("beta", helper.Get("beta", 0.5f)); output = model_builder.GetBuilder().call("hardSigmoid", input, options); } else if (op_type == "HardSwish") { - output = model_builder.GetBuilder().call("hardSwish", input); + output = model_builder.GetBuilder().call("hardSwish", input, options); } else if (op_type == "LeakyRelu") { options.set("alpha", helper.Get("alpha", 0.0f)); output = model_builder.GetBuilder().call("leakyRelu", input, options); } else if (op_type == "Relu") { - output = model_builder.GetBuilder().call("relu", input); + output = model_builder.GetBuilder().call("relu", input, options); } else if (op_type == "Sigmoid") { - output = model_builder.GetBuilder().call("sigmoid", input); + output = model_builder.GetBuilder().call("sigmoid", input, options); } else if (op_type == "Softplus") { - output = model_builder.GetBuilder().call("softplus", input); + output = model_builder.GetBuilder().call("softplus", input, options); } else if (op_type == "Softsign") { - output = model_builder.GetBuilder().call("softsign", input); + output = model_builder.GetBuilder().call("softsign", input, options); } else if (op_type == "Tanh") { - output = model_builder.GetBuilder().call("tanh", input); + output = model_builder.GetBuilder().call("tanh", input, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 1ae63a644a287..05f3a742a3775 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -47,6 +47,7 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("keepDimensions", keep_dims == 1); // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API. options.set("outputDataType", "int64"); + options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 23e19d5943144..555de68cd60fe 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -35,18 +35,21 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + if (op_type == "Add") { - output = model_builder.GetBuilder().call("add", input0, input1); + output = model_builder.GetBuilder().call("add", input0, input1, options); } else if (op_type == "Sub") { - output = model_builder.GetBuilder().call("sub", input0, input1); + output = model_builder.GetBuilder().call("sub", input0, input1, options); } else if (op_type == "Mul") { - output = model_builder.GetBuilder().call("mul", input0, input1); + output = model_builder.GetBuilder().call("mul", input0, input1, options); } else if (op_type == "Div") { - output = model_builder.GetBuilder().call("div", input0, input1); + output = model_builder.GetBuilder().call("div", input0, input1, options); } else if (op_type == "Pow") { - output = model_builder.GetBuilder().call("pow", input0, input1); + output = model_builder.GetBuilder().call("pow", input0, input1, options); } else if (op_type == "PRelu") { - output = model_builder.GetBuilder().call("prelu", input0, input1); + output = model_builder.GetBuilder().call("prelu", input0, input1, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a97d71b90de55..a08e1681a8464 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -69,8 +69,11 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, node.Name(), " type: ", to_type); } + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = - model_builder.GetBuilder().call("cast", input, emscripten::val(operand_type)); + model_builder.GetBuilder().call("cast", input, emscripten::val(operand_type), options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index e6403a4cd12dc..b5c3206072d50 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -53,6 +53,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, "GetClipMinMax failed"); options.set("minValue", minValue); options.set("maxValue", maxValue); + options.set("label", node.Name()); emscripten::val input = model_builder.GetOperand(input_name); emscripten::val output = model_builder.GetBuilder().call("clamp", input, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index e4f98b09e03c5..dedc76b80e978 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -42,8 +42,11 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, inputs.push_back(model_builder.GetOperand(input->Name())); } + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = - model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis); + model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 320aaa03930fd..4f3f7459a7b5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -242,6 +242,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); ORT_RETURN_IF_ERROR(SetConvBaseOptions( model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); bool depthwise = false; @@ -276,7 +277,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (!is_nhwc || !is_constant_weight) { // The weight_shape has been appended 1's, reshape weight operand. std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_filter"); + filter = model_builder.GetBuilder().call("reshape", + filter, + emscripten::val::array(new_shape), + reshape_options); } } @@ -293,6 +299,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N perm = {0, 2, 3, 1}; // L_0231 } transpose_options.set("permutation", emscripten::val::array(perm)); + transpose_options.set("label", node.Name() + "_transpose_filter"); filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); } @@ -323,7 +330,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector output_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); std::vector new_shape = GetVecUint32FromVecInt64(output_shape); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_output"); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc index 66d502a4e6727..93a12a696cce1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc @@ -50,11 +50,22 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil std::vector target_shape{static_cast(input_shape[axis])}; target_shape.insert(target_shape.begin(), axis, 1); target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); - scale = model_builder.GetBuilder().call("reshape", scale, emscripten::val::array(target_shape)); + emscripten::val reshape_scale_options = emscripten::val::object(); + reshape_scale_options.set("label", node.Name() + "_reshape_scale"); + scale = model_builder.GetBuilder().call("reshape", + scale, + emscripten::val::array(target_shape), + reshape_scale_options); + emscripten::val reshape_zero_point_options = emscripten::val::object(); + reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); zero_point = model_builder.GetBuilder().call("reshape", - zero_point, emscripten::val::array(target_shape)); + zero_point, + emscripten::val::array(target_shape), + reshape_zero_point_options); } - output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc index 3b5f64584b828..55746bb1f61f0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc @@ -31,8 +31,9 @@ Status DynamicQuantizaLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); - output_array = model_builder.GetBuilder().call("dynamicQuantizeLinear", input); + output_array = model_builder.GetBuilder().call("dynamicQuantizeLinear", input, options); for (size_t i = 0, count = output_array["length"].as(); i < count; i++) { model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i])); diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index 9c75c00fa9273..c8cea833983b1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -53,10 +53,14 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector output_shape; ORT_RETURN_IF_NOT(GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape), "Cannot get output shape."); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("expand", input, - emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(output_shape)), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc index 31b1bd92a9503..d0ece026a7048 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc @@ -52,8 +52,10 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, SafeInt(num_post_axis_elements)}; emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call( - "reshape", inputs, emscripten::val::array(new_shape)); + "reshape", inputs, emscripten::val::array(new_shape), options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 014a08616c44f..23233539d34c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -42,6 +42,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); emscripten::val options = emscripten::val::object(); options.set("axis", axis); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("gather", input, indices, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 53f885019ab2f..bd452b118fe3e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -39,6 +39,8 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "MatMul") { std::vector a_shape; if (!GetShape(*input_defs[a_idx], a_shape, logger)) { @@ -53,23 +55,34 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (a_shape.size() == 1) { extended_a_shape = true; a_shape.insert(a_shape.begin(), 1); + emscripten::val reshape_a_options = emscripten::val::object(); + reshape_a_options.set("label", node.Name() + "_reshape_a"); a = model_builder.GetBuilder().call("reshape", a, - emscripten::val::array(GetVecUint32FromVecInt64(a_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(a_shape)), + reshape_a_options); } // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. bool extended_b_shape = false; if (b_shape.size() == 1) { extended_b_shape = true; b_shape.push_back(1); + emscripten::val reshape_b_options = emscripten::val::object(); + reshape_b_options.set("label", node.Name() + "_reshape_b"); b = model_builder.GetBuilder().call("reshape", b, - emscripten::val::array(GetVecUint32FromVecInt64(b_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(b_shape)), + reshape_b_options); } - output = model_builder.GetBuilder().call("matmul", a, b); + output = model_builder.GetBuilder().call("matmul", a, b, options); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node.Name() + "_reshape_output"); // If the inputs are both 1D, reduce the output to a scalar. if (extended_a_shape && extended_b_shape) { - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array()); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(), + reshape_output_options); } // After matrix multiplication the prepended 1 is removed. else if (extended_a_shape) { @@ -78,7 +91,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N new_shape.push_back(narrow(b_shape[i])); } new_shape.push_back(narrow(b_shape.back())); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_output_options); } // After matrix multiplication the appended 1 is removed. else if (extended_b_shape) { @@ -86,7 +102,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N for (size_t i = 0; i < a_shape.size() - 1; i++) { new_shape.push_back(narrow(a_shape[i])); } - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_output_options); } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); @@ -101,9 +120,13 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { b_zero_point = model_builder.GetZeroConstant("uint8"); } - output = model_builder.GetBuilder().call("matmulInteger", a, a_zero_point, b, b_zero_point); + output = model_builder.GetBuilder().call("matmulInteger", + a, + a_zero_point, + b, + b_zero_point, + options); } else { // Gemm - emscripten::val options = emscripten::val::object(); NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); options.set("aTranspose", emscripten::val(transA == 1)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index e56e8f6a3eb6d..23f3a938fee5e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -33,16 +33,18 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Equal") { - output = model_builder.GetBuilder().call("equal", input0, input1); + output = model_builder.GetBuilder().call("equal", input0, input1, options); } else if (op_type == "Greater") { - output = model_builder.GetBuilder().call("greater", input0, input1); + output = model_builder.GetBuilder().call("greater", input0, input1, options); } else if (op_type == "GreaterOrEqual") { - output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1); + output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1, options); } else if (op_type == "Less") { - output = model_builder.GetBuilder().call("lesser", input0, input1); + output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { - output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1); + output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 0168f59273545..1080fd0a3f943 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -43,22 +43,26 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(op_type == "Max" || op_type == "Min", "MaxMinOpBuilder, unknown op: ", op_type); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (input_count == 1) { // For 1 input, just concat the single input as workaround. // TODO: use identity instead once it's available in WebNN. emscripten::val inputs = emscripten::val::array(); inputs.call("push", input0); - output = model_builder.GetBuilder().call("concat", inputs, 0); + output = model_builder.GetBuilder().call("concat", inputs, 0, options); } else { std::string webnn_op_name = op_type == "Max" ? "max" : "min"; emscripten::val input1 = model_builder.GetOperand(input_defs[1]->Name()); - output = model_builder.GetBuilder().call(webnn_op_name.c_str(), input0, input1); + output = model_builder.GetBuilder().call(webnn_op_name.c_str(), input0, input1, options); for (size_t input_index = 2; input_index < input_count; ++input_index) { emscripten::val next_input = model_builder.GetOperand(input_defs[input_index]->Name()); - output = model_builder.GetBuilder().call(webnn_op_name.c_str(), output, next_input); + emscripten::val next_options = emscripten::val::object(); + next_options.set("label", node.Name() + "_" + input_defs[input_index]->Name()); + output = model_builder.GetBuilder().call(webnn_op_name.c_str(), output, next_input, next_options); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index a2aa0df5586e3..4d068baf35e72 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -42,6 +42,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder const auto rank = input_shape.size(); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); std::vector scale_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); @@ -116,7 +117,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder new_shape.erase(insertion_point, insertion_point + excess_rank); *insertion_point = sum; } - input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + emscripten::val reshape_input_options = emscripten::val::object(); + reshape_input_options.set("label", node.Name() + "_reshape_input"); + input = model_builder.GetBuilder().call("reshape", + input, + emscripten::val::array(new_shape), + reshape_input_options); } if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { @@ -126,8 +132,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder // Reshape back to the original output shape for 3D input. if (input_shape.size() != 4) { std::vector output_shape = GetVecUint32FromVecInt64(input_shape); - output = model_builder.GetBuilder().call( - "reshape", output, emscripten::val::array(output_shape)); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node.Name() + "reshape_output"); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(output_shape), + reshape_output_options); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index bc90821ba4ed8..071155a2fb372 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -73,6 +73,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto pad_mode = helper.Get("mode", std::string("constant")); @@ -145,9 +146,12 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, starts.push_back(start_padding[i] >= 0 ? SafeInt(0) : SafeInt(-start_padding[i])); sizes.push_back(SafeInt(input_shape[i] + start_padding[i] + end_padding[i])); } + emscripten::val slice_options = emscripten::val::object(); + slice_options.set("label", node.Name() + "_slice_output"); output = model_builder.GetBuilder().call("slice", output, emscripten::val::array(starts), - emscripten::val::array(sizes)); + emscripten::val::array(sizes), + slice_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 8b3eecf35fcc8..0af62dacedbd5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -59,6 +59,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 461050849385a..3e6d4d9820e9a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -57,6 +57,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); const auto keep_dims = helper.Get("keepdims", 1); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); options.set("keepDimensions", keep_dims == 1); std::vector axes_data; diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index b5005269b96a7..a7911683f0355 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -58,8 +58,13 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::transform(target_shape.cbegin(), target_shape.cend(), std::back_inserter(new_shape), [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("reshape", - input, emscripten::val::array(new_shape)); + input, + emscripten::val::array(new_shape), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index c4ca980fec715..2218c858951d3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -106,6 +106,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto mode = helper.Get("mode", "nearest"); if (mode == "linear") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 1552023d3f876..0eb7dafdffe4d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -55,8 +55,15 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val sizes = emscripten::val::array(); sizes.call("push", slice_length); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + // Since WebNN doesn't support Shape op, we use constant + slice ops as workaround. - emscripten::val output = model_builder.GetBuilder().call("slice", shape_constant, starts, sizes); + emscripten::val output = model_builder.GetBuilder().call("slice", + shape_constant, + starts, + sizes, + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index fb452aec1c929..bef13841c646c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -97,9 +97,12 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, sizes.begin(), [](int64_t i, int64_t j) { return SafeInt(i - j); }); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("slice", inputs, emscripten::val::array(starts), - emscripten::val::array(sizes)); + emscripten::val::array(sizes), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 95c1dbd518061..798cfabae65db 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -42,7 +42,9 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, int32_t axis = helper.Get("axis", default_axis); axis = static_cast(HandleNegativeAxis(axis, input_size)); - emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index ea3b8ef384ddc..4c59b694d690a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -49,6 +49,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const size_t rank = input_shape.size(); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); int32_t axis = helper.Get("axis", 0); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 8e6feb62fa8c4..5eff96873b8c4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -54,7 +54,6 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); const auto input_rank = input_shape.size(); - emscripten::val options = emscripten::val::object(); std::vector axes_data; auto rank = input_rank; @@ -111,7 +110,12 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil "SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); } - output = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("reshape", + input, + emscripten::val::array(new_shape), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 841e2d18244d5..2ed8330bf25be 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -32,9 +32,11 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val input2 = model_builder.GetOperand(node.InputDefs()[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); if (op_type == "Where") { - output = model_builder.GetBuilder().call("where", input0, input1, input2); + output = model_builder.GetBuilder().call("where", input0, input1, input2, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TernaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 3921b1da188c3..03c88ad9db88a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -42,6 +42,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); std::vector permutation = GetVecUint32FromVecInt64(perm); options.set("permutation", emscripten::val::array(permutation)); emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index e4b7021d49b30..0c818533918a4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -46,6 +46,7 @@ Status TriangularOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val output = emscripten::val::object(); NodeAttrHelper helper(node); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); const bool upper = helper.Get("upper", 1); options.set("upper", upper); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index e0016de8e69b7..061404c8a9ce0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -30,35 +30,37 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Abs") { - output = model_builder.GetBuilder().call("abs", input); + output = model_builder.GetBuilder().call("abs", input, options); } else if (op_type == "Ceil") { - output = model_builder.GetBuilder().call("ceil", input); + output = model_builder.GetBuilder().call("ceil", input, options); } else if (op_type == "Cos") { - output = model_builder.GetBuilder().call("cos", input); + output = model_builder.GetBuilder().call("cos", input, options); } else if (op_type == "Erf") { - output = model_builder.GetBuilder().call("erf", input); + output = model_builder.GetBuilder().call("erf", input, options); } else if (op_type == "Exp") { - output = model_builder.GetBuilder().call("exp", input); + output = model_builder.GetBuilder().call("exp", input, options); } else if (op_type == "Floor") { - output = model_builder.GetBuilder().call("floor", input); + output = model_builder.GetBuilder().call("floor", input, options); } else if (op_type == "Identity") { - output = model_builder.GetBuilder().call("identity", input); + output = model_builder.GetBuilder().call("identity", input, options); } else if (op_type == "Log") { - output = model_builder.GetBuilder().call("log", input); + output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { - output = model_builder.GetBuilder().call("neg", input); + output = model_builder.GetBuilder().call("neg", input, options); } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input); + output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { - output = model_builder.GetBuilder().call("reciprocal", input); + output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { - output = model_builder.GetBuilder().call("sin", input); + output = model_builder.GetBuilder().call("sin", input, options); } else if (op_type == "Sqrt") { - output = model_builder.GetBuilder().call("sqrt", input); + output = model_builder.GetBuilder().call("sqrt", input, options); } else if (op_type == "Tan") { - output = model_builder.GetBuilder().call("tan", input); + output = model_builder.GetBuilder().call("tan", input, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); From d8888136e3cdf29fa63d3b0a08a58683a7c9f0a0 Mon Sep 17 00:00:00 2001 From: mingyueliuh <131847423+mingyueliuh@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:45:52 -0400 Subject: [PATCH 16/37] Add support tensor element type for register custom op shape infer function (#21387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Functionality extension for the SetOutputShape method in custom op shape inference. ### Motivation and Context - **SetOutputShape** Interface enhancement Actually, the shape infer function need set the tensor type and shape ,Add a parameter **type** to allow users to specify the tensor type, and set **ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT** as default value to ensure compatibility. Co-authored-by: mingyue --- include/onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +- include/onnxruntime/core/session/onnxruntime_cxx_inline.h | 3 ++- onnxruntime/core/session/custom_ops.cc | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5d974e1ff5185..29a229f427163 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2216,7 +2216,7 @@ struct ShapeInferContext { size_t GetInputCount() const { return input_shapes_.size(); } - Status SetOutputShape(size_t indice, const Shape& shape); + Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); int64_t GetAttrInt(const char* attr_name); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index aaef111b9f15b..9b9dd81a749c0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1998,9 +1998,10 @@ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, } } -inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { OrtTensorTypeAndShapeInfo* info = {}; ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); using InfoPtr = std::unique_ptr>; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 4c782f647371e..33d2a0244b453 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -105,6 +105,7 @@ struct OrtShapeInferContext { } } ONNX_NAMESPACE::updateOutputShape(ctx_, index, shape_proto); + ONNX_NAMESPACE::updateOutputElemType(ctx_, index, info->type); return onnxruntime::Status::OK(); } From 05cef469e81e3695667f122beecf97600094d09b Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 30 Jul 2024 00:59:46 +0800 Subject: [PATCH 17/37] Move on-device training packages publish step (#21539) ### Description Since the onedevice training cpu packaging has been a separated pipeline, it's nuget package publishing step must be moved as well. ### Motivation and Context Fixes the exception in Nuget Publishing Packaging Pipeline caused by #21485 --- .../c-api-training-packaging-pipelines.yml | 27 +++++++++++++++++-- .../github/azure-pipelines/publish-nuget.yml | 7 +---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml index aecece05a0e58..22ee7de8a5de0 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml @@ -32,13 +32,25 @@ parameters: displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. type: number default: 0 - + +# these 2 parameters are used for debugging. +- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + stages: - template: stages/set_packaging_variables_stage.yml parameters: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + - template: templates/ondevice-training-cpu-packaging-pipeline.yml parameters: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} @@ -48,4 +60,15 @@ stages: OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime.Training' AdditionalBuildFlags: '--enable_training_apis' AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos' - BuildVariant: 'default' \ No newline at end of file + BuildVariant: 'default' + +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Training Package' + ArtifactName: 'drop-signed-nuget-Training-CPU' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index 206a9464de6ef..b78d586288ba3 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -32,11 +32,6 @@ stages: artifact: 'drop-signed-nuget-dml' - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-Training-CPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - # Publish CUDA 11 Nuget/Java pkgs to ADO feed - template: stages/nuget-cuda-publishing-stage.yml parameters: @@ -44,4 +39,4 @@ stages: - template: stages/java-cuda-publishing-stage.yml parameters: - artifact_feed: $(ArtifactFeed) \ No newline at end of file + artifact_feed: $(ArtifactFeed) From bc3713206dc1d6c7e5062389ef7db42ac2051a30 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Mon, 29 Jul 2024 10:00:21 -0700 Subject: [PATCH 18/37] Update QNN pipeline pool (#21482) ### Description Update QNN pipeline pool ### Motivation and Context Let all our pipelines are using the latest NDK version --- ...droid-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 6649206c0d79c..c80092fc82ed5 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -35,7 +35,7 @@ parameters: jobs: - job: Build_QNN_EP - pool: onnxruntime-qnn-ubuntu-2004-cpu + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 30 workspace: clean: all @@ -46,6 +46,10 @@ jobs: inputs: versionSpec: $(pythonVersion) + - script: | + env | grep ANDROID + displayName: View Android ENVs + - script: sudo apt-get update -y && sudo apt-get install -y coreutils ninja-build displayName: Install coreutils and ninja @@ -56,13 +60,6 @@ jobs: parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} - - script: | - export ANDROID_SDK_ROOT=/usr/local/lib/android/sdk - export ANDROID_HOME=/usr/local/lib/android/sdk - export ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle - export ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle - displayName: set Android ENVs - - script: | set -e -x rm -rf /tmp/scripts From 79537d0523a7c215ef1685bf46efbd423242c4c1 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Mon, 29 Jul 2024 10:00:52 -0700 Subject: [PATCH 19/37] Remove tools/ci_build/github/android/run_nnapi_code_coverage.sh (#21371) ### Description Remove tools/ci_build/github/android/run_nnapi_code_coverage.sh ### Motivation and Context This file is no longer needed --- .../github/android/run_nnapi_code_coverage.sh | 36 ------------------- 1 file changed, 36 deletions(-) delete mode 100755 tools/ci_build/github/android/run_nnapi_code_coverage.sh diff --git a/tools/ci_build/github/android/run_nnapi_code_coverage.sh b/tools/ci_build/github/android/run_nnapi_code_coverage.sh deleted file mode 100755 index 472e824eaa47a..0000000000000 --- a/tools/ci_build/github/android/run_nnapi_code_coverage.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -# This script will run ORT build for Android with code coverage option - -set -e -set -x - -if [ $# -ne 1 ]; then - echo "One command line argument, the ROOT root directory, is expected" -fi - -ORT_ROOT=$1 -# Build and run onnxruntime using NNAPI execution provider targeting android emulator -python3 ${ORT_ROOT}/tools/ci_build/build.py \ - --android \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME \ - --android_ndk_path $ANDROID_NDK_HOME \ - --android_abi=x86_64 \ - --android_api=29 \ - --skip_submodule_sync \ - --parallel \ - --use_nnapi \ - --cmake_generator=Ninja \ - --build_java \ - --path_to_protoc_exe $ORT_ROOT/protobuf_install/bin/protoc \ - --code_coverage - -# Install gcovr -python3 -m pip install gcovr - -# Retrieve runtime code coverage files from the emulator and analyze -python3 ${ORT_ROOT}/tools/ci_build/coverage.py \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME - From 0d7cf301a1e0ea784edcdf2242e973643f0bb9c9 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 30 Jul 2024 02:05:34 +0800 Subject: [PATCH 20/37] [js/webgpu] Add activation Tanh (#21540) Bug:https://github.com/microsoft/onnxruntime/issues/21467 ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 4 +++ js/web/test/data/ops/fused-conv.jsonc | 33 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 6e66abacf3471..cfa0b42ef9eeb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -30,6 +30,10 @@ export const getActivationSnippet = baseType}(uniforms.beta)));`; case 'LeakyRelu': return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case 'Tanh': + return `let e2x = exp(-2.0 * abs(value)); + value = sign(value) * (1.0 - e2x) / (1.0 + e2x); + `; case '': return ''; // TODO: adding other activations that can be fused. diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 6a10e3b96a26a..d88c91ebc9de7 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -430,5 +430,38 @@ ] } ] + }, + { + "name": "fused conv with tanh", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Tanh", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [0.11, 0.12, 0.13, 0.14], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.15572261810302734, 0.20409323275089264, 0.29770541191101074, 0.3425688147544861], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] From b03c9496aa081fa6c07c5b266800694c830afd60 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:39:38 -0700 Subject: [PATCH 21/37] [js/web] allow load WebAssembly binary from buffer (#21534) ### Description This PR adds a new option `ort.env.wasm.wasmBinary`, which allows user to set to a buffer containing preload .wasm file content. This PR should resolve the problem from latest discussion in #20876. --- cmake/onnxruntime_webassembly.cmake | 2 +- js/common/lib/env.ts | 6 +++++ js/web/lib/wasm/wasm-factory.ts | 8 ++++++- .../e2e/browser-test-wasm-binary-override.js | 22 +++++++++++++++++++ js/web/test/e2e/run-data.js | 3 +++ 5 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 js/web/test/e2e/browser-test-wasm-binary-override.js diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 7a49e90c00bce..0686b66876d9f 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -225,7 +225,7 @@ else() "SHELL:-s EXPORT_ALL=0" "SHELL:-s VERBOSE=0" "SHELL:-s FILESYSTEM=0" - "SHELL:-s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm,mainScriptUrlOrBlob]" + "SHELL:-s INCOMING_MODULE_JS_API=[locateFile,instantiateWasm,wasmBinary]" "SHELL:-s WASM_BIGINT=1" ${WASM_API_EXCEPTION_CATCHING} --no-entry diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index dbb5f8118363f..1a87569a115a6 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -74,6 +74,12 @@ export declare namespace Env { */ wasmPaths?: WasmPrefixOrFilePaths; + /** + * Set a custom buffer which contains the WebAssembly binary. If this property is set, the `wasmPaths` property will + * be ignored. + */ + wasmBinary?: ArrayBufferLike|Uint8Array; + /** * Set or get a boolean value indicating whether to proxy the execution of main thread to a worker thread. * diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index fb068ab42d04c..0f5f10716a00b 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -108,6 +108,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag; const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm; const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag; + const wasmBinaryOverride = flags.wasmBinary; const [objectUrl, ortWasmFactory] = (await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1)); @@ -135,7 +136,12 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise numThreads, }; - if (wasmPathOverride || wasmPrefixOverride) { + if (wasmBinaryOverride) { + /** + * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching. + */ + config.wasmBinary = wasmBinaryOverride; + } else if (wasmPathOverride || wasmPrefixOverride) { /** * A callback function to locate the WebAssembly file. The function should return the full path of the file. * diff --git a/js/web/test/e2e/browser-test-wasm-binary-override.js b/js/web/test/e2e/browser-test-wasm-binary-override.js new file mode 100644 index 0000000000000..35d427fa3b722 --- /dev/null +++ b/js/web/test/e2e/browser-test-wasm-binary-override.js @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const documentUrl = document.currentScript.src; + +it('Browser E2E testing - WebAssembly backend', async function() { + // preload .wasm file binary + const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.wasm', documentUrl).href; + const response = await fetch(wasmUrl); + + // make sure the .wasm file is loaded successfully + assert(response.ok); + assert(response.headers.get('Content-Type') === 'application/wasm'); + + // override wasm binary + const binary = await response.arrayBuffer(); + ort.env.wasm.wasmBinary = binary; + + await testFunction(ort, {executionProviders: ['wasm']}); +}); diff --git a/js/web/test/e2e/run-data.js b/js/web/test/e2e/run-data.js index 507192f29be9c..856f29eac6ddf 100644 --- a/js/web/test/e2e/run-data.js +++ b/js/web/test/e2e/run-data.js @@ -36,6 +36,9 @@ const BROWSER_TEST_CASES = [ [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2', 'proxy=1']], // 2 threads, proxy [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=1', 'proxy=1']], // 1 thread, proxy + // wasm binary override: + [true, false, './browser-test-wasm-binary-override.js', 'ort.min.js'], + // path override: // wasm, path override filenames for both mjs and wasm, same origin [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], From c39f1c4fd80668fd7619719ebe7a374f4ae11a5e Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Mon, 29 Jul 2024 14:12:36 -0700 Subject: [PATCH 22/37] ORT- OVEP 1.19 PR-follow up (#21546) ### Description Follow up PR for bug fixes on 1.19 ### Motivation and Context - Handles 1.19 docker file fixes. - Sets the default file naming of epctx onnx model with _ctx.onnx as suffix. - Create epctx model directories if it doesn't exist. --------- Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> --- dockerfiles/Dockerfile.openvino | 10 ++++------ .../providers/openvino/backend_manager.cc | 9 ++++++++- .../openvino/openvino_execution_provider.cc | 5 ----- .../openvino/openvino_provider_factory.cc | 20 ++++++++++++++++++- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 75898770acf28..39e75a68a369f 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -3,11 +3,11 @@ # SPDX-License-Identifier: MIT #-------------------------------------------------------------------------- -ARG OPENVINO_VERSION=2024.0.0 +ARG OPENVINO_VERSION=2024.2.0 # Build stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} AS builder +FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} AS builder ENV WORKDIR_PATH=/home/openvino WORKDIR $WORKDIR_PATH @@ -34,20 +34,18 @@ RUN cat /etc/apt/sources.list | sed 's/^# deb-src/deb-src/g' > ./temp; mv temp / RUN apt update; apt install dpkg-dev RUN mkdir /sources WORKDIR /sources -RUN apt-get source cron iso-codes lsb-release powermgmt-base python-apt-common python3-apt python3-dbus python3-gi unattended-upgrades libapt-pkg6.0 libhogweed5 libnettle7 +RUN apt-get source cron iso-codes lsb-release powermgmt-base python-apt-common python3-apt python3-dbus python3-gi libapt-pkg6.0 libhogweed6 libnettle8 WORKDIR / RUN tar cvf GPL_sources.tar.gz /sources # Deploy stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} +FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} ENV DEBIAN_FRONTEND noninteractive USER root COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/dist/*.whl ./ COPY --from=builder /GPL_sources.tar.gz ./ RUN python3 -m pip install ./*.whl && rm ./*.whl -RUN apt update; apt install -y unattended-upgrades && \ - unattended-upgrade ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 8f3658df0d09d..18a6257910a56 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -128,6 +128,13 @@ BackendManager::BackendManager(const GlobalContext& global_context, #endif } } + if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { + auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph, + logger); + if ((!status.IsOK())) { + ORT_THROW(status); + } + } } // Call EPContext model exporter here if the provider option for exporting @@ -158,7 +165,7 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie if (dot == std::string::npos) return graph_name; return graph_name.substr(0, dot); }(); - graph_name = graph_name + "-ov_" + GetGlobalContext().device_type + "_blob.onnx"; + graph_name = graph_name + "_ctx.onnx"; } // If embed_mode, then pass on the serialized blob // If not embed_mode, dump the blob here and only pass on the path to the blob diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 5627cb2c122fb..29c45916795d3 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -147,11 +147,6 @@ common::Status OpenVINOExecutionProvider::Compile( *GetLogger(), ep_ctx_handle_); - if (global_context_->export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { - ORT_RETURN_IF_ERROR(backend_manager->ExportCompiledBlobAsEPCtxNode(graph_body_viewer, - *GetLogger())); - } - compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 716a7cd936405..3738f2a534154 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -192,6 +192,10 @@ struct OpenVINO_Provider : Provider { } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { + if (!std::all_of(provider_options_map.at("num_of_threads").begin(), + provider_options_map.at("num_of_threads").end(), ::isdigit)) { + ORT_THROW("[ERROR] [OpenVINO-EP] Number of threads should be a number. \n"); + } num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; @@ -298,7 +302,21 @@ struct OpenVINO_Provider : Provider { // The path to dump epctx model is valid only when epctx is enabled. // Overrides the cache_dir option to dump model cache files from OV. if (export_ep_ctx_blob) { - cache_dir = provider_options_map.at("so_epctx_path").c_str(); + auto ep_context_file_path_ = provider_options_map.at("so_epctx_path"); + auto file_path = std::filesystem::path(ep_context_file_path_); + // ep_context_file_path_ file extension must be .onnx + if (!ep_context_file_path_.empty() && + file_path.extension().generic_string() == ".onnx") { + // ep_context_file_path_ must be provided as a directory, create it if doesn't exist + auto parent_path = file_path.parent_path(); + if (!std::filesystem::is_directory(parent_path) && + !std::filesystem::create_directory(parent_path)) { + ORT_THROW("[ERROR] [OpenVINO] Failed to create directory : " + file_path.parent_path().generic_string() + " \n"); + } + cache_dir = ep_context_file_path_.c_str(); + } else { + ORT_THROW("[ERROR] [OpenVINO] Invalid ep_ctx_file_path" + ep_context_file_path_ + " \n"); + } } } From 7543dd040b2d32109a2718d7276d3aca1edadaae Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Jul 2024 10:50:13 +1200 Subject: [PATCH 23/37] Propagate NaNs in the CPU min and max operators (#21492) ### Description Propagates NaN values in the min and max operators so that min or max with a NaN in either input always produces NaN. ### Motivation and Context Fixes #21455 --- .../providers/cpu/math/element_wise_ops.cc | 18 +- onnxruntime/test/providers/checkers.cc | 2 +- .../cpu/math/element_wise_ops_test.cc | 188 ++++++++++++++++-- 3 files changed, 187 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 1d524a90302e7..5ea6000da1cba 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -705,7 +705,7 @@ Status Min_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - min = min.array().min(EigenMap(data_n).array()); + min = min.array().template min(EigenMap(data_n).array()); } return Status::OK(); @@ -721,15 +721,16 @@ struct Min_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().min(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template min(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template min(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template min( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); @@ -827,7 +828,7 @@ Status Max_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - max = max.array().max(EigenMap(data_n).array()); + max = max.array().template max(EigenMap(data_n).array()); } return Status::OK(); @@ -843,15 +844,16 @@ struct Max_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().max(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template max(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template max(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template max( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 5f332ddcddb8d..182fa4729a88f 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -427,7 +427,7 @@ struct TensorCheck { for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { - EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i; } else { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb3575f2cde88..bd3d21d4929f3 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1553,6 +1553,47 @@ TEST(MathOpTest, Min_12_Float_Nan) { } } +TEST(MathOpTest, Min_12_Float_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.25f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Float_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Double) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1586,12 +1627,53 @@ TEST(MathOpTest, Min_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -1.0, -1.0, -2.0, 0.5, 0.0, 1.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.25}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1666,7 +1748,7 @@ TEST(MathOpTest, Min_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16) { +TEST(MathOpTest, Min_12_MLFloat16) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({1.f, 1.f, 1.f})); @@ -1679,7 +1761,7 @@ TEST(MathOpTest, Min_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar0) { OpTester test("Min", 12); test.AddInput("data_0", {}, MakeMLFloat16({-10.f})); @@ -1692,7 +1774,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({2.f, 3.f, 4.f})); @@ -1809,12 +1891,53 @@ TEST(MathOpTest, Max_12_Float_Nan) { std::numeric_limits::quiet_NaN(), -0.5f, 0.0f, -1.0f, 1.0f, 1.0f, 2.0f}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25f, 0.5f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1854,12 +1977,53 @@ TEST(MathOpTest, Max_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -0.5, 0.0, -1.0, 1.0, 1.0, 2.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25, 0.5}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1934,7 +2098,7 @@ TEST(MathOpTest, Max_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16) { +TEST(MathOpTest, Max_12_MLFloat16) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -1.f, -1.f})); @@ -1947,7 +2111,7 @@ TEST(MathOpTest, Max_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar0) { OpTester test("Max", 12); test.AddInput("data_0", {}, MakeMLFloat16({-1.f})); @@ -1960,7 +2124,7 @@ TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -2.f, -3.f})); From d98581495f996084af65ae1e6600378bed949460 Mon Sep 17 00:00:00 2001 From: Sophie Schoenmeyer <107952697+sophies927@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:06:03 -0700 Subject: [PATCH 24/37] Update labeling bot (#21548) Current labeling bot over-applies many of the labels (e.g., ep:CUDA and platform:windows) and is missing some of the APIs + EPs Working on migrating this workflow to GitHub policies but would like to use this fix in the meantime to avoid causing any issues w/ ORT 1.19 ### Description ### Motivation and Context --- .github/labeler.yml | 31 ++++++++++++++---------- .github/title-only-labeler.yml | 4 +++ .github/workflows/title-only-labeler.yml | 20 +++++++++++++++ 3 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 .github/title-only-labeler.yml create mode 100644 .github/workflows/title-only-labeler.yml diff --git a/.github/labeler.yml b/.github/labeler.yml index 526d8a643e713..c14e2a213bc60 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,20 +1,25 @@ -api:javascript: '/\bjavascript\b/i' +api:CSharp: '/(\bc\s*sharp\b|\bc#)/i' api:java: '/\bjava\b/i' +api:javascript: '/\bjavascript\b/i' ep:ACL: '/\bacl\b/i' ep:ArmNN: '/\barmnn\b/i' -ep:CUDA: '/\bcuda\b/i' -ep:DML: '/(\bdirectml\b|\bdml\b)/i' -ep:MIGraphX: '/\bmigraphx\b/i' -ep:oneDNN: '/\bonednn\b/i' +ep:CANN: '/\bcann\b/i' +ep:CoreML: '/\bcore\s*ml\b/i' +ep:DML: '/(\bdirect\s*ml\b|\bdml\b)/i' +ep:MIGraphX: '/\bmi\s*graph\s*x\b/i' +ep:oneDNN: '/\bone\s*dnn\b/i' ep:OpenVINO: '/\bopen\s*vino\b/i' -ep:RockchipNPU: '/\brockchip\b/i' +ep:QNN: '/\bqnn\b/i' +ep:RockchipNPU: '/\brockchip(?:npu)?\b/i' ep:ROCm: '/\brocm\b/i' -ep:TensorRT: '/(\btensor\s*rt\b|\btrt\b)/i' +ep:SNPE: '/\bsnpe\b/i' ep:tvm: '/\btvm\b/i' ep:VitisAI: '/\bvitis(?:ai)?\b/i' -platform:jetson: '/\bjetson\b/i' -platform:mobile: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bcore-?ml\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' -platform:web: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b)/i' -platform:windows: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' -model:transformer: '/(\bbert\b|\bgpt-?2\b|\bhugging-?face\b|\blong-?former\b|\bt5\b)/i' -quantization: '/(is this a quantized model\?\n\nYes|\bquantization\b)/i' +ep:WebGPU: '/\bwebgpu\b/i' +ep:WebNN: '/\bwebnn\b/i' +ep:Xnnpack: '/\bxnn\s*pack\b/i' +.NET: '/(\bdot\s*net\b|\bnuget\b|\.net\b)/i' +platform:jetson: '/(\bjetson\b|\bjetpack\b)/i' +platform:mobile: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' +platform:web: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b|\bonnxruntime-react-native\b|\bnpm\b|\btransformers\.js\b)/i' +model:transformer: '/\btransformers(?!\.js)\b/i' diff --git a/.github/title-only-labeler.yml b/.github/title-only-labeler.yml new file mode 100644 index 0000000000000..4980f7251bcb4 --- /dev/null +++ b/.github/title-only-labeler.yml @@ -0,0 +1,4 @@ +ep:CUDA: '/\bcuda\b/i' +ep:TensorRT: '/(\btensor\s*rt\b|\btrt\b)/i' +platform:windows: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' +quantization: '/(quant|\bqdq\b)/i' diff --git a/.github/workflows/title-only-labeler.yml b/.github/workflows/title-only-labeler.yml new file mode 100644 index 0000000000000..e0af2dd06b1b7 --- /dev/null +++ b/.github/workflows/title-only-labeler.yml @@ -0,0 +1,20 @@ +name: "Title Only Issue Labeler" +on: + issues: + types: [opened, edited] + +permissions: + issues: write + +jobs: + triage: + runs-on: ubuntu-latest + steps: + - uses: github/issue-labeler@v3.4 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + configuration-path: .github/title-only-labeler.yml + not-before: 2020-01-15T02:54:32Z + enable-versioned-regex: 0 + include-title: 1 + include-body: 0 From 8417c325ec160dc8ee62edaf6d1daf91ad979d56 Mon Sep 17 00:00:00 2001 From: mcollinswisc Date: Mon, 29 Jul 2024 16:06:51 -0700 Subject: [PATCH 25/37] Keep QDQ nodes w/ nonpositive scale around MaxPool (#21182) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This change adds a check for whether the scale in the QuantizeLinear (or DequantizeLinear) is a positive scalar, and a new selector to disallow removing the QDQ around MaxPool if it is not. ### Motivation and Context Currently, the DropQDQNodesRules optimization removes QuantizeLinear and DequantizeLinear nodes from DequantizeLinear ∘ MaxPool ∘ QuantizeLinear. However, if the x_scale/y_scale values are non-positive, the (de-)quantization changes the ordering of the elements in the input value, so this optimization is changing the results. https://github.com/microsoft/onnxruntime/issues/21176 --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../optimizer/qdq_transformer/qdq_util.cc | 35 ++++++++++++++ .../core/optimizer/qdq_transformer/qdq_util.h | 4 ++ .../qdq_selector_action_transformer.cc | 27 +++++++++-- .../selectors_actions/qdq_selectors.cc | 7 +++ .../selectors_actions/qdq_selectors.h | 10 ++-- .../test/optimizer/qdq_transformer_test.cc | 46 +++++++++++++++++++ 6 files changed, 120 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index a4d1ea3c7cf56..7ef4ced1835f0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -166,6 +166,41 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( return true; } +bool IsQOrDQScalePositiveConstantScalar( + const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, + const std::filesystem::path& model_path) { + auto q_or_dq_input_defs = q_or_dq_node.InputDefs(); + + ORT_ENFORCE(q_or_dq_input_defs.size() >= 2); + + if (!optimizer_utils::IsScalar(*q_or_dq_input_defs[InputIndex::SCALE_ID])) { + return false; + } + + const ONNX_NAMESPACE::TensorProto* q_or_dq_scale_tensor_proto = + get_const_initializer(q_or_dq_input_defs[InputIndex::SCALE_ID]->Name()); + if (nullptr == q_or_dq_scale_tensor_proto) { + return false; + } + + Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path); + + switch (q_or_dq_scale.data_type()) { + case ONNX_NAMESPACE::TensorProto::FLOAT: + return q_or_dq_scale.data()[0] > 0; + + case ONNX_NAMESPACE::TensorProto::FLOAT16: + return q_or_dq_scale.data()[0] > 0; + + case ONNX_NAMESPACE::TensorProto::BFLOAT16: + return q_or_dq_scale.data()[0] > 0; + + default: + assert(false); + return false; + } +} + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) bool MatchQNode(const Node& node) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 5d11b8bfd5558..008f9972a143b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -65,6 +65,10 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( const GetConstantInitializerFn& get_const_initializer, bool& zero_point_exists); +// Checks that the y_scale/x_scale input to the QuantizeLinear/DequantizeLinear node is a positive scalar. +bool IsQOrDQScalePositiveConstantScalar(const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, + const std::filesystem::path& model_path); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // Check Q node op type, version, and domain. bool MatchQNode(const Node& node); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 17e66a3953b97..d81701fdf443b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -35,6 +35,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q. const std::string drop_action_name{"drop"}; const std::string drop_action_no_int16_name{"drop_no_int16_support"}; + const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"}; NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; @@ -46,19 +47,32 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr drop_action_no_int16 = std::make_unique( std::vector(moves)); // Copy before std::move(moves) + std::unique_ptr drop_action_no_int16_and_positive_scale = std::make_unique( + std::vector(moves)); // Copy before std::move(moves) std::unique_ptr drop_action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) - // Use a separate selector + action that disallows 16-bit types for MaxPool and Resize. + // Use separate selectors & actions for MaxPool and Resize. + // + // They disallow 16-bit types for MaxPool and Resize: // int16 MaxPool is not supported by the ONNX specification. // int16 Resize is not supported by the ORT implementation (although allowed by ONNX). - std::unique_ptr selector_disallow_16bit = std::make_unique(false); + // + // And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative + // scale will change the ordering of the elements between quantized & de-quantized values. + std::unique_ptr selector_no_16bit = std::make_unique(false); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name, - {{"MaxPool", {12}}, - {"Resize", {}}}, - std::move(selector_disallow_16bit), + {{"Resize", {}}}, + std::move(selector_no_16bit), std::move(drop_action_no_int16)); + std::unique_ptr selector_no_16bit_and_positive_scale = + std::make_unique(false, true, false); + qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name, + {{"MaxPool", {12}}}, + std::move(selector_no_16bit_and_positive_scale), + std::move(drop_action_no_int16_and_positive_scale)); + std::unique_ptr selector = std::make_unique(true); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name, {{"Gather", {}}, @@ -70,6 +84,9 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::move(drop_action)); #else qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16)); + qdq_selector_action_registry.RegisterAction( + drop_action_no_int16_and_positive_scale_name, + std::move(drop_action_no_int16_and_positive_scale)); qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action)); #endif } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index e271ae8df3356..203aba2c3dd91 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -150,6 +150,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return graph_viewer.GetConstantInitializer(initializer_name, true); }; + if (!allow_nonpositive_scale_) { + // IsQDQPairSupported will check that the scale is the same between q_node and dq_node. + if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 491a15b62cb03..7e009da39403b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -48,8 +48,9 @@ class NodeGroupSelector { // Zero point and scale are constant scalars and must match class DropQDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) - : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} + explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true, + bool allow_nonpositive_scale = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -58,6 +59,7 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; bool allow_4bit_; + bool allow_nonpositive_scale_; }; // Single DQ -> node. @@ -300,8 +302,8 @@ class BaseSelector : public NodeSelector { class DropQDQNodesSelector : public BaseSelector { public: - explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) - : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} + explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit, allow_nonpositive_scale)) {} }; class DropDQNodesSelector : public BaseSelector { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 367b4a65e3b7b..a043d6553bdfd 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -980,6 +980,52 @@ TEST(QDQTransformerTests, ReshapeDropQDQ) { RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}, false, 21); // Use int16 ONNX QDQ ops } +// Runs a test case that checks if Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q if the quantization scale is +// negative. +template +static void RunMaxPoolNegativeScaleDropQDQTestCase() { + auto build_test_case = [](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); + + const std::vector input_shape = {1, 17, 17, 3}; + auto* input_arg = builder.MakeInput(input_shape, qmin, qmax); + auto* output_arg = builder.MakeOutput(); + + constexpr float scale = -0.003f; + QuantType zero_point = 1 + (qmax + qmin) / 2; + + auto* input_arg_dq = builder.MakeIntermediate(); + auto* maxpool_output = builder.MakeIntermediate(); + + builder.AddDequantizeLinearNode(input_arg, scale, zero_point, input_arg_dq); + + Node& maxpool_node = builder.AddNode("MaxPool", {input_arg_dq}, {maxpool_output}); + maxpool_node.AddAttribute("auto_pad", "VALID"); + maxpool_node.AddAttribute("kernel_shape", std::vector({2, 2})); + + builder.AddQuantizeLinearNode(maxpool_output, scale, zero_point, output_arg); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MaxPool"], 1); + EXPECT_EQ(op_to_count["QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"], 1); + }; + + constexpr int opset = 21; + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset); +} + +// Checks that Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q for negative scale. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, MaxpoolDontDropQDQForNegativeScale) { + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); +} + // Runs a test case that checks if Q/DQ nodes are dropped from DQ -> (Un)Squeeze -> Q. template static void RunSqueezeUnsqueezeDropQDQTestCase(const std::string& squeeze_type, From 5d78b9a17bb6d126f8ae7fa7eef05cabe4a08dae Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 29 Jul 2024 17:27:38 -0700 Subject: [PATCH 26/37] [TensorRT EP] Update TRT OSS Parser to 10.2 (#21552) ### Description Update TRT OSS Parser to [latest 10.2-GA branch](https://github.com/onnx/onnx-tensorrt/commit/f161f95883b4ebd8cb789de5efc67b73c0a6e694) ### Motivation and Context --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 4 ++-- .../github/azure-pipelines/templates/download-deps.yml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 66b305a6d36de..7de3f346f6386 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -216,7 +216,7 @@ "component": { "type": "git", "git": { - "commitHash": "06adf4461ac84035bee658c6cf5df39f7ab6071d", + "commitHash": "f161f95883b4ebd8cb789de5efc67b73c0a6e694", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" diff --git a/cmake/deps.txt b/cmake/deps.txt index 9d206b6bb3aeb..d0edf963451d5 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -37,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/v0.3.zip;5ec64e3071edc7347ebd8a81679cf06e2bb9b851 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.1.zip;2eb9198bb352757d5ff13977cbe0634898e0837c -#use the latest commit of 10.0-GA -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/06adf4461ac84035bee658c6cf5df39f7ab6071d.zip;46dceef659d75d276e7914a8057c2282269d5e7b +#use the latest commit of 10.2-GA +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/f161f95883b4ebd8cb789de5efc67b73c0a6e694.zip;2148d0c79a171abf2b9451f3bfec164e85caf2ef protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index bf11730c2ce28..01965343c4592 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.167 + version: 1.0.173 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.167 + version: 1.0.173 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 07d3be5b0e037927c3defd8a7e389e59ec748ad8 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 29 Jul 2024 21:04:47 -0700 Subject: [PATCH 27/37] CoreML: Add ML Program Split Op (#21456) ### Description Add support for Split Op ### Motivation and Context Address operator gaps in high priority model. --------- Co-authored-by: Scott McKay Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../coreml/builders/impl/split_op_builder.cc | 138 ++++++++++++------ .../apple/coreml_supported_mlprogram_ops.md | 1 + 2 files changed, 94 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 0497357c45c54..dbd0f48576f8b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -5,6 +5,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -24,6 +25,8 @@ class SplitOpBuilder : public BaseOpBuilder { // Split opset 13- uses "split" as attribute. Currently it's not supported. int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } + + bool SupportsMLProgram() const override { return true; } }; void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -43,55 +46,98 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); NodeAttrHelper helper(node); - const auto axis = helper.Get("axis", 0); + int64_t axis = helper.Get("axis", 0); - // attribute introduced since opset 18 - uint64_t num_outputs; - - std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_splitnd = layer->mutable_splitnd(); - coreml_splitnd->set_axis(axis); - - if (input_defs.size() > 1) { - // if "split" is explicitly provided as an input - const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - Initializer unpacked_tensor(split_tensor); - auto split_span = unpacked_tensor.DataAsSpan(); - auto split_sizes = split_span.size(); - num_outputs = narrow(split_sizes); - for (size_t i = 0; i < split_sizes; i++) { - coreml_splitnd->add_splitsizes(split_span[i]); - } - } else if (node.SinceVersion() < 18) { - num_outputs = narrow(node.OutputDefs().size()); - coreml_splitnd->set_numsplits(num_outputs); - } else { - // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt64("num_outputs").value()); + auto calculate_remainder_and_chunk_size = [&](int32_t num_outputs) { // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); + uint64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs; uint64_t remainder = split_dim_size % chunk_size; - if (remainder) { - // uneven - auto split_sizes = InlinedVector(num_outputs, chunk_size); - split_sizes.back() = remainder; - for (size_t i = 0; i < split_sizes.size(); i++) { - coreml_splitnd->add_splitsizes(split_sizes[i]); - } + return std::make_tuple(remainder, chunk_size); + }; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::unique_ptr split_op = model_builder.CreateOperation(node, "split"); + AddOperationInput(*split_op, "axis", model_builder.AddScalarConstant(split_op->type(), "axis", axis)); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + auto split_span = unpacked_tensor.DataAsSpan(); + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", split_span)); + } else if (node.SinceVersion() < 18) { + int64_t num_outputs = narrow(node.OutputDefs().size()); + AddOperationInput(*split_op, "num_splits", + model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs)); } else { - // even + // note: for opset 18+ 'num_outputs' is a required attribute + int64_t num_outputs = helper.GetInt64("num_outputs").value(); + auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); + if (remainder) { + // uneven + std::vector split_sizes(num_outputs, chunk_size); + split_sizes.back() = remainder; + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", split_sizes)); + } else { + // even + AddOperationInput(*split_op, "num_splits", + model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs)); + } + } + + AddOperationInput(*split_op, "x", input_defs[0]->Name()); + for (const auto& output_def : node.OutputDefs()) { + AddOperationOutput(*split_op, *output_def); + } + model_builder.AddOperation(std::move(split_op)); + + } else +#endif + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + auto* coreml_splitnd = layer->mutable_splitnd(); + coreml_splitnd->set_axis(axis); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + auto split_span = unpacked_tensor.DataAsSpan(); + for (const auto& split_size : split_span) { + coreml_splitnd->add_splitsizes(split_size); + } + } else if (node.SinceVersion() < 18) { + uint64_t num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); + } else { + // note: for opset 18+ 'num_outputs' is a required attribute + uint64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); + auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); + if (remainder) { + // uneven + auto split_sizes = InlinedVector(num_outputs, chunk_size); + split_sizes.back() = remainder; + for (size_t i = 0; i < split_sizes.size(); i++) { + coreml_splitnd->add_splitsizes(split_sizes[i]); + } + } else { + // even + coreml_splitnd->set_numsplits(num_outputs); + } } - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. - // Otherwise, uses attribute value 'num_outputs'. - for (uint64_t i = 0; i < num_outputs; i++) { - *layer->mutable_output()->Add() = node.OutputDefs()[i]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. + // Otherwise, uses attribute value 'num_outputs'. + for (const auto& output_def : node.OutputDefs()) { + *layer->mutable_output()->Add() = output_def->Name(); + } + model_builder.AddLayer(std::move(layer)); } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -99,7 +145,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); NodeAttrHelper helper(node); const auto axis = helper.Get("axis", 0); @@ -110,16 +155,19 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())]; if (input_defs.size() > 1 && input_defs[1]->Exists()) { - if (!CheckIsConstantInitializer(*input_defs[1], input_params.graph_viewer, logger, "'split'")) { + const auto* splits_tensor = input_params.graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!splits_tensor) { + LOGS(logger, VERBOSE) << "CoreML 'splits' input must be a constant initializer."; return false; } + const auto split_shape = *input_defs[1]->Shape(); if (split_shape.dim_size() < 2) { - LOGS(logger, VERBOSE) << "CoreML SplitND requires to produce at least 2 outputs."; + LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; return false; } - const auto& splits_tensor = *initializers.at(input_defs[1]->Name()); - Initializer unpacked_tensor(splits_tensor); + + Initializer unpacked_tensor(*splits_tensor); auto splits_span = unpacked_tensor.DataAsSpan(); int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), int64_t{0}); if (sum_of_splits != split_dims_at_axis) { diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index d2a961f17bd6a..b546c266c131b 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -24,6 +24,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| |ai.onnx.Slice|starts/ends/axes/steps must be constant initializers.| +|ai.onnx:Split|| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| |ai:onnx:Tanh|| From 82036b04978b7930185996a70d2146c2895469ea Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 29 Jul 2024 21:59:16 -0700 Subject: [PATCH 28/37] Remove references to the outdated CUDA EP factory method (#21549) The function "OrtSessionOptionsAppendExecutionProvider_CUDA" is deprecated. --- .../global_thread_pools/test_inference.cc | 4 +++- onnxruntime/test/shared_lib/test_inference.cc | 20 ++++++++++++++----- .../test/shared_lib/test_model_loading.cc | 5 +++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index f553682975f11..c6d958536f488 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -74,7 +74,9 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) if (provider_type == 1) { #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); std::cout << "Running simple inference with cuda provider" << std::endl; #else return Ort::Session(nullptr); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 52491a179c2ce..7a33bf8a527cd 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1959,7 +1959,9 @@ TEST(CApiTest, get_allocator_cpu) { #ifdef USE_CUDA TEST(CApiTest, get_allocator_cuda) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -2076,7 +2078,9 @@ TEST(CApiTest, io_binding_cuda) { #ifdef USE_TENSORRT Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); #else - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); @@ -3438,7 +3442,9 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); #else // arena is enabled but the sole initializer will still be allocated from non-arena memory Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -3890,7 +3896,9 @@ TEST(CApiTest, GitHubIssue10179) { try { const auto* model_path = MODEL_URI; Ort::SessionOptions session_options{}; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; } catch (const std::exception& e) { std::cerr << "exception: " << e.what() << "\n"; @@ -3920,7 +3928,9 @@ TEST(CApiTest, GitHubIssue10179) { TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) { const auto* model_path = SEQUENCE_MODEL_URI_2; Ort::SessionOptions session_options{}; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); diff --git a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index b7f6f7f4b9a77..5694398b9cb10 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -60,8 +60,9 @@ TEST(CApiTest, model_from_array) { create_session(so); #ifdef USE_CUDA - // test with CUDA provider when using onnxruntime as dll - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + so.AppendExecutionProvider_CUDA_V2(*options); create_session(so); #endif } From 530a2d7b41b0584f67ddfef6679a79e9dbeee556 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Tue, 30 Jul 2024 03:49:14 -0700 Subject: [PATCH 29/37] Enable FP16 Clip and Handle Bias in FP16 Depthwise Conv (#21493) - Improved accuracy for face-detection, image-classification, and object-detection in the GeekBench ML benchmark on ARM64. - Fixed issue https://github.com/microsoft/onnxruntime/issues/18992 --- docs/OperatorKernels.md | 4 +- onnxruntime/core/mlas/inc/mlas.h | 2 + onnxruntime/core/mlas/lib/dwconv.cpp | 32 +-- onnxruntime/core/mlas/lib/fp16_common.h | 17 ++ .../core/providers/cpu/fp16/fp16_conv.cc | 4 +- onnxruntime/core/providers/cpu/math/clip.cc | 2 +- .../test/providers/cpu/math/clip_test.cc | 18 ++ .../test/providers/cpu/nn/conv_fp16_test.cc | 237 +++++++++++++++++- .../test/providers/cpu/nn/conv_op_test.cc | 235 +++++++++++++++++ 9 files changed, 531 insertions(+), 20 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 211c53d0fecc8..f265c9f985070 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -58,8 +58,8 @@ Do not modify directly.* |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float)| -|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||11|**T** = tensor(float)| |||[6, 10]|**T** = tensor(float)| |Col2Im|*in* input:**T**
*in* image_shape:**tensor(int64)**
*in* block_shape:**tensor(int64)**
*out* output:**T**|18+|**T** = tensor(float)| diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 675f7c7a13e8c..e46105324a7fb 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1751,6 +1751,7 @@ MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* Pac * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input * @param Filter Supplies the address for filter tensor + * @param Bias Supplies the address for 1D bias tensor B, has size of M * @param Output Supplies the address for the result tensor * @param Channels # of input channels * @param OutputCount # of output pixels @@ -1762,6 +1763,7 @@ MLASCALL MlasConvDepthwise( const MLAS_FP16* const* Input, const MLAS_FP16* Filter, + const MLAS_FP16* Bias, MLAS_FP16* Output, size_t Channels, size_t OutputCount, diff --git a/onnxruntime/core/mlas/lib/dwconv.cpp b/onnxruntime/core/mlas/lib/dwconv.cpp index 15511d2d8ceac..d48d9cbb17502 100644 --- a/onnxruntime/core/mlas/lib/dwconv.cpp +++ b/onnxruntime/core/mlas/lib/dwconv.cpp @@ -14,7 +14,6 @@ Module Name: --*/ - #include "fp16_common.h" #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -24,19 +23,20 @@ void MlasConvDepthwiseKernel( const _mlas_fp16_* const* Input, const _mlas_fp16_* Filter, + const _mlas_fp16_* Bias, _mlas_fp16_* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { while (OutputCount > 0) { size_t ChannelOffset = 0; size_t c = Channels; while (c >= 8) { - MLAS_FLOAT16X8 Accumulator = MlasZeroFloat16x8(); + MLAS_FLOAT16X8 Accumulator = Bias == nullptr ? MlasZeroFloat16x8() : MlasLoadFloat16x8(&Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -54,7 +54,7 @@ MlasConvDepthwiseKernel( } if (c >= 4) { - MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + MLAS_FLOAT16X4 Accumulator = Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadFloat16x4(&Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -72,7 +72,8 @@ MlasConvDepthwiseKernel( } if (c > 0) { - MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + MLAS_FLOAT16X4 Accumulator = + Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadPartialFloat16x4(&Bias[ChannelOffset], c); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -86,8 +87,7 @@ MlasConvDepthwiseKernel( Output += c; } if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, - Channels); + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); } Input += KernelSize; OutputCount -= 1; @@ -101,16 +101,17 @@ void MlasConvDepthwiseKernel( const _mlas_fp16_* const* Input, const _mlas_fp16_* Filter, + const _mlas_fp16_* Bias, _mlas_fp16_* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { while (OutputCount > 0) { for (size_t ChannelOffset = 0; ChannelOffset < Channels; ChannelOffset++) { - float Accumulator = 0.0f; + float Accumulator = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -120,35 +121,36 @@ MlasConvDepthwiseKernel( *Output++ = MLAS_Float2Half(Accumulator); } if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, - Channels); + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); } Input += KernelSize; OutputCount -= 1; } } -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED - +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED void MLASCALL MlasConvDepthwise( const MLAS_FP16* const* Input, const MLAS_FP16* Filter, + const MLAS_FP16* Bias, MLAS_FP16* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { MlasConvDepthwiseKernel( reinterpret_cast(Input), reinterpret_cast(Filter), + reinterpret_cast(Bias), reinterpret_cast<_mlas_fp16_*>(Output), Channels, OutputCount, KernelSize, - PostProc); + PostProc + ); } diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 1fcab870af64f..30b66cdb2ea78 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -64,6 +64,23 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) +{ + MLAS_FLOAT16X4 Vector = MlasZeroFloat16x4(); + if ((len & 1) != 0) { + Vector = vreinterpret_f16_u16(vld1_lane_u16(Buffer + (len - 1), vreinterpret_u16_f16(Vector), 0)); + } + if ((len & 2) != 0) { + Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 0)); + Vector = vreinterpret_f16_f32( + vld1_lane_f32(reinterpret_cast(Buffer), vreinterpret_f32_f16(Vector), 0) + ); + } + return Vector; +} + MLAS_FORCEINLINE void MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index e6867f10819ae..37db095e92570 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -139,8 +139,9 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr bool share_prepacked_weights = (prepacked_weights != nullptr); + const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1); // Don't pack the filter buffer if the MlasConvDepthwise path is used. - if (!(group_input_channels == 1 && group_output_channels == 1)) { + if (!is_depthwise_conv) { packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false); if (packed_W_size_ != 0) { size_t packed_W_data_size = SafeInt(group_count) * packed_W_size_; @@ -472,6 +473,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { MlasConvDepthwise( worker_indirection_buffer, reordered_W, + Bdata, worker_output, static_cast(M), static_cast(output_count), diff --git a/onnxruntime/core/providers/cpu/math/clip.cc b/onnxruntime/core/providers/cpu/math/clip.cc index ddb64a5a0e461..200469bc47835 100644 --- a/onnxruntime/core/providers/cpu/math/clip.cc +++ b/onnxruntime/core/providers/cpu/math/clip.cc @@ -23,7 +23,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( float); ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0, - float, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t); + float, MLFloat16, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t); } // namespace op_kernel_type_control using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index 6f81bbbe31d54..9948a6cc8a681 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -119,6 +119,24 @@ TEST(MathOpTest, Clip_Default_uint64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(MathOpTest, Clip_MLFloat16) { + OpTester test("Clip", 12); + + std::vector dims{3, 3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), + MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + test.AddInput("min", {}, {MLFloat16(0.0f)}); + test.AddInput("max", {}, {MLFloat16(6.0f)}); + test.AddOutput("Y", dims, + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + + test.Run(); +} + TEST(MathOpTest, Clip_int32) { OpTester test("Clip", 12); diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index cb5fc8095982c..95b274966fbbb 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -714,6 +714,241 @@ TEST(ConvFp16Test, Conv2D_group) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Depthwise2D_Bias_Group1_Issue18992) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(1.0f)}; + vector X_shape = {1, 1, 1, 1}; + vector W = {MLFloat16(0.5f)}; + vector W_shape = {1, 1, 1, 1}; + vector B = {MLFloat16(0.5f)}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 1, 1}; + auto expected_vals = {MLFloat16(1.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Depthwise2D_Bias_Group2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), + MLFloat16(12.0f), MLFloat16(13.0f), MLFloat16(14.0f), + MLFloat16(15.0f), MLFloat16(16.0f), MLFloat16(17.0f)}; + vector X_shape = {1, 2, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(2.0f)}; + vector W_shape = {2, 1, 1, 1}; + vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f), + + MLFloat16(17.0f), MLFloat16(19.0f), MLFloat16(21.0f), + MLFloat16(23.0f), MLFloat16(25.0f), MLFloat16(27.0f), + MLFloat16(29.0f), MLFloat16(31.0f), MLFloat16(33.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Depthwise2D_Bias_Group15) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 15, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + // C = 0 + MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(2.0f), MLFloat16(3.0f), + + // C = 1 + MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), + + // C = 2 + MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), + + // C = 3 + MLFloat16(12.0f), MLFloat16(13.0f), + MLFloat16(14.0f), MLFloat16(15.0f), + + // C = 4 + MLFloat16(16.0f), MLFloat16(17.0f), + MLFloat16(18.0f), MLFloat16(19.0f), + + // C = 5 + MLFloat16(20.0f), MLFloat16(21.0f), + MLFloat16(22.0f), MLFloat16(23.0f), + + // C = 6 + MLFloat16(24.0f), MLFloat16(25.0f), + MLFloat16(26.0f), MLFloat16(27.0f), + + // C = 7 + MLFloat16(28.0f), MLFloat16(29.0f), + MLFloat16(30.0f), MLFloat16(31.0f), + + // C = 8 + MLFloat16(32.0f), MLFloat16(33.0f), + MLFloat16(34.0f), MLFloat16(35.0f), + + // C = 9 + MLFloat16(36.0f), MLFloat16(37.0f), + MLFloat16(38.0f), MLFloat16(39.0f), + + // C = 10 + MLFloat16(40.0f), MLFloat16(41.0f), + MLFloat16(42.0f), MLFloat16(43.0f), + + // C = 11 + MLFloat16(44.0f), MLFloat16(45.0f), + MLFloat16(46.0f), MLFloat16(47.0f), + + // C = 12 + MLFloat16(48.0f), MLFloat16(49.0f), + MLFloat16(50.0f), MLFloat16(51.0f), + + // C = 13 + MLFloat16(52.0f), MLFloat16(53.0f), + MLFloat16(54.0f), MLFloat16(55.0f), + + // C = 14 + MLFloat16(56.0f), MLFloat16(57.0f), + MLFloat16(58.0f), MLFloat16(59.0f)}; + vector X_shape = {1, 15, 2, 2}; + vector W = { + // M = 0 + MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(2.0f), MLFloat16(3.0f), + + // M = 1 + MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), + + // M = 2 + MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), + + // M = 3 + MLFloat16(12.0f), MLFloat16(13.0f), + MLFloat16(14.0f), MLFloat16(15.0f), + + // M = 4 + MLFloat16(16.0f), MLFloat16(17.0f), + MLFloat16(18.0f), MLFloat16(19.0f), + + // M = 5 + MLFloat16(20.0f), MLFloat16(21.0f), + MLFloat16(22.0f), MLFloat16(23.0f), + + // M = 6 + MLFloat16(24.0f), MLFloat16(25.0f), + MLFloat16(26.0f), MLFloat16(27.0f), + + // M = 7 + MLFloat16(28.0f), MLFloat16(29.0f), + MLFloat16(30.0f), MLFloat16(31.0f), + + // M = 8 + MLFloat16(32.0f), MLFloat16(33.0f), + MLFloat16(34.0f), MLFloat16(35.0f), + + // M = 9 + MLFloat16(36.0f), MLFloat16(37.0f), + MLFloat16(38.0f), MLFloat16(39.0f), + + // M = 10 + MLFloat16(40.0f), MLFloat16(41.0f), + MLFloat16(42.0f), MLFloat16(43.0f), + + // M = 11 + MLFloat16(44.0f), MLFloat16(45.0f), + MLFloat16(46.0f), MLFloat16(47.0f), + + // M = 12 + MLFloat16(48.0f), MLFloat16(49.0f), + MLFloat16(50.0f), MLFloat16(51.0f), + + // M = 13 + MLFloat16(52.0f), MLFloat16(53.0f), + MLFloat16(54.0f), MLFloat16(55.0f), + + // M = 14 + MLFloat16(56.0f), MLFloat16(57.0f), + MLFloat16(58.0f), MLFloat16(59.0f)}; + vector W_shape = {15, 1, 2, 2}; + vector B = { + MLFloat16(101.0f), + MLFloat16(102.0f), + MLFloat16(103.0f), + MLFloat16(104.0f), + MLFloat16(105.0f), + MLFloat16(106.0f), + MLFloat16(107.0f), + MLFloat16(108.0f), + MLFloat16(109.0f), + MLFloat16(110.0f), + MLFloat16(111.0f), + MLFloat16(112.0f), + MLFloat16(113.0f), + MLFloat16(114.0f), + MLFloat16(115.0f)}; + vector B_shape = {15}; + vector Y_shape = {1, 15, 1, 1}; + auto expected_vals = { + MLFloat16(115.0f), // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 101.0 + MLFloat16(228.0f), + MLFloat16(469.0f), + MLFloat16(838.0f), + MLFloat16(1335.0f), + MLFloat16(1960.0f), + MLFloat16(2713.0f), // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 107.0 + MLFloat16(3594.0f), + MLFloat16(4603.0f), + MLFloat16(5740.0f), + MLFloat16(7005.0f), + MLFloat16(8398.0f), + MLFloat16(9919.0f), // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 113.0 + MLFloat16(11568.0f), // 52.0*52.0 + 53.0*53.0 + 54.0*54.0 + 55.0*55.0 + 114.0 + MLFloat16(13345.0f) // 56.0*56.0 + 57.0*57.0 + 58.0*58.0 + 59.0*59.0 + 115.0 + }; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + TEST(ConvFp16Test, ConvDimWithZero) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -1074,4 +1309,4 @@ TEST(ConvFp16Test, SharedPrepackedWeights) { } // namespace test } // namespace onnxruntime -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED \ No newline at end of file +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 0efa78af2795c..2d885ee9d479f 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -647,6 +647,241 @@ TEST(ConvTest, Conv2D_group) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvTest, Depthwise2D_Bias_Group1_Issue18992) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {1.0f}; + vector X_shape = {1, 1, 1, 1}; + vector W = {0.5f}; + vector W_shape = {1, 1, 1, 1}; + vector B = {0.5f}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 1, 1}; + auto expected_vals = {1.0f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Depthwise2D_Bias_Group2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + + 9.0f, 10.0f, 11.0f, + 12.0f, 13.0f, 14.0f, + 15.0f, 16.0f, 17.0f}; + vector X_shape = {1, 2, 3, 3}; + vector W = {1.0f, 2.0f}; + vector W_shape = {2, 1, 1, 1}; + vector B = {1.0f, -1.0f}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, + + 17.0f, 19.0f, 21.0f, + 23.0f, 25.0f, 27.0f, + 29.0f, 31.0f, 33.0f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Depthwise2D_Bias_Group15) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 15, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + // C = 0 + 0.0f, 1.0f, + 2.0f, 3.0f, + + // C = 1 + 4.0f, 5.0f, + 6.0f, 7.0f, + + // C = 2 + 8.0f, 9.0f, + 10.0f, 11.0f, + + // C = 3 + 12.0f, 13.0f, + 14.0f, 15.0f, + + // C = 4 + 16.0f, 17.0f, + 18.0f, 19.0f, + + // C = 5 + 20.0f, 21.0f, + 22.0f, 23.0f, + + // C = 6 + 24.0f, 25.0f, + 26.0f, 27.0f, + + // C = 7 + 28.0f, 29.0f, + 30.0f, 31.0f, + + // C = 8 + 32.0f, 33.0f, + 34.0f, 35.0f, + + // C = 9 + 36.0f, 37.0f, + 38.0f, 39.0f, + + // C = 10 + 40.0f, 41.0f, + 42.0f, 43.0f, + + // C = 11 + 44.0f, 45.0f, + 46.0f, 47.0f, + + // C = 12 + 48.0f, 49.0f, + 50.0f, 51.0f, + + // C = 13 + 52.0f, 53.0f, + 54.0f, 55.0f, + + // C = 14 + 56.0f, 57.0f, + 58.0f, 59.0f}; + vector X_shape = {1, 15, 2, 2}; + vector W = { + // M = 0 + 0.0f, 1.0f, + 2.0f, 3.0f, + + // M = 1 + 4.0f, 5.0f, + 6.0f, 7.0f, + + // M = 2 + 8.0f, 9.0f, + 10.0f, 11.0f, + + // M = 3 + 12.0f, 13.0f, + 14.0f, 15.0f, + + // M = 4 + 16.0f, 17.0f, + 18.0f, 19.0f, + + // M = 5 + 20.0f, 21.0f, + 22.0f, 23.0f, + + // M = 6 + 24.0f, 25.0f, + 26.0f, 27.0f, + + // M = 7 + 28.0f, 29.0f, + 30.0f, 31.0f, + + // M = 8 + 32.0f, 33.0f, + 34.0f, 35.0f, + + // M = 9 + 36.0f, 37.0f, + 38.0f, 39.0f, + + // M = 10 + 40.0f, 41.0f, + 42.0f, 43.0f, + + // M = 11 + 44.0f, 45.0f, + 46.0f, 47.0f, + + // M = 12 + 48.0f, 49.0f, + 50.0f, 51.0f, + + // M = 13 + 52.0f, 53.0f, + 54.0f, 55.0f, + + // M = 14 + 56.0f, 57.0f, + 58.0f, 59.0f}; + vector W_shape = {15, 1, 2, 2}; + vector B = { + 101.0f, + 102.0f, + 103.0f, + 104.0f, + 105.0f, + 106.0f, + 107.0f, + 108.0f, + 109.0f, + 110.0f, + 111.0f, + 112.0f, + 113.0f, + 114.0f, + 115.0f}; + vector B_shape = {15}; + vector Y_shape = {1, 15, 1, 1}; + auto expected_vals = { + 115.0f, // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 101.0 + 228.0f, + 469.0f, + 838.0f, + 1335.0f, + 1960.0f, + 2713.0f, // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 107.0 + 3594.0f, + 4603.0f, + 5740.0f, + 7005.0f, + 8398.0f, + 9919.0f, // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 113.0 + 11568.0f, // 52.0*52.0 + 53.0*53.0 + 54.0*54.0 + 55.0*55.0 + 114.0 + 13345.0f // 56.0*56.0 + 57.0*57.0 + 58.0*58.0 + 59.0*59.0 + 115.0 + }; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + TEST(ConvTest, ConvDimWithZero) { ConvOpAndTestAttributes attrs = { "", // auto_pad From 1637f22d39b6d57d2774d7d41e6a8ae1815180c5 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Tue, 30 Jul 2024 09:35:45 -0700 Subject: [PATCH 30/37] Extend Pad Fusion for AveragePool (#21556) ### Description This extends the existing pad_fusion for AveragePool operator i.e. fuse Pad if it is followed by AveragePool operator. ### Motivation and Context --- onnxruntime/core/optimizer/pad_fusion.cc | 3 ++- onnxruntime/core/optimizer/pad_fusion.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index a1c7f8de9e6fe..e266946b0d9e0 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -12,7 +12,7 @@ namespace onnxruntime { * It matches following pattern: * Pad * | - * Conv/MaxPool + * Conv/MaxPool/AveragePool */ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { // if Pad has input axis, don't fuse it. @@ -28,6 +28,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log const Node& child_node = *node.OutputNodesBegin(); if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) && !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { return false; } diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h index a1b6978a83d1e..ca05d219b7e2c 100644 --- a/onnxruntime/core/optimizer/pad_fusion.h +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -8,7 +8,7 @@ namespace onnxruntime { /* * This fusion submerges a Pad operator to it's child - * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * Conv or MaxPool or AveragePool operator, if and only if PadFusion::SatisfyCondition() * is true. */ class PadFusion : public RewriteRule { From e7aa11607f59c7efeb8505af8fe9186ff2e3dd37 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:22:46 -0700 Subject: [PATCH 31/37] Utilize ext data location to reduce qd matmul memory usage (#21451) ### Description When the graph is quantized to qdq format, the DQ + MatMul is transformed to MatMulNBits in the level 2 optimizer when the model is initialized in an inference session. In the transformation step, tensors are transposed and new tensor protos are created. Instead of using protobuf arena allocated memory, the PR sets the tensor proto to use external buffer, and point the external location to memory location which contains the tensor buffer allocated by CPU. Then, in the step that creates OrtValue using the tensor proto, the memory buffers in the tensor proto are directly assigned to the tensors which were originally allocated by Ort Arena. With these two steps, the peak memory usage of QDQ format model is the same as usage of QOperator model. Besides, the model initialization time is significantly reduced. Take [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) for example: || QOperator Model (MatMulNBits) | QDQ Model (DQ + MatMul, original code) | QDQ Model (this PR) | |---|---|---|---| | peak memory consumption | 2.8 GB | ~4.8 GB | 2.8 GB | | initialization time | 3 sec | 9 sec | 5 sec | ### Motivation and Context When the graph is quantized to qdq format, the DQ + MatMul is converted to MatMulNBits in the level 2 optimizer. Originally, the newly created tensor proto use memory allocated by protobuf arena. These memory usage cannot be fully released when the tensor protos are deleted. Then, in the tensor proto to OrtValue step, tensors are created using ORT arena. Later, in the pre-pack step for MatMulNBits, new OrtValues are created. The tensors in the ORT arena are not fully released as well. The two arena memory allocation steps in the DQ + MatMul -> MatMulNBits transformation will result in almost 2x memory consumption in the model initialization. --- .../core/optimizer/graph_transformer_utils.h | 9 +- onnxruntime/core/framework/session_state.cc | 3 +- onnxruntime/core/framework/session_state.h | 11 ++ .../core/framework/session_state_utils.cc | 41 ++++-- .../core/framework/session_state_utils.h | 6 +- .../core/framework/tensorprotoutils.cc | 36 +++++- onnxruntime/core/framework/tensorprotoutils.h | 29 +++-- .../core/optimizer/graph_transformer_utils.cc | 12 +- .../selectors_actions/qdq_actions.cc | 119 +++++++++++------- .../selectors_actions/qdq_actions.h | 6 +- .../qdq_selector_action_transformer.cc | 35 ++++-- .../qdq_selector_action_transformer.h | 7 +- onnxruntime/core/session/inference_session.cc | 14 ++- 13 files changed, 239 insertions(+), 89 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 0bb5c7432f0a7..6cff153c336f0 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -3,12 +3,15 @@ #pragma once +#include #include +#include #include #include #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" +#include "core/framework/tensor.h" #include "core/optimizer/graph_transformer.h" #include "core/platform/threadpool.h" @@ -51,7 +54,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -81,7 +85,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a88f36f63639c..ddb0c3356e544 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1486,7 +1486,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string #include #include +#include #include #include "core/common/flatbuffers.h" @@ -303,6 +304,10 @@ class SessionState { const InlinedHashSet* GetToBeExecutedRange(gsl::span fetch_mlvalue_idxs) const; #endif + std::unordered_map>* GetMutableBufferedTensors() { + return &name_to_buffered_tensor_; + } + Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, bool remove_initializers = true, @@ -562,6 +567,12 @@ class SessionState { // flag to indicate whether current session using any EP that create device stream dynamically. bool has_device_stream_enabled_ep_ = false; #endif + + // Holds the tensors which provide memory buffer for TensorProtos + // Use case: in optimizer, transform a TensorProto to a new TensorProto whose the memory buffer is + // allocated by CPU instead by protobuf's arena. Arena style memory allocators do not fully release + // a instance's memory which may result large memory consumption, which is a tradeoff for speed. + std::unordered_map> name_to_buffered_tensor_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 059de8e3c8c4a..b13b0cd27496d 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include @@ -61,17 +63,23 @@ struct ExtDataValueDeleter { // given a tensor proto with external data return an OrtValue with a tensor for // that data; the pointers for the tensor data and the tensor itself are owned -// by the OrtValue's deleter +// by the OrtValue's deleter. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor, OrtCallback& ext_data_deleter) { + Tensor& tensor, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); void* ext_data_buf = nullptr; SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, - ext_data_buf, ext_data_len, ext_data_deleter)); + ext_data_buf, ext_data_len, ext_data_deleter, + buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -83,16 +91,24 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, return common::Status::OK(); } +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, - bool use_device_allocator_for_initializers = false) { + bool use_device_allocator_for_initializers = false, + Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } + ORT_RETURN_IF(buffered_tensor && !utils::HasExternalData(tensor_proto), + "With buffered tensor, tensor proto must use external location and point to buffered tensor"); + // Get shape and type of the tensor, and allocate the empty tensor TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); @@ -123,7 +139,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, ext_data_deleter)); + ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, + ext_data_deleter, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; @@ -154,7 +171,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st std::optional scoped_ort_callback_invoker; if (utils::HasExternalData(tensor_proto)) { ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter)); + ext_data_deleter, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); } else { ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); @@ -187,7 +204,8 @@ common::Status SaveInitializedTensors( const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func) { + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -307,9 +325,16 @@ common::Status SaveInitializedTensors( bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; + Tensor* p_tensor = nullptr; + if (auto iter = buffered_tensors.find(name); + iter != buffered_tensors.end()) { + p_tensor = iter->second.release(); + buffered_tensors.erase(iter); + } + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, - use_device_allocator_for_initializers); + use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index af44c35fbb7f5..499222b6ec613 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -3,6 +3,9 @@ #pragma once #include +#include +#include +#include #include "core/common/const_pointer_container.h" #include "core/framework/allocator.h" @@ -44,7 +47,8 @@ common::Status SaveInitializedTensors( const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func); + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors); common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, SessionState& session_state, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 4ecd61962d797..cbd53298ab2ad 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -987,7 +987,8 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, - SafeInt& ext_data_len, OrtCallback& ext_data_deleter) { + SafeInt& ext_data_len, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -1003,7 +1004,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo // the value in location is the memory address of the data ext_data_buf = reinterpret_cast(file_offset); ext_data_len = raw_data_safe_len; - ext_data_deleter = OrtCallback{nullptr, nullptr}; + if (buffered_tensor) { + ext_data_deleter = OrtCallback{[](void* p) noexcept { delete reinterpret_cast(p); }, + reinterpret_cast(buffered_tensor)}; + } else { + ext_data_deleter = OrtCallback{nullptr, nullptr}; + } } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, @@ -1241,7 +1247,9 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto return CApiElementTypeFromProtoType(tensor_proto.data_type()); } -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name) { +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer) { // Set name, dimensions, type, and data of the TensorProto. ONNX_NAMESPACE::TensorProto tensor_proto; @@ -1259,6 +1267,28 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: for (; f < end; ++f) { *mutable_string_data->Add() = *f; } + } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { + // The logic aligns with + // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302 + const auto* raw_data = tensor.DataRaw(); + ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); + static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. + // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the + // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. + auto offset = narrow(reinterpret_cast(raw_data)); + + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(tensor.SizeInBytes())); } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index e5197adcb94ec..2af1f080be7ee 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -114,14 +114,22 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); -/** Creates a TensorProto from a Tensor. - @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. - @param[in] tensor_proto_name the name of the TensorProto. - @return the TensorProto. - - Note: Method currently requires that data is in little-endian format. +/** + * @brief Creates a TensorProto from a Tensor. + * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. + * @param[in] tensor_proto_name the name of the TensorProto. + * @param[in] use_tensor_buffer the tensor proto is set to use external location, with + * 'location' set to onnxruntime::utils::kTensorProtoMemoryAddressTag + * 'offset' set to tensor's memory location, and 'length' set to tensor's + * memory size. The caller is responsible to maintain the lifetime of + * the allocated memory buffer. Use with caution. + * @return the TensorProto. + * + * Note: Method currently requires that data is in little-endian format. */ -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name); +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer = false); ONNXTensorElementDataType CApiElementTypeFromProtoType(int type); ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto); @@ -141,10 +149,15 @@ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_A // Given a tensor proto with external data obtain a pointer to the data and its length. // The ext_data_deleter argument is updated with a callback that owns/releases the data. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, - OrtCallback& ext_data_deleter); + OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr); // Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer // If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ab1dbaea7b7fd..54bd44ec2dba4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -189,7 +189,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -309,7 +310,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -419,7 +421,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -444,7 +447,8 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 74fecb0427e14..8f99b7409d4fe 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" @@ -275,8 +278,10 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } -DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( + int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -286,7 +291,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + p_buffered_tensors_{p_buffered_tensors} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -311,6 +317,7 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { + ORT_RETURN_IF_NOT(p_buffered_tensors_, "Buffered tensors map cannot be null"); const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -338,24 +345,35 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, // to what we need. But it does not handle external data. Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::optional zp_src; - Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(weight_arg->Name() + "_T"), - std::vector{N, quant_num, blob_bytes}); - Initializer scale_dst(static_cast(scale_src.data_type()), - graph.GenerateNodeArgName(scale_arg->Name() + "_T"), - std::vector{N * quant_num}); - std::optional zp_dst; + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + std::optional zp_src_ptr; + auto cpu_allocator = std::make_shared(); + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + auto weight_dst_ptr = std::make_unique(uint8_type, + TensorShape{N, quant_num, blob_bytes}, + cpu_allocator); + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); + auto scale_dst_ptr = std::make_unique(scale_type, + TensorShape{scale_size}, + cpu_allocator); + std::string zp_dst_name; + std::unique_ptr zp_dst_ptr; + auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); if (zp_tensor_proto) { - zp_src.emplace(*zp_tensor_proto, graph.ModelPath()); - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{zp_size}, + cpu_allocator); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{zp_size}, + cpu_allocator); + memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes()); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -363,10 +381,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -376,10 +394,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -391,10 +409,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -405,10 +423,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -417,28 +435,43 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - ONNX_NAMESPACE::TensorProto weight_T_tp; - ONNX_NAMESPACE::TensorProto scale_T_tp; + auto weight_T_tp = utils::TensorToTensorProto(*weight_dst_ptr, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(*scale_dst_ptr, scale_dst_name, true); std::optional zp_T_tp; - // TODO(fajin): external_data to memory location to avoid arena allocation - // https://github.com/microsoft/onnxruntime/pull/12465 - weight_dst.ToProto(weight_T_tp); - scale_dst.ToProto(scale_T_tp); - if (zp_dst) { - zp_T_tp.emplace(); - zp_dst->ToProto(zp_T_tp.value()); + if (zp_dst_ptr) { + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst_ptr, zp_dst_name, true)); } auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + if (weight_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + // If tensor is too small, tensor proto directly copies data from tensor. The tensor allocated + // here can be directly destructed. + // Only keep the tensor in p_buffered_tensors_ when the tensor proto is using external data location + // and pointing the location to tensor's buffer. + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(weight_dst_name, std::move(weight_dst_ptr)).second, + "Failed to add buffered tensor ", + weight_dst_name); + } + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + if (scale_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(scale_dst_name, std::move(scale_dst_ptr)).second, + "Failed to add buffered tensor ", + scale_dst_name); + } if (zp_T_tp) { input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); replacement_node.MutableInputArgsCount().push_back(1); + if (zp_T_tp->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(zp_dst_name, std::move(zp_dst_ptr)).second, + "Failed to add buffered tensor ", + zp_dst_name); + } } return Status::OK(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 47821619db65a..d25077ca4b491 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -5,10 +5,12 @@ #include #include +#include #include #include "core/optimizer/selectors_actions/actions.h" #include "core/platform/threadpool.h" +#include "core/framework/tensor.h" namespace onnxruntime { @@ -84,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -103,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + std::unordered_map>* p_buffered_tensors_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d81701fdf443b..379d271fbdca7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include +#include +#include + +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" @@ -247,7 +250,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. @@ -255,7 +259,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr selector = std::make_unique(); @@ -312,9 +317,11 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { +SelectorActionRegistry CreateSelectorActionRegistry( + bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -328,20 +335,24 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, - const SatApplyContextVariant& apply_context, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +QDQSelectorActionTransformer::QDQSelectorActionTransformer( + bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, + intra_op_thread_pool, p_buffered_tensors), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index ba636f76d1900..627ddd35b9919 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" @@ -25,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5ad2f08467792..5eed7c5c6f2b5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1615,7 +1615,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1623,7 +1624,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable, intra_op_thread_pool); + optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2012,7 +2013,8 @@ common::Status InferenceSession::Initialize() { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, - cpu_ep, GetIntraOpThreadPoolToUse())); + cpu_ep, GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3175,7 +3177,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3185,7 +3188,8 @@ common::Status InferenceSession::AddPredefinedTransformers( : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } }(); From 1d4b161145ee9604b98d904a6c3da5610fbf6d01 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 31 Jul 2024 08:46:08 +0800 Subject: [PATCH 32/37] [WebNN EP] Support ConvTranspose for TFLite backend (#21291) ### Description Chromium supports ConvTranspose for TFLite in https://chromium-review.googlesource.com/c/chromium/src/+/5635194 With constraint that only default dilations and groups are supported. --------- Co-authored-by: Dwayne Robinson --- js/web/docs/webnn-operators.md | 2 +- .../core/providers/webnn/builders/helper.h | 2 +- .../webnn/builders/impl/conv_op_builder.cc | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 75652899b5e5e..8711d4d20e370 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -22,7 +22,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | | | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | -| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✗ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). | +| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 496f886e5a076..63fd97abb9a9a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -167,7 +167,7 @@ static const InlinedHashMap op_map = { {"Concat", {"concat", true}}, {"Conv", {"conv2d", true}}, {"ConvInteger", {"conv2dInteger", false}}, - {"ConvTranspose", {"convTranspose2d", false}}, + {"ConvTranspose", {"convTranspose2d", true}}, {"Cos", {"cos", true}}, {"Div", {"div", true}}, {"DequantizeLinear", {"dequantizeLinear", false}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 4f3f7459a7b5b..22049d2519712 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -427,6 +427,22 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } + // WebNN CPU backend (TFLite) only supports default dilations and group. + // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 + if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { + NodeAttrHelper helper(node); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + const auto group = helper.Get("group", 1); + if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; + return false; + } + if (group != 1) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; + return false; + } + } + return true; } From b341c44c20d45c88430feb62b684033ee7b12a8f Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 31 Jul 2024 08:59:55 -0700 Subject: [PATCH 33/37] Fix ETW trace logging crash in multithreading situations (#21566) ### Description ETW trace logger is fakely registered as initialized_ is marked as true before the registration is done, causing crashing issue for Lenovo camera application. A prior attempt to address was made here: https://github.com/microsoft/onnxruntime/pull/21226 It was reverted here: https://github.com/microsoft/onnxruntime/pull/21360 ### Motivation and Context The problem is that during initialization of TraceLoggingRegisterEx, it will reinvoke the callback and attempt reinitialization, which is not allowed. TraceLoggingRegisterEx however can be initialized concurrently when initialization happens on multiple threads. For these reasons it needs to be protected by a lock, but the lock cannot naively block because the callback's reinvocation will cause a deadlock. To solve this problem another tracking variable is added : "initializing" which protects against reinitialization during the first initialization. --------- Co-authored-by: Sheil Kumar --- .../core/platform/windows/logging/etw_sink.cc | 27 +++++++++++++++---- .../core/platform/windows/logging/etw_sink.h | 7 ++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index b0f9eaf4f62d2..ef42c88a67ba6 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -137,28 +137,45 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( EtwRegistrationManager::~EtwRegistrationManager() { std::lock_guard lock(callbacks_mutex_); callbacks_.clear(); - ::TraceLoggingUnregister(etw_provider_handle); + if (initialization_status_ == InitializationStatus::Initialized || + initialization_status_ == InitializationStatus::Initializing) { + std::lock_guard init_lock(init_mutex_); + assert(initialization_status_ != InitializationStatus::Initializing); + if (initialization_status_ == InitializationStatus::Initialized) { + ::TraceLoggingUnregister(etw_provider_handle); + initialization_status_ = InitializationStatus::NotInitialized; + } + } } EtwRegistrationManager::EtwRegistrationManager() { } -void EtwRegistrationManager::LazyInitialize() { - if (!initialized_) { +void EtwRegistrationManager::LazyInitialize() try { + if (initialization_status_ == InitializationStatus::NotInitialized) { std::lock_guard lock(init_mutex_); - if (!initialized_) { // Double-check locking pattern - initialized_ = true; + if (initialization_status_ == InitializationStatus::NotInitialized) { // Double-check locking pattern + initialization_status_ = InitializationStatus::Initializing; etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); if (FAILED(etw_status_)) { ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status_)); } + initialization_status_ = InitializationStatus::Initialized; } } +} catch (...) { + initialization_status_ = InitializationStatus::Failed; + throw; } void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { + if (initialization_status_ != InitializationStatus::Initialized) { + // Drop messages until manager is fully initialized. + return; + } + std::lock_guard lock(callbacks_mutex_); for (const auto& callback : callbacks_) { (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 3af45b813a625..d6c9ea27b2955 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -47,6 +47,11 @@ class EtwSink : public ISink { }; class EtwRegistrationManager { + enum class InitializationStatus { NotInitialized, + Initializing, + Initialized, + Failed }; + public: using EtwInternalCallback = std::function Date: Wed, 31 Jul 2024 09:01:05 -0700 Subject: [PATCH 34/37] [CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed
not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed
not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention
bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past
only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present
(not share buffer) ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/18854 --- .../contrib_ops/cpu/bert/attention_base.cc | 1 - .../contrib_ops/cpu/bert/attention_common.h | 13 +- .../cpu/bert/multihead_attention.cc | 15 +- .../cpu/bert/multihead_attention_helper.h | 574 +++++++----- .../cuda/bert/add_bias_transpose.cu | 114 +++ .../cuda/bert/add_bias_transpose.h | 54 +- .../contrib_ops/cuda/bert/attention.cc | 4 +- .../contrib_ops/cuda/bert/attention_impl.cu | 202 ++-- .../contrib_ops/cuda/bert/attention_impl.h | 55 +- .../cuda/bert/attention_kernel_options.h | 1 + .../cuda/bert/attention_kv_cache.cu | 73 +- .../cuda/bert/attention_prepare_qkv.cu | 864 ++++++++++++------ .../cuda/bert/attention_transpose.cu | 6 + .../decoder_masked_multihead_attention.cc | 13 +- ...decoder_masked_multihead_attention_impl.cu | 3 + .../cuda/bert/multihead_attention.cc | 221 ++--- .../cuda/bert/multihead_attention.h | 6 + .../cuda/bert/packed_attention_impl.cu | 22 +- .../bert/packed_multihead_attention_impl.cu | 58 +- .../quantization/attention_quantization.cc | 4 +- .../cuda/utils/dump_cuda_tensor.cc | 9 + .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 2 +- .../contrib_ops/rocm/bert/attention_impl.cu | 10 +- .../contrib_ops/rocm/bert/attention_impl.h | 6 - .../rocm/bert/multihead_attention.cu | 8 +- .../tools/transformers/fusion_attention.py | 3 + .../test/python/transformers/benchmark_mha.py | 99 +- .../test/python/transformers/test_mha.py | 346 ++++--- 28 files changed, 1729 insertions(+), 1057 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 515a967aa2386..f7d8fedc734e4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -258,7 +258,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->scale = scale_; output_parameters->mask_type = mask_type; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = false; output_parameters->qkv_format = Q_K_V_BNSH; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 55292b35e1e38..88127387d08ea 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -6,6 +6,12 @@ namespace onnxruntime { namespace contrib { +enum AttentionType { + kAttention, + kMultiHeadAttention, + kDecoderMaskedMultiHeadAttention, +}; + enum AttentionMaskType { MASK_NONE, // No mask MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length @@ -24,10 +30,12 @@ enum AttentionQkvFormat { UNKNOWN, // enum value not set, or depends on qkv projection implementation details Q_K_V_BNSH, // for non-packed qkv, permuted Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention - QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BSNH_BNSH_BNSH, // for cross attention, k and v are permuted Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) - Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + QKV_BSN3H, // for TRT fused attention, qkv are packed + QKV_BS3NH, // for DecoderMaskedMultiHeadAttention, qkv are packed QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed }; @@ -61,7 +69,6 @@ struct AttentionParameters { bool past_present_share_buffer; bool do_rotary; bool broadcast_res_pos_bias; - bool pass_past_in_kv; float mask_filter_value; float scale; bool use_tf32; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 9677c30f22d8a..0d77376779230 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -85,7 +85,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { scale_, is_unidirectional_, past_present_share_buffer, - false)); + kMultiHeadAttention)); const int batch_size = parameters.batch_size; const int q_sequence_length = parameters.sequence_length; @@ -121,20 +121,13 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - // For each of Q/K/V, there are multiple scenarios: - // 1) Combined QKV bias is null - // a) Q/K/V is (B, S, D) - // b) Q/K/V is (B, S, N, H) - // 2) No packed QKV in Q - // a) Q/K/V has seq_len = 1 - // b) Q/K/V has seq_len > 1 - OrtValue Q; ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q)); - if (parameters.pass_past_in_kv) { // key and value in BNSH format - assert(bias == nullptr); + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + // For cross attention with k and v in BNSH format, we assume that bias for key and value are zeros. + // So we don't need to add bias for key and value here. assert(past_key == nullptr); assert(past_value == nullptr); return ApplyAttention(Q.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index bd7ab09659170..cfb8d36843777 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -11,6 +11,232 @@ namespace onnxruntime { namespace contrib { namespace multihead_attention_helper { +template +Status Check_QKV(const T* packed_qkv, AttentionQkvFormat& qkv_format) { + const auto& query_dims = packed_qkv->Shape().GetDims(); + if (query_dims.size() == 3) { + // Packed qkv used by DecoderMaskedMultiHeadAttention. Query shape is (B, S, 3D), no key and value. + qkv_format = AttentionQkvFormat::QKV_BS3NH; + } else { + assert(query_dims.size() == 5); + if (static_cast(query_dims[3]) != 3) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'query' shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv"); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + + return Status::OK(); +} + +template +Status Check_Q_KV(const T* query, const T* packed_kv, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = packed_kv->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query be 3 for packed kv"); + } + + if (key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key be 5 for packed kv"); + } + + if (key_dims[0] != query_dims[0] || + static_cast(key_dims[2]) != num_heads || + static_cast(key_dims[3]) != 2 || + static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + kv_sequence_length = static_cast(key_dims[1]); + return Status::OK(); +} + +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length, int& v_hidden_size) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = key->Shape().GetDims(); + const auto& value_dims = value->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query be 3 for packed kv"); + } + + if (key_dims.size() != value_dims.size() || (key_dims.size() != 3 && value_dims.size() != 4)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key and value be same, and either 3 or 4"); + } + + if (key_dims[0] != query_dims[0] || value_dims[0] != query_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query', 'key' and 'value' shall have same dim 0 (batch_size)"); + } + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + + if (key_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same dim 1 (kv_sequence_length)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + kv_sequence_length = static_cast(key_dims[1]); + v_hidden_size = static_cast(value_dims[2]); + } else { // key_dims.size() == 4 + if (value->Shape() != key->Shape() || + static_cast(key_dims[1]) != num_heads || + static_cast(key_dims[3]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same shape (batch_size, num_heads, kv_sequence_length, head_size)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + kv_sequence_length = static_cast(key_dims[2]); + v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); + } + + return Status::OK(); +} + +template +Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, + int batch_size, int num_heads, int head_size, bool past_present_share_buffer, + int& past_sequence_length, int& max_sequence_length) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 1 should be same as number of heads, got ", + past_key_dims[1]); + } + if (past_value_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 1 should be same as number of heads, got ", + past_value_dims[1]); + } + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); + } + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + past_sequence_length = static_cast(past_key_dims[2]); + if (past_present_share_buffer) { + max_sequence_length = static_cast(past_key_dims[2]); + if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); + } + past_sequence_length = *((*past_seq_len).template Data()); + } + return Status::OK(); +} + +template +Status CheckRelativePositionBias( + const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length, + bool& broadcast_res_pos_bias) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[0] == 1) { + broadcast_res_pos_bias = true; + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + return Status::OK(); +} + +template +AttentionMaskType GetMaskType(const T* key_padding_mask, int batch_size, int sequence_length, int total_sequence_length) { + AttentionMaskType mask_type = AttentionMaskType::MASK_UNKNOWN; + const auto& mask_dims = key_padding_mask->Shape().GetDims(); + if (mask_dims.size() == 1) { + if (mask_dims[0] == static_cast(batch_size)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; + } + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; + } + return mask_type; +} + template Status CheckInputs(const T* query, const T* key, @@ -27,176 +253,128 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing) { - // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None - // relative_position_bias : (B, 1, S, L) - // past_key : (B, N, S*, H) - // past_value : (B, N, S*, H) - // When no packing for q/k/v: + AttentionType operator_type) { + // --------------------------------------------------------------- + // Notations: + // B: batch_size + // N: num_heads + // H: head_size (V might have different head size than Q and K) + // D: hidden_size = N * H + // S: q_sequence_length + // P: past_sequence_length + // L: kv_sequence_length + // T: total_sequence_length = P + L + // M: max_sequence_length + // --------------------------------------------------------------- + // MultiHeadAttention inputs: + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: // query (Q) : (B, S, D) - // key (K) : (B, L, D) or (B, N, S*, H) - // value (V) : (B, L, D_v) or (B, N, S*, H) - // bias (Q/K/V) : (D + D + D_v) - // When packed kv is used: + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) - // key (K) : (B, L, N, 2, H) - // value (V) : None - // bias (Q/K/V) : None - // When packed qkv is used: - // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): + // query (Q) : (B, S, D) + // key (K/V) : (B, L, N, 2, H) + // value : None + // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v): + // query (Q/K/V) : (B, S, N, 3, H) + // key : None + // value : None + // + // Other inputs: + // bias (Q/K/V) : None or (D + D + D_v) + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) + // relative_position_bias : (B, N, S, T) or (1, N, S, T) + // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // --------------------------------------------------------------- + // DecoderMaskedMultiHeadAttention inputs (S == 1, D == D_v): + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T): + // query (Q) : (B, S, D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // QKV_BS3NH - packed qkv (S == L): + // query (Q) : (B, S, 3 * D) // key (K) : None // value (V) : None - // bias (Q/K/V) : None or (D + D + D_v) - - AttentionQkvFormat qkv_format; + // + // Other inputs: + // bias (Q/K/V) : None or (3 * D) + // key_padding_mask (K/V) : None or (B, T) + // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. + // + // The following inputs are not used in cross attention (so they are None for cross attention): + // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_value : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_sequence_length : scalar (1) when past_present_share_buffer is True. + // CUDA version has extra inputs (beam_width, cache_indirection) that are not checked in the class. + // For ROCm, see contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh for more details. + // --------------------------------------------------------------- + AttentionQkvFormat qkv_format = UNKNOWN; const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3 && query_dims.size() != 5) { + + int query_rank = static_cast(query_dims.size()); + if (query_rank != 3 && query_rank != 5) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ", - query_dims.size()); + query_rank); } int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); - int hidden_size = (query_dims.size() == 3) + bool dmmha_packing = operator_type == kDecoderMaskedMultiHeadAttention && key == nullptr && value == nullptr; + int hidden_size = (query_rank == 3) ? (dmmha_packing ? (static_cast(query_dims[2]) / 3) : static_cast(query_dims[2])) : (num_heads * static_cast(query_dims[4])); int head_size = static_cast(hidden_size) / num_heads; int kv_sequence_length = sequence_length; + int v_hidden_size = hidden_size; + if (key != nullptr) { + if (value == nullptr) { + ORT_RETURN_IF_ERROR(Check_Q_KV(query, key, num_heads, head_size, qkv_format, kv_sequence_length)); + } else { + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, head_size, + qkv_format, kv_sequence_length, v_hidden_size)); + } + } else if (value == nullptr) { // no key and value + ORT_RETURN_IF_ERROR(Check_QKV(query, qkv_format)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value' shall absent when 'key' is absent"); + } + int past_sequence_length = 0; int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - if (past_key_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 1 should be same as number of heads, got ", - past_key_dims[1]); - } - if (past_value_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 1 should be same as number of heads, got ", - past_value_dims[1]); - } - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", - past_key_dims[2], " vs ", past_value_dims[2]); - } - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - past_sequence_length = static_cast(past_key_dims[2]); - max_sequence_length = static_cast(past_key_dims[2]); - if (past_present_share_buffer) { - if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); - } - past_sequence_length = *((*past_seq_len).template Data()); - } + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, past_seq_len, + batch_size, num_heads, head_size, past_present_share_buffer, + past_sequence_length, max_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ", - query_dims.size()); - } - - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3 && key_dims.size() != 4 && key_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3, 4, or 5 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { + if (operator_type == kMultiHeadAttention) { + if (qkv_format == AttentionQkvFormat::QKV_BS3NH) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); + "Packed qkv of 3D BS3NH format is not support by MultiHeadAttention"); } - if (key_dims.size() == 3) { - if (key_dims[2] != query_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else if (key_dims.size() == 5) { - if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); - } - - qkv_format = Q_KV_BSNH_BSN2H; - kv_sequence_length = static_cast(key_dims[1]); - } else { // key_dims.size() == 4 (cross-attention with past_key) - if (static_cast(key_dims[1]) != num_heads || static_cast(key_dims[3]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, num_heads, kv_sequence_length, head_size)"); - } - - if (value == nullptr || value->Shape().GetDims().size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' shall be 4D when 'key' is 4D"); - } - - if (bias != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when 'key' is 4D"); - } - - qkv_format = UNKNOWN; - kv_sequence_length = static_cast(key_dims[2]); - } - } else { // packed QKV - if (query_dims.size() != 3 && query_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions when key is empty, got ", - query_dims.size()); - } - if (query_dims.size() == 5 && (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3)) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv"); + if (qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H && bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when packed kv is used"); } - - qkv_format = QKV_BSN3H; } if (bias != nullptr) { @@ -206,116 +384,31 @@ Status CheckInputs(const T* query, bias_dims.size()); } - if (value == nullptr) { - // Currently, bias is not allowed for packed KV. This constraint can be removed later. - // Here we assume that fusion tool will not include bias for packed KV. - if (query_dims.size() == 5 && query_dims[3] == 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. "); - } + int expected_bias_length = 2 * hidden_size + v_hidden_size; + if (bias_dims[0] != static_cast(expected_bias_length)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' length is expected to be 2 * hidden_size + hidden_size_v, got ", + bias_dims.size()); } } int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { - mask_type = AttentionMaskType::MASK_UNKNOWN; - const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1) { - if (mask_dims[0] == static_cast(batch_size)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(kv_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(sequence_length) && - mask_dims[2] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_3D_ATTENTION; - } - + mask_type = GetMaskType(key_padding_mask, batch_size, sequence_length, total_sequence_length); if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); - } - } - - // NOTE: In Cross-Attention, we pass the past key and value to 'key' and 'value' instead of 'past_key' and 'past_value'. - bool pass_past_in_kv = false; - int v_hidden_size = hidden_size; - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3 && value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 or 4 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (value_dims.size() == 3) { - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } - v_hidden_size = static_cast(value_dims[2]); - } else { // value_dims.size() == 4 - if (static_cast(kv_sequence_length) != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 2 (kv_sequence_length)"); - } - - if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be empty when 'value' is 4D"); - } - - v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); - pass_past_in_kv = true; + "Input 'key_padding_mask' shape is not expected."); } } bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - if (relative_position_bias_dims[3] != total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); - } + ORT_RETURN_IF_ERROR(CheckRelativePositionBias( + relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias)); } - // TODO: ORT_RETURN_IF(qkv_format == UNKNOWN, "Unrecognized QKV format"); + assert(qkv_format != UNKNOWN); + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; @@ -323,7 +416,7 @@ Status CheckInputs(const T* query, output_parameters->past_sequence_length = past_sequence_length; output_parameters->kv_sequence_length = kv_sequence_length; output_parameters->total_sequence_length = total_sequence_length; - output_parameters->max_sequence_length = max_sequence_length; + output_parameters->max_sequence_length = past_present_share_buffer ? max_sequence_length : total_sequence_length; output_parameters->input_hidden_size = 0; output_parameters->hidden_size = hidden_size; output_parameters->v_hidden_size = v_hidden_size; @@ -336,7 +429,6 @@ Status CheckInputs(const T* query, output_parameters->mask_type = mask_type; output_parameters->scale = scale; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = pass_past_in_kv; output_parameters->qkv_format = qkv_format; } @@ -359,7 +451,7 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing, + AttentionType operator_type, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); @@ -367,7 +459,7 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, - past_present_share_buffer, dmmha_packing); + past_present_share_buffer, operator_type); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 9e6752b451868..62d6a723bf32c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -520,6 +520,39 @@ __global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) } } +template +__global__ void AddBiasTransposeUnpack(int M, const T* input, const T* biases, T* output) { + // Format 5 to unpack TRT packed input format to BNSH for unfused attention. + // Input: BxSxNxMxH + // Output: MxBxNxSxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = (s + n * sequence_length) * head_size + (b + m * batch_size) * NHS; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -692,6 +725,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else if (format == 5) { // format == 5 + AddBiasTransposeUnpack<<>>(total_matrix_count, input, biases, output); } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } @@ -716,6 +751,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); + } else if (format == 5) { // format == 5 + ORT_THROW("AddBiasTranspose (format 5) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -904,6 +941,7 @@ void InvokeAddBias( AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } } + // K { const dim3 grid(kv_sequence_length, batch_size, num_matrices); @@ -1011,6 +1049,82 @@ void LaunchAddBias( } } +template +void InvokeAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q) { + assert(num_heads <= max_threads_per_block); + constexpr int num_matrices = 1; + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const float* biases, const float* query, float* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const float4* query2 = reinterpret_cast(query); + const float4* biases2 = reinterpret_cast(biases); + float4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const float2* query2 = reinterpret_cast(query); + const float2* biases2 = reinterpret_cast(biases); + float2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const half* biases, const half* query, half* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const Half4* query2 = reinterpret_cast(query); + const Half4* biases2 = reinterpret_cast(biases); + Half4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const half2* query2 = reinterpret_cast(query); + const half2* biases2 = reinterpret_cast(biases); + half2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index efc31db43bcdb..bd4e123a272bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -3,14 +3,15 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { namespace cuda { -// Fused kernel of Add (bias) and Transpose. +// Fused kernel of Add bias (optional, can be None) and Transpose. // Shape of inputs and outputs: -// biases: (num_matrices, num_heads * head_size) +// biases: (num_matrices, num_heads * head_size) or None // format 0: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (num_matrices, batch_size, sequence_length, num_heads, head_size) // output: (num_matrices, batch_size, num_heads, sequence_length, head_size) @@ -24,9 +25,12 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) -// format 4: (requires qk_head_size = v_head_size) +// format 4: (requires qk_head_size == v_head_size) // input: (batch_size, sequence_length, num_heads, num_matrices, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 5: (requires qk_head_size == v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, num_heads, sequence_length, head_size) template void LaunchAddBiasTranspose( @@ -35,7 +39,7 @@ void LaunchAddBiasTranspose( const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr, int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0); -// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format. +// Add bias (optional, can be None) and Transpose for separated inputs of Q, K and V, and output Trt format. // For self attention: // output: (batch_size, sequence_length, num_heads, 3, head_size) // It assumes sequence_length == kv_sequence_length and head_size == v_head_size. @@ -50,7 +54,7 @@ void LaunchAddBiasTransposeTrt( const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length = -1); -// Add (bias) for separated inputs of Q, K and V. +// Add bias (required) for separated inputs of Q, K and V. // Q: (batch_size, sequence_length, num_heads, head_size) // K: (batch_size, kv_sequence_length, num_heads, head_size) // V: (batch_size, kv_sequence_length, num_heads, v_head_size) @@ -61,6 +65,46 @@ void LaunchAddBias( const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v); +// Add bias (required) for Q: (batch_size, sequence_length, num_heads, head_size) +template +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q); + +// Add bias (optional, can be None) transpose kernel defined in packed_multihead_attention_impl.cu. +// Support the following format transforms (for float and half only). +// source_format => target_format: +// Q_K_V_TNH => Q_K_V_BNSH (requires token_offset) +// Q_K_V_TNH => Q_K_V_TNH +// Q_K_V_TNH => QKV_TN3H +// QKV_TN3H => Q_K_V_BNSH (requires token_offset) +// QKV_TN3H => Q_K_V_TNH +// QKV_TN3H => QKV_TN3H +template +void AddBiasTransposePacked( + const T* query, const T* key, const T* value, const T* bias, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +// Add bias (required) transpose kernel defined in packed_attention_impl.cu. +// Support the following format transforms (for float and half only): +// format transform +// Q_K_V_BNSH: Tx3xNxH => 3xBxNxSxH (requires token_offset) +// Q_K_V_BSNH: Tx3xNxH => 3xTxNxH +// QKV_BSN3H: Tx3xNxH => TxNx3xH +template +void AddBiasTransposePacked( + const T* input, const T* biases, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3b7f980ba1881..5c0989bced70c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -260,7 +260,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; @@ -281,6 +282,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 997493acd9cb7..f9eabe27d97e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,31 +58,25 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { - if (this->sequence_length != seq_length) { - ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); - LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, seq_length, stream); - this->sequence_length = seq_length; +const int32_t* CumulatedSequenceLengthCache::TryGet(int batch_size, int32_t seq_len, cudaStream_t stream) { + if (this->sequence_length == 0 && seq_len > 0) { + // Initialize only once with sequence length in the first request. + std::call_once(init_once_flag_, [&]() { + ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, + this->max_batch_size, seq_len, stream); + // Syncronize to ensure thread-safe since other thread will not wait for the above kernel finish. + // Otherwise, the data might be consumed by other threads before it is ready and causes data race issue. + cudaStreamSynchronize(stream); + this->sequence_length = seq_len; + }); } -} -int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache, - const int* mask_index, - int batch_size, - int sequence_length, - cudaStream_t stream, - void* scratch_buffer) { - if (mask_index == nullptr && cache != nullptr) { - if (batch_size <= cache->max_batch_size) { - cache->Initialize(sequence_length, stream); - return reinterpret_cast(cache->buffer.get()); - } + if (this->sequence_length == seq_len && batch_size <= this->max_batch_size) { + return reinterpret_cast(buffer.get()); } - int* sequence_offset = reinterpret_cast(scratch_buffer); - LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream); - return sequence_offset; + return nullptr; } size_t GetAttentionScratchSize( @@ -114,10 +108,12 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention) { + bool use_memory_efficient_attention, + bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. - const size_t qkv_bytes = element_size * batch_size * num_heads * - ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_size = element_size * batch_size * num_heads * + ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_bytes = no_qkv_workspace ? 0 : qkv_size; #if USE_FLASH_ATTENTION if (use_flash_attention) { @@ -162,39 +158,44 @@ Status FusedTrtCrossAttention( // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. assert(data.mask_index == nullptr); - + assert(data.scratch != nullptr); + assert(data.q != nullptr); + assert(data.k != nullptr); + +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + 2 * GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, - sequence_length, stream, - data.scratch); + int32_t* q_sequence_offset = const_cast(data.cumulated_sequence_length_q_cache); + if (q_sequence_offset == nullptr) { + q_sequence_offset = reinterpret_cast(data.scratch); + LaunchTrtSequenceOffset(q_sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_INIT(); DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, parameters.kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + int32_t* kv_sequence_offset = const_cast(data.cumulated_sequence_length_kv_cache); + if (kv_sequence_offset == nullptr) { + int* scratch = reinterpret_cast(data.scratch) + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = reinterpret_cast(scratch); + LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, parameters.kv_sequence_length, stream); + } + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); - // When there is no bias, we can directly use q and packed kv from inputs. - void const* query = data.q; - void const* packed_kv = data.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV + data.q, // Q + data.k, // packed KV q_sequence_offset, // cumulated sequence length of Q kv_sequence_offset, // cumulated sequence length of KV data.output, // output @@ -206,8 +207,6 @@ Status FusedTrtCrossAttention( parameters.kv_sequence_length, // sequence length of KV stream); - DUMP_TENSOR("trt cross output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -225,24 +224,33 @@ Status FusedTrtSelfAttention( cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { + assert(data.scratch != nullptr); +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const bool causal = parameters.is_unidirectional; - int* sequence_offset = reinterpret_cast(data.scratch); - - DUMP_TENSOR_INIT(); + const int32_t* sequence_offset = data.cumulated_sequence_length_q_cache; if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + LaunchTrtSequenceOffset2d(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); + if (sequence_offset == nullptr) { + LaunchTrtSequenceOffset(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); + } } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length); @@ -252,22 +260,12 @@ Status FusedTrtSelfAttention( if (!causal) { assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. - void const* packed_qkv = data.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + fused_fp16_runner->Run(b, s, data.q, sequence_offset, data.output, stream); } else { assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); } @@ -289,38 +287,19 @@ Status FlashAttention( contrib::AttentionParameters& parameters, AttentionData& data, float scale) { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); assert(nullptr == data.mask_index); assert(nullptr == data.relative_position_bias); assert(parameters.head_size == parameters.v_head_size); - void* query = reinterpret_cast(data.q); - void* key = reinterpret_cast(data.k); - void* value = reinterpret_cast(data.v); - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); - - bool is_bf16 = false; + constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - true)); - - DUMP_TENSOR("flash attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); } @@ -351,25 +330,8 @@ Status EfficientAttention( float scale) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = data.q; - const void* key = data.k; - const void* value = data.v; - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -394,21 +356,19 @@ Status EfficientAttention( ? nullptr : const_cast(reinterpret_cast( data.mask_index + 2 * parameters.batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; + p.query = data.q; + p.key = data.k; + p.value = data.v; p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; - p.is_kv_bsnh = true; + p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; p.stream = stream; p.has_custom_right_padding = false; run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -449,10 +409,6 @@ Status UnfusedAttention( cublasSetStream(cublas, stream); - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); - const int present_sequence_length = parameters.past_present_share_buffer ? parameters.max_sequence_length : total_sequence_length; @@ -467,8 +423,7 @@ Status UnfusedAttention( &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop, parameters.use_tf32)); - DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_INIT(); DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); constexpr size_t element_size = sizeof(T); @@ -523,7 +478,6 @@ Status UnfusedAttention( // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, device_prop.maxThreadsPerBlock, false, temp_output, data.output); - DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } @@ -554,7 +508,7 @@ Status QkvToContext( if (!parameters.past_present_share_buffer) { ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, parameters.pass_past_in_kv, + sequence_length, total_sequence_length, stream, max_threads_per_block, data)); } else { // past_present_share_buffer diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 56836bdda197c..fad353dcfeb07 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include "core/framework/allocator.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -15,13 +17,18 @@ namespace cuda { constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128; +// A cache for cumulated sequence length. It will be initialized in the first request, then become read-only after that. struct CumulatedSequenceLengthCache { onnxruntime::IAllocatorUniquePtr buffer; int32_t max_batch_size; int32_t sequence_length; - CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {} - void Initialize(int32_t sequence_length, cudaStream_t stream); + CumulatedSequenceLengthCache() : max_batch_size(kCumulatedSequenceLengthCacheMaxBatchSize), sequence_length(0) {} + + const int32_t* TryGet(int batch_size, int32_t sequence_length, cudaStream_t stream); + + // Use this flag to guard the initializaton only once in multi-threading. + mutable std::once_flag init_once_flag_; }; size_t @@ -46,7 +53,8 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention); + bool use_memory_efficient_attention, + bool no_qkv_workspace); template struct AttentionData { @@ -65,8 +73,6 @@ struct AttentionData { bool has_qkv_workspace = false; T* workspace = nullptr; - T* temp_k_workspace = nullptr; - T* temp_v_workspace = nullptr; T* output = nullptr; T* present = nullptr; @@ -79,22 +85,50 @@ struct AttentionData { bool use_flash_attention = false; bool use_memory_efficient_attention = false; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; + const int32_t* cumulated_sequence_length_q_cache = nullptr; + const int32_t* cumulated_sequence_length_kv_cache = nullptr; // Intermediate data T* q = nullptr; T* k = nullptr; T* v = nullptr; T* scratch = nullptr; - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + AttentionQkvFormat qkv_format = AttentionQkvFormat::UNKNOWN; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + + // For Debugging + size_t workspace_bytes = 0; + bool allow_debug_info = false; + + bool IsUnfused() const { + return !use_flash_attention && !use_memory_efficient_attention && + (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr); + } + + void PrintDebugInfo() const { + std::cout << "flash=" << use_flash_attention + << ", efficient=" << use_memory_efficient_attention + << ", fused_runner=" << (fused_runner != nullptr) + << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) + << ", bias=" << (bias != nullptr) + << ", attn_bias=" << (relative_position_bias != nullptr) + << ", mask_dims=" << mask_index_dims.size() + << ", has_qkv_workspace=" << has_qkv_workspace + << ", workspace=" << workspace_bytes + << ", past=" << (past != nullptr ? 1 : (past_key != nullptr ? 2 : 0)) + << ", present=" << (present != nullptr ? 1 : (present_key != nullptr ? 2 : 0)) + << std::endl; + } }; +// Return true if it does not need qkv workspace, false otherwise. +template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, @@ -129,6 +163,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, int total_matrix_count = -1); +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block); + Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const half* input, half* output, cudaStream_t stream, const int max_threads_per_block); @@ -158,7 +195,7 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index bd7df5f490c76..aba1e01bfd91b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -50,6 +50,7 @@ class AttentionKernelOptions { bool use_unfused_{true}; bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; // Causal attention is disabled by default in #14732. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 89be0f1115f41..9f0f49348c225 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -249,16 +249,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, - cudaStream_t stream, - int max_threads_per_block, + int sequence_length, int total_sequence_length, + cudaStream_t stream, int max_threads_per_block, AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. - if (nullptr != data.present) { + if (nullptr != data.present) { // Attention op assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); @@ -270,58 +269,52 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int // Update pointers to present_k and present_v. data.k = data.present; data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } else if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + } else { // MultiHeadAttention op + if (nullptr != data.present_key) { + ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + if (nullptr != data.past_key) { + assert(data.past_key != data.k); + assert(data.past_value != data.v); + + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + // Update pointers to present_k and present_v. data.k = data.present_key; data.v = data.present_value; - } else { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - data.k = data.temp_k_workspace; - data.v = data.temp_v_workspace; + } else { // nullptr == data.past_key && nullptr != data.present_key + if (data.k != data.present_key) { + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } + + if (data.v != data.present_value) { + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_value, data.v, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } } - } else if (pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, data.k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, data.v, data.present_value)); - // Update pointers to present_k and present_v. - data.k = data.present_key; - data.v = data.present_value; } } + return CUDA_CALL(cudaGetLastError()); } // Template Instantiation template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 040d6124e7456..05c592ec61059 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -12,12 +12,101 @@ namespace onnxruntime { namespace contrib { namespace cuda { +#if DEBUG_TENSOR_LEVEL > 1 +// Dump the workspace for Q, K, V after processing QKV data. +template +void DumpQkv(AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("q(BSN3H)", data.q, batch_size, sequence_length, num_heads * 3, qk_head_size); + } +} + +// Dump the inputs before processing QKV data. +template +void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BSNH)", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Key(BSNH)", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Value(BSNH)", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("Query(BSN3H)", data.query, batch_size, sequence_length, num_heads * 3, qk_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BSN2H)", data.value, batch_size, sequence_length, num_heads * 2, qk_head_size); + } + + if (data.bias != nullptr) { + DUMP_TENSOR_D("Q_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("K_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } + + if (data.relative_position_bias != nullptr) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + parameters.broadcast_res_pos_bias ? 1 : batch_size, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr) { + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length); + } + if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1); + } + } +} + +// Dump the kernel outputs +template +void DumpOutputs(AttentionData& data) { + DUMP_TENSOR_INIT(); + DUMP_TENSOR("output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); +} +#endif + template Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -40,7 +129,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, int matrix_to_trans = (past_present_share_buffer ? 1 : 3); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } else { // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) @@ -48,13 +137,13 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, // For fused causal kernel, use format 1 since we need have K and V to update present state, // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); + data.qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); // For fused causal, we will update gemm_buffer with bias directly. T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; @@ -71,367 +160,526 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, return Status::OK(); } -// For MultiHeadAttention with past state +// Return true if the workspace is not needed for Q, K, V inputs, false otherwise. +// This shall be in sync with the following function PrepareQkv_MHA_Cross. template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { +bool NoQkvWorkspace_MHA_Cross(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr); +} + +// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format) +template +Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + // past_key or past_value is not supported for cross attention + // present_key and present_value can be supported in theory, although we do not allow the senario for now. + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_Cross(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - DUMP_TENSOR_INIT(); - if (data.bias == nullptr) { - // Below logic does not support fused attention with past without bias - // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. - - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Add bias for Q + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, + data.bias, data.query, data.q); + } else { + data.q = const_cast(data.query); } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance + // Here we have assumption that there is no bias for key and value when they are in BNSH format. + data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + if (data.bias == nullptr) { + // Transpose query from BSNH to BNSH ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); + max_threads_per_block, false, data.query, data.q)); + } else { + // Add bias to query, and transpose it: Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + + // Here we have assumption that there is no bias for key and value when they are in BNSH format. + // So we do not need to add bias for key and value. Just use the key and value directly. + data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_NoPast(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr; +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + assert(data.mask_index == nullptr); + assert(parameters.hidden_size == parameters.v_hidden_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); + data.v = nullptr; + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; } #if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, data.q, data.k, data.v); + } else { + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } + + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. get initial kv to use as past_kv in the next iteration) - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.present_key != nullptr && - data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); +#endif + else if (data.fused_runner != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + + // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length); + data.k = nullptr; + data.v = nullptr; + + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k, + true, -1); - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData& data) { + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Q, K and V redirects to query, present_k and present_v, so we do not need extra workspace for QKV. + return data.past_key == nullptr && data.present_key != nullptr; + } + return false; +} + +// For MultiHeadAttention with kv cache (past or present), but no bias +template +Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.bias == nullptr); + assert(data.fused_runner == nullptr); + assert(data.fused_cross_attention_kernel == nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.past_key == nullptr && data.past_value == nullptr || + data.past_key != nullptr && data.past_value != nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // When there is no past state and there is present state, we output K and V directly to present state. + if (data.past_key == nullptr && data.present_key != nullptr) { + data.k = data.present_key; + data.v = data.present_value; + } + +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Use oiginal Query (BSNH) since there is no bias. + data.q = const_cast(data.query); + + // Key (BxLxNxH) => K (BxNxLxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + // Value (BxLxNxH) => V (BxNxLxH) ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, data.q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + + return Status::OK(); +} - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +template +constexpr bool NoQkvWorkspace_MHA_WithPast_Bias(AttentionData& /*data*/) { + return false; +} + +// For MultiHeadAttention with both kv cache (past or present) and bias +template +Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.bias != nullptr); + assert(!(data.past_key != nullptr && data.present_key == nullptr)); + assert(data.fused_runner == nullptr); + assert(data.fused_cross_attention_kernel == nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.past_key == nullptr && data.past_value == nullptr || + data.past_key != nullptr && data.past_value != nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_Bias(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // When there is no past state and there is present state, we output K and V directly to present state. + if (data.past_key == nullptr && data.present_key != nullptr) { + data.k = data.present_key; + data.v = data.present_value; } + +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH) + LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, + data.bias, data.query, data.q); + + // Key (BxLxNxH) + Bias_K => K (BxNxLxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, data.k, true, -1); + + // Key (BxLxNxH) + Bias_K => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else #endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed + { // unfused kernel + assert(data.IsUnfused()); + constexpr int format = 0; // Query (BxSxNxH) => Q (BxNxSxH) LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, + data.query, data.bias, data.q, true, -1); - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, data.k, + true, -1); - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +template +bool NoQkvWorkspace_MHA_PackedQKV(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return nullptr != data.fused_runner && data.bias == nullptr; +} + // For MultiHeadAttention without past state, with packed QKV inputs template Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedQKV(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + // unpack qkv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, + data.query, data.bias, data.q, true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (nullptr != data.fused_runner) { + assert(nullptr == data.relative_position_bias); + if (data.bias == nullptr) { + // When there is no bias, we can directly use the original packed QKV input. + // Need revisit this when we add support for causal. + data.q = const_cast(data.query); + data.k = nullptr; + data.v = nullptr; + } else { // data.bias != nullptr + AddBiasTransposePacked( + data.query, data.key, data.value, data.bias, data.q, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + AttentionQkvFormat::QKV_TN3H, AttentionQkvFormat::QKV_TN3H, + nullptr, batch_size * sequence_length, + stream); } - qkv_format = AttentionQkvFormat::QKV_BSN3H; + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // unpack qkv to BNSH + constexpr int format = 5; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, v_head_size, qkv_add_bias, 3); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +// This shall be in sync with the following function PrepareQkv_MHA_PackedQKV. +template +bool NoQkvWorkspace_MHA_PackedKV(AttentionData& data) { + return data.fused_cross_attention_kernel != nullptr; +} + // For MultiHeadAttention without past state, with packed KV inputs template Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); + assert(data.bias == nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_runner == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedKV(data)); + const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + // Note that there is no bias so we need not output query to q. + data.q = const_cast(data.query); + // Unpack kv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, batch_size, kv_sequence_length, num_heads, qk_head_size, data.key, kv_bias, data.k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } + true, v_head_size, qkv_add_bias); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (data.fused_cross_attention_kernel != nullptr) { + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = nullptr; + } else { // unfused kernel + assert(data.IsUnfused()); + // Transpose q from BSNH to BNSH. Note that there is no bias. + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(batch_size, parameters.sequence_length, num_heads, qk_head_size, + data.query, data.q, stream, max_threads_per_block)); - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + // Unpack kv to BNSH. + constexpr int format = 5; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, data.k, + true, v_head_size, qkv_add_bias, 2); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } -// For MultiHeadAttention without past state, with Q, K and V inputs +// Prepare Q, K and V for MultiHeadAttention operator. template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif - - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_Cross(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::QKV_BSN3H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block)); + } + } else { // no past state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block)); + } + break; + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } -#endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + return Status::OK(); +} - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; +// Check whether there is no needed to have workspace for Q, K and V for MultiHeadAttention operator. +// Please make it in sync with PrepareQkv_MultiHeadAttention. +template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + return NoQkvWorkspace_MHA_Cross(data); + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + return NoQkvWorkspace_MHA_PackedKV(data); + case AttentionQkvFormat::QKV_BSN3H: + return NoQkvWorkspace_MHA_PackedQKV(data); + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + return NoQkvWorkspace_MHA_WithPast_NoBias(data); + } else { + return NoQkvWorkspace_MHA_WithPast_Bias(data); + } + } else { // no past state + return NoQkvWorkspace_MHA_NoPast(data); + } + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } - return Status::OK(); } template @@ -439,7 +687,6 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block) { - data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -452,28 +699,37 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.k = data.workspace + elements_q; data.v = data.k + elements_k; data.scratch = data.v + elements_v; + } else { + data.q = nullptr; + data.k = nullptr; + data.v = nullptr; + data.scratch = data.workspace; } +#if DEBUG_TENSOR_LEVEL > 1 + DumpInputs(parameters, data); +#endif + if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - data.qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block)); + } else { // MultiHeadAttention operator + ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention(parameters, data, stream, max_threads_per_block)); } + assert(data.qkv_format != AttentionQkvFormat::UNKNOWN); + +#if DEBUG_TENSOR_LEVEL > 1 + DumpQkv(data); +#endif + CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } // Template Instantiation +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index bd38a21aadfcb..9f3e396b7f949 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -304,6 +304,12 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c max_threads_per_block, false, input, output); } +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 66c0aceaed1e7..037a4fdf3d9a0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -75,7 +75,6 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); bool is_unidirectional = false; - bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, @@ -91,7 +90,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* scale_, is_unidirectional, past_present_share_buffer_, - is_dmmha_packing, // dmmha_packing + kDecoderMaskedMultiHeadAttention, device_prop.maxThreadsPerBlock)); if (bias) { @@ -157,7 +156,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.is_cross_attention = true; parameters.total_sequence_length = parameters.kv_sequence_length; parameters.max_sequence_length = parameters.kv_sequence_length; - // parameters.k and paraneters.v are nullptr + // parameters.k and parameters.v are nullptr parameters.k_cache = const_cast(key->Data()); parameters.v_cache = const_cast(value->Data()); parameters.k_bias = nullptr; @@ -188,12 +187,14 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* } parameters.is_cross_attention = false; - parameters.is_packed_qkv = is_dmmha_packing; - parameters.k = is_dmmha_packing + bool is_packed_qkv = (key == nullptr && value == nullptr); + parameters.is_packed_qkv = is_packed_qkv; + + parameters.k = is_packed_qkv ? const_cast(query->Data() + parameters.hidden_size) : const_cast(key->Data()); - parameters.v = is_dmmha_packing + parameters.v = is_packed_qkv ? const_cast(query->Data() + 2 * static_cast(parameters.hidden_size)) : const_cast(value->Data()); parameters.k_cache = present_key_data; diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 9efb6f08e8e99..2f8d277cb7342 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -183,6 +183,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; } + // This has assumption that key and value does not have bias for cross attention when they are in BNSH format. if (!params.is_cross_attention) { Qk_vec_k k; @@ -580,6 +581,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // One group of threads computes the product(s) for the current timestep. V_vec_k v_bias; + + // This has assumption that key and value does not have bias for cross attention when they are in BNSH format. if (params.v_bias && !params.is_cross_attention) { zero(v_bias); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 663bd020ddac7..c36abc8e1d624 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -7,6 +7,7 @@ #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -44,7 +45,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); + ORT_ENFORCE(!is_unidirectional_, + "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead."); kernel_options_ = this->GetAttentionKernelOptions(); @@ -95,7 +97,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { scale_, is_unidirectional_, false, // past_present_share_buffer - false, // dmmha_packing + kMultiHeadAttention, device_prop.maxThreadsPerBlock)); int sequence_length = parameters.sequence_length; @@ -111,25 +113,43 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_key = context->Output(1, present_shape); Tensor* present_value = context->Output(2, present_shape); - MHARunner* fused_runner = nullptr; + int num_past = static_cast(past_key != nullptr) + static_cast(past_value != nullptr); + int num_present = static_cast(present_key != nullptr) + static_cast(present_value != nullptr); + if (num_past == 0 && num_present == 0) { + // It is valid case without past state. + } else if ((num_past == 2 && num_present == 2) || (num_past == 0 && num_present == 2)) { + if (parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed QKV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed KV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for cross attention"); + } + } else { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be all provided, " + "or all empty, or only present_key and present_value are provided"); + } + MHARunner* fused_runner = nullptr; const FusedMultiHeadCrossAttentionKernel* fused_cross_attention_kernel = nullptr; // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; - bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - - const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); - -#if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION - // Exclude this case since PrepareQkv will convert the format to BNSH. - bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; -#endif - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - !past_no_bias && nullptr == relative_position_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -138,7 +158,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. - if (use_flash_attention && key == nullptr && value == nullptr && + if (use_flash_attention && parameters.qkv_format == AttentionQkvFormat::QKV_BS3NH && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } @@ -162,19 +182,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - bool use_fused_cross_attention = !use_flash_attention && - !disable_fused_cross_attention_ && - nullptr == key_padding_mask && - nullptr == relative_position_bias && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - key != nullptr && - (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV - parameters.hidden_size == parameters.v_hidden_size && - has_fused_cross_attention_kernel(sm, parameters.head_size, - parameters.kv_sequence_length); + bool use_fused_cross_attention = + !use_flash_attention && + !disable_fused_cross_attention_ && + nullptr == key_padding_mask && + nullptr == relative_position_bias && + nullptr == past_key && nullptr == present_key && + (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && + parameters.hidden_size == parameters.v_hidden_size && + has_fused_cross_attention_kernel(sm, parameters.head_size, + parameters.kv_sequence_length); if (use_fused_cross_attention) { if (fused_fp16_cross_attention_kernel_ == nullptr) { - fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + std::call_once(fused_cross_init_once_flag_, [&]() { + fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -184,17 +206,18 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !use_flash_attention && - !disable_fused_self_attention_ && - fused_cross_attention_kernel == nullptr && - nullptr == relative_position_bias && - (value != nullptr || key == nullptr) && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - (nullptr == key_padding_mask || is_mask_1d_seq_len) && - parameters.hidden_size == parameters.v_hidden_size && - parameters.sequence_length == parameters.kv_sequence_length && - FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + bool use_fused_runner = + !use_flash_attention && + !disable_fused_self_attention_ && + fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && + (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && + nullptr == past_key && nullptr == present_key && + (nullptr == key_padding_mask || AttentionMaskType::MASK_1D_KEY_SEQ_LEN) && + parameters.hidden_size == parameters.v_hidden_size && + parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { @@ -214,10 +237,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); - bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 + bool is_long_sequence = std::is_same::value || // sequence length threshold is 0 for FP16 parameters.sequence_length >= length_threshold || parameters.kv_sequence_length >= length_threshold; + // Check whether the relative position bias alignment is good for memory efficient attention. bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; bool use_memory_efficient_attention = @@ -226,82 +250,25 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - !past_no_bias && (relative_position_bias == nullptr || is_good_for_rpb) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); + has_memory_efficient_attention(sm, std::is_same::value, + parameters.head_size, parameters.v_head_size); #else constexpr bool use_memory_efficient_attention = false; #endif - if (kernel_options_->AllowDebugInfo()) { - AttentionKernelDebugInfo debug_info; - debug_info.use_flash_attention = use_flash_attention; - debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; - debug_info.use_efficient_attention = use_memory_efficient_attention; - if (fused_fp16_runner_ != nullptr) { - debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); - } - - debug_info.Print("MultiHeadAttention", - this->Node().Name(), - std::is_same::value, - std::is_same::value); - } - - // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. - // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. - bool no_qkv_workspace = nullptr == value && - (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && - nullptr == key_padding_mask && - nullptr == bias; - - size_t workspace_bytes; - constexpr size_t element_size = sizeof(T); - if (no_qkv_workspace) { - workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0; - } else { - workspace_bytes = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.v_head_size, - parameters.sequence_length, - parameters.kv_sequence_length, - parameters.total_sequence_length, - fused_runner, - use_flash_attention, - use_fused_cross_attention, - use_memory_efficient_attention); - } - - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - const size_t past_k_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.head_size; - const size_t past_v_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.v_head_size; - const bool use_temp_k_v_workspace = parameters.pass_past_in_kv || use_memory_efficient_attention || use_flash_attention; - auto temp_k_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; - auto temp_v_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; - typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); - data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); - data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); + data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) - : (nullptr == past_key) ? nullptr - : reinterpret_cast(past_key->Data()); - data.past_value = pass_key_value_as_past ? reinterpret_cast(value->Data()) - : (nullptr == past_value) ? nullptr - : reinterpret_cast(past_value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; - data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); @@ -309,8 +276,41 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); - data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + + // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). + // The cache will be initialized only once, and become readonly after that. + if ((data.fused_cross_attention_kernel != nullptr || data.fused_runner != nullptr) && data.mask_index == nullptr) { + cudaStream_t stream = Stream(context); + data.cumulated_sequence_length_q_cache = this->cumulated_sequence_length_q_cache_.TryGet( + parameters.batch_size, parameters.sequence_length, stream); + + if (data.fused_cross_attention_kernel != nullptr) { + data.cumulated_sequence_length_kv_cache = this->cumulated_sequence_length_kv_cache_.TryGet( + parameters.batch_size, parameters.kv_sequence_length, stream); + } + } + + const bool no_qkv_workspace = NoQkvWorkspace(parameters, data); + size_t workspace_bytes = GetAttentionWorkspaceSize(sizeof(T), + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.v_head_size, + parameters.sequence_length, + parameters.kv_sequence_length, + parameters.total_sequence_length, + fused_runner, + use_flash_attention, + use_fused_cross_attention, + use_memory_efficient_attention, + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + data.allow_debug_info = kernel_options_->AllowDebugInfo(); if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } @@ -318,8 +318,23 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - cublasHandle_t cublas = GetCublasHandle(context); + if (data.allow_debug_info) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + + data.PrintDebugInfo(); + } + cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 26e38dbad9fd7..68fd0c9943fca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" @@ -32,11 +33,16 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + + // These mutable members are readonly after they are initialized so that they can be shared among multiple threads. + // Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource. mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; + mutable std::once_flag fused_cross_init_once_flag_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ac2cb5165a94c..2521cd49b5482 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -297,7 +297,7 @@ struct T2 { }; template -void LaunchAddBiasTranspose( +void AddBiasTransposePacked( const T* input, const T* biases, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -452,7 +452,7 @@ Status FusedScaledDotProductAttention( void* fused_runner = data.fused_runner; ORT_RETURN_IF_NOT(nullptr != fused_runner, "fused_runner cannot be NULL"); - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::QKV_BSN3H, data.token_offset, @@ -477,7 +477,7 @@ Status FusedScaledDotProductAttentionCutlass( const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BSNH, data.token_offset, @@ -564,7 +564,7 @@ Status UnfusedScaledDotProductAttention( T* k = q + elements_q; T* v = k + elements_k; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BNSH, data.token_offset, @@ -657,6 +657,20 @@ Status QkvToContext( return UnfusedScaledDotProductAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const float* input, const float* biases, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const half* input, const half* biases, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index b4ca0194b08bc..e5a4c54f48903 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -502,7 +502,7 @@ struct T2 { }; template -void LaunchTranspose( +void AddBiasTransposePacked( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -566,11 +566,11 @@ Status FusedAttentionTrt( // When packed QKV is used, we can directly pass it to fused runner. Otherwise, we need transpose to BSN3H format. const T* qkv = data.query; if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, + data.token_offset, parameters.token_count, stream); qkv = data.workspace; } @@ -601,11 +601,11 @@ Status FlashAttention( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) @@ -675,11 +675,11 @@ Status FusedAttentionCutlass( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } MemoryEfficientAttentionParams p; @@ -746,11 +746,11 @@ Status UnfusedAttention( const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); // Q, K and V pointers when fused attention is not used - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, + data.token_offset, parameters.token_count, stream); T* qkv = data.workspace; T* q = qkv; @@ -848,6 +848,22 @@ Status QkvToContext( return UnfusedAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const half* query, const half* key, const half* value, const half* bias, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const float* query, const float* key, const float* value, const float* bias, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 168c69c69f003..b62e566d43f89 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -190,7 +190,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + true); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -208,6 +209,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index e10c2ec63fd51..6d52ff7282799 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -13,6 +13,9 @@ namespace cuda { #if DUMP_TENSOR_LEVEL > 0 +// Environment variable to enable/disable GPU Tensor dumping +constexpr const char* kEnableGpuTensorDumper = "ORT_ENABLE_GPU_DUMP"; + // Total number of elements which trigger snippet rather than full dump (default 200). Value 0 disables snippet. constexpr const char* kTensorSnippetThreshold = "ORT_TENSOR_SNIPPET_THRESHOLD"; @@ -202,6 +205,10 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { + is_enabled_ = ParseEnvironmentVariableWithDefault(kEnableGpuTensorDumper, 1) != 0; +} + void CudaTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } @@ -329,6 +336,8 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { +} void CudaTensorConsoleDumper::Print(const std::string&) const { } diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 6ad0ad9a67b75..4f41161cd4a31 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -13,7 +13,7 @@ namespace cuda { class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { public: - CudaTensorConsoleDumper() = default; + CudaTensorConsoleDumper(); virtual ~CudaTensorConsoleDumper() {} void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index b0ed3ff82226a..b94971ffd44d5 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -119,7 +119,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; return Status::OK(); } @@ -128,7 +128,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; return Status::OK(); } @@ -136,7 +136,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; return Status::OK(); } @@ -146,7 +146,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); } @@ -154,7 +154,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 349df045becf2..d593bc0012826 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -132,12 +132,6 @@ class CompatRocblasMathModeSetter { } }; -enum AttentionType { - kAttention, - kMultiHeadAttention, - kDecoderMaskedMultiHeadAttention, -}; - enum AttentionMode { // Q,K,V,PastK,PastV,PresentK,PresentV QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 09e7d61b71db9..5997daaca6e8a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -122,9 +122,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, num_heads_, - mask_filter_value_, scale_, false, /*is_unidirectional_*/ - past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ + past_present_share_buffer_, + attn_type_, + device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index dc2b38f3928ac..a9ff623fb6967 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -691,6 +691,9 @@ def create_multihead_attention_node( return None # Add bias to inputs for MHA + # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume + # bias has been added to key and value when they are in BNSH format, so only bias for query is used. + # Need add checks if we found such assumption is not true. if not self.disable_multi_head_attention_bias: bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name) mha_inputs.append(bias_name) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 715a92431e6bf..ec350874af32c 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -88,9 +88,11 @@ def __init__( enable_cuda_graph: bool = False, dtype=torch.float, use_kv_cache: bool = False, + has_past_input: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, verbose: bool = False, + has_bias: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -103,15 +105,25 @@ def __init__( self.causal = causal self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + # Support the case that there is no past but need present output (for prompt case). + self.has_past_input = has_past_input + if has_past_input: + assert use_kv_cache + else: # no past input + assert past_sequence_length == 0 + + self.has_present_output = use_kv_cache + self.use_kv_cache = use_kv_cache if not use_kv_cache: assert past_sequence_length == 0 else: assert self.kv_sequence_length == self.sequence_length - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - assert not use_kv_cache + # Only BSNH input format supports past state. + if input_format != InputFormats.Q_K_V_BSNH_BSNH_BSNH: + assert not self.has_past_input + assert not self.has_present_output # Derived values self.total_sequence_length = self.kv_sequence_length + past_sequence_length @@ -130,6 +142,7 @@ def __init__( self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H self.verbose = verbose + self.has_bias = has_bias def __repr__(self): return ( @@ -140,7 +153,8 @@ def __repr__(self): f"causal={self.causal}), softmax_scale={self.softmax_scale}, use_kv_cache={self.use_kv_cache}, " f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " - f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}" + f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " + f"has_bias={self.has_bias}" ) def shape_dict(self, input_format=None): @@ -176,16 +190,23 @@ def shape_dict(self, input_format=None): "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), } - if self.use_kv_cache: - assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" + if self.has_past_input: shapes = { **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), + } + + if self.has_present_output: + shapes = { + **shapes, "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes def symbolic_shape_dict(self, input_format=None): @@ -221,19 +242,26 @@ def symbolic_shape_dict(self, input_format=None): "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), } - if self.use_kv_cache: - assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" + if self.has_past_input: shapes = { **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), + } + + if self.has_present_output: + shapes = { + **shapes, "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes - def random_inputs(self, seed: int = 123): + def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): device = self.device dtype = self.dtype @@ -246,6 +274,14 @@ def random_inputs(self, seed: int = 123): q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + + bias_q = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_k = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_v = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + if no_bias_k_v: + bias_k = torch.zeros_like(bias_k) + bias_v = torch.zeros_like(bias_v) + k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) @@ -277,7 +313,7 @@ def random_inputs(self, seed: int = 123): "value": v_bnsh.contiguous(), } - if self.use_kv_cache: + if self.has_past_input: feeds = { **feeds, "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), @@ -286,6 +322,9 @@ def random_inputs(self, seed: int = 123): ), } + if self.has_bias: + feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous() + return feeds def get_input_output_names(self): @@ -299,15 +338,29 @@ def get_input_output_names(self): else: inputs, outputs = ["query", "key", "value"], ["output"] - if self.use_kv_cache: - return [*inputs, "past_key", "past_value"], [*outputs, "present_key", "present_value"] - else: - return inputs, outputs + if self.has_bias: + inputs = [*inputs, "bias"] + + if self.has_past_input: + inputs = [*inputs, "past_key", "past_value"] + + if self.has_present_output: + outputs = [*outputs, "present_key", "present_value"] + + return inputs, outputs def fill_optional_mha_inputs(input_names): inputs = ["query", "key", "value", "bias", "key_padding_mask", "relative_position_bias", "past_key", "past_value"] - return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:] + + # Remove optional inputs that are not in input_names with empty string + inputs_with_optional = [input if input in input_names else "" for input in inputs] + + # Remove empty string at the end of the list. + while inputs_with_optional[-1] == "": + inputs_with_optional.pop(-1) + + return inputs_with_optional def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False): @@ -317,7 +370,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use nodes = [ helper.make_node( "MultiHeadAttention", - fill_optional_mha_inputs(input_names) if config.use_kv_cache else input_names, + fill_optional_mha_inputs(input_names), output_names, "MultiHeadAttention_0", num_heads=config.num_heads, @@ -331,11 +384,13 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use inputs = [ helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) for input_name in input_names + if input_name ] outputs = [ helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name])) for output_name in output_names + if output_name ] graph = helper.make_graph( @@ -355,6 +410,7 @@ def create_ort_session( session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_symbolic_shape: bool = True, + use_tf32: bool = True, ) -> CudaSession: if config.verbose: print(f"create session for {vars(config)}") @@ -364,6 +420,7 @@ def create_ort_session( device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) provider_options["sdpa_kernel"] = int(attention_kernel) + provider_options["use_tf32"] = int(use_tf32) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] @@ -373,9 +430,11 @@ def create_ort_session( def create_session( - config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_tf32: bool = True ) -> CudaSession: - ort_session = create_ort_session(config, session_options, attention_kernel, use_symbolic_shape=False) + ort_session = create_ort_session( + config, session_options, attention_kernel, use_symbolic_shape=False, use_tf32=use_tf32 + ) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -385,8 +444,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__(self, config: MultiHeadAttentionConfig, session_options=None): - self.ort_session = create_session(config, session_options) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_tf32: bool = True): + self.ort_session = create_session(config, session_options, use_tf32=use_tf32) self.feed_dict = config.random_inputs() def infer(self): diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 0fcbd889847e9..a35d02b0b9d52 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -21,6 +21,47 @@ import onnxruntime +def get_provider_support_info(provider: str, use_kv_cache: bool): + if provider == "CUDAExecutionProvider": + if not use_kv_cache: + formats = [ + InputFormats.Q_K_V_BSNH_BSNH_BSNH, + InputFormats.Q_KV_BSNH_BSN2H, + InputFormats.QKV_BSN3H, + InputFormats.Q_K_V_BSNH_BNSH_BNSH, + ] + else: + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + else: + assert provider == "CPUExecutionProvider" + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + if not use_kv_cache: + formats.append(InputFormats.Q_K_V_BSNH_BNSH_BNSH) + device = torch.device("cpu") + dtype = torch.float + return device, dtype, formats + + +def get_bias_support(format: InputFormats): + if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + return [True, False] + + if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return [True, False] + + if format == InputFormats.Q_KV_BSNH_BSN2H: + return [False] + + if format == InputFormats.QKV_BSN3H: + return [True, False] + + raise RuntimeError(f"Unknown format: {format}") + + def attention_reference( head_size: int, query: torch.Tensor, @@ -84,8 +125,8 @@ def attention_reference( def mha_with_past_reference( config: MultiHeadAttentionConfig, - past_k: torch.Tensor, - past_v: torch.Tensor, + past_k: Optional[torch.Tensor], + past_v: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -94,41 +135,23 @@ def mha_with_past_reference( ): assert config.kv_sequence_length == config.sequence_length assert config.use_kv_cache - assert past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) # both BNSH format - assert past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) # both BNSH format - - present_k = torch.cat((past_k, k), dim=2) - present_v = torch.cat((past_v, v), dim=2) + if past_k is not None: + assert ( + past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) + ), f"expect BNSH format: {past_k.shape=} {k.shape=}" + + if past_v is not None: + assert ( + past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) + ), f"expect BNSH format: {past_v.shape=} {v.shape=}" + + present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k + present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, mask=mask) return out, present_k, present_v -def get_provider_support_info(provider: str, use_kv_cache: bool): - if provider == "CUDAExecutionProvider": - if not use_kv_cache: - formats = [ - InputFormats.Q_K_V_BSNH_BSNH_BSNH, - InputFormats.Q_KV_BSNH_BSN2H, - InputFormats.QKV_BSN3H, - InputFormats.Q_K_V_BSNH_BNSH_BNSH, - ] - else: - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - dtype = torch.float16 - else: - assert provider == "CPUExecutionProvider" - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - if not use_kv_cache: - formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH) - device = torch.device("cpu") - dtype = torch.float - return device, dtype, formats - - def get_compute_capability(): if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): major, minor = torch.cuda.get_device_capability() @@ -143,35 +166,38 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 2048] + sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -179,25 +205,27 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for format in formats: + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def kv_cache_test_cases(provider: str, comprehensive: bool): @@ -206,37 +234,42 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 15, 16, 255, 256, 2048] + sequence_lengths = [1, 15, 16, 255, 256, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - sequence_length = 1 device, dtype, formats = get_provider_support_info(provider, True) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for past_sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -244,31 +277,31 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): past_sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config - - -def mha_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_test_cases(provider, comprehensive), kv_cache_test_cases(provider, comprehensive) - ) + for format in formats: + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): @@ -343,6 +376,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device=device, dtype=dtype, use_kv_cache=True, + has_past_input=True, share_past_present_buffer=False, input_format=format, ) @@ -350,13 +384,6 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): yield configs -def multi_thread_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_multi_thread_test_cases(provider, comprehensive), - kv_cache_multi_thread_test_cases(provider, comprehensive), - ) - - def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) @@ -374,28 +401,31 @@ def parity_check_mha( if config.causal and config.provider == "CUDAExecutionProvider": return - ort_mha = OrtMultiHeadAttention(config) + ort_mha = OrtMultiHeadAttention(config, use_tf32=False) ort_outputs = ort_mha.infer() out = ort_outputs["output"] out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + no_bias_k_v = config.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH - ref_inputs = config.random_inputs() - q = ( - ref_inputs["query"] - .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - k = ( - ref_inputs["key"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - v = ( - ref_inputs["value"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) + ref_inputs = config.random_inputs(no_bias_k_v=no_bias_k_v) + q = ref_inputs["query"].reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + k = ref_inputs["key"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + v = ref_inputs["value"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + + if "bias" in ref_inputs: + bias = ref_inputs["bias"] + bias = bias.reshape((3, config.num_heads, config.head_size)) + bias_q = bias[0, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_k = bias[1, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_v = bias[2, :, :].reshape(1, 1, config.num_heads, config.head_size) + q = q + bias_q + k = k + bias_k + v = v + bias_v + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) mask = None if config.causal: @@ -404,8 +434,8 @@ def parity_check_mha( k_cache = None v_cache = None if config.use_kv_cache: - past_k = ref_inputs["past_key"] - past_v = ref_inputs["past_value"] + past_k = ref_inputs.get("past_key", None) + past_v = ref_inputs.get("past_value", None) out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) else: out_ref = attention_reference(config.head_size, q, k, v, mask=mask) @@ -445,7 +475,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - attention_kernel: int = SdpaKernel.DEFAULT, + attention_kernel=SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -454,6 +484,7 @@ def parity_check_mha_multi_threading( # For now, MHA CUDA kernel does not support causal so skip such test cases. if config.causal and config.provider == "CUDAExecutionProvider": return None + # Some kernel does not support certain input format. if attention_kernel not in [ SdpaKernel.DEFAULT, @@ -462,7 +493,7 @@ def parity_check_mha_multi_threading( ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True) + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True, use_tf32=False) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -572,18 +603,32 @@ def check_parity_with_config(i: int): return None -# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. +def mha_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_test_cases(provider, comprehensive), + kv_cache_test_cases(provider, comprehensive), + ) + + +def multi_thread_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_multi_thread_test_cases(provider, comprehensive), + kv_cache_multi_thread_test_cases(provider, comprehensive), + ) + + +# Off by default so that we do not run too many tests in CI pipeline. comprehensive_mode = False class TestMultiHeadAttention(unittest.TestCase): @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cuda(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) @parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cpu(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): @@ -604,19 +649,24 @@ def run_mha_cuda_multi_threading(self, attention_kernel): assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): - self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) + if get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) def test_mha_cuda_multi_threading_efficient(self): - self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + + def test_mha_cuda_multi_threading_math(self): + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.MATH) def test_mha_cuda_multi_threading_trt(self): - sm = get_compute_capability() - if sm in [75, 80, 86, 89]: + if get_compute_capability() in [75, 80, 86, 89]: self.run_mha_cuda_multi_threading( SdpaKernel.TRT_FUSED_ATTENTION | SdpaKernel.TRT_FLASH_ATTENTION - | SdpaKernel.TRT_CROSS_ATTENTION | SdpaKernel.TRT_CAUSAL_ATTENTION + | SdpaKernel.TRT_CROSS_ATTENTION ) From a3883af7bfede84315abaa94fbc4cc2a0d2b02a3 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 1 Aug 2024 05:39:21 +0800 Subject: [PATCH 35/37] [WebNN EP] Fixed bug in ConvTranspose (#21569) The constraint of ConvTranspose was placed in wrong place. --- .../webnn/builders/impl/conv_op_builder.cc | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 22049d2519712..76a8a178678df 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -28,7 +28,7 @@ class ConvOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; @@ -378,6 +378,22 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } + // WebNN CPU backend (TFLite) only supports default dilations and group. + // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 + if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { + NodeAttrHelper helper(node); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + const auto group = helper.Get("group", 1); + if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; + return false; + } + if (group != 1) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; + return false; + } + } + return true; } @@ -427,22 +443,6 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - // WebNN CPU backend (TFLite) only supports default dilations and group. - // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 - if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { - NodeAttrHelper helper(node); - const auto dilations = helper.Get("dilations", std::vector{1, 1}); - const auto group = helper.Get("group", 1); - if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { - LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; - return false; - } - if (group != 1) { - LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; - return false; - } - } - return true; } From 8540ac4f78bf06c9b6a0fe90d71e76c947836817 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:30:33 -0700 Subject: [PATCH 36/37] Fix quant_format argument for 4bit quantizer (#21581) ### Description Original argument accepts Enum QuantFormat.QOperator or QuantFormat.QDQ, but the default value is QOperator. Change the argument to str to accept QOperator or QDQ and convert to QuantFormat after parsing. ### Motivation and Context Bug fix --- .../python/tools/quantization/matmul_4bits_quantizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 40a4a4d26dc1c..cc8bd622df9b1 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -797,8 +797,8 @@ def parse_args(): parser.add_argument( "--quant_format", default="QOperator", - type=QuantFormat, - choices=list(QuantFormat), + type=str, + choices=["QOperator", "QDQ"], help="QuantFormat {QOperator, QDQ}" "QOperator format quantizes the model with quantized operators directly." "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", @@ -814,7 +814,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model - quant_format = args.quant_format + quant_format = QuantFormat[args.quant_format] if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") From 4b8f6dcbb69ee9c74330d7785fe5b7ef656a94f5 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 31 Jul 2024 21:05:11 -0700 Subject: [PATCH 37/37] [QNN EP] Improve INT4 accuracy (#21582) ### Description Masks off top 4-bits of INT4 weights, improving accuracy. ### Motivation and Context This is a workaround as the QNN docs state masking is not required. --- .../qnn/builder/qnn_model_wrapper.cc | 6 + onnxruntime/test/providers/qnn/conv_test.cc | 5 +- .../test/providers/qnn/matmul_test.cpp | 154 ++++++++++++++++-- 3 files changed, 151 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index c8537307ef3ba..9d3f460572d84 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -617,6 +617,12 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + + // NOTE: Masking off top 4 bits to workaround a QNN INT4 accuracy bug. + // Docs explicitly state that masking off top 4 bits should not be required. + for (size_t i = 0; i < dst.size(); i++) { + dst[i] &= 0x0F; // -3 (0b1111_1101) becomes 13 (0b0000_1101) + } } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); const size_t num_elems = shape.Size(); diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 99636976b9c05..35889c9fa2307 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -799,7 +799,7 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) { // CPU EP (f32 model): 25.143 21.554 17.964 10.785 7.195 3.605 -3.574 -7.164 -10.753 // CPU EP (qdq model): 24.670 21.103 17.536 10.254 6.689 2.972 -4.161 -7.728 -10.700 // QNN EP (qdq model): 27.186 27.186 27.186 21.541 6.685 -8.022 -10.548 -10.548 -10.548 -TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { +TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_AccuracyIssue) { std::vector input_shape = {1, 2, 4, 4}; std::vector weight_shape = {3, 2, 2, 2}; std::vector bias_shape = {3}; @@ -835,7 +835,8 @@ TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 21); // opset + 21, // opset + QDQTolerance(0.005f)); } // Test per-channel QDQ Conv is rejected with weight axis != 0 diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index dba60b1041696..d8c34d6a6c6ed 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -28,26 +28,25 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de // Returns a function that creates a graph with a QDQ MatMul operator. template -static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, - const TestInputDef& input2_def, +static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def, bool use_contrib_qdq) { - return [input1_def, input2_def, use_contrib_qdq](ModelTestBuilder& builder, + return [input0_def, input1_def, use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { // input1 -> Q -> DQ -> - NodeArg* input1 = MakeTestInput(builder, input1_def); - QuantParams input1_qparams = GetTestInputQuantParams(input1_def); - auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_qparams = GetTestInputQuantParams(input0_def); + auto* input0_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, input0_qparams.zero_point, use_contrib_qdq); - - // input2 -> Q -> DQ -> - NodeArg* input2 = MakeTestInput(builder, input2_def); - QuantParams input2_qparams = GetTestInputQuantParams(input2_def); - auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point, + // input1 -> Q -> DQ -> + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, use_contrib_qdq); // MatMul auto* op_output = builder.MakeIntermediate(); - builder.AddNode("MatMul", {input1_qdq, input2_qdq}, {op_output}); + builder.AddNode("MatMul", {input0_qdq, input1_qdq}, {op_output}); // op_output -> Q -> DQ -> output AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, @@ -55,6 +54,88 @@ static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDe }; } +template +static GetTestQDQModelFn BuildQDQPerChannelMatMulTestCase(const TestInputDef& input_def, + const TestInputDef& weights_def, + int64_t weight_quant_axis, + bool use_contrib_qdq = false) { + return [input_def, weights_def, weight_quant_axis, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector matmul_inputs; + + // input -> Q/DQ -> + auto* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + matmul_inputs.push_back(input_qdq); + + // Quantized(weights) -> DQ -> + ORT_ENFORCE(weights_def.IsInitializer() && weights_def.IsRawData()); + std::vector weight_scales; + std::vector weight_zero_points; + TensorShape weights_shape = weights_def.GetTensorShape(); + int64_t pos_weight_quant_axis = weight_quant_axis; + if (pos_weight_quant_axis < 0) { + pos_weight_quant_axis += static_cast(weights_shape.NumDimensions()); + } + GetTestInputQuantParamsPerChannel(weights_def, weight_scales, weight_zero_points, + static_cast(pos_weight_quant_axis), true); + + std::vector quantized_weights; + size_t num_weight_storage_elems = weights_shape.Size(); + if constexpr (std::is_same_v || std::is_same_v) { + num_weight_storage_elems = Int4x2::CalcNumInt4Pairs(weights_shape.Size()); + } + quantized_weights.resize(num_weight_storage_elems); + QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, + weight_scales, weight_zero_points, pos_weight_quant_axis); + + NodeArg* weights_initializer = builder.MakeInitializer(weights_def.GetShape(), quantized_weights); + NodeArg* weights_dq = builder.MakeIntermediate(); + Node& weights_dq_node = builder.AddDequantizeLinearNode(weights_initializer, weight_scales, + weight_zero_points, weights_dq, + nullptr, use_contrib_qdq); + weights_dq_node.AddAttribute("axis", weight_quant_axis); + matmul_inputs.push_back(weights_dq); + + auto* matmul_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", matmul_inputs, {matmul_output}); + + AddQDQNodePairWithOutputAsGraphOutput(builder, matmul_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ per-channel MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the +// QDQ model is accurate on QNN EP (compared to CPU EP). +template +static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, + const TestInputDef& weights_def, + int64_t weight_quant_axis, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 21, + bool use_contrib_qdq = false, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), + BuildQDQPerChannelMatMulTestCase(input_def, + weights_def, + weight_quant_axis, + use_contrib_qdq), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + // Runs an MatMul model on the QNN CPU backend. Checks the graph node assignment, and that inference // outputs for QNN and CPU match. static void RunMatMulOpOpTest(const TestInputDef& input1_def, @@ -160,6 +241,55 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { true); // Use com.microsoft Q/DQ ops } +// Test QDQ per-channel MatMul with 16-bit act, signed 4-bit weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt4) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + +// Test QDQ per-channel MatMul with 16-bit act, unsigned 4-bit weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + +// Test QDQ per-channel MatMul with int8 act, int4 weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { + std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); + std::vector input1_data = {-2.0f, -1.0f, -0.5f, 0.0f, 1.0f, 2.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false, + QDQTolerance(0.007f)); +} + +// Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt8) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + // Test QDQ MatMul with uint16 activation uint16 weights, both dynamic // Inaccuracy detected for output 'output_0', element 1. // Output quant params: scale=0.0015259021893143654, zero_point=0.