From eaf047c82092062e5a5df7f4087a9a5c7ec75f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 19 Jan 2024 19:36:19 +0100 Subject: [PATCH 01/23] Increment year to 2024 in conf.py (python documentation) (#19107) ### Description Update copyright in python documentation. --- docs/python/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python/conf.py b/docs/python/conf.py index 065149441b72c..7ab2d42aa15e1 100644 --- a/docs/python/conf.py +++ b/docs/python/conf.py @@ -17,7 +17,7 @@ # -- Project information ----------------------------------------------------- project = "Python API" -copyright = "2018-2023, Microsoft" +copyright = "2018-2024, Microsoft" author = "Microsoft" # -- General configuration --------------------------------------------------- From a3ecb6326711ca368ccc402886dcaf1ef56ff79c Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:09:24 -0800 Subject: [PATCH 02/23] Update LLaMA attention fusions (#19200) ### Description This PR updates the LLaMA-2 attention fusions by adding the following. - Loading the PyTorch model from Hugging Face with the `LlamaAttention` class before exporting - Updating the attention mask pattern matching to support another case This PR also fixes [this issue](https://github.com/microsoft/onnxruntime/issues/19040). ### Motivation and Context Recent changes to Hugging Face's `transformers` library break the existing pattern matching. Since the attention fusions aim to change the graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to `LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, ultimately it does not matter what nodes comprise the `Set of Attention Nodes` because they will all be removed and replaced by the `Attention Op` in the end. Therefore, it does not matter whether the `LlamaAttention` class or a different attention class is used to load the PyTorch model before exporting because the expected graphs after the attention fusions will look identical no matter the attention class chosen. By loading the PyTorch model with the `LlamaAttention` class instead of other attention classes (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then exporting it to ONNX, the existing pattern matching will continue to work. 
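For reference, a minimal sketch (not part of this PR's diff) of selecting the eager `LlamaAttention` path when loading the PyTorch model before export, mirroring the `_attn_implementation = "eager"` change in `llama_torch.py` below; the checkpoint name and dtype are illustrative placeholders:

```
# Hedged sketch: force the eager LlamaAttention implementation before ONNX export
# so the exported graph matches the fusion patterns described above.
# The checkpoint name and dtype are placeholders, not part of this PR.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_id)
config._attn_implementation = "eager"  # "eager" selects the LlamaAttention class

model = AutoModelForCausalLM.from_pretrained(model_id, config=config, torch_dtype=torch.float32)
model.eval()  # ONNX export (e.g. via convert_to_onnx) would follow here
```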
--- .../transformers/fusion_rotary_attention.py | 10 ++++++ .../tools/transformers/models/llama/README.md | 31 +++++-------------- .../models/llama/convert_to_onnx.py | 27 ++++++++++++++++ .../transformers/models/llama/llama_torch.py | 1 + .../models/llama/requirements.txt | 2 +- 5 files changed, 46 insertions(+), 25 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index de89b35366a23..618d3c2fab12c 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -539,6 +539,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask + # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP + # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask attn_mask, add_qk_str = "", "" attn_mask_nodes_1 = self.model.match_parent_path( add_qk, @@ -570,6 +572,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 0, 2, 1, 0, 0, 0], ) + attn_mask_nodes_7 = self.model.match_parent_path( + add_qk, + ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 0, 0, 0, 1, 0, 0, 0], + ) if attn_mask_nodes_1 is not None: _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 attn_mask = slice_mask_1.output[0] @@ -588,6 +595,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif attn_mask_nodes_6 is not None: # The mask has already been reshaped to (B,N,S,T) add_qk_str = attn_mask_nodes_6[0].output[0] + elif attn_mask_nodes_7 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0]) else: logger.debug("fuse_rotary_attention: failed to match attention mask nodes") return diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index e7bcc19635f40..f9552e02d74b9 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf). -As indicated in `requirements.txt`, you will also need to install Optimum from source. 
Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows: - -``` -# Before -if self.use_cache: - if past_key_values is not None: - input_ids = input_ids[:, -1:] - # Flatten the past_key_values (no need to flatten for models using multi-query attn) - - -# After -if self.use_cache: - if past_key_values is not None: - input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids - # Flatten the past_key_values (no need to flatten for models using multi-query attn) -``` - ### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx) Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2. @@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2. 1. PyTorch without `torch.compile`, FP32 ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \ --benchmark-type hf-pt-eager \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ @@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \ 2. PyTorch with `torch.compile`, FP16 ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \ --benchmark-type hf-pt-compile \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ @@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \ 3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \ --benchmark-type hf-ort \ --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ --model-name meta-llama/Llama-2-7b-hf \ @@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \ 4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \ --benchmark-type hf-ort \ --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ --model-name meta-llama/Llama-2-7b-hf \ @@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \ 5. ONNX Runtime, FP32, Microsoft custom export ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ @@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \ 6. ONNX Runtime, FP16, Microsoft custom export ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ @@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch ### Benchmark All You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example. 
``` -python3 -m models.llama.benchmark_all \ +CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \ --hf-pt-eager \ --hf-pt-compile \ --hf-ort-dir-path ./llama2-7b-fp16/ \ diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index bc09b52574a27..71f52faa2c1e6 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -4,6 +4,8 @@ import logging import os import shutil +import subprocess +import sys from itertools import chain import onnx @@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov only_onnxruntime=False, ) model_opt.save_model_to_file(output_path, use_external_data_format=True) + + # Run symbolic shape inference on optimized model to avoid shape errors during runtime + # Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output. + # After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output. + wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"] + source_cmd = [sys.executable, "../symbolic_shape_infer.py"] + symbolic_shape_infer_args = [ + "--input", + output_path, + "--output", + output_path, + "--auto_merge", + "--save_as_external_data", + "--all_tensors_to_one_file", + "--external_data_location", + os.path.basename(output_path) + ".data", + ] + + file_path = os.path.dirname(__file__) + if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")): + main_cmd = wheel_cmd + else: + main_cmd = source_cmd + subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510 + logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!") if remove_model: remove_existing_model(input_path) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index 94e0397116d1c..89b459c80beec 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, if i == rank % (world_size): l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) l_config.use_cache = True + l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer llama = AutoModelForCausalLM.from_pretrained( location, use_auth_token=use_auth_token, diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 4210f36982aef..b72c972e7a16a 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/huggingface/optimum.git +optimum>=1.14.1 transformers>=4.33.2 torch>=2.2.0.dev20230920 onnx>=1.14.0 From 6e17571f2f0cb8eb841395f208d8c31359e2f054 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 19 Jan 2024 15:16:17 -0800 Subject: [PATCH 03/23] Fix issue that the generated context cache model inputs/outputs order is not guaranteed (#19195) Fix issue that the generated context cache model inputs/outputs order is not guaranteed ### Description Currently, QNN EP 
generate the context cache model in Compile() method which only get access to the partitioned graph. And the inputs/outputs order for the partitioned graph is not guaranteed. And EP doesn't have the view of the input user model. Have to move the context cache model generation to a higher level in GraphPartitioner which has the view of the partitioned model. This is also a break down of PR for multi-partition support. https://github.com/microsoft/onnxruntime/pull/18865 --- .../core/framework/execution_provider.h | 9 ++ .../onnxruntime_session_options_config_keys.h | 2 +- .../core/framework/graph_partitioner.cc | 105 ++++++++++++++++++ .../core/framework/graph_partitioner.h | 3 + .../qnn/builder/onnx_ctx_model_helper.cc | 13 +-- .../qnn/builder/onnx_ctx_model_helper.h | 3 +- .../providers/qnn/qnn_execution_provider.cc | 16 ++- .../providers/qnn/qnn_execution_provider.h | 4 + onnxruntime/core/session/inference_session.cc | 9 +- .../test/framework/session_state_test.cc | 25 +++-- .../test/providers/qnn/qnn_basic_test.cc | 45 ++++++++ .../test/providers/qnn/simple_op_htp_test.cc | 2 + .../testdata/qnn_ctx_2_inputs_order_test.onnx | Bin 0 -> 2053 bytes 13 files changed, 210 insertions(+), 26 deletions(-) create mode 100644 onnxruntime/test/testdata/qnn_ctx_2_inputs_order_test.onnx diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index ea4f52f99649d..1de0217c7e1fa 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -326,6 +326,15 @@ class IExecutionProvider { */ virtual std::vector CreatePreferredAllocators() { return std::vector(); }; + /** + * Get the array of pointers for EPContext nodes + * EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it. + * Default return an empty vector if not provided by the Execution Provider + */ + virtual const InlinedVector GetEpContextNodes() const { + return InlinedVector(); + } + private: const std::string type_; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index df79cb6e5b21b..8fd51962bf087 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; -// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file. +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) // "1": enable. 
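As an illustration of the session options documented above (not part of this patch), a minimal Python sketch of dumping an EP-context model with QNN; the config key strings and `backend_path` mirror the options and tests in this patch, while the model and output file names are placeholders:

```
# Hedged usage sketch: enable EP-context model dumping via session options.
# Model and output file names are placeholders.
import onnxruntime as ort

so = ort.SessionOptions()
so.add_session_config_entry("ep.context_enable", "1")                  # kOrtSessionOptionEpContextEnable
so.add_session_config_entry("ep.context_file_path", "model_ctx.onnx")  # kOrtSessionOptionEpContextFilePath

sess = ort.InferenceSession(
    "model.onnx",
    sess_options=so,
    providers=[("QNNExecutionProvider", {"backend_path": "QnnHtp.dll"})],
)
```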
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e4fe0c7564548..07b465c80745a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, + const Graph& graph, + const std::string& ep_context_path, + const logging::Logger& logger) { + InlinedVector all_ep_context_nodes; + for (const auto& ep : execution_providers) { + const InlinedVector ep_context_nodes = ep->GetEpContextNodes(); + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { + return std::make_pair(true, node); + } + } + return std::make_pair(false, static_cast(nullptr)); + }; + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); + if (all_ep_context_nodes.size() > 0) { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { +#ifdef _WIN32 + std::wifstream fs(context_cache_path); +#else + std::ifstream fs(context_cache_path); +#endif + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } + + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); + + // Set inputs outputs explicitly to make sure the order is same as the user model. 
+ auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); + + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } + + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } + + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); + } + } + + // handle initializers + for (const auto& input : graph.GetInputsIncludingInitializers()) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(input->Name(), initializer)) { + // There initializer could have duplicates so make sure we only add once + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; + if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { + ep_graph.AddInitializedTensor(*initializer); + } + } + } + + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_)); + + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (ep_context_enabled) { + ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger)); + } #else + ORT_UNUSED_PARAMETER(config_options); + ORT_UNUSED_PARAMETER(logger); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! 
defined(ORT_MINIMAL_BUILD) } else { diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4fc85c2588260..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; class Model; +struct ConfigOptions; class GraphPartitioner { public: @@ -31,6 +32,8 @@ class GraphPartitioner { // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index fd9bf200c45ef..5d3f406f50612 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -230,8 +230,7 @@ Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path return Status::OK(); } -Status GenerateCtxCacheOnnxModel(const std::string model_name, - const std::string model_description, +Status GenerateCtxCacheOnnxModel(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, @@ -240,11 +239,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name, const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, const logging::Logger& logger) { - std::unordered_map domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}}; - Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version, {}, logger); - auto& graph = model.MainGraph(); - graph.SetDescription(model_description); + auto& graph = model->MainGraph(); using namespace ONNX_NAMESPACE; int index = 0; @@ -270,7 +265,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name, nullptr, kMSDomain); - // Only dump the context buffer once since all QNN graph are in one single context + // Only dump the context buffer once since all QNN graphs are in one single context if (0 == index) { if (qnn_context_embed_mode) { std::string cache_payload(buffer, buffer + buffer_size); @@ -296,8 +291,6 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name, ep_node.AddAttribute(SOURCE, kQnnExecutionProvider); ++index; } - ORT_RETURN_IF_ERROR(graph.Resolve()); - ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path)); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 0011d0f43f5bc..ba6fe23ecd56e 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -73,8 +73,7 @@ Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_mod std::string& cache_source, const logging::Logger& logger); -Status GenerateCtxCacheOnnxModel(const std::string model_name, - const std::string model_description, +Status GenerateCtxCacheOnnxModel(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc 
b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 04bd58c237141..56eb1f4f59f33 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -613,8 +613,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); - ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name, - model_description, + qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); + ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(), context_buffer.get(), buffer_size, qnn_backend_manager_->GetSdkVersion(), @@ -626,4 +626,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused } return Status::OK(); } + +const InlinedVector QNNExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_nodes; + if (qnn_ep_context_model_) { + const auto& graph = qnn_ep_context_model_->MainGraph(); + for (const auto& node : graph.Nodes()) { + ep_context_nodes.push_back(graph.GetNode(node.Index())); + } + } + + return ep_context_nodes; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 8b5d0929209ee..d4927f3fa505e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -9,6 +9,7 @@ #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_graph_configs_helper.h" +#include "core/graph/model.h" namespace onnxruntime { @@ -35,6 +36,8 @@ class QNNExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override; + const InlinedVector GetEpContextNodes() const override; + private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::unordered_map& node_unit_supported_result, @@ -66,6 +69,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; + std::unique_ptr qnn_ep_context_model_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 93877c8dd66bd..e8853c8824738 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1164,6 +1164,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Do partitioning based on execution providers' capabilities. ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn, + session_options_.config_options, *session_logger_, mode, debug_graph_fn)); // apply Level2 and higher transformers. 
@@ -1458,7 +1459,9 @@ namespace { Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, - SessionState& session_state) { + SessionState& session_state, + const ConfigOptions& config_options, + const logging::Logger& logger) { layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1479,6 +1482,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, + config_options, + logger, GraphPartitioner::Mode::kOrtFormatLoad)); return Status::OK(); @@ -1833,7 +1838,7 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) } else { ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_)); + *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 8990c23e4af39..0c2d8bcb2eb93 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -171,13 +171,16 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK( - partitioner.Partition(graph, session_state.GetMutableFuncMgr(), - [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, - const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); - return layout_transformation::TransformLayoutForEP( - graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); - })); + partitioner.Partition( + graph, session_state.GetMutableFuncMgr(), + [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, + const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { + AllocatorPtr cpu_allocator = std::make_shared(); + return layout_transformation::TransformLayoutForEP( + graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -257,7 +260,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { return layout_transformation::TransformLayoutForEP(graph, modified, execution_provider, cpu_allocator, debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -314,7 +319,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { return layout_transformation::TransformLayoutForEP( graph, modified, execution_provider, cpu_allocator, debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); // Finalize the session state ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); diff --git 
a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index f9064cad3fe12..bc40682cf87b7 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -600,6 +600,51 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +// Generate context cache model from the ONNX models with 2 inputs. +// The generated model should have same input order. +// The input ONNX model is created in the way that the model inputs order +// is different with the order in the graph (topological order). +// It cause issue if the generated model doesn't set the inputs/outputs explicitly. +TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + auto inputs = model->MainGraph().GetInputs(); + EXPECT_TRUE(inputs.size() == 2); + EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); + EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // A repro of QC case 06838696, accuracy issue for Cast + Op (quantized) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 4ac1f5ddca643..1e938ae9e334b 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -778,6 +778,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { QDQTolerance(), logging::Severity::kERROR, context_binary_file); + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Run QDQ model on HTP 3 times diff --git a/onnxruntime/test/testdata/qnn_ctx_2_inputs_order_test.onnx b/onnxruntime/test/testdata/qnn_ctx_2_inputs_order_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..46b212dc1fc0e9c2f4e6bd32424fcbd90a2b6c9f GIT binary patch literal 2053 zcmb7FKX21O6z3dTxsQ@MPASAq(-@W*sxr4LbZiXyW8$O z{j&n+2<{#9`^2<{W0!QGk~)yld*g({T3R%rj!lNPX}M}MEq@b*eq~zJaLD1<_2Oxo z*rZj?y1)g3su3VDNteV>`>>J-Lp-aAAdM!GKBuuagGOx9QdlSWj-YI~F7+6*Eiy1h zpI|k6j`*oD(iEs2MwPvC%+kh8s~k~35EN5?n?i0MBn?1V9%7L7Sw~bV5;IJ_@F!bXg(qPZz9N`q3(HZU zaNJa)Q>rR;ex}24=sn~Qw{g0S2fjdx`vOL9q(bl#+Yxp;JRb#-%twJ-+n 
zLEuMz>epXRdph%e@K)$0p2oQ0`~Lbj+1I&)+>0Gx&leo8`JR6-FME2XH;#AM6``u2 z$VgoVNk-l$d0*+cnOU(slXrE9s>#!SY$RiYr`IK=>TCi8JKnTDP)Du!|IXOkT>~1f F{r^R&&7%MS literal 0 HcmV?d00001 From c8ce83967e5b52062558046d769f3af7d871e893 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 19 Jan 2024 15:30:09 -0800 Subject: [PATCH 04/23] Download protoc for all Apple host builds, remove protoc build from iOS packaging pipeline. (#19209) --- .../external/onnxruntime_external_deps.cmake | 74 ++++++++++--------- .../stages/mac-ios-packaging-build-stage.yml | 7 +- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 78f63227c8392..403b4b2c4107a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -108,41 +108,14 @@ FetchContent_Declare( ) # Download a protoc binary from Internet if needed -if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) +if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually # download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE # variable. - message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") - if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") - if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") + if (CMAKE_HOST_APPLE) + # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices. + # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html + # To keep it simple, just download and use the universal protoc binary for all Apple host builds. 
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) @@ -150,6 +123,38 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() + elseif (CMAKE_CROSSCOMPILING) + message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") + if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + endif() endif() endif() @@ -184,9 +189,9 @@ FetchContent_Declare( ) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) -#TODO: we'd better to turn the following option off. However, it will cause +#TODO: we'd better to turn the following option off. However, it will cause # ".\build.bat --config Debug --parallel --skip_submodule_sync --update" fail with an error message: -# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is +# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is # not in any export set. 
#set(protobuf_INSTALL OFF CACHE BOOL "Install protobuf binaries and files" FORCE) set(protobuf_USE_EXTERNAL_GTEST ON CACHE BOOL "" FORCE) @@ -562,4 +567,3 @@ endif() FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) - diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index d1dff0769e25f..ed32c5d0e15be 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -78,10 +78,6 @@ stages: pip install -r tools/ci_build/github/apple/ios_packaging.requirements.txt displayName: "Install Python requirements" - - script: | - $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/protobuf_install -d $(Build.SourcesDirectory)/cmake/deps.txt - displayName: "Build Host Protoc" - # create and test mobile pods - script: | python tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ @@ -91,8 +87,7 @@ stages: --test \ --variant ${{ parameters.packageVariant }} \ --build-settings-file "${{ variables.buildSettingsFile }}" \ - ${{ variables.optionalIncludeOpsByConfigOption }} \ - -b="--path_to_protoc_exe=$(Build.BinariesDirectory)/protobuf_install/bin/protoc" + ${{ variables.optionalIncludeOpsByConfigOption }} displayName: "Build macOS/iOS framework and assemble pod package files" - script: | From f3402de01e732283283aaa208022d6c7ae85ca4a Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Sun, 21 Jan 2024 10:51:58 -0800 Subject: [PATCH 05/23] [TensorRT EP] Enhance EP context configs in session options and provider options (#19154) Several changes: 1. To align with other EPs' setting of EP context configs in session options, for example [QNN EP](https://github.com/microsoft/onnxruntime/pull/18877), EP context configs for TRT EP can be configured through: 1. Session Options: `ep.context_enable`, `ep.context_file_path` and `ep.context_embed_mode` 2. Provider Options: `trt_dump_ep_context_model`, `trt_ep_context_file_path` and `trt_dump_ep_context_embed_mode` 3. Above setting has 1:1 mapping and provider options has higher priority over session options. ``` Please note that there are rules for using following context model related provider options: 1. In the case of dumping the context model and loading the context model, for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be the absolute path or relative path that is outside of context model directory. It means engine cache needs to be in the same directory or sub-directory of context model. 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. For example: If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, if "trt_ep_context_file_path" is "./context_model_dir", - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" ``` 2. User can decide the naming of the dumped "EP context" model by using `trt_ep_context_file_path`, please see GetCtxModelPath() for more details. 3. 
Added suggested comments from https://github.com/microsoft/onnxruntime/pull/18217 --- .../tensorrt/tensorrt_provider_options.h | 28 ++- .../tensorrt/onnx_ctx_model_helper.cc | 211 +++++++++++++----- .../tensorrt/onnx_ctx_model_helper.h | 34 +-- .../tensorrt/tensorrt_execution_provider.cc | 153 ++++++++----- .../tensorrt/tensorrt_execution_provider.h | 5 +- .../tensorrt_execution_provider_info.cc | 13 +- .../tensorrt_execution_provider_info.h | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 9 +- .../core/session/provider_bridge_ort.cc | 87 +++++++- .../python/onnxruntime_pybind_state.cc | 17 +- .../gen_trt_engine_wrapper_onnx_model.py | 19 +- .../providers/tensorrt/tensorrt_basic_test.cc | 208 ++++++++++++++++- 12 files changed, 624 insertions(+), 162 deletions(-) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 60196d0c80cbb..32a9f06464ace 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -11,6 +11,8 @@ /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. /// struct OrtTensorRTProviderOptionsV2 { + OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator + int device_id{0}; // cuda device id. int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. void* user_compute_stream{nullptr}; // user specified CUDA compute stream. @@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT - int trt_dump_ep_context_model{0}; // Dump EP context node model - int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data - int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute - const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + + /* + * Please note that there are rules for using following context model related provider options: + * + * 1. In the case of dumping the context model and loading the context model, + * for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be + * the absolute path or relative path that is outside of context model directory. + * It means engine cache needs to be in the same directory or sub-directory of context model. + * + * 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. + * For example: + * If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, + * if "trt_ep_context_file_path" is "./context_model_dir", + * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" + * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" + * + */ + int trt_dump_ep_context_model{0}; // Dump EP context node model + const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. 
Can be a path or a file name or a file name with path. + int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 4d8ba6a0891e3..1994d1f5ab0b8 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -38,13 +38,6 @@ const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) { return main_graph.ModelPath(); } -std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path) { - std::filesystem::path base_path(path.ToPathString()); - std::filesystem::path parent_path = base_path.parent_path(); - std::filesystem::path engine_path = parent_path.append(engine_cache_path); - return engine_path; -} - /* * Update ep_cache_context attribute of the EP context node with the given engine binary data */ @@ -69,14 +62,13 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, /* * Create "EP context node" model where engine information is embedded */ -ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - bool compute_capability_enable, - std::string compute_capability, - const logging::Logger* logger) { +ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + std::string compute_capability, + const logging::Logger* logger) { auto model_build = graph_viewer.CreateModel(*logger); auto& graph_build = model_build->MainGraph(); @@ -107,21 +99,20 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, engine_data_str.assign(engine_data, size); } attr_1->set_s(engine_data_str); + LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } else { attr_1->set_s(engine_cache_path); } + attr_2->set_name(COMPUTE_CAPABILITY); + attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_2->set_s(compute_capability); + auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); - int num_attributes = compute_capability_enable ? 
3 : 2; + int num_attributes = 3; node_attributes->reserve(num_attributes); node_attributes->emplace(EMBED_MODE, *attr_0); node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); - - if (compute_capability_enable) { - attr_2->set_name(COMPUTE_CAPABILITY); - attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_2->set_s(compute_capability); - node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); - } + node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); // Create EP context node graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); @@ -138,14 +129,111 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, } /* - * Dump "EP context node" model + * Return the directory where the ep context model locates + */ +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { + if (ep_context_file_path.empty()) { + return std::filesystem::path(); + } + std::filesystem::path ctx_path(ep_context_file_path); + if (std::filesystem::is_directory(ep_context_file_path)) { + return ctx_path; + } else { + return ctx_path.parent_path(); + } +} + +/* + * Get "EP context" model path. + * + * Function logic: + * If ep_context_file_path is provided, + * - If ep_context_file_path is a file, return "ep_context_file_path". + * - If ep_context_file_path is a directory, return "ep_context_file_path/original_model_name_ctx.onnx". + * If ep_context_file_path is not provided, + * - Return "original_model_name_ctx.onnx". + * + * TRT EP has rules about context model path and engine cache path (see tensorrt_execution_provider.cc): + * - If dump_ep_context_model_ and engine_cache_enabled_ is enabled, TRT EP will dump context model and save engine cache + * to the same directory provided by ep_context_file_path_. (i.e. 
engine_cache_path_ = ep_context_file_path_) + * + * Example 1: + * ep_context_file_path = "/home/user/ep_context_model_directory" + * original_model_path = "model.onnx" + * => return "/home/user/ep_context_model_folder/model_ctx.onnx" + * + * Example 2: + * ep_context_file_path = "my_ctx_model.onnx" + * original_model_path = "model.onnx" + * => return "my_ctx_model.onnx" + * + * Example 3: + * ep_context_file_path = "/home/user2/ep_context_model_directory/my_ctx_model.onnx" + * original_model_path = "model.onnx" + * => return "/home/user2/ep_context_model_directory/my_ctx_model.onnx" + * + */ +std::string GetCtxModelPath(const std::string& ep_context_file_path, + const std::string& original_model_path) { + std::string ctx_model_path; + + if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) { + ctx_model_path = ep_context_file_path; + } else { + std::filesystem::path model_path = original_model_path; + std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name + std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx"; + + if (std::filesystem::is_directory(ep_context_file_path)) { + std::filesystem::path model_directory = ep_context_file_path; + ctx_model_path = model_directory.append(ctx_model_name).string(); + } else { + ctx_model_path = ctx_model_name; + } + } + return ctx_model_path; +} + +/* + * Dump "EP context" model * */ -void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string engine_cache_path) { - std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary); +void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, + const std::string& ctx_model_path) { + std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path; +} + +bool IsAbsolutePath(std::string& path_string) { +#ifdef _WIN32 + onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + return path.is_absolute(); +#else + if (!path_string.empty() && path_string[0] == '/') { + return true; + } + return false; +#endif +} + +// Like "../file_path" +bool IsRelativePathToParentPath(std::string& path_string) { +#ifdef _WIN32 + onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + auto relative_path = path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return true; + } + return false; +#else + if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { + return true; + } + return false; +#endif } Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) { @@ -157,7 +245,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph const int64_t embed_mode = attrs.at(EMBED_MODE).i(); if (embed_mode) { - // Get engine from byte stream + // Get engine from byte stream. 
const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s(); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); @@ -167,19 +255,41 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph "TensorRT EP could not deserialize engine from binary data"); } } else { - // Get engine from cache file - std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in); + // Get engine from cache file. + std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s(); + + // For security purpose, in the case of running context model, TRT EP won't allow + // engine cache path to be the relative path like "../file_path" or the absolute path. + // It only allows the engine cache to be in the same directory or sub directory of the context model. + if (IsAbsolutePath(cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path); + } + if (IsRelativePathToParentPath(cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."); + } + + // The engine cache and context model (current model) should be in the same directory + std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); + auto engine_cache_path = ctx_model_dir.append(cache_path); + + if (!std::filesystem::exists(engine_cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP can't find engine cache: " + engine_cache_path.string() + + ". Please make sure engine cache is in the same directory or sub-directory of context model."); + } + + std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); engine_file.seekg(0, std::ios::end); size_t engine_size = engine_file.tellg(); engine_file.seekg(0, std::ios::beg); std::unique_ptr engine_buf{new char[engine_size]}; engine_file.read((char*)engine_buf.get(), engine_size); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path_.string(); if (!(*trt_engine_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string()); + "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); } return Status::OK(); } @@ -193,37 +303,26 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe auto node = graph_viewer.GetNode(0); auto& attrs = node->GetAttributes(); - // Check hardware_architecture(compute_capability) if it's present as an attribute + // Show the warning if compute capability is not matched if (attrs.count(COMPUTE_CAPABILITY) > 0) { std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s(); if (model_compute_capability != compute_capability_) { - LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability"; - LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability; - LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_; - return false; + 
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_; } } // "embed_mode" attr and "ep_cache_context" attr should be present - if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) { - // ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0 - const int64_t embed_mode = attrs.at(EMBED_MODE).i(); - - // engine cache path - if (embed_mode == 0) { - // First assume engine cache path is relatvie to model path, - // If not, then assume the engine cache path is an absolute path. - engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer)); - auto default_engine_cache_path_ = engine_cache_path_; - if (!std::filesystem::exists(engine_cache_path_)) { - engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s()); - if (!std::filesystem::exists(engine_cache_path_)) { - LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine"; - return false; - } - } - } + assert(attrs.count(EMBED_MODE) > 0); + assert(attrs.count(EP_CACHE_CONTEXT) > 0); + + const int64_t embed_mode = attrs.at(EMBED_MODE).i(); + if (embed_mode == 1) { + // engine binary data + LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } + return true; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index ab6ea733adfa1..bf3bf9e3495d7 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -16,20 +16,27 @@ static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; +static const std::string EPCONTEXT_WARNING = + "It's suggested to set the ORT graph optimization level to 0 and \ + make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ + for the best model loading time"; bool GraphHasCtxNode(const GraphViewer& graph_viewer); const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer); -std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path); -ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - bool compute_capability_enable, - std::string compute_capability, - const logging::Logger* logger); -void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string engine_cache_path); +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); +ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + std::string compute_capability, + const logging::Logger* logger); +std::string GetCtxModelPath(const std::string& ep_context_file_path, + const std::string& original_model_path); +bool 
IsAbsolutePath(std::string& path_string); +bool IsRelativePathToParentPath(std::string& path_string); +void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, + const std::string& ctx_model_path); void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, char* engine_data, size_t size); @@ -38,7 +45,8 @@ class TensorRTCacheModelHandler { public: TensorRTCacheModelHandler(std::unique_ptr* trt_engine, nvinfer1::IRuntime* trt_runtime, - std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability) { + std::string ep_context_model_path, + std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), ep_context_model_path_(ep_context_model_path), compute_capability_(compute_capability) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); @@ -49,7 +57,7 @@ class TensorRTCacheModelHandler { private: std::unique_ptr* trt_engine_; nvinfer1::IRuntime* trt_runtime_; - std::filesystem::path engine_cache_path_; + std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory std::string compute_capability_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index aa02d8384afa6..fe6b959b962de 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1079,8 +1079,6 @@ Status BindKernelOutput(Ort::KernelContext& ctx, char const* output_name, size_t output_index, size_t output_type, - std::vector>& scratch_buffers, - OrtAllocator* alloc, cudaStream_t stream) { auto allocator = allocator_map[output_name].get(); auto& shape = allocator->getOutputShape(); @@ -1350,6 +1348,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv timing_cache_enable_ = info.timing_cache_enable; force_timing_cache_match_ = info.force_timing_cache; detailed_build_log_ = info.detailed_build_log; + dump_ep_context_model_ = info.dump_ep_context_model; + ep_context_file_path_ = info.ep_context_file_path; + ep_context_embed_mode_ = info.ep_context_embed_mode; if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { cache_path_ = info.engine_cache_path; cache_prefix_ = info.engine_cache_prefix; @@ -1380,9 +1381,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv profile_max_shapes = info.profile_max_shapes; profile_opt_shapes = info.profile_opt_shapes; cuda_graph_enable_ = info.cuda_graph_enable; - dump_ep_context_model_ = info.dump_ep_context_model; - ep_context_embed_mode_ = info.ep_context_embed_mode; - ep_context_compute_capability_enable_ = info.ep_context_compute_capability_enable; } else { try { const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); @@ -1461,6 +1459,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv force_timing_cache_match_ = (std::stoi(timing_force_match_env) == 0 ? false : true); } + const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); + if (!dump_ep_context_model_env.empty()) { + dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? 
false : true); + } + + const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); + if (!ep_context_file_path_env.empty()) { + ep_context_file_path_ = ep_context_file_path_env; + } + + const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); + if (!ep_context_embed_mode_env.empty()) { + ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); + } + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); @@ -1538,21 +1551,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true); } - const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); - if (!dump_ep_context_model_env.empty()) { - dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true); - } - - const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); - if (!ep_context_embed_mode_env.empty()) { - ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); - } - - const std::string ep_context_compute_capability_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); - if (!ep_context_compute_capability_env.empty()) { - ep_context_compute_capability_enable_ = (std::stoi(ep_context_compute_capability_env) == 0 ? false : true); - } - } catch (const std::invalid_argument& ex) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what(); } catch (const std::out_of_range& ex) { @@ -1580,7 +1578,36 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv dla_core_ = 0; } - if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_ || !cache_prefix_.empty()) { + // If ep_context_file_path_ is provided as a directory, create it if it's not existed + if (dump_ep_context_model_ && !ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) { + if (!std::filesystem::create_directory(ep_context_file_path_)) { + throw std::runtime_error("Failed to create directory " + ep_context_file_path_); + } + } + + // If dump_ep_context_model_ is enable, TRT EP forces cache_path_ to be the relative path of ep_context_file_path_. + // For example, + // - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir" + // - original cache path = "" -> new cache path = "./context_model_dir" + // The new cache path will be saved as the "ep_cache_context" node attritue of the EP context node. + // For security reason, it needs to make sure the engine cache is saved inside context model directory. 
+ if (dump_ep_context_model_ && engine_cache_enable_) { + if (IsAbsolutePath(cache_path_)) { + LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " << cache_path_; + } + if (IsRelativePathToParentPath(cache_path_)) { + LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory."; + } + + // Engine cache relative path to context model directory. + // It's used when dumping the "ep_cache_context" node attribute. + engine_cache_relative_path_to_context_model_dir = cache_path_; + + // Make cache_path_ to be the relative path of ep_context_file_path_ + cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string(); + } + + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { if (!cache_path_.empty() && !fs::is_directory(cache_path_)) { if (!fs::create_directory(cache_path_)) { throw std::runtime_error("Failed to create directory " + cache_path_); @@ -1692,6 +1719,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_profile_max_shapes: " << profile_max_shapes << ", trt_profile_opt_shapes: " << profile_opt_shapes << ", trt_cuda_graph_enable: " << cuda_graph_enable_ + << ", trt_dump_ep_context_model: " << dump_ep_context_model_ + << ", trt_ep_context_file_path: " << ep_context_file_path_ + << ", trt_ep_context_embed_mode: " << ep_context_embed_mode_ << ", trt_cache_prefix: " << cache_prefix_; } @@ -2309,6 +2339,14 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // Construct subgraph capability from node list std::vector> result; + // Get ModelPath + const auto& path_string = graph.ModelPath().ToPathString(); +#ifdef _WIN32 + wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_)); +#else + strcpy(model_path_, path_string.c_str()); +#endif + // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. // So, simply return the ComputeCapability here. 
@@ -2319,14 +2357,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } - // Get ModelPath - const auto& path_string = graph.ModelPath().ToPathString(); -#ifdef _WIN32 - wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_)); -#else - strcpy(model_path_, path_string.c_str()); -#endif - // Generate unique kernel name for TRT graph HashValue model_hash = TRTGenerateId(graph); @@ -2831,10 +2861,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView std::unique_ptr trt_engine; std::unique_ptr trt_context; - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - std::string cache_suffix = ""; std::string cache_path = ""; + std::string cache_suffix = ""; // Customize cache prefix if assigned if (!cache_prefix_.empty()) { // Generate cache suffix in case user would like to customize cache prefix @@ -2843,11 +2871,19 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } else { cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); } + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_; const std::string engine_cache_path = cache_path_prefix + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; const std::string profile_cache_path = cache_path_prefix + ".profile"; + // Generate file name for dumping ep context model + if (dump_ep_context_model_ && ctx_model_path_.empty()) { + ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); + } + if (!has_dynamic_shape) { std::string timing_cache_path = ""; bool engine_update = false; @@ -2984,15 +3020,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // dump EP context node model if (dump_ep_context_model_) { - std::unique_ptr model_proto{CreateCtxNodeModel(graph_body_viewer, - engine_cache_path, - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - ep_context_embed_mode_, - ep_context_compute_capability_enable_, - compute_capability_, - GetLogger())}; - DumpCtxNodeModel(model_proto.get(), cache_path_prefix); + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + } + + std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, + ep_cache_context_attr_, + reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + ep_context_embed_mode_, + compute_capability_, + GetLogger())}; + DumpCtxModel(model_proto.get(), ctx_model_path_); } } } @@ -3052,16 +3093,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // TRT EP will serialize the model at inference time due to engine can be updated and the updated engine should be included in the model. 
// However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize it here. if (dump_ep_context_model_ && has_dynamic_shape) { - model_proto_.reset(CreateCtxNodeModel(graph_body_viewer, - engine_cache_path, - nullptr, - 0, - ep_context_embed_mode_, - ep_context_compute_capability_enable_, - compute_capability_, - GetLogger())); + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + } + model_proto_.reset(CreateCtxModel(graph_body_viewer, + ep_cache_context_attr_, + nullptr, + 0, + ep_context_embed_mode_, + compute_capability_, + GetLogger())); if (ep_context_embed_mode_ == 0) { - DumpCtxNodeModel(model_proto_.get(), cache_path_prefix); + DumpCtxModel(model_proto_.get(), ctx_model_path_); } } @@ -3382,7 +3427,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // dump ep context model if (dump_ep_context_model_ && ep_context_embed_mode_) { UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - DumpCtxNodeModel(model_proto_.get(), cache_path_prefix); + DumpCtxModel(model_proto_.get(), ctx_model_path_); } context_update = true; } @@ -3521,7 +3566,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream); + auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } @@ -3575,7 +3620,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con std::unordered_map output_types; // TRT engine output name -> ORT output tensor type // Get engine binary data and deserialize it - auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), compute_capability_); + auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), model_path_, compute_capability_); auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -3802,7 +3847,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream); + auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 401a8da119ac2..ad2d2c55c67e1 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ 
-301,8 +301,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { // For create/dump EP context node model bool dump_ep_context_model_ = false; + std::string ep_context_file_path_; int ep_context_embed_mode_ = 0; - bool ep_context_compute_capability_enable_ = true; + std::string ctx_model_path_; + std::string ep_cache_context_attr_; + std::string engine_cache_relative_path_to_context_model_dir; std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 28f6e1720f615..ba9251c71bced 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -47,9 +47,9 @@ constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; -constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; -constexpr const char* kEpContextComputeCapabilityEnable = "trt_ep_context_compute_capability_enable"; +constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; +constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; } // namespace provider_option_names } // namespace tensorrt @@ -103,8 +103,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model) + .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) - .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, info.ep_context_compute_capability_enable) .Parse(options)); // add new provider option here. 
return info; @@ -148,8 +148,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, + {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, - {tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.ep_context_compute_capability_enable)}, }; return options; } @@ -166,6 +166,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); + const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); const ProviderOptions options{ {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, @@ -202,9 +203,9 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, + {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_}, {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, - {tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.trt_ep_context_compute_capability_enable)}, }; return options; } @@ -299,6 +300,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model; trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode; - trt_provider_options_v2.trt_ep_context_compute_capability_enable = internal_options.ep_context_compute_capability_enable; + trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index a133ef45affe8..80424b8d6d196 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -52,8 +52,8 @@ struct TensorrtExecutionProviderInfo { std::string profile_opt_shapes{""}; bool cuda_graph_enable{false}; bool dump_ep_context_model{false}; + std::string ep_context_file_path{""}; int ep_context_embed_mode{0}; - bool ep_context_compute_capability_enable{1}; std::string engine_cache_prefix{""}; static 
TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 62f124afbd1e5..568da57a50956 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -61,13 +61,6 @@ std::unique_ptr TensorrtProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { - TensorrtExecutionProviderInfo info; - info.device_id = device_id; - info.has_trt_options = false; - return std::make_shared(info); -} - struct Tensorrt_Provider : Provider { void* GetInfo() override { return &g_info; } std::shared_ptr CreateExecutionProviderFactory(int device_id) override { @@ -117,8 +110,8 @@ struct Tensorrt_Provider : Provider { info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes; info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; info.dump_ep_context_model = options.trt_dump_ep_context_model != 0; + info.ep_context_file_path = options.trt_ep_context_file_path == nullptr ? "" : options.trt_ep_context_file_path; info.ep_context_embed_mode = options.trt_ep_context_embed_mode; - info.ep_context_compute_capability_enable = options.trt_ep_context_compute_capability_enable != 0; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; return std::make_shared(info); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 45d8006e6b49e..3269c9f0f4e4b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -89,6 +89,10 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; #include "core/providers/cann/cann_provider_options.h" #include "core/providers/dnnl/dnnl_provider_options.h" +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) +#include "core/session/onnxruntime_session_options_config_keys.h" +#endif + // The filename extension for a shared library is different per platform #ifdef _WIN32 #define LIBRARY_PREFIX @@ -1372,10 +1376,6 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(in return s_library_dnnl.Get().CreateExecutionProviderFactory(use_arena); } -std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { - return s_library_tensorrt.Get().CreateExecutionProviderFactory(device_id); -} - std::shared_ptr MIGraphXProviderFactoryCreator::Create(int device_id) { return s_library_migraphx.Get().CreateExecutionProviderFactory(device_id); } @@ -1419,11 +1419,44 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_profile_max_shapes = ""; trt_options_converted.trt_profile_opt_shapes = ""; trt_options_converted.trt_cuda_graph_enable = 0; + trt_options_converted.trt_dump_ep_context_model = 0; + trt_options_converted.trt_ep_context_file_path = ""; + trt_options_converted.trt_ep_context_embed_mode = 0; trt_options_converted.trt_engine_cache_prefix = ""; return trt_options_converted; } +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) +// Apply configs from session options to TensorRT provider options V2 that are needed for TensorRT EP. +// For example, EP context configs. 
+void UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(OrtSessionOptions* session_options, OrtTensorRTProviderOptionsV2* tensorrt_options) { + if (session_options) { + auto context_cache_enabled = (session_options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; + tensorrt_options->trt_dump_ep_context_model = context_cache_enabled; + LOGS_DEFAULT(VERBOSE) << "Context cache enable: " << context_cache_enabled; + + auto context_cache_path = (session_options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + tensorrt_options->trt_ep_context_file_path = context_cache_path.c_str(); + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << tensorrt_options->trt_ep_context_file_path; + + auto embed_mode = (session_options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + if ("1" == embed_mode) { + tensorrt_options->trt_ep_context_embed_mode = 1; + } else if ("0" == embed_mode) { + tensorrt_options->trt_ep_context_embed_mode = 0; + } else { + LOGS_DEFAULT(VERBOSE) << "Invalid ep.context_embed_mode: " << embed_mode << " only 0 or 1 allowed. Set to 1."; + } + LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << tensorrt_options->trt_ep_context_embed_mode; + } +} +#endif + +std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { + return s_library_tensorrt.Get().CreateExecutionProviderFactory(device_id); +} + std::shared_ptr TensorrtProviderFactoryCreator::Create(const OrtTensorRTProviderOptions* provider_options) { OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options); return s_library_tensorrt.Get().CreateExecutionProviderFactory(&trt_options_converted); @@ -1708,7 +1741,24 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + + std::shared_ptr factory; + +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) + auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; + // If EP context configs are provided in session options, we need to propagate them to provider options + if (ep_context_cache_enabled_from_sess_options) { + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + + onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); + } else { + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + } +#else + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); +#endif + if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); } @@ -1845,7 +1895,31 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) 
{ API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + + std::shared_ptr factory; + +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) + auto ep_context_cache_enabled_from_provider_options = tensorrt_options->trt_dump_ep_context_model != 0; + auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; + + // If EP context configs are provided in session options, we need to propagate them to provider options. However, + // if provider options already have the EP context configs provided, the configs in session options will be ignored + // since provider options has higher priority than session options. + if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { + // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. + // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. + OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options + + onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); + } else { + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + } +#else + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); +#endif + if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT: Failed to load shared library"); } @@ -1991,6 +2065,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor delete[] ptr->trt_profile_min_shapes; delete[] ptr->trt_profile_max_shapes; delete[] ptr->trt_profile_opt_shapes; + delete[] ptr->trt_ep_context_file_path; } std::unique_ptr p(ptr); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d2cd6140b838e..f7ed5520727db 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -475,7 +475,7 @@ std::unique_ptr CreateExecutionProviderInstance( // So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance. // (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance // and TRT EP instance, so it won't be released.) 
- std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; + std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -728,20 +728,19 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_ep_context_file_path") { + if (!option.second.empty()) { + ep_context_file_path = option.second; + params.trt_ep_context_file_path = ep_context_file_path.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_file_path' should be a string.\n"); + } } else if (option.first == "trt_ep_context_embed_mode") { if (!option.second.empty()) { params.trt_ep_context_embed_mode = std::stoi(option.second); } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n"); } - } else if (option.first == "trt_ep_context_compute_capability_enable") { - if (option.second == "True" || option.second == "true") { - params.trt_ep_context_compute_capability_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.trt_ep_context_compute_capability_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_compute_capability_enable' should be 'True' or 'False'. 
Default value is 'False'.\n"); - } } else { ORT_THROW("Invalid TensorRT EP option: ", option.first); } diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index 717a0816247e7..b94c2cb76a635 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -15,6 +15,7 @@ def __init__(self, args): engine_cache_path = args.trt_engine_cache_path self.model_name = args.model_name self.dynamic_dim_count = 0 + self.plugins = args.plugins # Get serialized engine from engine cache with open(engine_cache_path, "rb") as file: @@ -25,8 +26,16 @@ def __init__(self, args): else: ep_cache_context_content = engine_cache_path - # Deserialize an TRT engine logger = trt.Logger(trt.Logger.WARNING) + + # Enable TRT plugins + trt.init_libnvinfer_plugins(logger, "") + if len(self.plugins): + import ctypes + + ctypes.CDLL(self.plugins) + + # Deserialize an TRT engine runtime = trt.Runtime(logger) engine = runtime.deserialize_cuda_engine(engine_buffer) num_bindings = engine.num_bindings @@ -165,6 +174,14 @@ def main(): default="trt_engine_wrapper.onnx", type=str, ) + parser.add_argument( + "--plugins", + help="List of plugin paths to load", + required=False, + default=[], + nargs="+", + type=str, + ) args = parser.parse_args() ctor = TensorRTEngineWrapperCreator(args) ctor.create_model() diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 508739ae1d235..4d2538c947dcc 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -122,9 +122,15 @@ void CreateBaseModel(std::string model_name, status = onnxruntime::Model::Save(model, model_name); } -bool HasCacheFileWithPrefix(const std::string& prefix) { - const std::filesystem::path current_dir = std::filesystem::current_path(); - for (const auto& entry : std::filesystem::directory_iterator(current_dir)) { +bool HasCacheFileWithPrefix(const std::string& prefix, std::string file_dir = "") { + std::filesystem::path target_dir; + if (file_dir.empty()) { + target_dir = std::filesystem::current_path(); + } else { + target_dir = std::filesystem::path(file_dir); + } + + for (const auto& entry : std::filesystem::directory_iterator(target_dir)) { if (entry.is_regular_file()) { std::string filename = entry.path().filename().string(); if (filename.rfind(prefix, 0) == 0) { @@ -191,6 +197,8 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string OrtTensorRTProviderOptionsV2 params; params.trt_engine_cache_enable = 1; params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; + params.trt_dump_ep_context_model = 1; + params.trt_ep_context_file_path = "EP_Context_model.onnx"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -209,6 +217,9 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string // Verify on cache with customized prefix ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); + + // Verify EP context model with user provided name + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); } void RunWithOneSessionMultiThreadsInference(std::string model_name, 
std::string sess_log_id, bool has_non_zero_node = false) { @@ -348,6 +359,192 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded"; } +TEST(TensorrtExecutionProviderTest, EPContextNode) { + std::string model_name = "EPContextNode_test.onnx"; + std::string graph_name = "EPContextNode_test"; + std::string sess_log_id = "EPContextNode_test"; + std::vector dims = {1, 3, 2}; + CreateBaseModel(model_name, graph_name, dims); + + SessionOptions so; + so.session_logid = sess_log_id; + RunOptions run_options; + run_options.run_tag = so.session_logid; + InferenceSession session_object{so, GetEnvironment()}; + auto cuda_provider = DefaultCudaExecutionProvider(); + auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1]; + std::vector dims_mul_x = {1, 3, 2}; + std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + OrtValue ml_value_x; + CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); + OrtValue ml_value_y; + CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_y); + OrtValue ml_value_z; + CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_z); + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + feeds.insert(std::make_pair("Y", ml_value_y)); + feeds.insert(std::make_pair("Z", ml_value_z)); + + // prepare outputs + std::vector output_names; + output_names.push_back("M"); + + // prepare expected inputs and outputs + std::vector expected_dims_mul_m = {1, 3, 2}; + std::vector expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}; + + /* + * Test case 1: Dump context model + * + * provider options=> + * trt_ep_context_file_path = "EP_Context_model.onnx" + * + * expected result => + * context model "EP_Context_model.onnx" should be created in current directory + * + */ + OrtTensorRTProviderOptionsV2 params; + params.trt_engine_cache_enable = 1; + params.trt_dump_ep_context_model = 1; + params.trt_ep_context_file_path = "EP_Context_model.onnx"; + std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); + EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + auto status = session_object.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()); + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); + + /* + * Test case 2: Dump context model + * + * provider options=> + * trt_engine_cache_prefix = "TRT_engine_cache" + * trt_ep_context_file_path = "context_model_folder" + * trt_engine_cache_path = "engine_cache_folder" + * + * expected result => + * engine cache "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created + * context model "./context_model_folder/EPContextNode_test_ctx.onnx" should be created + */ + InferenceSession session_object2{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params2; + params2.trt_engine_cache_enable = 1; + params2.trt_dump_ep_context_model = 1; + params2.trt_engine_cache_prefix = "TRT_engine_cache"; + params2.trt_engine_cache_path = "engine_cache_folder"; // due to dump_ep_context_model = 1, the new cache path is ./context_model_folder/engine_cache_folder + params2.trt_ep_context_file_path = "context_model_folder"; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms2); + 
EXPECT_TRUE(session_object2.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object2.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object2.Initialize(); + ASSERT_TRUE(status.IsOK()); + auto new_engine_cache_path = std::filesystem::path(params2.trt_ep_context_file_path).append(params2.trt_engine_cache_path).string(); + // Test engine cache path: + // "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created + ASSERT_TRUE(HasCacheFileWithPrefix(params2.trt_engine_cache_prefix, new_engine_cache_path)); + // Test context model path: + // "./context_model_folder/EPContextNode_test_ctx.onnx" should be created + ASSERT_TRUE(HasCacheFileWithPrefix("EPContextNode_test_ctx.onnx", params2.trt_ep_context_file_path)); + + /* + * Test case 3: Run the dumped context model + * + * context model path = "./EP_Context_model.onnx" (created from case 1) + * + * expected result=> + * engine cache is also in the same current dirctory as "./xxxxx.engine" + * and the "ep_cache_context" attribute node of the context model should point to that. + * + */ + InferenceSession session_object3{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params3; + model_name = params.trt_ep_context_file_path; + params3.trt_engine_cache_enable = 1; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms3); + EXPECT_TRUE(session_object3.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object3.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object3.Initialize(); + ASSERT_TRUE(status.IsOK()); + // run inference + // TRT engine will be created and cached + // TRT profile will be created and cached only for dynamic input shape + // Data in profile, + // X: 1, 3, 3, 2, 2, 2 + // Y: 1, 3, 3, 2, 2, 2 + // Z: 1, 3, 3, 2, 2, 2 + RunSession(session_object3, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 4: Run the dumped context model + * + * context model path = "./context_model_folder/EPContextNode_test_ctx.onnx" (created from case 2) + * + * expected result=> + * engine cache path is "./context_model_folder/engine_cache_folder/xxxxx.engine" + * and the "ep_cache_context" attribute node of the context model should point to "engine_cache_folder/xxxxx.engine". 
+ * + */ + InferenceSession session_object4{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params4; + model_name = "./context_model_folder/EPContextNode_test_ctx.onnx"; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms4); + EXPECT_TRUE(session_object4.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object4.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object4.Initialize(); + ASSERT_TRUE(status.IsOK()); + // run inference + // TRT engine will be created and cached + // TRT profile will be created and cached only for dynamic input shape + // Data in profile, + // X: 1, 3, 3, 2, 2, 2 + // Y: 1, 3, 3, 2, 2, 2 + // Z: 1, 3, 3, 2, 2, 2 + RunSession(session_object4, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 5: Dump context model with embed_model = 1 + */ + InferenceSession session_object5{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params5; + params5.trt_dump_ep_context_model = 1; + params5.trt_ep_context_embed_mode = 1; + params5.trt_ep_context_file_path = "EP_Context_model_2.onnx"; + model_name = "EPContextNode_test.onnx"; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms5); + EXPECT_TRUE(session_object5.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object5.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object5.Initialize(); + ASSERT_TRUE(status.IsOK()); + + /* + * Test case 6: Run context model with embed_model = 1 (created from case 5) + */ + InferenceSession session_object6{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params6; + params6.trt_ep_context_embed_mode = 1; + model_name = params5.trt_ep_context_file_path; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms6); + EXPECT_TRUE(session_object6.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object6.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object6.Initialize(); + ASSERT_TRUE(status.IsOK()); + // run inference + // TRT engine will be created and cached + // TRT profile will be created and cached only for dynamic input shape + // Data in profile, + // X: 1, 3, 3, 2, 2, 2 + // Y: 1, 3, 3, 2, 2, 2 + // Z: 1, 3, 3, 2, 2, 2 + RunSession(session_object6, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); +} + TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { std::string model_name = "testdata/trt_plugin_custom_op_test.onnx"; SessionOptions so; @@ -448,6 +645,8 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { params.trt_engine_cache_enable = 1; params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; + params.trt_dump_ep_context_model = 1; + params.trt_ep_context_file_path = "EP_Context_model.onnx"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -576,6 +775,9 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { // Verify on cache with customized prefix ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); + // Verify EP context model with user provided name + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); + if (input_type.compare("static") == 0) { // Can't run inference since input shape changes but the engine is built with static input ASSERT_FALSE(status.IsOK()); From 
21034a2c37d707ee913dcca1b00d0b9e7651f980 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 22 Jan 2024 18:17:11 +0000 Subject: [PATCH 06/23] phi2 contrib ops changes (#19112) ### Description 1. support causal mask in MHA cpu 2. support custom rotary_dim in rotary_emb 3. add bf16 for rotary_emb 4. fix a bug in attention rotary ### Motivation and Context --- docs/ContribOperators.md | 12 +- docs/OperatorKernels.md | 2 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 6 + .../cpu/bert/multihead_attention.cc | 4 +- .../cpu/bert/multihead_attention.h | 1 + .../cpu/bert/multihead_attention_helper.h | 8 +- .../contrib_ops/cpu/bert/rotary_embedding.cc | 47 ++++--- .../contrib_ops/cpu/bert/rotary_embedding.h | 2 + .../cpu/bert/rotary_embedding_helper.h | 55 ++++---- .../cuda/bert/multihead_attention.cc | 3 + .../cuda/bert/multihead_attention.h | 1 + .../contrib_ops/cuda/bert/rotary_embedding.cc | 6 + .../contrib_ops/cuda/bert/rotary_embedding.h | 2 + .../cuda/bert/rotary_embedding_impl.cu | 64 ++++++--- .../cuda/bert/rotary_embedding_impl.h | 1 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/bert_defs.cc | 18 ++- .../contrib_ops/rotary_embedding_op_test.cc | 127 ++++++++++++++++-- 18 files changed, 280 insertions(+), 81 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45c0e6f822ce9..22e82443167f6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+unidirectional : int
+Whether every token can only attend to previous tokens. Default value is 0.
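(Illustrative aside, not part of the patch.) With `unidirectional` set to 1, the CPU MultiHeadAttention kernel lets each token attend only to itself and to earlier tokens. A minimal NumPy sketch of the equivalent masking, assuming raw scores of shape (batch, num_heads, seq_len, seq_len) and no past key/value cache:

```python
import numpy as np

def apply_causal_mask(scores):
    # scores: (batch, num_heads, seq_len, seq_len), i.e. Q*K^T / sqrt(head_size)
    seq_len = scores.shape[-1]
    # token i may attend to tokens j <= i only
    allowed = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # blocked positions get -inf so they vanish after the softmax
    return np.where(allowed, scores, np.float32(-np.inf))
```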
#### Inputs (1 - 8)
@@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
+num_heads : int
+Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
+rotary_embedding_dim : int
+Rotary embedding dimension. Default value is 0.
scale : float
Custom scale will be used if specified. Default value is 1.0
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs
@@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-T : tensor(float), tensor(float16)
+T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output types to float tensors.
M : tensor(int64)
Constrain input and output types to integer tensors
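As a companion to the documentation change above, here is a NumPy sketch (an illustrative reference, not the actual kernel, with hypothetical helper names) of how the new `rotary_embedding_dim` attribute is meant to behave: only the first `rotary_embedding_dim` elements of each head are rotated using the (max_sequence_length, rotary_embedding_dim / 2) cos/sin caches, and the remaining head elements are copied through unchanged. It assumes 2D `position_ids` of shape (batch_size, sequence_length) and input laid out as (batch, seq_len, num_heads, head_size).

```python
import numpy as np

def rotary_embedding_ref(x, cos_cache, sin_cache, position_ids, rotary_dim, interleaved=False):
    # x: (batch, seq_len, num_heads, head_size)
    # cos_cache, sin_cache: (max_sequence_length, rotary_dim // 2)
    batch, seq_len, num_heads, head_size = x.shape
    half = rotary_dim // 2
    out = x.copy()  # elements beyond rotary_dim pass through untouched
    for b in range(batch):
        for s in range(seq_len):
            pos = int(position_ids[b, s])
            cos = cos_cache[pos]  # (rotary_dim // 2,)
            sin = sin_cache[pos]  # (rotary_dim // 2,)
            for n in range(num_heads):
                v = x[b, s, n, :rotary_dim]
                if interleaved:
                    # pairs (0,1), (2,3), ... are rotated together
                    x1, x2 = v[0::2], v[1::2]
                    rot = np.empty_like(v)
                    rot[0::2] = x1 * cos - x2 * sin
                    rot[1::2] = x2 * cos + x1 * sin
                else:
                    # first half rotated against second half
                    x1, x2 = v[:half], v[half:]
                    rot = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin])
                out[b, s, n, :rotary_dim] = rot
    return out
```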
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 394bd7ad2abae..9ecc58bee0725 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,7 +868,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
 |RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
 |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 4711ccf487cc8..768676259aa14 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const { relative_position_bias, ¶meters)); + if (parameters.do_rotary) { + ORT_NOT_IMPLEMENTED( + "Rotary embedding is not supported in Attention CPU kernel. \ + Please fuse the model with MHA + RotaryEmbedding."); + } + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int input_hidden_size = parameters.input_hidden_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 694c40bf3eda6..eb25d0fd7cc1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i num_heads_ = static_cast(num_heads); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } // Reshape Q/K/V from BxSxD to BxSxNxH @@ -283,8 +284,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { nullptr, ¶meters, num_heads_, - scale, mask_filter_value_, + scale, + is_unidirectional_, past_present_share_buffer, false)); diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index 4c86b777e9842..fb7da78a5c0a5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { protected: int num_heads_; // number of attention heads float mask_filter_value_; + bool is_unidirectional_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 00e82c9844b3d..c91f5b601b4e9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,6 +25,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None @@ -315,7 +316,7 @@ Status CheckInputs(const T* query, output_parameters->head_size = hidden_size / num_heads; output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; - output_parameters->is_unidirectional = false; + output_parameters->is_unidirectional = is_unidirectional; output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; @@ -342,6 +343,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing, int max_threads_per_block) { @@ -350,8 +352,8 @@ Status CheckInputs(const T* query, } return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, 
past_key, past_value, - past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer, - dmmha_packing); + past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, + past_present_share_buffer, dmmha_packing); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 47f462d75fcc4..aa8b5b5f608fa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } } template @@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; + const int n_heads = parameters.num_heads; const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; - const int half_head_size = head_size / 2; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] int head_stride = head_size; - int seq_stride = num_heads * head_stride; + int seq_stride = n_heads * head_stride; int batch_stride = sequence_length * seq_stride; if (parameters.transposed) { - // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] seq_stride = head_size; head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; + batch_stride = n_heads * head_stride; } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * num_heads; - const double cost = static_cast(head_size); + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / num_heads) / sequence_length); - const int s = static_cast((ptr / num_heads) % sequence_length); - const int n = static_cast(ptr % num_heads); + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input_src + block_offset; T* output_data = output_dest + block_offset; - // Cache is (M, H/2) + // Cache is (M, H/2) or (M, 
rotary_embedding_dim/2) const int position_id = (position_ids_format == 0) ? static_cast(pos_ids_data[0]) + s : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_emb_dim; const T* cos_data = cos_cache_data + cache_offset; const T* sin_data = sin_cache_data + cache_offset; int cache_idx = 0; T sign = 0; int j = 0; - for (int i = 0; i < head_size; i++) { + for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_emb_dim; sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } + for (int i = rotary_emb_dim; i < head_size; i++) { + output_data[i] = input_data[i]; + } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index be834a66cdc69..4e32424a22b6c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index 7b2e8289f7b06..dcbb36d1c4a3c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -11,14 +11,15 @@ namespace rotary_embedding_helper { // Parameters deduced from node attributes and inputs/outputs. struct RotaryParameters { - int batch_size; // Batch size used by input - int sequence_length; // Sequence length used by input - int hidden_size; // Hidden size used by input - int head_size; // Head size used by cos/sin cache * 2 - int num_heads; // num_heads = hidden_size / head_size - int max_sequence_length; // Sequence length used by cos/sin cache - int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) - bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. 
+ int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -26,11 +27,13 @@ Status CheckInputs(const T* input, const T* position_ids, const T* cos_cache, const T* sin_cache, + int num_heads, + int rotary_embedding_dim, void* parameters) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) - // cos cache : (max_sequence_length, head_size / 2) - // sin cache : (max_sequence_length, head_size / 2) + // cos cache : (max_sequence_length, rotary_embedding_dim / 2) + // sin cache : (max_sequence_length, rotary_embedding_dim / 2) // Check input const auto& input_dims = input->Shape().GetDims(); @@ -60,6 +63,12 @@ Status CheckInputs(const T* input, "the same shape"); } + // Check num_heads and rotary_embedding_dim + if (rotary_embedding_dim > 0 && num_heads == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", + "specified"); + } + // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -73,8 +82,13 @@ Status CheckInputs(const T* input, transposed = true; } int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = static_cast(cos_cache_dims[1]) * 2; - int num_heads = hidden_size / head_size; + int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + int position_ids_format = -1; // Check position_ids input shapes @@ -91,23 +105,15 @@ Status CheckInputs(const T* input, } else { position_ids_format = 0; } + // Check cos_cache input shapes if (max_sequence_length != static_cast(cos_cache_dims[0])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", "max_sequence_length, got ", cos_cache_dims[0]); } - if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2, got ", cos_cache_dims[1]); - } - // Check sin_cache input shapes - if (max_sequence_length != static_cast(sin_cache_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", - "max_sequence_length, got ", sin_cache_dims[0]); - } - if ((head_size / 2) != static_cast(sin_cache_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", - "head_size / 2, got ", sin_cache_dims[1]); + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } // Set rotary parameters @@ -117,10 +123,11 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads; + output_parameters->num_heads = 
num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ebd66d8c6528e..f978f50c6851f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); disable_fused_self_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); @@ -105,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { num_heads_, mask_filter_value_, scale_, + is_unidirectional_, false, // past_present_share_buffer false, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index c162f7133cc1c..86a32c92ce003 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel { int num_heads_; // number of attention heads float mask_filter_value_; float scale_; + bool is_unidirectional_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 2d12e975d88d7..9de7ba3885c3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -29,10 +29,13 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); } @@ -48,6 +51,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -71,6 +76,7 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.num_heads, parameters.head_size, + parameters.rotary_embedding_dim, parameters.max_sequence_length, parameters.position_ids_format, interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h index 6dab2ad56749e..d52f61d670444 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h +++ 
b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index e1b83bd8caf54..c6637041f05bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -26,6 +26,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, const bool interleaved, const int batch_stride, @@ -33,24 +34,33 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently - + const int b = blockIdx.z; const int s = blockIdx.y; const int n = blockIdx.x; const int i = threadIdx.x; + if (i >= head_size) { + return; + } + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input + block_offset; T* output_data = output + block_offset; + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + // Cache is (M, H/2) - const int half_head_size = head_size / 2; + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; const int position_id = (position_ids_format == 0) ? \ static_cast(position_ids[0]) + s \ : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_embedding_dim; const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; @@ -58,13 +68,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH T sign = 0; int j = 0; if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_embedding_dim; sign = (i % 2 == 0) ? -1 : 1; j = (i % 2 == 0) ? i+1 : i-1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? -1 : 1; - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? -1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } @@ -82,20 +92,23 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed) { - - constexpr int smem_size = 0; - const dim3 grid(num_heads, sequence_length, batch_size); - const dim3 block(head_size, 1, 1); - // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. 
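To make the launch-geometry change above concrete: the revised kernel no longer launches exactly `head_size` threads per block; it rounds the block size up to a multiple of the 32-thread warp, masks off the excess threads, and lets threads at or beyond `rotary_embedding_dim` copy their element through untouched. The following is a minimal host-side C++ sketch of that indexing scheme for the non-interleaved case only; the function and variable names are illustrative and not part of the patch, and `cos_slice`/`sin_slice` stand for the cache rows already offset by the position id:

```cpp
#include <vector>

// Mirror of the `(head_size + 31) / 32 * 32` rounding used for the block size above.
constexpr int RoundUpToWarpMultiple(int n) { return (n + 31) / 32 * 32; }

// Reference (CPU) walk over one head: padding "threads" do nothing, indices in
// [rotary_dim, head_size) pass through, and the first rotary_dim indices are rotated.
std::vector<float> RotateHeadReference(const std::vector<float>& in,
                                       const std::vector<float>& cos_slice,
                                       const std::vector<float>& sin_slice,
                                       int head_size, int rotary_dim) {
  std::vector<float> out(head_size);
  const int half = rotary_dim / 2;
  for (int i = 0; i < RoundUpToWarpMultiple(head_size); ++i) {
    if (i >= head_size) continue;                       // masked-off padding thread
    if (i >= rotary_dim) { out[i] = in[i]; continue; }  // tail of the head is untouched
    const int cache_idx = i % half;                     // non-interleaved indexing
    const float sign = (i < half) ? -1.0f : 1.0f;
    const int j = (i + half) % rotary_dim;
    out[i] = in[i] * cos_slice[cache_idx] + sign * in[j] * sin_slice[cache_idx];
  }
  return out;
}
```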
+ ORT_ENFORCE(head_size <= max_threads_per_block, + "Rotary embedding dim must be <= max_threads_per_block"); + + int tpb = (head_size + 31)/32*32; + + const dim3 block(tpb); + const dim3 grid(num_heads, sequence_length, batch_size); // Default input tensor shape is [batch, seq, hidden_size] int head_stride = head_size; @@ -109,10 +122,9 @@ Status LaunchRotaryEmbeddingKernel( } assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>( - output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved, - batch_stride, seq_stride, head_stride + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, + rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -129,6 +141,7 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, @@ -146,6 +159,25 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + BFloat16* output, + const BFloat16* input, + const int64_t* position_ids, + const BFloat16* cos_cache, + const BFloat16* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index ee1ccc43dcbff..36300fe7a660f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 34b44694a5fcc..fa73950c9c6f5 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -299,6 
+300,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0317ffcfb0e31..7f34647f1faef 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -927,6 +927,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("unidirectional", + "Whether every token can only attend to previous tokens. Default value is 0.", + AttributeProto::INT, + static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", @@ -1145,6 +1149,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding_dim", + "Rotary embedding dimension. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("num_heads", + "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "input", "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", @@ -1155,17 +1167,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(2, "cos_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Input(3, "sin_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Output(0, "output", "tensor with same shape as input.", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 55f01bf0d3f1d..e64de0e6da16a 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -11,6 +11,14 @@ namespace onnxruntime { namespace test { +namespace { +enum class TensorType { + kFloat, + kFloat16, + kBFloat16 +}; +} // anonymous namespace + static void RunTest( const std::vector& input_data, const std::vector& position_ids, @@ -20,10 +28,11 @@ static void RunTest( int batch_size, int sequence_length, int head_size, + int rotary_embedding_dim, int num_heads, int max_sequence_length, int64_t interleaved, - bool use_float16, + TensorType tensor_type, bool disable_cpu, bool disable_cuda, bool disable_dml) { @@ -36,7 +45,9 @@ static void RunTest( int hidden_size = num_heads * head_size; std::vector input_dims = {batch_size, sequence_length, hidden_size}; 
std::vector pos_dims; - std::vector cache_dims = {max_sequence_length, head_size / 2}; + std::vector cache_dims = {max_sequence_length, rotary_embedding_dim > 0 + ? rotary_embedding_dim / 2 + : head_size / 2}; assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); assert(max_sequence_length >= sequence_length); @@ -49,7 +60,10 @@ static void RunTest( std::string op_type = "RotaryEmbedding"; std::vector> execution_providers; - int min_cuda_architecture = use_float16 ? 530 : 0; + int min_cuda_architecture = (tensor_type == TensorType::kBFloat16) + ? 800 + : (tensor_type == TensorType::kFloat16) ? 530 + : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; @@ -59,7 +73,7 @@ static void RunTest( if (enable_dml && !disable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (!use_float16 && !disable_cpu) { + if (tensor_type == TensorType::kFloat && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } if (execution_providers.size() == 0) { @@ -70,20 +84,36 @@ static void RunTest( OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddAttribute("interleaved", interleaved); - if (!use_float16) { + if (rotary_embedding_dim > 0) { + test.AddAttribute("rotary_embedding_dim", rotary_embedding_dim); + test.AddAttribute("num_heads", num_heads); + } + + if (tensor_type == TensorType::kFloat) { test.AddInput("input", input_dims, input_data); test.AddInput("position_ids", pos_dims, position_ids); test.AddInput("cos_cache", cache_dims, cos_cache); test.AddInput("sin_cache", cache_dims, sin_cache); test.AddOutput("output", input_dims, output_data); - } else { + } else if (tensor_type == TensorType::kFloat16) { test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("position_ids", pos_dims, position_ids); test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); test.AddOutput("output", input_dims, ToFloat16(output_data)); + } else { + test.AddInput("input", input_dims, FloatsToBFloat16s(input_data)); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, FloatsToBFloat16s(cos_cache)); + test.AddInput("sin_cache", cache_dims, FloatsToBFloat16s(sin_cache)); + test.AddOutput("output", input_dims, FloatsToBFloat16s(output_data)); + } + if (tensor_type == TensorType::kBFloat16) { + test.SetOutputAbsErr("output", 0.03f); + } else { + test.SetOutputAbsErr("output", 0.002f); } - test.SetOutputAbsErr("output", 0.002f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -95,10 +125,12 @@ static void RunTests(const std::vector& input_data, int batch_size, int sequence_length, int head_size = 0, + int rotary_embedding_dim = 0, int num_heads = 0, int max_sequence_length = 0, int64_t interleaved = 0, - bool use_float16 = true) { + bool use_float16 = true, + bool disable_dml = false) { // FP32 test for CPU RunTest(input_data, position_ids, @@ -108,10 +140,11 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - false, /* use_fp16 */ + TensorType::kFloat, false, /* disable_cpu */ true, /* disable_cuda */ true /* disable_dml */); @@ -125,13 +158,14 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, + 
rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - false, /* use_fp16 */ + TensorType::kFloat, false, /* disable_cpu */ false, /* disable_cuda */ - false /* disable_dml */); + disable_dml || false /* disable_dml */); // FP16 test for CUDA and DML if (use_float16) { @@ -143,13 +177,31 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - true, /* use_fp16 */ + TensorType::kFloat16, true, /* disable_cpu */ false, /* disable_cuda*/ - false /* disable_dml */); + disable_dml || false /* disable_dml */); + + // RunTest(input_data, + // position_ids, + // cos_cache, + // sin_cache, + // output_data, + // batch_size, + // sequence_length, + // head_size, + // rotary_embedding_dim, + // num_heads, + // max_sequence_length, + // interleaved, + // TensorType::kBFloat16, + // true, /* disable_cpu */ + // false, /* disable_cuda*/ + // false /* disable_dml */); } } @@ -159,6 +211,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { int sequence_length = 3; int num_heads = 2; int head_size = 4; + int rotary_embedding_dim = 0; int max_sequence_length = 8; int64_t interleaved = 1; // true @@ -190,6 +243,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -201,6 +255,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { int sequence_length = 8; int num_heads = 4; int head_size = 6; + int rotary_embedding_dim = 0; int max_sequence_length = 16; int64_t interleaved = 1; // true @@ -388,6 +443,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -399,6 +455,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { int sequence_length = 8; int num_heads = 4; int head_size = 6; + int rotary_embedding_dim = 0; int max_sequence_length = 16; int64_t interleaved = 0; // false @@ -586,6 +643,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -597,6 +655,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { int sequence_length = 2; int num_heads = 3; int head_size = 6; + int rotary_embedding_dim = 0; int max_sequence_length = 4; int64_t interleaved = 0; // false @@ -632,10 +691,52 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); } +TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 1; + int head_size = 6; + int rotary_embedding_dim = 4; + int max_sequence_length = 2; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, 
-0.0427f, + -0.2250f, -0.8673f, -1.5071f, -0.4586f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + true, /*use_fp16*/ + true /*disable_dml*/); +} + } // namespace test } // namespace onnxruntime From 373ebac167a1d7f1dfb7a576a6c92f1c75cb711e Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Mon, 22 Jan 2024 10:40:48 -0800 Subject: [PATCH 07/23] Zhalei/fix seqoutput type (#18765) After refactoring beamsearch, all scores become fp32. Yet it need support fp16 according to original specs. --- .../cpu/transformers/beam_search_impl_gpt.h | 8 +- .../cpu/transformers/beam_search_impl_t5.h | 8 +- .../transformers/beam_search_impl_whisper.h | 8 +- .../cpu/transformers/beam_search_scorer.cc | 82 +++++++++++++------ .../cpu/transformers/beam_search_scorer.h | 12 +-- .../cpu/transformers/generation_shared.h | 3 + .../cuda/transformers/generation_cuda_impl.cu | 63 +++++++++++++- .../cuda/transformers/generation_cuda_impl.h | 19 +++-- .../transformers/generation_device_helper.cc | 55 +++++++++++-- .../models/whisper/whisper_chain.py | 31 ++++++- 10 files changed, 220 insertions(+), 69 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 56d950ca2f41e..dc72a038c3d58 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -397,12 +397,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 94547887d3a90..cd891a9508019 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -404,12 +404,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 91b93a125ad7a..4d6643c68a98b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -500,12 +500,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe output_sequences_scores); // Output per token scores - if (output_scores) { - 
gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 7e2e5b2129221..0eccbe26605f5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con return beams_.back().score < current_score; } +template void BeamHypotheses::Output( int top_k, int max_length, - gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) - gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty + gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) + gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences ORT_ENFORCE(top_k <= beams_used_); @@ -67,7 +68,7 @@ void BeamHypotheses::Output( gsl::copy(item.hypothesis, target); if (!sequences_scores.empty()) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences, } } -void BeamSearchScorer::Finalize(ISequences& sequences, - gsl::span& final_beam_scores, - Tensor* output_sequences, - Tensor* output_sequence_scores) { - ORT_ENFORCE(output_sequences != nullptr); - +template +void OutputSequenceScores(BeamSearchScorer* scorer, + ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { // Finalize all open beam hypotheses and add to generated hypotheses. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; if (beam_hyp.done_) { continue; } - for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) { - size_t batch_beam_index = batch_index * num_beams_ + beam_index; + for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) { + size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index; float final_score = final_beam_scores[batch_beam_index]; auto final_tokens = sequences.GetSequence(narrow(batch_beam_index)); beam_hyp.Add(final_tokens, final_score); @@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences, gsl::span output = output_sequences->MutableDataAsSpan(); // Fill output sequences with pad token ID so that we do not need append it later. - std::fill_n(output.data(), output.size(), pad_token_id_); + std::fill_n(output.data(), output.size(), scorer->pad_token_id_); // Score of each sequence, with shape (batch_size * num_return_sequences). 
- gsl::span sequence_scores; + gsl::span sequence_scores; if (output_sequence_scores) { - sequence_scores = output_sequence_scores->MutableDataAsSpan(); + sequence_scores = output_sequence_scores->MutableDataAsSpan(); } // Select the best hypotheses according to number of sequences to return. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; - auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_, - num_return_sequences_ * max_length_); - gsl::span sequence_scores_buffer; + auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_, + scorer->num_return_sequences_ * scorer->max_length_); + gsl::span sequence_scores_buffer; if (!sequence_scores.empty()) - sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_); + sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_); + + beam_hyp.template Output(narrow(scorer->num_return_sequences_), narrow(scorer->max_length_), batch_output, + sequence_scores_buffer); + } +} + +void BeamSearchScorer::Finalize(ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + ORT_ENFORCE(output_sequences != nullptr); - beam_hyp.Output(narrow(num_return_sequences_), narrow(max_length_), batch_output, - sequence_scores_buffer); + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } +} + +void BeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + std::copy_n(final_scores.data(), final_scores.size(), target.data()); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + const float* src = final_scores.data(); + MLFloat16* dst = target.data(); + for (size_t i = 0; i < target.size(); i++) { + dst[i] = MLFloat16(src[i]); + } + } } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 94b6d340d9f4a..dc92e8038a68e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -35,10 +35,11 @@ struct BeamHypotheses { bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) - gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + gsl::span& sequences, // buffer with 
pad token, shape (num_return_sequences, max_length) + gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) gsl::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring int beams_used_; // Number of elements used in beams_ @@ -60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return not_done_count_ == 0; } gsl::span GetNextScores() override { return next_beam_scores_; } gsl::span GetNextTokens() override { return next_beam_tokens_; } gsl::span GetNextIndicesCPU() override { return next_beam_indices_; } - private: size_t batch_size_; size_t num_beams_; size_t max_length_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index f6faf2e325f8f..cb62e2f7bf4da 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -120,6 +120,9 @@ struct IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) = 0; + virtual void OutputScores(gsl::span& final_scores, + Tensor* output_scores) = 0; + virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchous result to complete here diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index dbd7fb010462d..a39abefed9cd0 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -307,12 +307,13 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_ return beams_[beams_count_ - 1].score < current_score; } +template __device__ void BeamHypotheses::Output( int top_k, int max_length, int pad_token_id, int32_t* sequences, // buffer of shape (num_return_sequences, max_length) - float* sequences_scores) // buffer of shape (num_return_sequences) or empty + T* sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences for (int index = 0; index < top_k; index++) { @@ -327,7 +328,7 @@ __device__ void BeamHypotheses::Output( target[i] = pad_token_id; if (sequences_scores) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -501,13 +502,14 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp next_beam_tokens.data()); } +template __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, const int32_t* sequences_buffer, int sequence_length, BeamHypotheses* beam_hyps_, const float* final_beam_scores, int32_t* output, - float* sequence_scores) { + T* sequence_scores) { int batch_index = blockIdx.x * blockDim.x + threadIdx.x; if (batch_index >= state.batch_size_) return; @@ -534,6 +536,7 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, sequence_scores ? 
sequence_scores + batch_index * state.num_return_sequences_ : nullptr); } +template void LaunchBeamSearchScorer_Finalize(int batch_size, BeamScorerState& state, gsl::span sequences, @@ -541,7 +544,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream) { BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state, sequences.data(), @@ -552,6 +555,58 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, sequence_scores.data()); } +template void LaunchBeamSearchScorer_Finalize( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span sequence_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScorer_Finalize<__half>( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span<__half> sequence_scores, + cudaStream_t stream); + +template +__global__ void FloatConvertAndCopyKernel(const float* src, T* dst, size_t total_elements) { + int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + dst[index] = (T)src[index]; + } +} + +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream) { + ORT_ENFORCE(final_scores.size() == output_scores.size()); + constexpr unsigned ThreadPerBlock = 256; + unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock); + + typedef typename ToCudaType::MappedType CudaT; + + FloatConvertAndCopyKernel<<>>( + final_scores.data(), (CudaT*)output_scores.data(), final_scores.size()); +} + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + __global__ void AddProbsKernel(float* log_probs, float* cum_log_probs, const int vocab_size, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 5ed5949196b29..281cb6c725975 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -65,11 +65,12 @@ struct BeamHypotheses { __device__ bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - __device__ void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - int pad_token_id, // pad token - int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) - float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + __device__ void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + int pad_token_id, // pad token + int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) + T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) }; struct BeamScorerState { @@ -110,6 +111,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp gsl::span next_beam_indices, cudaStream_t stream); +template void LaunchBeamSearchScorer_Finalize(int 
batch_size, BeamScorerState& state, gsl::span sequences, @@ -117,9 +119,14 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps_, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream); +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + void LaunchNextTokenKernel(const int64_t* next_token_indices, int32_t* next_indices, int32_t* next_tokens, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 380d561bbb23c..bba30805ae1be 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -620,6 +620,8 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return false; } // For CUDA we speculatively run the next step while we wait for the GPU to report status. We use 'IsDoneLater()' for this bool IsDoneLater() const override; @@ -632,7 +634,6 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { } gsl::span GetNextIndicesGPU() override { return next_beam_indices_; } - private: mutable cuda::AutoDestoryCudaEvent event_process_complete_; IAllocatorUniquePtr state_cpu_; IAllocatorUniquePtr state_gpu_; @@ -743,22 +744,58 @@ bool CudaBeamSearchScorer::IsDoneLater() const { return state_cpu_->not_done_count_ == 0; } +template +void CudaOutputSequenceScores(CudaBeamSearchScorer* scorer, + transformers::ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). + gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; + + // Score of each sequence, with shape (batch_size * num_return_sequences). + using CudaT = typename ToCudaType::MappedType; + gsl::span sequence_scores; + if (output_sequence_scores) { + sequence_scores = gsl::span{(CudaT*)output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + } + + cuda::LaunchBeamSearchScorer_Finalize(scorer->state_cpu_->batch_size_, + *scorer->state_gpu_, + sequences.GetCurrentDeviceSequences(), + sequences.GetSequenceLength(), + scorer->beam_hyps_, + final_beam_scores, + output, + sequence_scores, + scorer->stream_); +} + void CudaBeamSearchScorer::Finalize(transformers::ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) { ORT_ENFORCE(output_sequences != nullptr); - // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). - gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; - - // Score of each sequence, with shape (batch_size * num_return_sequences). 
- gsl::span sequence_scores; - if (output_sequence_scores) { - sequence_scores = gsl::span{output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); } +} - cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetCurrentDeviceSequences(), sequences.GetSequenceLength(), beam_hyps_, final_beam_scores, output, sequence_scores, stream_); +void CudaBeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } + } } std::unique_ptr CreateBeamScorer(const transformers::IGenerationParameters& parameters, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 33958e55f8c38..a74666b7af297 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -53,9 +53,9 @@ def chain_model(args): beam_outputs = ["sequences"] if args.output_sequence_scores: - beam_outputs.append("sequence_scores") + beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") if args.output_scores: - beam_outputs.append("scores") + beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") if args.use_whisper_beamsearch: assert len(beam_inputs) == 12 @@ -75,6 +75,7 @@ def chain_model(args): beam_outputs.extend(["no_speech_probs_beam"]) input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None + output_scores_cast_node = output_sequence_scores_cast_node = None if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -97,6 +98,22 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) + if args.output_sequence_scores: + output_sequence_scores_cast_node = helper.make_node( + "Cast", + inputs=["sequence_scores_fp16"], + outputs=["sequence_scores"], + name="CastOutputSequenceScoresToFp32", + to=TensorProto.FLOAT, + ) + if args.output_scores: + output_scores_cast_node = helper.make_node( + "Cast", + inputs=["scores_fp16"], + outputs=["scores"], + name="CastScoresToFp32", + to=TensorProto.FLOAT, + ) operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") @@ -214,10 +231,18 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] graph_nodes = ( - [input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] + [ + input_features_cast_node, + len_pen_cast_node, + rep_pen_cast_node, + node, + 
output_sequence_scores_cast_node, + output_scores_cast_node, + ] if args.precision == Precision.FLOAT16 else [node] ) + graph_nodes = [node for node in graph_nodes if node is not None] if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", From 8d9d7511799e2138c14454bb672caf07dcdc2457 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 22 Jan 2024 12:47:42 -0800 Subject: [PATCH 08/23] [QNN EP] Expose device-level session options (#19212) ### Description - Adds the following session options to configure the device: - `soc_model`: The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). - `htp_arch`: The minimum HTP architecture the driver will use to select compatible QNN operators. - `device_id`: The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). ### Motivation and Context Allow more configuration. --- .../core/session/onnxruntime_c_api.h | 8 ++ .../qnn/builder/qnn_backend_manager.cc | 31 ++++++- .../qnn/builder/qnn_backend_manager.h | 14 ++- .../qnn/builder/qnn_configs_helper.h | 90 +++++++++++++++++++ .../qnn/builder/qnn_graph_configs_helper.cc | 43 --------- .../qnn/builder/qnn_graph_configs_helper.h | 56 ------------ .../providers/qnn/qnn_execution_provider.cc | 69 ++++++++++++-- .../providers/qnn/qnn_execution_provider.h | 7 +- onnxruntime/test/onnx/main.cc | 18 +++- .../test/perftest/command_args_parser.cc | 4 + onnxruntime/test/perftest/ort_test_session.cc | 14 ++- .../test/providers/qnn/qnn_basic_test.cc | 56 +++++++++++- 12 files changed, 292 insertions(+), 118 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index aca9f4896fbdb..101a578ec3e1d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3608,6 +3608,14 @@ struct OrtApi { * - "1": Faster preparation time, less optimal graph. * - "2": Longer preparation time, more optimal graph. * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). 
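For readers wiring this up from the API side: the three keys documented above are ordinary string-valued QNN provider options, so they are passed the same way as the existing keys (e.g. `backend_path`). A hedged C++ usage sketch follows; the model path and option values are placeholders, and valid `soc_model`/`htp_arch` numbers come from the QNN SDK documentation:

```cpp
#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions so;

  // Device-level options introduced by this patch, next to the usual backend selection.
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "QnnHtp.dll"},  // HTP backend (example value)
      {"soc_model", "0"},              // 0 = unknown / let the driver decide
      {"htp_arch", "68"},              // minimum HTP architecture to target
      {"device_id", "0"}};             // device the htp_arch constraint applies to

  so.AppendExecutionProvider("QNN", qnn_options);
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  return 0;
}
```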
* * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 193e4f5ff2a31..973b81d337c81 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -17,6 +17,7 @@ #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/providers/qnn/builder/qnn_configs_helper.h" #ifdef _WIN32 #include @@ -329,9 +330,37 @@ Status QnnBackendManager::CreateDevice() { return Status::OK(); } + qnn::QnnConfigsBuilder device_configs_builder(QNN_DEVICE_CONFIG_INIT, + {}); + if (qnn_backend_type_ == QnnBackendType::HTP) { + // Set SoC Model. The *enum* Qnn_SocModel_t is deprecated and will not be updated in the future. Therefore, + // must use the latest SDK documentation to get the SoC model of the latest HW. + if (soc_model_ != QNN_SOC_MODEL_UNKNOWN) { + QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); + custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + custom_config.socModel = soc_model_; + + QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); + device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config.customConfig = &custom_config; + } + + // Set the minimum HTP architecture. The driver will use ops that are compatible with this minimum architecture. + if (htp_arch_ != QNN_HTP_DEVICE_ARCH_NONE) { + QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); + custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; + custom_config.arch.arch = htp_arch_; + custom_config.arch.deviceId = device_id_; + + QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); + device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config.customConfig = &custom_config; + } + } + LOGS_DEFAULT(INFO) << "Create device."; if (nullptr != qnn_interface_.deviceCreate) { - auto result = qnn_interface_.deviceCreate(log_handle_, nullptr, &device_handle_); + auto result = qnn_interface_.deviceCreate(log_handle_, device_configs_builder.GetQnnConfigs(), &device_handle_); if (QNN_SUCCESS != result) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create device. 
Error: ", result); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 58f207efb9e95..f7b8947ab84bb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -17,6 +17,7 @@ #include #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" +#include "QnnTypes.h" #include "System/QnnSystemInterface.h" #include "core/common/status.h" #include "core/common/logging/logging.h" @@ -35,13 +36,19 @@ class QnnBackendManager { uint32_t rpc_control_latency, HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, - std::string&& qnn_saver_path) + std::string&& qnn_saver_path, + uint32_t device_id, + QnnHtpDevice_Arch_t htp_arch, + uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), rpc_control_latency_(rpc_control_latency), htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), - qnn_saver_path_(qnn_saver_path) { + qnn_saver_path_(qnn_saver_path), + device_id_(device_id), + htp_arch_(htp_arch), + soc_model_(soc_model) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); @@ -233,6 +240,9 @@ class QnnBackendManager { #endif const std::string qnn_saver_path_; uint32_t htp_power_config_client_id_ = 0; + uint32_t device_id_ = 0; + QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; + uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h new file mode 100644 index 0000000000000..9dd9bbaa08d64 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace qnn { + +/** + * Helper class for building a null-terminated list of QNN configurations. + * A QNN configuration consists of multiple objects with references to each other. This + * class ensures that all configuration objects have the same lifetime, so that they remain valid + * across calls to qnn_interface.xxxCreate(). + */ +template +class QnnConfigsBuilder { + public: + /** + * Initializes the config build. Provide the initial/default value for each config struct type. + * \param base_config_init The initial/default value for objects of type BaseConfigType. + * \param custom_config_init The initial/default value for objects of type CustomConfigType. + */ + QnnConfigsBuilder(BaseConfigType base_config_init, CustomConfigType custom_config_init) + : base_config_init_(std::move(base_config_init)), custom_config_init_(std::move(custom_config_init)) {} + + /** + * Returns a pointer to the beginning of a null-terminated array of QNN base configurations. + * This result is typically passed to QNN's xxxCreate() APIs. + * + * \return Pointer to null-terminated BaseConfigType* array. + */ + const BaseConfigType** GetQnnConfigs() { + if (config_ptrs_.empty()) { + return nullptr; + } + + if (!IsNullTerminated()) { + config_ptrs_.push_back(nullptr); + } + + return config_ptrs_.data(); + } + + /** + * Creates and returns a reference to a new custom QNN configuration object. The object is initialized to + * the QNN recommended default value. The caller is meant to override fields in this object. + * + * \return A reference to a default CustomConfigType object. 
+ */ + CustomConfigType& PushCustomConfig() { + custom_configs_.push_back(custom_config_init_); + return custom_configs_.back(); + } + + /** + * Creates and returns a reference to a new QNN configuration object. The object is initialized to + * the QNN recommended default value. The caller is meant to override fields in this object. + * + * \return A reference to a default BaseConfigType object. + */ + BaseConfigType& PushConfig() { + configs_.push_back(base_config_init_); + BaseConfigType& config = configs_.back(); + + // Add pointer to this new config to the list of config pointers. + if (IsNullTerminated()) { + config_ptrs_.back() = &config; // Replace last nullptr entry. + } else { + config_ptrs_.push_back(&config); + } + + return config; + } + + private: + bool IsNullTerminated() const { + return !config_ptrs_.empty() && config_ptrs_.back() == nullptr; + } + + BaseConfigType base_config_init_; + CustomConfigType custom_config_init_; + InlinedVector custom_configs_; + InlinedVector configs_; + InlinedVector config_ptrs_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc b/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc deleted file mode 100644 index 63aa01b48e7e2..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/qnn/builder/qnn_graph_configs_helper.h" - -#include "HTP/QnnHtpGraph.h" - -namespace onnxruntime { -namespace qnn { - -const QnnGraph_Config_t** QnnGraphConfigsBuilder::GetQnnGraphConfigs() { - if (graph_config_ptrs_.empty()) { - return nullptr; - } - - if (!IsNullTerminated()) { - graph_config_ptrs_.push_back(nullptr); - } - - return graph_config_ptrs_.data(); -} - -QnnHtpGraph_CustomConfig_t& QnnGraphConfigsBuilder::PushHtpGraphCustomConfig() { - htp_custom_graph_configs_.push_back(QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); - return htp_custom_graph_configs_.back(); -} - -QnnGraph_Config_t& QnnGraphConfigsBuilder::PushGraphConfig() { - graph_configs_.push_back(QNN_GRAPH_CONFIG_INIT); - QnnGraph_Config_t& config = graph_configs_.back(); - - // Add pointer to this new graph config to the list of graph config pointers. - if (IsNullTerminated()) { - graph_config_ptrs_.back() = &config; // Replace last nullptr entry. - } else { - graph_config_ptrs_.push_back(&config); - } - - return config; -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h deleted file mode 100644 index 8c4928fdacbc4..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "HTP/QnnHtpGraph.h" - -namespace onnxruntime { -namespace qnn { - -/** - * Helper class for building a null-terminated list of QNN Graph configurations. - * A QNN configuration consists of multiple objects with references to each other. This - * class ensures that all configuration objects have the same lifetime, so that they remain valid - * across the call to graphCreate(). 
- */ -class QnnGraphConfigsBuilder { - public: - /** - * Returns a pointer to the beginning of a null-terminated array of QNN Graph configurations. - * This result is passed QNN's graphCreate() API. - * - * \return Pointer to null-terminated QnnGraph_Config_t* array. - */ - const QnnGraph_Config_t** GetQnnGraphConfigs(); - - /** - * Creates and returns a reference to a new HTP graph configuration object. The object is initialized to - * the QNN recommended default value. The caller is meant to override fields in this object. - * - * \return A reference to a default QnnHtpGraph_CustomConfig_t object. - */ - QnnHtpGraph_CustomConfig_t& PushHtpGraphCustomConfig(); - - /** - * Creates and returns a reference to a new graph configuration object. The object is initialized to - * the QNN recommended default value. The caller is meant to override fields in this object. - * - * \return A reference to a default QnnGraph_Config_t object. - */ - QnnGraph_Config_t& PushGraphConfig(); - - private: - bool IsNullTerminated() const { - return !graph_config_ptrs_.empty() && graph_config_ptrs_.back() == nullptr; - } - - InlinedVector htp_custom_graph_configs_; - InlinedVector graph_configs_; - InlinedVector graph_config_ptrs_; -}; - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 56eb1f4f59f33..0310cc2bc8f26 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -111,6 +111,22 @@ void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std:: } } +static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevice_Arch_t& qnn_htp_arch) { + if (htp_arch_string.empty() || htp_arch_string == "0") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_NONE; + } else if (htp_arch_string == "68") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V68; + } else if (htp_arch_string == "69") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V69; + } else if (htp_arch_string == "73") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V73; + } else if (htp_arch_string == "75") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V75; + } else { + LOGS_DEFAULT(WARNING) << "Invalid HTP architecture: " << htp_arch_string; + } +} + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} { @@ -223,13 +239,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } + static const std::string QNN_DEVICE_ID = "device_id"; + uint32_t device_id = 0; + auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); + if (dev_id_pos != provider_options_map.end()) { + int value = std::stoi(dev_id_pos->second); + if (value < 0) { + LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value + << "', only >= 0 allowed. 
Set to " << device_id << "."; + } else { + device_id = static_cast(value); + } + } + + static const std::string QNN_HTP_ARCH = "htp_arch"; + QnnHtpDevice_Arch_t htp_arch = QNN_HTP_DEVICE_ARCH_NONE; + auto htp_arch_pos = provider_options_map.find(QNN_HTP_ARCH); + if (htp_arch_pos != provider_options_map.end()) { + ParseHtpArchitecture(htp_arch_pos->second, htp_arch); + } + + static const std::string QNN_SOC_MODEL = "soc_model"; + uint32_t soc_model = QNN_SOC_MODEL_UNKNOWN; + auto soc_model_pos = provider_options_map.find(QNN_SOC_MODEL); + if (soc_model_pos != provider_options_map.end()) { + int value = std::stoi(soc_model_pos->second); + if (value < 0) { + LOGS_DEFAULT(WARNING) << "Invalid SoC Model '" << value + << "', only >= 0 allowed. Set to " << soc_model << "."; + } else { + soc_model = static_cast(value); + } + } + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, rpc_control_latency, htp_performance_mode, context_priority, - std::move(qnn_saver_path)); + std::move(qnn_saver_path), + device_id, + htp_arch, + soc_model); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -512,25 +564,25 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod return Status::OK(); } -void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_builder) const { +void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig(); + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushCustomConfig(); htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; htp_graph_opt_config.optimizationOption.floatValue = static_cast(htp_graph_finalization_opt_mode_); - QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig(); + QnnGraph_Config_t& graph_opt_config = configs_builder.PushConfig(); graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; graph_opt_config.customConfig = &htp_graph_opt_config; } if (vtcm_size_in_mb_ > 0) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushHtpGraphCustomConfig(); + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushCustomConfig(); htp_graph_opt_config_vtcm.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; htp_graph_opt_config_vtcm.vtcmSizeInMB = static_cast(vtcm_size_in_mb_); - QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushGraphConfig(); + QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushConfig(); graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm; } @@ -547,10 +599,11 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); - qnn::QnnGraphConfigsBuilder graph_configs_builder; + qnn::QnnConfigsBuilder graph_configs_builder(QNN_GRAPH_CONFIG_INIT, + QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); InitQnnGraphConfigs(graph_configs_builder); - ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnGraphConfigs())); + 
ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnConfigs())); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs()); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index d4927f3fa505e..3f75be0efebcd 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -5,11 +5,12 @@ #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" +#include "core/graph/model.h" #include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" -#include "core/providers/qnn/builder/qnn_graph_configs_helper.h" -#include "core/graph/model.h" +#include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "HTP/QnnHtpGraph.h" namespace onnxruntime { @@ -58,7 +59,7 @@ class QNNExecutionProvider : public IExecutionProvider { void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string); - void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const; + void InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 7e0a811b7d07c..aca609cf94270 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -60,6 +60,10 @@ void usage() { "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" "\t '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -483,7 +487,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { // no validation } else if (key == "htp_performance_mode") { std::set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -512,10 +516,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. 
select from: " + str); } + } else if (key == "htp_arch") { + std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + if (supported_htp_archs.find(value) == supported_htp_archs.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Wrong value for htp_arch. select from: " + str); + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', +'soc_model', 'htp_arch', 'device_id'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index ef04e2be8fd29..6c1d447c7b3a3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -78,6 +78,10 @@ namespace perftest { "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" "\t '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). 
\n" "\t [Usage]: -e -i '| |'\n\n" "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index f8a012af5bb13..6854a2649060a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -343,7 +343,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { // no validation } else if (key == "htp_performance_mode") { std::set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -372,10 +372,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) { ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } + } else if (key == "htp_arch") { + std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + if (supported_htp_archs.find(value) == supported_htp_archs.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Wrong value for htp_arch. select from: " + str); + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', +'htp_arch', 'device_id'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index bc40682cf87b7..c50b1002fa8c8 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -176,7 +176,10 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { // types and shapes. static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false, std::string htp_graph_finalization_opt_mode = "", - std::string qnn_context_priority = "") { + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! 
@@ -205,6 +208,18 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo options["qnn_context_priority"] = std::move(qnn_context_priority); } + if (!soc_model.empty()) { + options["soc_model"] = std::move(soc_model); + } + + if (!htp_arch.empty()) { + options["htp_arch"] = std::move(htp_arch); + } + + if (!device_id.empty()) { + options["device_id"] = std::move(device_id); + } + so.AppendExecutionProvider("QNN", options); Ort::Session session(*ort_env, ort_model_path, so); @@ -519,6 +534,45 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { } } +// Test that models run with various SoC model values +TEST_F(QnnHTPBackendTests, HTPSocModels) { + constexpr std::array soc_models = { "", // No explicit SoC model specified + "0", // "Unknown" +#if defined(_M_ARM64) + "37" }; // SC8280X +#elif defined(__linux__) + "30" }; // SM8350 +#else + "" }; +#endif + + for (auto soc_model : soc_models) { + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", + true, // use_htp + false, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + soc_model); + } +} + +// Test that models run with various HTP architecture values (and set device_id) +TEST_F(QnnHTPBackendTests, HTPArchValues) { + constexpr std::array htp_archs = {"", // No explicit arch specified + "0", // "None" + "68"}; // v68 + for (auto htp_arch : htp_archs) { + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", + true, // use_htp + false, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + "", // soc_model + htp_arch, // htp_arch + "0"); // device_id + } +} + // Test that models run with high QNN context priority. TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", From 780acda7b4f044564e1f222901fd6a676aa05cbf Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 23 Jan 2024 06:02:56 +0800 Subject: [PATCH 09/23] Add Big models pipeline (#19222) ### Description 2 models are added in CI. Stabe diffusion Model stage is based on https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md LLama2 FP16 is based on https://github.com/microsoft/Llama-2-Onnx. 12G GPU memory is not enough, so I choose T4 to run it. ### Motivation and Context Add regular E2E test for big models. It will be triggered in main build, that is, it'll run after one PR is merged. More models will be added later. 
### Test Runs ### https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1275191&view=results --- .../azure-pipelines/bigmodels-ci-pipeline.yml | 259 ++++++++++++++++++ 1 file changed, 259 insertions(+) create mode 100644 tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml new file mode 100644 index 0000000000000..ff2e7c0468a21 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -0,0 +1,259 @@ +# reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +parameters: +- name: specificArtifact + displayName: Use Specific Artifact + type: boolean + default: false +- name: BuildId + displayName: Specific Artifact's RunId + type: number + default: 0 + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + + - repository: LLaMa2Onnx + type: Github + endpoint: Microsoft + name: Microsoft/Llama-2-Onnx + ref: main + +variables: + - template: templates/common-variables.yml + - name: docker_base_image + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + - name: linux_trt_version + value: 8.6.1.6-1.cuda11.8 + +stages: +- stage: Build_Onnxruntime_Cuda + jobs: + - job: Linux_Build + timeoutInMinutes: 120 + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Ubuntu2204-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimecuda11build + + - task: Cache@2 + inputs: + key: '"ccache" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + restoreKeys: | + "ccache" | "$(Build.SourceBranch)" + "ccache" + cacheHitVar: CACHE_RESTORED + displayName: Cach Task + + - script: | + sudo mkdir -p $(Pipeline.Workspace)/ccache + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - task: CmdLine@2 + inputs: + script: | + mkdir -p $HOME/.onnx + docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume $(Pipeline.Workspace)/ccache:/cache \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + -e CCACHE_DIR=/cache \ + onnxruntimecuda11build \ + /bin/bash -c " + set -ex; \ + env; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 
/onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release --update --build \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel \ + --build_wheel \ + --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ + --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_pybind --build_java \ + --use_cache \ + --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \ + ccache -sv; \ + ccache -z" + workingDirectory: $(Build.SourcesDirectory) + + - task: CmdLine@2 + inputs: + script: | + rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 + rm -f $(Build.BinariesDirectory)/Release/models + find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete + cd $(Build.BinariesDirectory)/Release + find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt + + - script: | + set -ex + mkdir -p $(Agent.TempDirectory)/ort + cp $(Build.BinariesDirectory)/Release/dist/*.whl $(Agent.TempDirectory)/ort/ + displayName: 'Copy Wheels' + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline Artifact' + inputs: + artifactName: 'drop-ort-linux-gpu' + targetPath: '$(Agent.TempDirectory)/ort' + + - template: templates/explicitly-defined-final-tasks.yml + +- stage: Stale_Diffusion + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Stale_Diffusion + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Linux-GPU-A10-12G + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/Release' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - script: | + docker run --rm --gpus all -v $PWD:/workspace -v $(Build.BinariesDirectory)/Release:/Release nvcr.io/nvidia/pytorch:22.11-py3 \ + bash -c " + set -ex; \ + python3 --version; \ + python3 -m pip install --upgrade pip; \ + python3 -m pip install /Release/*.whl; \ + pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \ + python3 -m pip install -r requirements-cuda11.txt; \ + python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ + echo Generate an image guided by a text prompt; \ + python3 demo_txt2img.py "astronaut riding a horse on mars"; \ + echo Generate an image with Stable Diffusion XL guided by a text prompt; \ + python3 demo_txt2img_xl.py 'starry night over Golden Gate Bridge by van gogh'; \ + python3 demo_txt2img_xl.py --enable-refiner 'starry night over Golden Gate Bridge by van gogh'; \ + echo Generate an image guided by a text prompt using LCM LoRA; \ + python3 demo_txt2img_xl.py --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"; \ + popd; \ + " + displayName: 'Run stable diffusion demo' + workingDirectory: $(Build.SourcesDirectory) + +- stage: Llama2_ONNX_FP16 + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Llama2_ONNX_FP16 + variables: + skipComponentGovernanceDetection: true + 
workspace: + clean: all + pool: onnxruntime-Linux-GPU-T4 + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - checkout: LLaMa2Onnx + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/ort-artifact/' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: DownloadPackage@1 + displayName: 'Download Llama2 model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce' + downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16 + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: onnxruntime/tools/ci_build/github/linux/docker/ + ScriptName: onnxruntime/tools/ci_build/get_docker_image.py + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimeubi8packagestest + UpdateDepsTxt: false + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \ + onnxruntimeubi8packagestest \ + bash -c " + set -ex; \ + python3 -m pip install --upgrade pip ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + python3 -m pip install sentencepiece ; \ + pushd /workspace ; \ + python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \ + --embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the lightest element?' > /workspace/answer.txt ; \ + popd ; \ + " + displayName: 'Run Llama2 demo' + workingDirectory: $(Build.SourcesDirectory) + + - script: | + set -ex + real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt) + trim_actual=$(tr -dc '[[:print:]]' <<< "$real") + expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)." + [ "$expected" == "$trim_actual" ] && exit 0 || exit 1 + displayName: 'Check result' From 77da2ef278a4e77cca4cef4e5d72ed1ef46fcce3 Mon Sep 17 00:00:00 2001 From: snadampal <87143774+snadampal@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:43:06 -0600 Subject: [PATCH 10/23] [aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 (#17031) ### Description This PR adds SbgemmKernel for aarch64. This includes Sbegmm kernel to implement matrix multiplication with bfloat16 SIMD instructions (bfmmla) and MatMul operator changes to invoke the Sbgemm kernel. To enable Sbgemm kernel, set the following session option: "kOrtSessionOptionsGemmFastMathMode" The PR also adds new test cases for mlas and ort. ### Motivation and Context This is to improve MatMul performance on aarch64 platform. I have run the below benchmarking script (bert , roberta and gpt2 model inference) on AWS Graviton3 based c7g.4xl instance and observed 1.2x -1.76x performance improvement compared to sgemm (fp32) kernel performance. 
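
For reference, a minimal sketch (not part of this patch) of how the fast-math mode could be enabled from the Python API before running the benchmark script below. The config key string matches the `kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16` entry added to `onnxruntime_session_options_config_keys.h` in this patch; the model path is a placeholder:

```python
import onnxruntime as ort

so = ort.SessionOptions()
# Key defined by this patch in onnxruntime_session_options_config_keys.h;
# "1" enables the bfloat16-based fast-math GEMM path on supported aarch64 CPUs.
so.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")

# Placeholder model; the CPU EP is used so MatMul goes through the MLAS kernels.
session = ort.InferenceSession(
    "model.onnx", sess_options=so, providers=["CPUExecutionProvider"]
)
```
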
``` cd onnxruntime/python/tools/transformers python3 benchmark.py ``` And the unit test precision results are matching to sgemm kernel results. `./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync ` --- cmake/onnxruntime_mlas.cmake | 4 + .../onnxruntime_session_options_config_keys.h | 8 +- onnxruntime/core/common/cpuid_info.cc | 7 + onnxruntime/core/common/cpuid_info.h | 2 + onnxruntime/core/mlas/inc/mlas.h | 113 +++ .../core/mlas/lib/aarch64/SbgemmKernelNeon.S | 907 ++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 25 + onnxruntime/core/mlas/lib/platform.cpp | 6 + onnxruntime/core/mlas/lib/sbgemm.h | 399 ++++++++ .../core/mlas/lib/sbgemm_kernel_neon.cpp | 362 +++++++ onnxruntime/core/providers/cpu/math/matmul.cc | 106 +- onnxruntime/core/providers/cpu/math/matmul.h | 15 + .../test/mlas/unittest/test_sbgemm.cpp | 141 +++ onnxruntime/test/mlas/unittest/test_sbgemm.h | 281 ++++++ .../qdq_transformer_fastmath_test.cc | 730 ++++++++++++++ .../cpu/math/matmul_fastmath_test.cc | 305 ++++++ onnxruntime/test/util/compare_ortvalue.cc | 80 ++ 17 files changed, 3473 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/sbgemm.h create mode 100644 onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sbgemm.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sbgemm.h create mode 100644 onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc create mode 100644 onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f89d2150a6830..17de2aa4aaea6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -355,19 +355,23 @@ else() ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 8fd51962bf087..b282438795eb5 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea5..711fd595e90fd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() { /* TODO: implement them when hw+sw is available for testing these features */ has_arm_neon_i8mm_ = false; has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83a..2f8041e39f680 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -125,6 +126,7 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eba..ce7838556fbf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. 
+ * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. 
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000000..e424c30515e9f --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. 
+// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. 
+// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. 
+// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s 
+ + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + 
mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. 
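  Editorial note (not part of this patch): two variants of this routine are
  generated at the bottom of the file, MlasSbgemmKernelZero and
  MlasSbgemmKernelAdd. As a rough scalar model, with accumulation carried in
  fp32 through the bf16 data path:

      acc ~= sum over k of A[i][k] * B[k][j]
      "Zero" variant:  C[i][j] = acc
      "Add"  variant:  C[i][j] = C[i][j] + acc

  The OutputRow*Element macros above implement exactly this choice: the "Add"
  path loads the existing C values, de-interleaves the accumulator register
  pairs with uzp1/uzp2, adds, and stores, while the "Zero" path stores the
  de-interleaved accumulators directly. Bias is supplied by the C++ driver
  only for the first slice of K (see sbgemm.h later in this patch).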
+ +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. 
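// Editorial example (not part of this patch): because the routine returns the
// number of rows actually handled, a caller with CountM = 11 typically makes
// three calls into this kernel: the first handles 8 rows, the second handles
// 2 of the remaining 3, and this final path handles the last row. The C++
// driver MlasSBGemmKernel in sbgemm_kernel_neon.cpp below advances A and C by
// the returned row count between calls.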
+// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a84..624eb913d5c9e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -856,6 +877,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 1310ed3f384b9..de092f7d1d350 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 0000000000000..de7fd72fad45a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. 
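  Editorial sketch (hypothetical backend, not part of this patch): a new
  hardware kernel would plug into this driver by defining a kernel type tag,
  say MLAS_SBGEMM_KERNEL_SOMEISA, and providing specializations along the
  lines of

      template <> void MlasSBGemmConvertPackB<MLAS_SBGEMM_KERNEL_SOMEISA>(...);
      template <> void MlasSBGemmKernel<MLAS_SBGEMM_KERNEL_SOMEISA>(...);

  together with the constants listed next. The NEON definitions added later
  in this patch (sbgemm_kernel_neon.cpp) are the first concrete instance.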
+ + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? 
pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. 
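    // Editorial example of the partitioning just computed (numbers assumed
    // for illustration, not part of this patch): with N = 1000 and
    // MLAS_SGEMM_STRIDEN_THREAD_ALIGN = 16, BlockedN is 63 column blocks.
    // With ThreadCountN = 4, MlasPartitionWork gives three threads 16 blocks
    // (256 columns) each and the last thread 15 blocks, and the final
    // std::min clamp keeps the last range from running past column N - 1.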
+ // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. 
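    // Editorial example (assumes MLAS_SGEMM_THREAD_COMPLEXITY is the same
    // 64 * 1024 value as the MLAS_SBGEMM_THREAD_COMPLEXITY added in mlasi.h;
    // not part of this patch): for M = N = K = 64 the complexity is 262,144,
    // so the code below targets 262144 / 65536 + 1 = 5 threads; larger shapes
    // saturate at GetMlasPlatform().MaximumThreadCount and are then capped by
    // the thread pool via MlasGetMaximumThreadCount().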
+ // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..a6a73996c548b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. 
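    // Editorial sketch (not part of this patch): a scalar model of the
    // packing order produced below, intended only as a mental aid. The exact
    // lane order comes from the vzip/vcvt sequences in the NEON code; the
    // to_bfloat16() helper here is hypothetical and the 4/2/1-column tails
    // are glossed over.
#if 0
    for (size_t n0 = 0; n0 < CountN; n0 += 8) {          // 8-column groups
        for (size_t k = 0; k < CountK; k += 4) {         // K blocked by 4, zero-padded
            for (size_t n = n0; n < n0 + 8; ++n) {
                for (size_t kk = 0; kk < 4; ++kk) {
                    float v = (k + kk < CountK) ? B[(k + kk) * ldb + n] : 0.0f;
                    *D++ = to_bfloat16(v);               // hypothetical fp32 -> bf16 helper
                }
            }
        }
    }
#endif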
+ // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. 
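    // Editorial note (not part of this patch): a tail of CountN = 7 columns
    // is packed by the branches below as one 4-column group, then one
    // 2-column group, then a single column paired with a column of zeros, so
    // the packed width stays a multiple of the 2-column granularity and K is
    // still zero-padded to a multiple of 4.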
+ // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + 
vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ec395cf018f5e..583ee759cc2e6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -6,7 +6,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -125,6 +124,44 @@ Status MatMul::Compute(OpKernelContext* ctx) const { return Status::OK(); } +#if defined(__aarch64__) && defined(__linux__) +bool GemmPackBBfloat16(AllocatorPtr& alloc, + const Tensor& tensor_b, + bool trans_b, + IAllocatorUniquePtr& packed_b, + size_t& packed_b_size, + TensorShape& b_shape) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + if (tensor_b.Shape().NumDimensions() != 2) { + return false; + } + + b_shape = tensor_b.Shape(); + + const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); + const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); + + packed_b_size = MlasSBGemmPackBSize(N, K); + if (packed_b_size == 0) { + return false; + } + + packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + auto* packed_b_data = packed_b.get(); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we don not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset(packed_b_data, 0, packed_b_size); + MlasSBGemmConvertPackB(N, + K, + tensor_b.Data(), + trans_b ? 
K : N, + packed_b_data); + return true; +} +#endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -134,7 +171,24 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); +#if defined(__aarch64__) && defined(__linux__) + size_t dim1 = 0; + size_t dim2 = 0; + TensorShape b_shape = tensor.Shape(); + + if (b_shape.NumDimensions() == 2) { + dim1 = static_cast(b_shape[0]); + dim2 = static_cast(b_shape[1]); + } + + if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } else +#endif + { + is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } + bool share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); @@ -186,22 +240,40 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); - - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { - data[i].BIsPacked = bool(packed_b_); - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = alpha_attr_; - data[i].beta = 0.0f; +#if defined(__aarch64__) && defined(__linux__) + if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsfp32 = !(bool(packed_b_)); + data[i].AIsfp32 = true; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsfp32 ? b_data + helper.RightOffsets()[i] : (float*)packed_b_.get(); + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].Bias = nullptr; + data[i].OutputProcessor = nullptr; + } + MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); + } else +#endif + { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; + } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool); } - MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? 
CblasTrans : CblasNoTrans, - M, N, K, data.data(), max_len, thread_pool); - return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b960fa4fb0587..b9bbe36583879 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -4,6 +4,8 @@ #pragma once #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { @@ -27,6 +29,11 @@ class MatMul final : public OpKernel { info.GetAttrOrDefault("transBatchB", &trans_batch_b_attr, 0); trans_batch_a_ = trans_batch_a_attr != 0; trans_batch_b_ = trans_batch_b_attr != 0; + +#if defined(__aarch64__) && defined(__linux__) + auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); + use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported(); +#endif } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, @@ -48,6 +55,14 @@ class MatMul final : public OpKernel { int64_t trans_b_attr_; bool trans_batch_a_; bool trans_batch_b_; + +#if defined(__aarch64__) && defined(__linux__) + // fastmath mode state + bool use_fastmath_mode_; + // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 + // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead + const size_t kFastMathModeKernelsizeThreshold = 32; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp new file mode 100644 index 0000000000000..941de8f05061f --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -0,0 +1,141 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.cpp + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "test_sbgemm.h" + +// +// Short Execute() test helper to register each test seperately by all parameters. +// +template +class SBGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SBGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) + : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { + std::stringstream ss; + ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" + << "hasBias" << hasBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSBGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. 
+ [=]() -> MlasTestFixture>* { + return new SBGemmShortExecuteTest( + M, N, K, Batch, hasBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, 1, false); + test_registered += RegisterSingleTest(1, 32, b, 1, true); + test_registered += RegisterSingleTest(1, b, b, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(1, b, 32, 3, true); + test_registered += RegisterSingleTest(1, 32, b, 5, false); + } + } + // TODO: check why the cosine similary is < 0.99 for this shape alone + // test_registered += RegisterSingleTest(43, 500, 401, 1, true); + test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(43, 500, 401, 5, true); + test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); + } + + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; + bool hasBias_; +}; + +static size_t SBGemmRegistLongExecute() { + size_t count = 0; + + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } + + return count; +} + +static size_t SBGemmRegistShortExecute() { + size_t count = 0; + + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + + if (GetMlasThreadPool() != nullptr) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + } + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (!MlasBf16AccelerationSupported()) { + return false; + } + + if (is_short_execute) { + return SBGemmRegistShortExecute() > 0; + } + return SBGemmRegistLongExecute() > 0; +}); +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h new file mode 100644 index 0000000000000..13701e2e3de46 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -0,0 +1,281 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.h + +Abstract: + + Tests for MLAS bf16 precision GEMM. 
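  Editorial note (not part of this patch): bfloat16 keeps only 7 mantissa
  bits, so each converted operand carries a relative error of up to roughly
  2^-8 (about 0.4%). Element-wise comparison against the fp32 reference can
  therefore fail for legitimate results, which is why the harness below falls
  back to a cosine-similarity check over the whole output (threshold 0.98)
  whenever CloseEnough() does not hold.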
+ +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include "test_util.h" + +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto FillAddress = start; + size_t offset = size % 23; + + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); + } +} + +float cosine_similarity(const float* A, const float* B, size_t Vector_Length) { + float dot = 0.0, denom_a = 0.0, denom_b = 0.0; + for (size_t i = 0u; i < Vector_Length; ++i) { + dot += A[i] * B[i]; + denom_a += A[i] * A[i]; + denom_b += B[i] * B[i]; + } + return dot / (sqrt(denom_a) * sqrt(denom_b)); +} + +/** + * @brief Test class for bf16 precision GEMM + * @tparam AType Data type of A matrix, need to be float + * @tparam BType Data type of b matrix, can be either float or prepacked bf16 + */ +template +class MlasSBGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferBPacked; + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + MatrixGuardBuffer BufferFloatC; + MLAS_THREADPOOL* threadpool_; + + void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { + size_t PackedBSize = MlasSBGemmPackBSize(N, K); + if (PackedBSize == 0) { + return nullptr; + } + void* PackedB = BufferBPacked.GetBuffer(PackedBSize); + if (std::is_same::value) { + MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + } else { + } + return PackedB; + } + + void CallSBGemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const float* A, + size_t lda, + const BType* B, + size_t ldb, + const float* Bias, + float* C, + size_t ldc) { + std::vector GemmParameters(BatchSize); + + for (size_t i = 0; i < GemmParameters.size(); i++) { + auto& params = GemmParameters[i]; + params.A = A + (M * lda * i); + params.lda = lda; + if (nullptr != Bias) { + params.Bias = reinterpret_cast(Bias + N * i); + } else { + params.Bias = nullptr; + } + params.C = reinterpret_cast(C + (M * ldc * i)); + params.ldc = ldc; + params.AIsfp32 = true; + params.BIsfp32 = true; + + if (Packed) { + ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; + params.B = PackB(N, K, B, ldb); + params.ldb = 0; + params.BIsfp32 = false; + } else { + params.B = B + (K * N * i); + params.ldb = ldb; + } + } + + MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + } + + void ReferenceSgemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const float* Bias, + float* C) { + constexpr size_t KStride = 256; + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const AType* a = A + M * K * batch + m * K; + const BType* b = B + K * N * batch + n; + float* c = C + (M * N * batch) + (m * N) + n; + + for (size_t k = 0; k < K; k += KStride) { + float sum = 0.0f; + if (k == 0 && Bias != nullptr) { + sum = float(Bias[n]); + } + for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { + float down(float(*b) * float(*a) + sum); + sum = float(down); + b += N; + a += 1; + } + if (k == 0) { + *c = sum; + } else { + float d(sum + *c); + *c = float(d); + } + } + } + } + if (Bias) { + Bias += N; + } + } + } + + public: + MlasSBGemmTest() : threadpool_(Threaded ? 
GetMlasThreadPool() : nullptr) {} + + void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + float BiasTail[16]; + const float* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); + } + + float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N); + ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); + const float cosine_similarity_threshold = 0.98; + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + if (withBias) { + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; + } + } + + private: + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SBGemmFP") + + (std::is_same::value ? "32" : "16") + + (std::is_same::value ? "32" : "16") + + (Packed ? "_Packed" : "_NoPack") + + (Threaded ? 
"_Threaded" : "_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong(void) override { + for (size_t M = 16; M < 160; M += 32) { + for (size_t N = 16; N < 160; N += 32) { + static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; + for (size_t k = 0; k < _countof(ks); k++) { + size_t K = ks[k]; + + Test(M, N, K, 1, false); + Test(M, N, K, 1, true); + Test(M + 1, N, K, 1, false); + Test(M, N + 1, K, 1, true); + Test(M + 1, N + 1, K, 1, false); + Test(M + 3, N + 2, K, 1, true); + Test(M + 4, N, K, 1, false); + Test(M, N + 4, K, 1, true); + Test(M + 4, N + 4, K, 1, false); + Test(M + 3, N + 7, K, 1, true); + Test(M + 8, N, K, 1, false); + Test(M, N + 8, K, 1, true); + Test(M + 12, N + 12, K, 1, false); + Test(M + 13, N, K, 1, true); + Test(M, N + 15, K, 1, false); + Test(M + 15, N + 15, K, 1, false); + if (!Packed) { + Test(M, N, K, 7, false); + Test(M + 3, N, K, 8, true); + Test(M, N + 1, K, 9, false); + Test(M + 12, N, K, 10, true); + Test(M, N + 15, K, 11, false); + Test(M + 15, N + 15, K, 12, true); + } + } + } + printf("M %zd\n", M); + } + + for (size_t M = 1; M < 160; M++) { + for (size_t N = 1; N < 160; N++) { + for (size_t K = 1; K < 160; K++) { + Test(M, N, K, 1, true); + } + } + printf("M %zd\n", M); + } + + for (size_t M = 160; M < 320; M += 24) { + for (size_t N = 112; N < 320; N += 24) { + for (size_t K = 1; K < 16; K++) { + Test(M, N, K, 1, true); + } + for (size_t K = 16; K < 160; K += 32) { + Test(M, N, K, 1, false); + } + } + printf("M %zd\n", M); + } + } +}; + +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc new file mode 100644 index 0000000000000..ec9f78da14a75 --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -0,0 +1,730 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. 
+ +#include "core/framework/compute_capability.h" +#include "core/graph/model.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" +#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/utils.h" +#include "core/providers/partitioning_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/environment.h" +#include "core/session/inference_session.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" + +#include "qdq_test_utils.h" + +#if defined(__aarch64__) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +TEST(QDQTransformerTests, DQ_S8_to_U8_FastMath) { + auto test_case = [](bool use_contrib_qdq) { + const std::vector& input_shape = {19, 37}; + const std::vector& weights_shape = {37, 23}; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input_shape, -1.f, 1.f); + + // Use full range weight values to expose u8s8 overflow problems + auto* weight = builder.MakeInitializer(weights_shape, -128, 127); + auto* output_arg = builder.MakeOutput(); + + // add QDQ activation + typedef std::numeric_limits Input1Limits; + auto* dq1_output = AddQDQNodePair(builder, input1_arg, .039f, + (int8_t)((Input1Limits::max() + Input1Limits::min()) / 2 + 1), + use_contrib_qdq); + + // add DQ weight + auto* dq_w_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, .003f, -10, dq_w_output, use_contrib_qdq); + + builder.AddNode("MatMul", {dq1_output, dq_w_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, /*using NAN as a magic number to trigger cosine similarity*/ + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + 
TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case(false); // Use ONNX QDQ ops + test_case(true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerMatMulTests(bool has_output_q, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + // add QDQ 1 + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + // add QDQ 2 + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + + if (has_output_q) { + // add binary operator + auto* matmul_op_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {dq1_output, dq2_output}, {matmul_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(matmul_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + builder.AddNode("MatMul", {dq1_output, dq2_output}, {output_arg}); + } + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if (has_output_q) { + if constexpr (std::is_same::value && + (std::is_same::value || + QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["QLinearMatMul"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + } else { + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 3); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 3); + } + } else { + if constexpr (std::is_same::value 
|| + (QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + } else { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + } + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({1, 2, 2}, {1, 2, 4}); + test_case({1, 23, 13, 13}, {13, 13}); + test_case({1, 22, 11, 13, 15}, {1, 22, 11, 15, 15}); + test_case({1, 2, 2}, {1, 2, 4}, true); // Use com.microsoft QDQ ops +} + +TEST(QDQTransformerTests, MatMul_U8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +// dummy test to disable the fastmath session op +TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) { + QDQTransformerMatMulTests(false, true); + QDQTransformerMatMulTests(true, true); +} + +template +void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) { + auto 
test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + std::vector input_args; + + // add QDQ A + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + input_args.push_back(dq1_output); + + // add QDQ B + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + input_args.push_back(dq2_output); + + if (has_bias) { + auto* dq_bias_output = builder.MakeIntermediate(); + auto* bias = builder.MakeInitializer({input2_shape[1]}, static_cast(0), static_cast(127)); + builder.AddDequantizeLinearNode(bias, 0.00156f, + 0, + dq_bias_output, use_contrib_qdq); + input_args.push_back(dq_bias_output); + } + + Node* gemm_node = nullptr; + + if (has_output_q) { + auto* gemm_op_output = builder.MakeIntermediate(); + gemm_node = &builder.AddNode("Gemm", input_args, {gemm_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(gemm_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + gemm_node = &builder.AddNode("Gemm", input_args, {output_arg}); + } + + if (beta_not_one) { + gemm_node->AddAttribute("beta", 2.0f); + } + }; + + auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if ((!has_output_q || std::is_same_v)&&(!has_bias || (std::is_same_v && !beta_not_one)) && + (std::is_same_v || std::is_same_v)) { + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], has_output_q ? 
1 : 0); + } else { + int q_count = 2; // Q for A and B + int dq_count = 2; // DQ for A and B + if (has_bias) { + dq_count++; + } + if (has_output_q) { + q_count++; + dq_count++; + } + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 0); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], q_count); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], dq_count); + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({2, 2}, {2, 4}); + test_case({13, 15}, {15, 15}); + test_case({2, 2}, {2, 4}, true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerGemmTests() { + QDQTransformerGemmTests(false, false); + QDQTransformerGemmTests(false, true); + QDQTransformerGemmTests(true, false); + QDQTransformerGemmTests(true, true); + QDQTransformerGemmTests(false, false, true); + QDQTransformerGemmTests(false, true, true); + QDQTransformerGemmTests(true, false, true); + QDQTransformerGemmTests(true, true, true); + // dummy test to disable the fastmath session + QDQTransformerGemmTests(true, true, true, true); +} + +TEST(QDQTransformerTests, Gemm_U8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, MatMul_No_Fusion_FastMath) { + auto test_case = [&](const std::vector& input1_shape, 
const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output1 = AddQDQNodePair(builder, input1_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMul_1st_Input_Int8_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -128, 127); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add DQ with type int8 + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .004f, 1, dq_output_1, use_contrib_qdq); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output2 = AddQDQNodePair(builder, input2_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_output_1, dq_matmul_output2}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, 
TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMulIntegerToFloat_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input2_arg = builder.MakeInput(input2_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .0035f, 135, dq_output_1, use_contrib_qdq); + + auto* dq_output_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input2_arg, .0035f, 135, dq_output_2, use_contrib_qdq); + + builder.AddNode("MatMul", {dq_output_1, dq_output_2}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) && defined(__aarch64) + +} // namespace 
test +} // namespace onnxruntime + +#endif // defined(__aarch64) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc new file mode 100644 index 0000000000000..75e0c06b04f0d --- /dev/null +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -0,0 +1,305 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" +#include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "default_providers.h" + +#if defined(__aarch64__) && defined(__linux__) + +namespace onnxruntime { +namespace test { + +namespace { + +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +} // namespace + +template +struct MatMulTestData { + std::string name; + std::vector input0_dims; + std::vector input1_dims; + std::vector expected_dims; + std::vector expected_vals; +}; + +template +std::vector> GenerateTestCases() { + std::vector> test_cases; + test_cases.push_back( + {"test padding and broadcast A > B", + {3, 1, 1, 6}, + {2, 6, 7}, + {3, 2, 1, 7}, + {385, 400, 415, 430, 445, 460, 475, 1015, 1030, 1045, 1060, 1075, 1090, 1105, 1015, 1066, 1117, 1168, 1219, 1270, 1321, 3157, 3208, 3259, 3310, 3361, 3412, 3463, 1645, 1732, 1819, 1906, 1993, 2080, 2167, 5299, 5386, 5473, 5560, 5647, 5734, 5821}}); + + test_cases.push_back( + {"test padding and broadcast B > A", + {2, 3, 12}, + {3, 2, 12, 3}, + {3, 2, 3, 3}, + {1518, 1584, 1650, 3894, 4104, 4314, 6270, 6624, 6978, 26574, 27072, 27570, 34134, 34776, 35418, 41694, 42480, 43266, 6270, 6336, 6402, 19014, 19224, 19434, 31758, 32112, 32466, 62430, 62928, 63426, 80358, 81000, 81642, 98286, 99072, 99858, 11022, 11088, 11154, 34134, 34344, 34554, 57246, 57600, 57954, 98286, 98784, 99282, 126582, 127224, 127866, 154878, 155664, 156450}}); + + test_cases.push_back( + {"test 2D", + {8, 6}, + {6, 6}, + {8, 6}, + {330, 345, 360, 375, 390, 405, 870, 921, 972, 1023, 1074, 1125, 1410, 1497, 1584, 1671, 1758, 1845, 1950, 2073, 2196, 2319, 2442, 2565, 2490, 2649, 2808, 2967, 3126, 3285, 3030, 3225, 3420, 3615, 3810, 4005, 3570, 3801, 4032, 4263, 4494, 4725, 4110, 4377, 4644, 4911, 5178, 5445}}); + + test_cases.push_back( + {"test 2D special", + {2, 2, 16}, + {16, 4}, + {2, 2, 4}, + {4960, 5080, 5200, 5320, 12640, 13016, 13392, 13768, 20320, 20952, 21584, 22216, 28000, 28888, 29776, 30664}}); + + test_cases.push_back( + {"test 2D special 2", + {2, 2, 9}, + {1, 9, 4}, + {2, 2, 4}, + {816, 852, 888, 924, 2112, 2229, 2346, 2463, 3408, 3606, 3804, 4002, 4704, 4983, 5262, 5541}}); + + test_cases.push_back( + {"test 2D special 3", + {2, 12}, + {1, 1, 12, 3}, + {1, 1, 2, 3}, + {1518, 1584, 1650, 3894, 4104, 4314}}); + + test_cases.push_back( + {"test 3D batch", + {3, 1, 18}, + {3, 18, 2}, + {3, 1, 2}, + { + // clang-format off + 3570, 3723, + 26250, 26727, + 72258, 73059, + // clang-format on + }}); + + test_cases.push_back( + 
{"test 4D batch", + {2, 2, 1, 20}, + {2, 2, 20, 2}, + {2, 2, 1, 2}, + { + // clang-format off + 4940, 5130, + 36140, 36730, + 99340, 100330, + 194540, 195930, + // clang-format on + }}); + + return test_cases; +} + +template +void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant, bool disable_fastmath) { + for (auto t : GenerateTestCases()) { + SCOPED_TRACE("test case: " + t.name); + + OpTester test("MatMul", opset_version); + + int64_t size0 = TensorShape::FromExistingBuffer(t.input0_dims).SizeHelper(0, t.input0_dims.size()); + std::vector input0_vals = ValueRange(size0); + + test.AddInput("A", t.input0_dims, input0_vals, is_a_constant); + + int64_t size1 = TensorShape::FromExistingBuffer(t.input1_dims).SizeHelper(0, t.input1_dims.size()); + std::vector input1_vals = ValueRange(size1); + test.AddInput("B", t.input1_dims, input1_vals, is_b_constant); + + test.AddOutput("Y", t.expected_dims, t.expected_vals); + + // OpenVINO EP: Disabled temporarily matmul broadcasting not fully supported + // Disable TensorRT because of unsupported data type + std::unordered_set excluded_providers{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}; + if (t.name == "test 2D empty input") { + // NNAPI: currently fails for the "test 2D empty input" case + excluded_providers.insert(kNnapiExecutionProvider); + } + + if ("test padding and broadcast A > B" == t.name || "test 2D empty input" == t.name) { + // QNN can't handle 0 shap + excluded_providers.insert(kQnnExecutionProvider); + } + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + + if (disable_fastmath) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + } + } +} + +template +void RunMatMulTest(int32_t opset_version) { + RunMatMulTest(opset_version, false, false, false); +} + +TEST(MathOpTest, MatMulFloatType_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, false, false); +} + +TEST(MathOpTest, MatMulFloatTypeInitializer_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, true, false); +} + +TEST(MathOpTest, MatMulInt32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulInt64Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint64Type_FastMath) { + RunMatMulTest(9); +} + +#ifndef ENABLE_TRAINING +// Prepacking is disabled in full training build so no need to test the feature in a training build. 
+TEST(MathOpTest, MatMulSharedPrepackedWeights_FastMath) {
+  OpTester test("MatMul");
+
+  std::vector<float> b_init_values(32, 1.0f);
+  test.AddInput<float>("A", {8, 4},
+                       {1.0f, 2.0f, 3.0f, 4.0f,
+                        -1.0f, -2.0f, -3.0f, -4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f,
+                        -1.0f, -2.0f, -3.0f, -4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f,
+                        -1.0f, -2.0f, -3.0f, -4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f,
+                        -1.0f, -2.0f, -3.0f, -4.0f});
+  // B is to be an initializer for triggering pre-packing
+  test.AddInput<float>("B", {4, 8}, b_init_values, true);
+
+  test.AddOutput<float>("Y", {8, 8},
+                        {10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f,
+                         -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f,
+                         10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f,
+                         -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f,
+                         10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f,
+                         -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f,
+                         10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f,
+                         -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f});
+
+  OrtValue b;
+  Tensor::InitOrtValue(DataTypeImpl::GetType<float>(), TensorShape({4, 8}),
+                       b_init_values.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b);
+
+  SessionOptions so;
+  // Set up B as a shared initializer to be shared between sessions
+  ASSERT_EQ(so.AddInitializer("B", &b), Status::OK());
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(
+      kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1"));
+
+  // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable
+  test.EnableSharingOfPrePackedWeightsAcrossSessions();
+
+  // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP
+  // and we want to ensure that it is available in this build
+  auto cpu_ep = []() -> std::vector<std::unique_ptr<IExecutionProvider>> {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    return execution_providers;
+  };
+
+  size_t number_of_pre_packed_weights_counter_session_1 = 0;
+  size_t number_of_shared_pre_packed_weights_counter = 0;
+
+  // Session 1
+  {
+    test.Config(so)
+        .Config(run_with_tunable_op)
+        .ConfigEps(cpu_ep())
+        .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter);
+    // Assert that no pre-packed weights have been shared thus far
+    ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast<size_t>(0));
+  }
+
+  auto number_of_elements_in_shared_prepacked_buffers_container =
+      test.GetNumPrePackedWeightsShared();
+  // Assert that the number of elements in the shared container
+  // is the same as the number of weights that have been pre-packed
+  ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container);
+
+  // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements
+  // that have been pre-packed will be zero in which case we do not continue with the testing
+  // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all.
+  if (number_of_pre_packed_weights_counter_session_1 == 0)
+    return;
+
+  // Session 2
+  {
+    size_t number_of_pre_packed_weights_counter_session_2 = 0;
+    test.Config(so)
+        .Config(run_with_tunable_op)
+        .ConfigEps(cpu_ep())
+        .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter);
+
+    // Assert that the same number of weights were pre-packed in both sessions
+    ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2);
+
+    // Assert that the number of pre-packed weights that were shared equals
+    // the number of pre-packed weights in the second session
+    ASSERT_EQ(number_of_pre_packed_weights_counter_session_2,
+              static_cast<size_t>(number_of_shared_pre_packed_weights_counter));
+  }
+}
+
+#endif
+
+// Dummy run to disable the FastMath mode for the current session
+TEST(MathOpTest, MatMulUint64Type_DisableFastMath) {
+  RunMatMulTest<uint64_t>(9, false, false, true);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
+#endif  // defined(__aarch64__) && defined(__linux__)
diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc
index 3d53d4a3a0193..64ebe24188762 100644
--- a/onnxruntime/test/util/compare_ortvalue.cc
+++ b/onnxruntime/test/util/compare_ortvalue.cc
@@ -1,4 +1,5 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 // Licensed under the MIT License.
 
 #include "test/compare_ortvalue.h"
@@ -65,6 +66,54 @@ const char* ElementTypeToString(MLDataType type) {
   return DataTypeImpl::ToString(type);
 }
 
+#if defined(__aarch64__) && defined(__linux__)
+template <typename T>
+std::pair<COMPARE_RESULT, std::string> CheckCosineSimilarity(const Tensor& outvalue, const Tensor& expected_value) {
+  const size_t tensor_size = static_cast<size_t>(expected_value.Shape().Size());
+  const T* expected_output = expected_value.Data<T>();
+  const T* real_output = outvalue.Data<T>();
+  std::pair<COMPARE_RESULT, std::string> res = std::make_pair(COMPARE_RESULT::SUCCESS, "");
+  const T cosine_similarity_threshold = 0.99f;
+
+  T dot = 0.0f, denom_a = 0.0f, denom_b = 0.0f;
+  for (size_t i = 0u; i < tensor_size; ++i) {
+    if (isnan(expected_output[i]) && isnan(real_output[i]))
+      continue;
+    if (isinf(expected_output[i]) && isinf(real_output[i]))
+      continue;
+    dot += expected_output[i] * real_output[i];
+    denom_a += expected_output[i] * expected_output[i];
+    denom_b += real_output[i] * real_output[i];
+  }
+
+  T cos_factor = abs(dot / (sqrt(denom_a) * sqrt(denom_b)));
+  if (cos_factor < cosine_similarity_threshold) {
+    res.first = COMPARE_RESULT::RESULT_DIFFERS;
+    std::ostringstream oss;
+    oss << std::hex << "results differed, cosine similarity factor is " << cos_factor << ".";
+    res.second = oss.str();
+  }
+  return res;
+}
+
+template <typename T>
+std::pair<COMPARE_RESULT, std::string> CheckCloseMatch(const Tensor& outvalue, const Tensor& expected_value) {
+  const size_t size1 = static_cast<size_t>(expected_value.Shape().Size());
+  const T* expected_output = expected_value.Data<T>();
+  const T* real_output = outvalue.Data<T>();
+  const T close_match_threshold = 1.0;
+
+  for (size_t di = 0; di != size1; ++di) {
+    const T diff = expected_output[di] - real_output[di];
+    if (std::fabs(diff) > close_match_threshold) {
+      std::ostringstream oss;
+      oss << "expected " << expected_output[di] << ", got " << real_output[di];
+      return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str());
+    }
+  }
+  return std::make_pair(COMPARE_RESULT::SUCCESS, "");
+}
+#endif
 /**
  * @brief Check if two values are closely matched with given tolerance.
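These helpers back the NaN "magic number" tolerances used by the FastMath tests above: when the per-sample tolerance is NaN, floating-point outputs are compared by cosine similarity against a 0.99 threshold, and the remaining types fall back to the absolute close-match check. The following standalone sketch (not part of the patch; the vector values and the CosineClose name are made up for illustration) shows that acceptance criterion in isolation:

#include <cmath>
#include <cstdio>
#include <vector>

// Mirrors CheckCosineSimilarity above: pairs where both values are NaN or both are Inf are skipped,
// and the result is accepted when |cos(expected, actual)| reaches the 0.99 threshold.
static bool CosineClose(const std::vector<float>& expected, const std::vector<float>& actual,
                        float threshold = 0.99f) {
  float dot = 0.0f, denom_a = 0.0f, denom_b = 0.0f;
  for (size_t i = 0; i < expected.size(); ++i) {
    if ((std::isnan(expected[i]) && std::isnan(actual[i])) ||
        (std::isinf(expected[i]) && std::isinf(actual[i]))) {
      continue;
    }
    dot += expected[i] * actual[i];
    denom_a += expected[i] * expected[i];
    denom_b += actual[i] * actual[i];
  }
  return std::fabs(dot / (std::sqrt(denom_a) * std::sqrt(denom_b))) >= threshold;
}

int main() {
  // Small per-element errors of the kind the bfloat16 fast-math kernels introduce
  // leave the cosine similarity well above 0.99, so the comparison passes.
  std::vector<float> expected{1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> actual{1.01f, 1.98f, 3.02f, 3.97f};
  std::printf("accepted: %s\n", CosineClose(expected, actual) ? "yes" : "no");
  return 0;
}

Structurally wrong or uncorrelated outputs drag the factor far below the threshold, while the small rounding differences expected from the relaxed bfloat16 path do not, which is why the FastMath tests pass NaN tolerances instead of tight per-sample bounds.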
@@ -207,6 +256,37 @@ std::pair CompareTwoTensors(const Tensor& outvalue, oss << "shape mismatch, expect " << expected_tensor.Shape().ToString() << " got " << outvalue.Shape().ToString(); return std::make_pair(COMPARE_RESULT::SHAPE_MISMATCH, oss.str()); } + +#if defined(__aarch64__) && defined(__linux__) + if (isnan(per_sample_tolerance) || isnan(per_sample_tolerance)) { + if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else { + return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); + } + } +#endif + if (outvalue.IsDataType()) { return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing); From 24b74aebcbd5fbaaa44ca41143b3b6afe3207978 Mon Sep 17 00:00:00 2001 From: Linnea May Date: Mon, 22 Jan 2024 15:37:09 -0800 Subject: [PATCH 11/23] [DML] Register DML operators for opset 19 (#16939) ### Description Register DML operators for opset 19. - Cast19 - Castlike19 - Constant19 - Equal19 - Identity19 - QuantizeLinear19 - DequantizeLinear19 - Reshape19 - Shape19 - Size ### Motivation and Context --------- Co-authored-by: linnealovespie --- docs/OperatorKernels.md | 27 ++++++++++++------ .../src/Operators/DmlOperatorCast.cpp | 3 +- .../src/Operators/DmlOperatorElementWise.cpp | 28 +++++++++++-------- .../src/Operators/OperatorRegistration.cpp | 10 +++++++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 3 ++ .../OperatorAuthorHelper/OperatorVersions.h | 10 +++++++ .../cpu/tensor/quantize_linear_test.cc | 10 ------- 7 files changed, 59 insertions(+), 32 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9ecc58bee0725..9a2a7ac89bbb3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -922,10 +922,12 @@ Do not modify directly.* |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)| @@ -952,7 +954,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -961,7 +964,8 @@ Do not modify directly.* |DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| @@ -1004,7 +1008,8 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1099,7 +1104,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| @@ -1150,7 +1156,8 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| @@ -1178,7 +1185,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| @@ -1188,7 +1196,8 @@ Do not modify directly.* |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 76b9b308fe98f..45ff25c4fdd90 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -29,7 +29,7 @@ class DmlOperatorCast : public DmlOperator castDesc.OutputTensor = outputDescs.data(); DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; - + SetDmlOperatorDesc(opDesc, kernelInfo); } @@ -49,5 +49,6 @@ class DmlOperatorCast : public DmlOperator DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast); +DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index ab8ddbfe91bf0..16bb10f004f91 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -487,7 +487,7 @@ class DmlOperatorElementwisePow : public DmlOperator Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = &inputDescs[0]; @@ -497,11 +497,11 @@ class DmlOperatorElementwisePow : public DmlOperator SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); } else - { + { Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = &inputDescs[0]; @@ -519,13 +519,16 @@ class DmlOperatorElementwiseQLinear : public DmlOperator public: DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) { - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3); + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + Initialize(kernelInfo, std::nullopt, std::nullopt); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); - - Initialize(kernelInfo, std::nullopt, std::nullopt); + const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + bool hasZeroPointTensor = kernelInfo.IsInputValid(2); uint32_t axis = 0; @@ -541,9 +544,14 @@ class DmlOperatorElementwiseQLinear : public DmlOperator axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false); } - // Explicitly reshape each of the inputs after the first input (scale and zero point tensors). + // Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor). 
for (uint32_t index = 1, inputCount = gsl::narrow_cast(m_inputTensorDescs.size()); index < inputCount; ++index) { + if (!kernelInfo.IsInputValid(index)) + { + continue; + } + auto edgeDesc = kernelInfo.GetInputEdgeDescription(index); assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); @@ -587,12 +595,8 @@ class DmlOperatorElementwiseQLinear : public DmlOperator TOperatorDesc opDesc = {}; opDesc.InputTensor = &inputDescs[0]; opDesc.ScaleTensor = &inputDescs[1]; - opDesc.ZeroPointTensor = &inputDescs[2]; + opDesc.ZeroPointTensor = hasZeroPointTensor ? &inputDescs[2] : nullptr; opDesc.OutputTensor = &outputDescs[0]; - - TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1); - TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2); - SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 15a8051953c79..18e29c8b99ced 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -436,6 +436,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul); DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation); DML_OP_EXTERN_CREATION_FUNCTION(Cast); DML_OP_EXTERN_CREATION_FUNCTION(CastLike15); +DML_OP_EXTERN_CREATION_FUNCTION(CastLike19); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost); DML_OP_EXTERN_CREATION_FUNCTION(TopK7); @@ -785,6 +786,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, @@ -798,6 +800,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -857,8 +860,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, 
QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, @@ -943,6 +948,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, {REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 19, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, @@ -1004,7 +1010,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)}, @@ -1015,8 +1023,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, 
{REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 0e0e6bb1eaf5c..0d425997e6a6a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1606,6 +1606,7 @@ using ShapeInferenceHelper_Expand = ExpandHelper; using ShapeInferenceHelper_Reshape7 = ReshapeHelper; using ShapeInferenceHelper_Reshape13 = ReshapeHelper; using ShapeInferenceHelper_Reshape14 = ReshapeHelper; +using ShapeInferenceHelper_Reshape19 = ReshapeHelper; using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper; using ShapeInferenceHelper_Tile = TileHelper; using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; @@ -1725,6 +1726,7 @@ using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; @@ -1750,6 +1752,7 @@ using ShapeInferenceHelper_CumSum14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Range = RangeHelper; using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DmlFusedConv = ConvHelper; using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 8438bc620712c..79efc2d2836fe 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -413,6 +413,16 @@ namespace OperatorHelper namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Cast = 19; + static const int sc_sinceVer_CastLike = 19; + static const int sc_sinceVer_Constant = 19; + static const int sc_sinceVer_Equal = 19; + static const int sc_sinceVer_Identity = 19; + static const int sc_sinceVer_QuantizeLinear = 19; + static const int sc_sinceVer_DequantizeLinear = 19; + static const int sc_sinceVer_Reshape = 19; + static const int sc_sinceVer_Shape = 19; + static const int sc_sinceVer_Size = 19; } namespace MsftOperatorSet1 diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 026bb07edf44c..0c8d6c46d4639 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -34,11 +34,6 @@ TEST(DequantizeLinearOpTest, Int8) { // scalar zero & scale with int8 TEST(DequantizeLinearOpTest, Int32) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << 
"Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester test("DequantizeLinear", 10); std::vector dims{4}; test.AddInput("x", dims, {-30, -3, 100, 127}); @@ -98,11 +93,6 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) { // dequantize without zero point TEST(DequantizeLinearOpTest, Without_Zero_Point) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester test("DequantizeLinear", 10); test.AddInput("x", {}, {100}); test.AddInput("x_scale", {}, {2.0f}); From e283cdb21857170848e9b8a8fbca24d0463b4193 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:44:57 -0800 Subject: [PATCH 12/23] Fix Fuzz Testing CI (#19228) ### Description Add BuildArch To verify: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=400952&view=logs&j=5b022bb4-70a7-5401-8766-a8a7802c7150&t=291e85c7-5547-590b-50de-4e01fcd4eba3&l=14 ### Motivation and Context --- tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index b8f9566274acc..db39c2cd2087f 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -28,7 +28,7 @@ jobs: parameters: EnvSetupScript: $(EnvSetupScript) DownloadCUDA: false - BuildArch: $(buildArch) + BuildArch: x64 BuildConfig: $(BuildConfig) MachinePool: 'onnxruntime-Win-CPU-2022' WithCache: true From 2e0a388c36b92bc412dfa8ad45af23c7f28a4d49 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 07:53:26 +0800 Subject: [PATCH 13/23] [js/webgpu] Add HardSigmoid support (#19215) ### Description This op is required in mobilenetv3-small-100. With this PR, mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on ADL. 
--- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 20 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 6 +++--- .../providers/js/js_execution_provider.cc | 2 ++ .../core/providers/js/operators/unary.cc | 3 +++ 6 files changed, 30 insertions(+), 3 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d9306..2557971eb4ded 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -52,6 +52,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 90e02da986b8f..cc504093ca0d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b4..82311d72e58b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 033b3b3f4b0f5..373b3c645df57 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", diff --git 
a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c2ff2ebc39e13..af9658271d210 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log); @@ -392,6 +393,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(13, Erf), KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), KERNEL_CREATE_INFO(13, Log), diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 78563d30b0136..9082527e3a8d7 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5) +JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid) + JSEP_KERNEL_IMPL(Log, Log) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) From d226e40856738531cf8b481b07379545f7cfefe2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 08:08:55 +0800 Subject: [PATCH 14/23] [js/webgpu] set query type in onRunStart (#19202) ### Description `env.webgpu.profiling` is a global flag. It may change before each session.run. So the best place is to update it in `onRunStart` event. After this, we can directly check `this.queryType`'s value. Without this pr, we need to make sure that `getCommandEncoder()` is called before checking `this.queryType`. Otherwise, it may happen that `pendingKernels`'s length is not equal to `pendingDispatchNumber`'s length. See the two ugly workarounds [1)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268) and [2)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80) if we don't introduce `onRunStart`. Or we need to call `setQueryType` in each kernel run. 
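A condensed sketch of the per-run hook this change introduces (class and method names follow the diff below, but the body is abbreviated and omits the real backend state, so treat it as illustrative rather than the actual implementation):

```typescript
// Abbreviated view of the new flow: wasm-core-impl.ts calls jsepOnRunStart?.()
// right before _OrtRun/_OrtRunWithBinding, which forwards to the backend's
// onRunStart(), so the global profiling flag is re-read once per
// InferenceSession.run() instead of lazily inside getCommandEncoder().
class WebGpuBackendSketch {
  queryType: 'none' | 'inside-passes' = 'none'; // other query modes omitted in this sketch

  setQueryType(): void {
    // re-evaluate env.webgpu.profiling here (details omitted in this sketch)
  }

  onRunStart(): void {
    this.setQueryType();
  }
}

const backend = new WebGpuBackendSketch();
// Wiring analogous to js_internal_api.js: Module['jsepOnRunStart'] = () => backend.onRunStart();
```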
--- js/web/lib/wasm/binding/ort-wasm.d.ts | 4 ++++ js/web/lib/wasm/jsep/backend-webgpu.ts | 9 +++++---- js/web/lib/wasm/wasm-core-impl.ts | 2 +- onnxruntime/wasm/js_internal_api.js | 3 +++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 9d4d5875310b7..68054210e79a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. + */ + jsepOnRunStart: () => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2956ec1cad4da..afef7042a4280 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -208,7 +208,7 @@ export class WebGpuBackend { Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - // init queryType, which is necessary for createKernel + // init queryType, which is necessary for InferenceSession.create this.setQueryType(); } @@ -223,8 +223,6 @@ export class WebGpuBackend { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - // refresh queryType, as sometimes we only need to enable query for a specific run - this.setQueryType(); if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { this.querySet = this.device.createQuerySet({ type: 'timestamp', @@ -639,6 +637,7 @@ export class WebGpuBackend { return createView(data.buffer, type); }; } + // #endregion writeTimestamp(index: number): void { if (this.queryType !== 'inside-passes') { return; @@ -657,5 +656,7 @@ export class WebGpuBackend { } } } - // #endregion + onRunStart(): void { + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5821fac3c468f..8768643fa7257 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -488,8 +488,8 @@ export const run = async( } } + wasm.jsepOnRunStart?.(); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 25ece9c700d5d..7c70515e73eab 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnRunStart'] = () => { + return backend['onRunStart'](); + }; }; From 37d14d78960fb1ba54c0bb2dc3be740e93d2ca15 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 22 Jan 2024 18:14:41 -0800 Subject: [PATCH 15/23] [QNN EP] Create Windows ARM64 nightly python package (#19128) ### Description Adds a job to create a nightly python package for ORT/QNN on Windows ARM64. Must build onnxruntime-qnn with python 3.11 and numpy 1.25. **Note: pipeline run may take up to 3 hrs** ### Motivation and Context Make it possible to get a nightly python package with the latest updates to QNN EP. 
Issue #19161 --- .../azure-pipelines/py-packaging-pipeline.yml | 8 +- .../templates/py-packaging-stage.yml | 13 ++ .../templates/py-win-arm64-qnn.yml | 165 ++++++++++++++++++ 3 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 06cca0068523d..5349b1ca67ab1 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -29,6 +29,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -64,5 +69,6 @@ stages: enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file + cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 8669a883c31f1..297498843c38d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -35,6 +35,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type type: string @@ -446,3 +451,11 @@ stages: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} + + - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - template: py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: 'qnn-v2.18.0.240101_win' + PYTHON_VERSION: '3.11' + NUMPY_VERSION: '1.25.2' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml new file mode 100644 index 0000000000000..adf7aa9c43205 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -0,0 +1,165 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'onnxruntime-qnn-windows-vs-2022-arm64' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: PYTHON_VERSION + type: string + default: '3.11' + +- name: NUMPY_VERSION + type: string + default: '1.25.2' + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. 
+ type: string + default: '' + +jobs: +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\3.11.0 + DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + displayName: Copy python 3.11.0 version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: ${{ parameters.PYTHON_VERSION }} + addToPath: true + architecture: 'arm64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', 'numpy==${{parameters.NUMPY_VERSION}}']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + --numpy_version ${{ parameters.NUMPY_VERSION }} + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + 
displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' From b2aec41a8309bc2dced74a991b1f3c311e037e3d Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 22 Jan 2024 19:17:04 -0800 Subject: [PATCH 16/23] [ROCm] enable hipGraph (#18382) This ports the cudaGraph support from the CUDA EP to the ROCM EP's hipGraph. --- cmake/onnxruntime_unittests.cmake | 7 ++ .../core/session/onnxruntime_c_api.h | 3 + .../providers/rocm/rocm_execution_provider.cc | 77 +++++++++++- .../providers/rocm/rocm_execution_provider.h | 24 ++++ .../rocm/rocm_execution_provider_info.cc | 3 + .../rocm/rocm_execution_provider_info.h | 2 + .../providers/rocm/rocm_provider_factory.cc | 2 + onnxruntime/core/session/inference_session.cc | 52 +++++--- .../core/session/provider_bridge_ort.cc | 1 + onnxruntime/test/shared_lib/test_inference.cc | 112 +++++++++++++++--- 10 files changed, 241 insertions(+), 42 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fa395802d95ff..0987d6d164dbd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() + if (onnxruntime_USE_ROCM) + list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) + endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() @@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) + target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 101a578ec3e1d..2ce9d361e8e56 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, + enable_hip_graph{false}, tunable_op_enable{false}, tunable_op_tuning_enable{false}, tunable_op_max_tuning_duration_ms{} {} @@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; + int enable_hip_graph; + /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. 
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d7c5098d9dbe4..d7bec337a6be4 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -170,6 +170,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); + + hip_graph_.SetStream(stream); } ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { @@ -177,6 +179,33 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; +} + +void ROCMExecutionProvider::PerThreadContext::CaptureBegin() { + hip_graph_.Reset(); + hip_graph_.CaptureBegin(); +} + +void ROCMExecutionProvider::PerThreadContext::CaptureEnd() { + hip_graph_.CaptureEnd(); + is_graph_captured_ = true; +} + +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + return hip_graph_.Replay(); +} + +void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} + void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); @@ -219,6 +248,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; + } else if (info.enable_hip_graph) { + // current hip graph implementation only works with single stream + // use EP level unified stream for all the reqeust + HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); + use_ep_level_unified_stream_ = true; } else { stream_ = nullptr; } @@ -322,25 +356,58 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart() { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; + GetPerThreadContext().CaptureBegin(); + } return Status::OK(); } Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { + if (GetPerThreadContext().IsGraphCaptureAllowed()) { + GetPerThreadContext().CaptureEnd(); + // HIP work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. 
+ ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + } else { + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); + } + } + if (sync_stream) { HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream_))); } - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to - // release. This didn't happen before because we always call - // GetPerThreadContext in OnRunStart. - if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { + // The reason of !IsGraphCaptureEnabled(): + // If hip graph is enabled, the per thread context will not be released + // because the per thread hip graph needs to be maintained and replayed for + // the next run. + // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): + // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), + // PerThreadContext won't be created and there is nothing to + // release. This didn't happen before because we always call + // GetPerThreadContext in OnRunStart. + if (!IsGraphCaptureEnabled() && + PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { ReleasePerThreadContext(); } return Status::OK(); } +bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { + return info_.enable_hip_graph; +} + +bool ROCMExecutionProvider::IsGraphCaptured() const { + return GetPerThreadContext().IsGraphCaptured(); +} + +Status ROCMExecutionProvider::ReplayGraph() { + return GetPerThreadContext().ReplayGraph(); +} + namespace rocm { // opset 1 to 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index c4945b9ac2481..37d5f7b42210f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -10,6 +10,7 @@ #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" #include "core/providers/rocm/rocm_execution_provider_info.h" +#include "core/providers/rocm/rocm_graph.h" #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" @@ -73,6 +74,9 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -81,6 +85,7 @@ class ROCMExecutionProvider : public IExecutionProvider { ROCMExecutionProviderInfo info_; hipDeviceProp_t device_prop_; bool external_stream_ = false; + // only used when set user external stream or hip graph hipStream_t stream_ = nullptr; bool use_ep_level_unified_stream_ = false; @@ -133,6 +138,13 @@ class ROCMExecutionProvider : public IExecutionProvider { } } + bool IsGraphCaptureAllowed() const; + void CaptureBegin(); + void CaptureEnd(); + bool IsGraphCaptured() const; + Status ReplayGraph(); + void IncrementRegularRunCountBeforeGraphCapture(); + private: rocblas_handle rocblas_handle_ = nullptr; miopenHandle_t miopen_handle_ = nullptr; @@ -141,6 +153,18 @@ class ROCMExecutionProvider : 
public IExecutionProvider { std::unique_ptr> constant_ones_double_; std::unique_ptr> constant_ones_half_; std::unique_ptr> constant_ones_bfloat16_; + + // Hip graph with multi threads will be supported in the future, so hip_graph_ + // is put under PerThreadContext. + ROCMGraph hip_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + + // There is chance that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 650635c153640..b557f92287f2b 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; constexpr const char* kGpuExternalFree = "gpu_external_free"; constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace"; +constexpr const char* kEnableHipGraph = "enable_hip_graph"; constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P info.miopen_conv_exhaustive_search) .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) + .AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph) .AddValueParser( rocm::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, {rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)}, + {rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)}, {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index e35c0cc0afecc..2f549cc1ac143 100644 --- 
a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo { // If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best. bool miopen_conv_use_max_workspace{true}; + bool enable_hip_graph{false}; + rocm::TunableOpInfo tunable_op{}; static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 4d88c25469372..88ef666678b3e 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; + info.enable_hip_graph = params->enable_hip_graph; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider { rocm_options.user_compute_stream = internal_options.user_compute_stream; } rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.enable_hip_graph = internal_options.enable_hip_graph; rocm_options.tunable_op_enable = internal_options.tunable_op.enable; rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8853c8824738..39f47c09f2402 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -153,7 +153,7 @@ static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) && node_provider != kCpuExecutionProvider) { nodes_on_cpu_and_cuda_eps_only = false; break; @@ -1715,7 +1715,8 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently CUDA graph is only considered by CUDA EP and TRT EP, and + // HIP graph is only considered by ROCM EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1728,47 +1729,58 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). 
- std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for ROCM EP: + // If the ROCM EP is part of the providers list for this session AND + // The ROCM EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the ROCM EP, + // Then the ROCM EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // CUDA/HIP Graphs can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.")); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the CUDA/HIP EP."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.")); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the CUDA/HIP Graph feature with caution. 
" << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the CUDA/HIP graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the CUDA/HIP Graph feature."; } } else { // Following code path is for TRT EP currently. @@ -1787,7 +1799,7 @@ common::Status InferenceSession::Initialize() { } } - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + LOGS(*session_logger_, INFO) << "This session will use the CUDA/HIP Graph feature as requested by the user."; cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); break; // Make sure only one ep can run CUDA graph. } @@ -2477,7 +2489,9 @@ Status InferenceSession::Run(const RunOptions& run_options, // As N+1 inference runs (N for memory allocation and 1 for graph capturing) // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. - // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP. + // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, + // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, + // and the value could be different for other EP. if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3269c9f0f4e4b..3178c13d30eec 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2380,6 +2380,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProvider options->has_user_compute_stream = 0; options->user_compute_stream = nullptr; options->default_memory_arena_cfg = nullptr; + options->enable_hip_graph = false; options->tunable_op_enable = 0; options->tunable_op_tuning_enable = 0; options->tunable_op_max_tuning_duration_ms = 0; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6ffe72f81bd24..8dad2c8e2d10d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -43,6 +43,10 @@ #include #endif +#ifdef USE_ROCM +#include +#endif + // Once we use C++17 this could be replaced with std::size template constexpr size_t countof(T (&)[N]) { return N; } @@ -1762,6 +1766,27 @@ TEST(CApiTest, get_allocator_cuda) { } #endif +#ifdef USE_ROCM +TEST(CApiTest, get_allocator_rocm) { + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); + Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); + + Ort::MemoryInfo info_rocm("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + Ort::Allocator rocm_allocator(session, info_rocm); + + auto allocator_info = rocm_allocator.GetInfo(); + ASSERT_TRUE(info_rocm == allocator_info); + void* p = rocm_allocator.Alloc(1024); + 
ASSERT_NE(p, nullptr); + rocm_allocator.Free(p); + + auto mem_allocation = rocm_allocator.GetAllocation(1024); + ASSERT_NE(nullptr, mem_allocation.get()); + ASSERT_EQ(1024U, mem_allocation.size()); +} +#endif + TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -1937,7 +1962,7 @@ TEST(CApiTest, io_binding_cuda) { } #endif -#if defined(USE_CUDA) || defined(USE_TENSORRT) +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, basic_cuda_graph) { const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; @@ -1955,7 +1980,7 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( static_cast(session_options), rel_trt_options.get()) == nullptr); -#else +#elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -1968,34 +1993,55 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( static_cast(session_options), rel_cuda_options.get()) == nullptr); +#elif defined(USE_ROCM) + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); - Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#if defined(USE_ROCM) +// local hipify +#define cudaMemcpy hipMemcpy +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost + Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#else + Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#endif - Ort::Allocator cuda_allocator(session, info_cuda); - auto allocator_info = cuda_allocator.GetInfo(); - ASSERT_TRUE(info_cuda == allocator_info); + Ort::Allocator allocator(session, info_mem); + auto allocator_info = allocator.GetInfo(); + ASSERT_TRUE(info_mem == allocator_info); const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - auto input_data = cuda_allocator.GetAllocation(x_values.size() * sizeof(float)); + auto input_data = allocator.GetAllocation(x_values.size() * sizeof(float)); ASSERT_NE(input_data.get(), nullptr); - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data.get()), x_values.size(), + Ort::Value bound_x = Ort::Value::CreateTensor(info_mem, reinterpret_cast(input_data.get()), x_values.size(), x_shape.data(), x_shape.size()); const std::array expected_y_shape = {3, 2}; std::array expected_y = {1.0f, 4.0f, 
9.0f, 16.0f, 25.0f, 36.0f}; - auto output_data = cuda_allocator.GetAllocation(expected_y.size() * sizeof(float)); + auto output_data = allocator.GetAllocation(expected_y.size() * sizeof(float)); ASSERT_NE(output_data.get(), nullptr); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), + Ort::Value bound_y = Ort::Value::CreateTensor(info_mem, reinterpret_cast(output_data.get()), expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); // Create IoBinding for inputs and outputs. @@ -2008,31 +2054,37 @@ TEST(CApiTest, basic_cuda_graph) { // Check the values against the bound raw memory (needs copying from device to host first) std::array y_values; - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Replay the captured CUDA graph session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Change the input and replay the CUDA graph again. x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); binding.SynchronizeInputs(); session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f}; ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); +#if defined(USE_ROCM) +#undef cudaMemcpy +#undef cudaMemcpyHostToDevice +#undef cudaMemcpyDeviceToHost +#endif } -#ifndef REDUCED_OPS_BUILD // The following test uses some ops not supported in the reduced ops build +#ifndef REDUCED_OPS_BUILD +#if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { const auto& api = Ort::GetApi(); @@ -2053,10 +2105,34 @@ TEST(CApiTest, cuda_graph_with_shape_nodes) { // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); } +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) -#endif +#if defined(USE_ROCM) +TEST(CApiTest, hip_graph_with_shape_nodes) { + const auto& api = Ort::GetApi(); -#endif + // Enable hip graph in rocm provider option. 
+ OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + Ort::SessionOptions session_options; + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); + + // Successful loading of the ONNX model with shape nodes with hip graph feature enabled + Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); +} +#endif // defined(USE_ROCM) + +#endif // REDUCED_OPS_BUILD + +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, create_tensor) { const char* s[] = {"abc", "kmp"}; From 6ca7c1a933e57e0078d8d01eff3a1520098cfed1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jan 2024 20:42:30 -0800 Subject: [PATCH 17/23] unet fusion for stable diffusion webui (#19227) ### Description Update unet fusion for [stable diffusion webui extension](https://github.com/tianleiwu/Stable-Diffusion-WebUI-OnnxRuntime): (1) Update fusion pattern to support fp16 unet model. (2) Add progress bar (3) Use a cached map to speed up dtype or shape lookup in shape inference result. ### Motivation and Context --- .../tools/transformers/fusion_attention.py | 14 +- .../transformers/fusion_attention_unet.py | 166 ++++++++++++++++-- .../tools/transformers/fusion_embedlayer.py | 18 +- .../tools/transformers/fusion_gemmfastgelu.py | 2 +- .../tools/transformers/fusion_nhwc_conv.py | 15 +- .../python/tools/transformers/fusion_shape.py | 8 +- .../python/tools/transformers/fusion_utils.py | 47 +++-- .../python/tools/transformers/import_utils.py | 20 +++ .../models/stable_diffusion/README.md | 2 +- .../python/tools/transformers/onnx_model.py | 98 ++++++++--- .../tools/transformers/onnx_model_bert.py | 16 +- .../tools/transformers/onnx_model_unet.py | 71 +++++++- 12 files changed, 395 insertions(+), 82 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/import_utils.py diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index d11cb91d98b0c..f48cabd25fc5c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -129,6 +129,9 @@ def __init__( self.num_heads_warning = True self.hidden_size_warning = True + self.shape_infer = None + self.shape_infer_done = True + def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]: """ Detect num_heads and hidden_size from Concat node in the following subgraph: @@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] return num_heads, hidden_size def get_add_qk_str(self, add_qk: NodeProto): - shape_infer = self.model.infer_runtime_shape(update=True) - if shape_infer is None: + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is None: return None - input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) - input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) + input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0]) + input_1_shape = 
self.shape_infer.get_edge_shape(add_qk.input[1]) if input_0_shape is None or input_1_shape is None: logger.debug(f"one of the inputs of {add_qk} is None") diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 250ec5f3eb159..9a353e7e2d675 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -28,10 +28,19 @@ def __init__( enable_packed_qkv: bool, enable_packed_kv: bool, ): - super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) + super().__init__( + model, + "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention", + ["LayerNormalization"], + ) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + + # Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA. + # To support LoRA, it is better to use separated Q, K and V inputs in offline optimization, + # and CUDA operator pre-packs those tensors to preferred format based on available kernels. + # In this way, we can support LoRA and get optimal performance at same time. self.enable_packed_qkv = enable_packed_qkv self.enable_packed_kv = enable_packed_kv @@ -170,9 +179,7 @@ def create_attention_node( return None # Sometimes weights are stored in fp16 - if q_weight.data_type == 10: - logger.debug("weights are in fp16. Please run fp16 conversion after optimization") - return None + float_type = q_weight.data_type qw = NumpyHelper.to_array(q_weight) kw = NumpyHelper.to_array(k_weight) @@ -212,7 +219,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_weight.shape[0], qkv_weight.shape[1]], vals=qkv_weight, ) @@ -235,8 +242,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[attention_node_name + "_input"], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_qkv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -251,7 +261,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qw_in_size, qkv_weight_dim], vals=qkv_weight, ) @@ -280,7 +290,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[kv_weight.shape[0], kv_weight.shape[1]], vals=kv_weight, ) @@ -303,8 +313,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[k_matmul.output[0]], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_kv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -317,7 +330,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_bias", - 
data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_bias_dim], vals=qkv_bias, ) @@ -330,7 +343,7 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [attention_node_name + "_input"] + attention_inputs = [attention_node_name + "_qkv_input"] else: if not self.enable_packed_kv: attention_inputs = [ @@ -342,7 +355,7 @@ def create_attention_node( else: attention_inputs = [ q_matmul.output[0], - k_matmul.output[0], + attention_node_name + "_kv_input", ] attention_node = helper.make_node( @@ -839,6 +852,9 @@ def create_attention_node_lora( return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node): + return + node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) # In SD 1.5, for self attention, LayerNorm has parent Reshape @@ -1168,3 +1184,125 @@ def match_lora_path( return (lora_mul_node, lora_matmul_1_node) return None + + def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node): + """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension""" + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0]) + if entry_path is None: + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0]) + if entry_path is None: + return False + _cast, node_before_layernorm = entry_path + + root_input = node_before_layernorm.output[0] + + children_nodes = input_name_to_nodes[root_input] + skip_add = None + for node in children_nodes: + if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet + skip_add = node + break + if skip_add is None: + return False + + match_qkv = self.match_qkv_a1111(root_input, skip_add) + if match_qkv is None: + return False + + ( + reshape_qkv, + transpose_qkv, + reshape_q, + matmul_q, + matmul_k, + matmul_v, + ) = match_qkv + + cast_q = self.model.match_parent(matmul_q, "Cast", 0) + cast_k = self.model.match_parent(matmul_k, "Cast", 0) + cast_v = self.model.match_parent(matmul_v, "Cast", 0) + if not ( + cast_q is not None + and cast_k is not None + and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k) + and cast_k == cast_v + ): + return False + + if cast_q.input[0] != normalize_node.output[0]: + return False + + attention_last_node = reshape_qkv + + q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + q_hidden_size = self.get_hidden_size(normalize_node) + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=matmul_q.input[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return False + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. 
+ self.prune_graph = True + return True + + def match_qkv_a1111(self, root_input, skip_add): + """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"], + [another_input, None, None, 0, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, _, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path( + einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None] + ) + if qk_nodes is not None: + (_, _, _softmax_qk, _, einsum_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (_, _transpose_q, reshape_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_, _, _, matmul_k) = k_nodes + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index bc38399e3cce5..42156d9123383 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"): description, ) self.utils = FusionUtils(model) - self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = None + self.shape_infer_done = False + # The following will be reset in each fuse call of FusionEmbedLayerNormalization self.attention = None self.embed_node = None @@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None position_ids = position_embedding_gather.input[1] - if self.shape_infer_helper is not None: - input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids) - position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids) + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is not None: + input_ids_shape = self.shape_infer.get_edge_shape(input_ids) + position_ids_shape = self.shape_infer.get_edge_shape(position_ids) assert input_ids_shape and position_ids_shape if not ( len(input_ids_shape) == 2 @@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit ) return False - if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids): + if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids): logger.info( "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format( 
input_ids_shape, - self.shape_infer_helper.get_edge_shape(segment_ids), + self.shape_infer.get_edge_shape(segment_ids), ) ) return False diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index f1d803a3cc082..4d9913f427b37 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: return self.get_dimensions_from_tensor_proto(graph_input) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index 141ebb1f95a11..5233fdf272fbd 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -7,7 +7,8 @@ from typing import List from fusion_base import Fusion -from onnx import TensorProto, helper, numpy_helper +from fusion_utils import FusionUtils +from onnx import helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion): def __init__(self, model: OnnxModel, update_weight=False): super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv") self.update_weight = update_weight + self.fusion_utils = FusionUtils(model) def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): """Append a Transpose node after an input""" @@ -49,6 +51,15 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): if len(weight.shape) != 4: return + dtype = self.model.get_dtype(nhwc_conv_input) + if not (dtype is not None and weight_tensor.data_type == dtype): + cast_node = self.fusion_utils.add_cast_node( + input_name=nhwc_conv_input, + to_type=weight_tensor.data_type, + output_name_to_node=output_name_to_node, + ) + nhwc_conv_input = cast_node.output[0] + if self.update_weight: # Transpose weights from NCHW to NHWC weight = weight.transpose(0, 2, 3, 1) @@ -56,7 +67,7 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): weight_name = node_name + "_weight_NHWC" self.add_initializer( name=weight_name, - data_type=TensorProto.FLOAT, + data_type=weight_tensor.data_type, dims=list(weight.shape), vals=weight, ) diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index bc32d78eda66c..dfa77fc7d0221 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[i return None def get_dimensions(self, input_name: str) -> Union[int, None]: - graph_input = self.model.find_graph_input(input_name) - if graph_input: - return self.get_dimensions_from_tensor_proto(graph_input) + shape = self.model.get_shape(input_name) + if shape is not None: + return len(shape) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py 
b/onnxruntime/python/tools/transformers/fusion_utils.py index afc968fab46c1..726c587ff7043 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple +from typing import Optional, Tuple import numpy from numpy import array_equal, ndarray @@ -29,17 +29,7 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: return False, input_name def cast_input(self, input_name: str, target_type="int32"): - cast_output = input_name + "_" + target_type - - # Avoid consequent Cast nodes. - inputs = [input_name] - output_name_to_node = self.model.output_name_to_node() - if input_name in output_name_to_node: - parent_node = output_name_to_node[input_name] - if parent_node and parent_node.op_type == "Cast": - inputs = [parent_node.input[0]] - - cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) + output_name = input_name + "_" + target_type if target_type == "int32": to_type = int(TensorProto.INT32) @@ -50,10 +40,36 @@ def cast_input(self, input_name: str, target_type="int32"): else: raise ValueError("Invalid target_type: {target_type}") + cast_node = self.add_cast_node(input_name, to_type, output_name) + + return output_name, cast_node + + def add_cast_node( + self, + input_name: str, + to_type: int, + output_name: Optional[str] = None, + output_name_to_node=None, + graph_name: Optional[str] = None, + ): + if output_name is None: + output_name = input_name + f"_cast_to_{to_type}" + + # Avoid consequent Cast nodes. + inputs = [input_name] + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + if input_name in output_name_to_node: + parent_node = output_name_to_node[input_name] + if parent_node and parent_node.op_type == "Cast": + inputs = [parent_node.input[0]] + + cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name]) + cast_node.attribute.extend([helper.make_attribute("to", to_type)]) - self.model.add_node(cast_node) + self.model.add_node(cast_node, graph_name=graph_name) - return cast_output, cast_node + return cast_node def cast_input_to_int32(self, input_name: str): return self.cast_input(input_name, "int32") @@ -224,9 +240,10 @@ def check_node_input_value(self, node, input_index: int, expected_value): def remove_identity_nodes(self): """Remove Identity nodes, except those right before graph output.""" nodes_to_remove = [] + graph_output_names = self.model.get_graphs_output_names() for node in self.model.nodes(): if node.op_type == "Identity": - if node.output[0] not in self.model.get_graphs_output_names(): + if node.output[0] not in graph_output_names: self.model.replace_input_of_all_nodes(node.output[0], node.input[0]) nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/import_utils.py b/onnxruntime/python/tools/transformers/import_utils.py new file mode 100644 index 0000000000000..9755a26b7b004 --- /dev/null +++ b/onnxruntime/python/tools/transformers/import_utils.py @@ -0,0 +1,20 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import importlib.metadata +import importlib.util + + +def is_installed(package): + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index b10c10c87ee57..8607485bc265b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -51,7 +51,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --allow_running_as_root python3 -m pip install --upgrade pip -python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-*.whl --force-reinstall ``` If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity (like 89 for RTX 4090, or 86 for RTX 3090). diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 37b39c91b5c15..9d1066b6e372b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -40,6 +40,12 @@ def initialize(self, model): self.enable_shape_infer: bool = True self.all_graphs: Optional[List[GraphProto]] = None + # Cache of shape and data type from onnx graph to speed up optimization. + # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) + # Note that these do not cache the symbolic shape inference result. 
+ self._dtype_dict: Optional[Dict[str, int]] = None + self._shape_dict: Optional[Dict[str, List]] = None + def disable_shape_inference(self): self.enable_shape_infer = False @@ -519,20 +525,60 @@ def tensor_shape_to_list(self, tensor_type): shape_list.append("?") # shall not happen return shape_list - def get_dtype(self, input_or_output: str): - """Try get data type given a name (could be initializer, graph input or output).""" - tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get data type given a name (could be initializer, input or output of graph or node).""" + + if self._dtype_dict is None: + self._dtype_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type + + for initializer in self.model.graph.initializer: + if initializer.name not in self._dtype_dict: + self._dtype_dict[initializer.name] = initializer.data_type - if input_or_output in tensor_type_map: - return tensor_type_map[input_or_output].tensor_type.elem_type + if name in self._dtype_dict: + return self._dtype_dict[name] - graph_input = self.find_graph_input(input_or_output) - if graph_input: - return graph_input.type.tensor_type.elem_type + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None - graph_output = self.find_graph_output(input_or_output) - if graph_output: - return graph_output.type.tensor_type.elem_type + def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get shape given a name (could be initializer, input or output of graph or node).""" + + if self._shape_dict is None: + self._shape_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + if value_info.type.tensor_type.HasField("shape"): + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + self._shape_dict[value_info.name] = shape + + for initializer in self.model.graph.initializer: + if initializer.name not in self._shape_dict: + self._shape_dict[initializer.name] = initializer.dims + + if name in self._shape_dict: + return self._shape_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type return None @@ -566,23 +612,14 @@ def remove_cascaded_cast_nodes(self): def remove_useless_cast_nodes(self): """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.infer_runtime_shape(update=True) - if shape_infer is None: - logger.info("Skip removing useless cast nodes since shape inference failed.") - return - - def get_data_type(input_or_output_name): - dtype = self.get_dtype(input_or_output_name) - if dtype: - return dtype - if shape_infer.known_vi_[input_or_output_name].type.tensor_type.HasField("elem_type"): - return shape_infer.known_vi_[input_or_output_name].type.tensor_type.elem_type - return None + if self.enable_shape_infer and shape_infer is None: + logger.warning("shape inference failed which might 
impact useless cast node detection.") nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Cast": - input_dtype = get_data_type(node.input[0]) - output_dtype = get_data_type(node.output[0]) + input_dtype = self.get_dtype(node.input[0], shape_infer) + output_dtype = self.get_dtype(node.output[0], shape_infer) if input_dtype and input_dtype == output_dtype: nodes_to_remove.append(node) @@ -601,7 +638,10 @@ def get_data_type(input_or_output_name): self.replace_input_of_all_nodes(node.output[0], node.input[0]) self.remove_node(node) - logger.info("Removed %d Cast nodes with output type same as input", len(nodes_to_remove)) + logger.info( + "Removed %d Cast nodes with output type same as input", + len(nodes_to_remove), + ) def convert_model_float32_to_float16(self, cast_input_output=True): logger.warning( @@ -1214,7 +1254,10 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): continue for j in range(i + 1, initializer_count): if OnnxModel.has_same_value( - self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache + self.model.graph.initializer[i], + self.model.graph.initializer[j], + cache, + cache, ): same[j] = i @@ -1223,7 +1266,8 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): if same[i] >= 0: count += 1 self.replace_input_of_all_nodes( - self.model.graph.initializer[i].name, self.model.graph.initializer[same[i]].name + self.model.graph.initializer[i].name, + self.model.graph.initializer[same[i]].name, ) if count > 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 51deb67ce5bf3..431e64509e3cc 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -126,7 +126,8 @@ def fuse_rotary_embeddings(self): # Remove non-MS domain functions rot_emb_nodes = list( filter( - lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", + self.model.graph.node, ) ) non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) @@ -350,7 +351,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): self.attention_fusion = FusionAttention( - self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention + self, + self.hidden_size, + self.num_heads, + self.attention_mask, + options.use_multi_head_attention, ) if (options is None) or options.enable_attention: @@ -415,7 +420,12 @@ def get_fused_operator_statistics(self): "SkipSimplifiedLayerNormalization", "RotaryEmbedding", ] - q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] + q_ops = [ + "QOrderedAttention", + "QOrderedGelu", + "QOrderedLayerNormalization", + "QOrderedMatMul", + ] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 4d15b9288e7b6..01298b3576eb1 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from logging import getLogger +import logging from typing import Optional from fusion_attention_unet import FusionAttentionUnet @@ -14,11 +14,12 @@ from fusion_options import FusionOptions from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose +from import_utils import is_installed from onnx import ModelProto from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel -logger = getLogger(__name__) +logger = logging.getLogger(__name__) class UnetOnnxModel(BertOnnxModel): @@ -94,14 +95,24 @@ def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): # Self Attention enable_packed_qkv = (options is None) or options.enable_packed_qkv self_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False + self, + self.hidden_size, + self.num_heads, + is_cross_attention=False, + enable_packed_qkv=enable_packed_qkv, + enable_packed_kv=False, ) self_attention_fusion.apply() # Cross Attention enable_packed_kv = (options is None) or options.enable_packed_kv cross_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, True, False, enable_packed_kv + self, + self.hidden_size, + self.num_heads, + is_cross_attention=True, + enable_packed_qkv=False, + enable_packed_kv=enable_packed_kv, ) cross_attention_fusion.apply() @@ -110,23 +121,48 @@ def fuse_bias_add(self): fusion.apply() def optimize(self, options: Optional[FusionOptions] = None): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 18 + progress_bar = tqdm.tqdm(range(0, steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) # Remove cast nodes that having same data type of input and output based on symbolic shape inference. 
self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_gelu: self.fuse_gelu() + if progress_bar: + progress_bar.update(1) self.preprocess() + if progress_bar: + progress_bar.update(1) self.fuse_reshape() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_group_norm: channels_last = (options is None) or options.group_norm_channels_last @@ -135,42 +171,66 @@ def optimize(self, options: Optional[FusionOptions] = None): insert_transpose_fusion = FusionInsertTranspose(self) insert_transpose_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_splitgelu: bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_attention: + # self.save_model_to_file("before_mha.onnx") self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) self.fuse_shape() + if progress_bar: + progress_bar.update(1) # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_group_norm: skip_group_norm_fusion = FusionSkipGroupNorm(self) skip_group_norm_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. 
self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_gelu_approximation: self.gelu_approximation() + if progress_bar: + progress_bar.update(1) if options is None or options.enable_nhwc_conv: self.convert_conv_to_nhwc() - self.merge_adjacent_transpose() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_bias_add: self.fuse_bias_add() + if progress_bar: + progress_bar.update(1) self.postprocess() + if progress_bar: + progress_bar.update(1) logger.info(f"opset version: {self.get_opset_version()}") @@ -190,6 +250,7 @@ def get_fused_operator_statistics(self): "NhwcConv", "BiasAdd", ] + for op in ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) From 61610ff9862ad834f153ed3e70ba526dac86ae7c Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 24 Jan 2024 00:25:05 +0800 Subject: [PATCH 18/23] [js/webgpu] Add FusedConv clip test case (#18900) Bug: https://github.com/microsoft/onnxruntime/issues/18899 --- js/web/test/data/ops/fused-conv.jsonc | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 812e9d7c2def0..ad1c0a72c11d3 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -108,5 +108,39 @@ ] } ] + }, + { + "name": "fused conv with clip", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [400, 470, 600, 600], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] From 0ea48fc73ec6bdbb8af2010483a61823fcf1a613 Mon Sep 17 00:00:00 2001 From: Heflin Stephen Raj Date: Tue, 23 Jan 2024 23:40:54 +0530 Subject: [PATCH 19/23] Modified the condition to load the optimiser model (#18891) --- java/src/main/native/ai_onnxruntime_OrtTrainingSession.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index 9f7b8d3a3dcfc..464234c34798a 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes } } wchar_t* optimizerStr = NULL; - if (optimizerPath == NULL) { + if (optimizerPath != NULL) { optimizerStr = copyAndPad(jniEnv, optimizerPath); if (optimizerStr == NULL) { // exception has been thrown in Java, go to cleanup and return null. From 54871a27736cf54cbda9c4f09bb27e931de7334e Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jan 2024 02:49:24 +0800 Subject: [PATCH 20/23] Replace T4 to A10 in Linux GPU workflow (#19205) ### Description 1. Update Linux GPU machine from T4 to A10, sm=8.6 2. update the tolerance ### Motivation and Context 1. Free more T4 and test with higher compute capability. 2. ORT enables TF32 in GEMM for A10/100. 
TF32 will cause precision loss and fail this test ``` 2024-01-19T13:27:18.8302842Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12 2024-01-19T13:27:25.8438153Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:347: Failure 2024-01-19T13:27:25.8438641Z Expected equality of these values: 2024-01-19T13:27:25.8438841Z COMPARE_RESULT::SUCCESS 2024-01-19T13:27:25.8439276Z Which is: 4-byte object <00-00 00-00> 2024-01-19T13:27:25.8439464Z ret.first 2024-01-19T13:27:25.8445514Z Which is: 4-byte object <01-00 00-00> 2024-01-19T13:27:25.8445962Z expected 0.145984 (3e157cc1), got 0.975133 (3f79a24b), diff: 0.829149, tol=0.0114598 idx=375. 20 of 388 differ 2024-01-19T13:27:25.8446198Z 2024-01-19T13:27:25.8555736Z [ FAILED ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12, where GetParam() = "cuda_../models/zoo/opset12/SSD/ssd-12.onnx" (7025 ms) 2024-01-19T13:27:25.8556077Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_YOLOv312_yolov312 2024-01-19T13:27:29.3174318Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:347: Failure 2024-01-19T13:27:29.3175144Z Expected equality of these values: 2024-01-19T13:27:29.3175389Z COMPARE_RESULT::SUCCESS 2024-01-19T13:27:29.3175812Z Which is: 4-byte object <00-00 00-00> 2024-01-19T13:27:29.3176080Z ret.first 2024-01-19T13:27:29.3176322Z Which is: 4-byte object <01-00 00-00> 2024-01-19T13:27:29.3178431Z expected 4.34958 (408b2fb8), got 4.51324 (40906c80), diff: 0.16367, tol=0.0534958 idx=9929. 22 of 42588 differ ``` 3. Some other tests, like SSD, throw other exceptions, so skip them ``` 2024-01-22T09:07:40.8446910Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12 2024-01-22T09:07:51.5587571Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:358: Failure 2024-01-22T09:07:51.5588512Z Expected equality of these values: 2024-01-22T09:07:51.5588870Z COMPARE_RESULT::SUCCESS 2024-01-22T09:07:51.5589467Z Which is: 4-byte object <00-00 00-00> 2024-01-22T09:07:51.5589953Z ret.first 2024-01-22T09:07:51.5590462Z Which is: 4-byte object <01-00 00-00> 2024-01-22T09:07:51.5590841Z expected 1, got 63 ``` --- .../test/global_thread_pools/test_inference.cc | 8 +++++++- onnxruntime/test/providers/cpu/model_tests.cc | 17 +++++++++++++++++ .../providers/cuda/nhwc/conv_transpose_test.cc | 6 +++++- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 ++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 4772e7de2bdd7..f553682975f11 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -55,9 +55,15 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, // size_t total_len = type_info.GetElementCount(); ASSERT_EQ(values_y.size(), static_cast(5)); +// test inference is using onnxruntime_shared_lib_test_LIBS, so HasCudaEnvironment(800) isn't available +#ifdef USE_CUDA + const float tolerance = 1e-5f; +#else + const float tolerance = 1e-6f; +#endif OutT* f = output_tensor->GetTensorMutableData(); for (size_t i = 0; i != static_cast(5); ++i) { - ASSERT_NEAR(values_y[i], f[i], 1e-6f); + ASSERT_NEAR(values_y[i], f[i], tolerance); } } diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 859e082716760..8128c170c5211 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++
b/onnxruntime/test/providers/cpu/model_tests.cc @@ -39,6 +39,8 @@ #include "core/providers/armnn/armnn_provider_factory.h" #endif +#include "test/common/cuda_op_test_utils.h" + // test infrastructure #include "test/onnx/testenv.h" #include "test/onnx/TestCase.h" @@ -94,6 +96,21 @@ TEST_P(ModelTest, Run) { std::unique_ptr model_info = std::make_unique(model_path.c_str()); +#if defined(__linux__) + // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. + if (HasCudaEnvironment(800) && provider_name == "cuda") { + per_sample_tolerance = 1e-1; + if (model_path.find(ORT_TSTR("SSD")) > 0 || + model_path.find(ORT_TSTR("ssd")) > 0 || + model_path.find(ORT_TSTR("yolov3")) > 0 || + model_path.find(ORT_TSTR("mask_rcnn")) > 0 || + model_path.find(ORT_TSTR("FNS")) > 0) { + SkipTest("Skipping SSD test for big tolearance failure or other errors"); + return; + } + } +#endif + if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { SkipTest("it has the training domain. No pipeline should need to run these tests."); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 06da2a5304716..6514feadf0ff7 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -70,7 +70,11 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { auto op = ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; - MAKE_PROVIDERS_EPS_TYPE(TypeParam) + if (HasCudaEnvironment(800)) { + MAKE_PROVIDERS_EPS(1e-2) + } else { + MAKE_PROVIDERS_EPS_TYPE(TypeParam) + } } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1060a0138e0b7..5779b1da3fd43 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -137,7 +137,7 @@ jobs: --enable_cuda_profiling --enable_cuda_nhwc_ops \ --enable_pybind --build_java \ --use_cache \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75; \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) @@ -166,7 +166,7 @@ jobs: skipComponentGovernanceDetection: true workspace: clean: all - pool: Onnxruntime-Linux-GPU-T4 + pool: onnxruntime-Linux-GPU-A10 dependsOn: - Linux_Build steps: From f53068446e7e560012862e1812270bcf908fbda4 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Tue, 23 Jan 2024 13:44:34 -0800 Subject: [PATCH 21/23] Add Temperature to WhisperBeamSearch input (#19188) ### Description Add `temperature` as an input to WhisperBeamSearch op and initialize correctly in parameter setup. ### Motivation and Context Currently, temperature is included as an attribute to the BeamSearch op, which doesn't let the model act dynamically in a single inference session. 
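For illustration only (not part of this change), a sketch of how a per-call temperature could be supplied once it is a graph input rather than a fixed attribute — the model path and input names below are assumptions based on a typical Whisper beam-search export and may differ for a given model:

```python
# Hedged sketch: vary temperature per Run() call on a single session.
# "whisper_beamsearch.onnx" and the feed names are placeholders/assumptions.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("whisper_beamsearch.onnx")

features = np.zeros((1, 80, 3000), dtype=np.float32)  # dummy log-mel features
base_inputs = {
    "input_features": features,
    "max_length": np.array([128], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([4], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}

# The new optional "temperature" input has shape (1) and is consumed on CPU by the op,
# so the same session can decode with a different temperature on every call.
for t in (1.0, 0.8):
    feeds = dict(base_inputs, temperature=np.array([t], dtype=np.float32))
    sequences = session.run(None, feeds)[0]
```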
By including this variable as an input, the temperature value can be altered in any inference call (important for 1P teams) --------- Co-authored-by: Peter McAughan Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Kunal Vaishnavi --- docs/ContribOperators.md | 4 +++- docs/OperatorKernels.md | 4 ++-- .../cpu/transformers/beam_search_parameters.cc | 14 +++++++++++++- .../contrib_ops/cuda/transformers/beam_search.cc | 1 + onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 1 + 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 22e82443167f6..624cda1d37f73 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5761,7 +5761,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14) +#### Inputs (5 - 15)
input_ids : F
@@ -5792,6 +5792,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+
temperature (optional) : T
+
Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
#### Outputs (1 - 5) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9a2a7ac89bbb3..3b695af2839b6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -499,7 +499,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -876,7 +876,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5eb..bb6885c3216bc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; ORT_ENFORCE(logits_processor >= 0, "logits_processor shall be a non-negative integer, got ", logits_processor); -} + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + auto* temperature_tensor = context->Input(14); + if (temperature_tensor) { + if (temperature_tensor->IsDataType()) { + temperature = *temperature_tensor->Data(); + } else { + temperature = static_cast(*temperature_tensor->Data()); + } + } else { + temperature = 1.0f; + } + } +} void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 2a90e4911f286..08cbb145a6f65 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 982e8fd834b76..27c968a59eb91 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1231,6 +1231,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) + .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", From 532f8c642ce9c1ea2971b7d0f0ff8a4197bcb3a0 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 23 Jan 2024 14:57:30 -0800 Subject: [PATCH 22/23] Fix a backend test by using local backend (#19230) The decomposition pass (e.g., converting torch.add to aten.add) in DORT no longer exists. 
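As a rough illustration of the lowering being discussed (a hedged sketch, not part of this change, using public PyTorch 2.x tracing utilities rather than DORT itself):

```python
# Minimal sketch: torch-level calls are traced down to ATen ops, which is the kind of
# decomposition AOTAutograd provides to a backend. Assumes PyTorch 2.x.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

def f(x):
    return torch.add(x, x).relu()

gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(3))
print(gm.graph)  # graph nodes reference aten ops such as aten.add.Tensor and aten.relu.default
```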
Therefore, we have to use `use_aot_autograd=True` to enable Dynamo's built-in operator decomposition. I think we need to add the decomposition pass back to DORT or remove `use_aot_autograd` (remove because it will always be `true`). --- .../orttraining/test/python/orttraining_test_dort.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index f0b6b9c5fba28..573ec85d76013 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -216,7 +216,12 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.relu() return tensor_q - local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + # TODO: Set use_aot_autograd=False. In order to decompose torch + # function calls to aten ops, we need to set + # user_aot_autograd=True because there is no decomposition in DORT + # anymore. A long-term fix will be brining # decomposition pass back + # into DORT. + local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) def run(fun, list_x): From cbb29d80ff5ec63d3cc2289911c4420f5a9d8a2d Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:34:26 -0800 Subject: [PATCH 23/23] GQA Rotary and Packed QKV with Flash (#18906) ### Description These changes add rotary embedding and packed qkv input to gqa. As of now, the changes are only supported with Flash-Attention (SM >= 80) but should soon be supported with Memory Efficient Attention as well. ### Motivation and Context With the fusion of rotary embedding into this Attention op, we hope to observe some perf gain. The packed QKV should also provide some perf gain in the context of certain models, like Llama2, that would benefit from running ops on the fused QKV matrix, rather than the separate Q, K, and V. --------- Co-authored-by: Yufeng Li --- docs/ContribOperators.md | 16 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 5 + .../cuda/bert/flash_attention/flash_api.cc | 51 +- .../cuda/bert/flash_attention/flash_api.h | 6 +- .../cuda/bert/group_query_attention.cc | 26 +- .../cuda/bert/group_query_attention.h | 5 + .../cuda/bert/group_query_attention_helper.h | 150 ++-- .../cuda/bert/group_query_attention_impl.cu | 125 ++-- .../cuda/bert/group_query_attention_impl.h | 2 + .../core/graph/contrib_ops/bert_defs.cc | 34 +- .../test/python/transformers/rotary_flash.py | 693 ++++++++++++++++++ .../python/transformers/test_flash_attn.py | 668 ++++++++++++++--- tools/ci_build/build.py | 3 +- ...txt => requirements-transformers-test.txt} | 3 +- 15 files changed, 1517 insertions(+), 272 deletions(-) create mode 100644 onnxruntime/test/python/transformers/rotary_flash.py rename tools/ci_build/{requirements.txt => requirements-transformers-test.txt} (94%) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 624cda1d37f73..e7b537d6894c8 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+do_rotary : int
+Whether to use rotary position embedding. Default value is 0.
 kv_num_heads : int (required)
 Number of attention heads for k and v
 local_window_size : int
 left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
 num_heads : int (required)
 Number of attention heads for q
+rotary_interleaved : int
+Rotate using interleaved pattern. Default value is 0 (False). (See the layout sketch after this attribute list.)
 scale : float
 Custom scale will be used if specified. Default value is 1/sqrt(head_size)
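For readers unfamiliar with the two rotary layouts named above, here is a minimal PyTorch sketch (illustrative only, not part of the patch) of what the rotation does per position and head: the half-split (GPT-NeoX style) pairing used when `rotary_interleaved=0`, and the even/odd-pair (GPT-J style) pairing used when `rotary_interleaved=1`. The tensor names and sizes are assumptions chosen for the example.

```python
import torch

def rotate_half_split(x, cos, sin):
    # rotary_interleaved = 0: pair element i with element i + head_size/2
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

def rotate_interleaved(x, cos, sin):
    # rotary_interleaved = 1: pair adjacent elements (0,1), (2,3), ...
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

head_size = 8                                    # one position of one head
x = torch.arange(head_size, dtype=torch.float32)
angles = torch.rand(head_size // 2)              # one angle per rotated pair
cos, sin = angles.cos(), angles.sin()

print(rotate_half_split(x, cos, sin))
print(rotate_interleaved(x, cos, sin))
```

Both variants consume the same (max_sequence_length, head_size / 2) cos/sin caches listed under the inputs below; only the pairing of elements differs.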
-#### Inputs
+#### Inputs (7 - 9)
 query : T
-Query with shape (batch_size, sequence_length, hidden_size)
-key : T
+Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size). (See the packing sketch after this input list.)
+key (optional) : T
 Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-value : T
+value (optional) : T
 Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
 past_key (optional) : T
 past state key with support for format BNSH. When past_key uses same tensor as present_key (k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 seqlens_k : M
 1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
 total_sequence_length : M
 Scalar tensor of total sequence length (past + new).
+cos_cache (optional) : T
+2D tensor with shape (max_sequence_length, head_size / 2).
+sin_cache (optional) : T
+2D tensor with shape (max_sequence_length, head_size / 2).
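To make the packed-QKV layout and the rotary caches concrete, the sketch below shows, in NumPy, how a host-side caller could build a packed QKV tensor and the cos/sin caches described above. It is illustrative only and not taken from the patch; the shapes and the rotary base of 10000 (the common LLaMA default) are assumptions.

```python
import numpy as np

batch, seq, num_heads, kv_num_heads, head_size = 2, 16, 8, 4, 64
max_seq, theta = 4096, 10000.0  # rotary base is an assumed default

# Packed QKV: Q, K and V concatenated along the hidden dimension, giving
# (batch, seq, num_heads*head_size + 2*kv_num_heads*head_size); the separate
# key/value inputs are then left unset.
q = np.random.randn(batch, seq, num_heads * head_size).astype(np.float16)
k = np.random.randn(batch, seq, kv_num_heads * head_size).astype(np.float16)
v = np.random.randn(batch, seq, kv_num_heads * head_size).astype(np.float16)
packed_qkv = np.concatenate([q, k, v], axis=-1)

# Rotary caches: one row per absolute position, head_size / 2 frequencies per row.
inv_freq = 1.0 / (theta ** (np.arange(0, head_size, 2, dtype=np.float32) / head_size))
angles = np.outer(np.arange(max_seq, dtype=np.float32), inv_freq)
cos_cache = np.cos(angles).astype(np.float16)  # (max_seq, head_size / 2)
sin_cache = np.sin(angles).astype(np.float16)  # (max_seq, head_size / 2)

print(packed_qkv.shape, cos_cache.shape, sin_cache.shape)
```

The test changes later in this patch (`create_group_query_attention_graph_prompt` and `create_group_query_attention_graph_past`) exercise the same combinations of packed vs. separate QKV and rotary on/off through the ONNX graph builder.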
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b695af2839b6..31cca232fde34 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -843,7 +843,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br>
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index da489a6901512..8afeb874750b4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters { bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; + bool is_packed_qkv; bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; }; namespace attention { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index d6eb87228bb4a..2c296bf4f8483 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size) { - // if (seqlen_q == 1) { - // is_causal = false; - // } // causal=true is the same as causal=false in this case - + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, is_causal ? 
0 : -1); params.dprops = &dprops; - if (k != nullptr && v != nullptr) { + if (k_new != nullptr && v_new != nullptr) { params.seqlen_knew = seqlen_k_new; - params.knew_ptr = k; - params.vnew_ptr = v; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; // All stride are in elements, not bytes. - params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.knew_row_stride = num_heads_k * head_size; - params.vnew_row_stride = num_heads_k * head_size; + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; } else { @@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.cu_seqlens_k = static_cast(seqlens_k_); } + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; @@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 3d75d6834b8e0..387d1cf9d84fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size = -1); + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index fd6fb79742cac..fe56f84f0a886 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(kv_num_heads); is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_memory_efficient_attention_ = true; #endif + if (!disable_flash_attention_) { + zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); + } } template @@ -73,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { value, past_key, past_value, + cos_cache, + sin_cache, ¶meters, num_heads_, kv_num_heads_, @@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); @@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && + do_rotary_ == false && + key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? 
nullptr : reinterpret_cast(past_value->Data()); data.output = reinterpret_cast(output->MutableData()); @@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + // Rotary + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 54a8127e29e7b..15573ece166fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel { int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention int local_window_size_; + bool is_unidirectional_; bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) + IAllocatorUniquePtr zeros_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 2cb9955807f26..853e1a710cb24 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query, bool is_past_bsnh, float scale) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // no packing for q/k/v: - // query (Q) : (B, S, D) - // key (K) : (B, S, D_kv) - // value (V) : (B, S, D_kv) + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - + const bool is_packed_qkv = key == nullptr; const auto& query_dims = query->Shape().GetDims(); - const auto& key_dims = key->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query, int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); int q_hidden_size = static_cast(query_dims[2]); - int head_size = static_cast(q_hidden_size) / num_heads; + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. 
Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - int kv_hidden_size = static_cast(key_dims[2]); + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); @@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. 
Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (static_cast(sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); - } - - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - // Check seqlens_k tensor (holding past seqlen for token gen) const auto& seqlens_dim = seqlens_k->Shape().GetDims(); if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { @@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 must be of present_sequence_length."); + } + if (sin_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 must be of present_sequence_length."); + } + if (cos_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + bool is_prompt = sequence_length != 1; if (parameters != nullptr) { @@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query, output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; - output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; @@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no 
larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 5b0f5d0cfe601..d88e9a49fb5ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -151,9 +151,10 @@ template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, cudaStream_t stream, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.sequence_length; + const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; @@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } - __global__ void PastToTotalSeqlen(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int add_seqlen) { @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int threads_per_block) { if (parameters.is_prompt) { return Status::OK(); } @@ -482,91 +482,63 @@ Status FlashAttention( const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; - const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = true; - bool is_bf16 = std::is_same::value; - // Note: seqlens_k is past sequence length for flash - if (parameters.is_prompt) { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); - } - - void* seqlens_k = reinterpret_cast(data.seqlens_k); - - if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } - - if (parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - key = nullptr; - value = nullptr; - seqlens_k = 
reinterpret_cast(data.seqlens_k_total); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); } else { - // Not share buffer case - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (data.past_key != nullptr && data.past_key == data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv share the same tensor but kv_share_buffer is not on."); - } - - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } - if (!parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + void* seqlens_k = reinterpret_cast(data.seqlens_k); + if (parameters.is_prompt) { + // set seqlens_k to zeros... 
flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; } - - seqlens_k = reinterpret_cast(data.seqlens_k_total); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); - DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); - DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); - DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, 0, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); } + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -672,7 +644,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index de32d7ea93163..1bf91f9c875eb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -21,6 +21,8 @@ struct GroupQueryAttentionData { const T* past_key = nullptr; const T* past_value = nullptr; int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7f34647f1faef..8583474a1e391 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); - } else { - fail_shape_inference("Missing input 2 (value)"); } } if (ctx.getNumOutputs() > 1) { // has present output if (hasInputShape(ctx, past_key_index)) { + // auto& query_shape = getInputShape(ctx, 0); + // auto& query_dims = query_shape.dim(); auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); if (past_dims.size() != 4) { @@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& } ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not } } } @@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", AttributeProto::INT, static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" + "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "past_key", "past state key with support for format BNSH. 
When past_key uses same tensor as present_key" @@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "total_sequence_length", "Scalar tensor of total sequence length (past + new).", "M") + .Input(7, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) + .Input(8, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py new file mode 100644 index 0000000000000..42bff9c92b41b --- /dev/null +++ b/onnxruntime/test/python/transformers/rotary_flash.py @@ -0,0 +1,693 @@ +# Copyright (c) 2023, Tri Dao. + + +from typing import Optional, Tuple, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange, repeat + +##### TRITON KERNEL FOR ROTARY ##### + + +# @triton.autotune( +# configs=[ +# triton.Config({"block_m": 2}), +# triton.Config({"block_m": 4}), +# triton.Config({"block_m": 8}), +# triton.Config({"block_m": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + out_, # Pointers to matrices + x_, + cos_, + sin_, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + block_k: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + block_m: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * block_m >= seqlen: + return + rm = pid_m * block_m + tl.arange(0, block_m) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, block_k) + rk_half = tl.arange(0, block_k // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_ + x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to( + tl.float32 + ) + sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to( + tl.float32 + ) + x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x1 = tl.load( + x_ + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + 
other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + out_ + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, block_k) // 2 + x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + cos_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + sin_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) + x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa + block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + block_k, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + block_m, + ) + return output + + +##### ROTARY API ##### + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class ApplyRotaryEmbQKV(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + q, k = qkv[:, :, 0], qkv[:, :, 1] + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cos_k, sin_k = ctx.saved_tensors + if cos_k is None and sin_k is None and dqkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dqkv, None, None, None, None, None, None + + +def apply_rotary_emb_qkv_( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + qkv: (batch_size, seqlen, 3, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of Q and K. 
+ """ + return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) + + +class ApplyRotaryEmbKV(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + k = kv[:, :, 0] + apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dkv, None, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply + + +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 
1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. 
+ """ + seqlen = qkv.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + kv = apply_rotary_emb_kv_( + kv, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + return q, kv diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 8a839875de2a2..90d28872d3cc8 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -184,7 +185,13 @@ def create_multihead_attention_graph(config): def create_group_query_attention_graph_prompt( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -193,18 +200,22 @@ def create_group_query_attention_graph_prompt( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key" if share_buffer else "", "past_value" if share_buffer else "", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -218,25 +229,9 @@ def create_group_query_attention_graph_prompt( [ config.batch_size, config.q_sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * 
config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -250,6 +245,27 @@ def create_group_query_attention_graph_prompt( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] if share_buffer: graph_input += [ helper.make_tensor_value_info( @@ -273,6 +289,25 @@ def create_group_query_attention_graph_prompt( ], ), ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -334,7 +369,13 @@ def create_group_query_attention_graph_prompt( def create_group_query_attention_graph_past( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -345,18 +386,22 @@ def create_group_query_attention_graph_past( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key", "past_value", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -370,25 +415,9 @@ def create_group_query_attention_graph_past( [ config.batch_size, config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -411,8 +440,6 @@ def create_group_query_attention_graph_past( config.head_size, ], ), - ] - graph_input += [ helper.make_tensor_value_info( "seqlens_k", TensorProto.INT32, @@ -424,6 +451,46 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * 
config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -663,21 +730,38 @@ def mha_func(q, k, v, config): def gqa_prompt_func( - q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + window_size=-1, + past_kv_format=Formats.BSNH, + share_buffer=True, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_prompt( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -686,9 +770,17 @@ def gqa_prompt_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -713,17 +805,23 @@ def gqa_prompt_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() ort_session = 
InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") @@ -737,21 +835,38 @@ def gqa_prompt_func( def gqa_past_func( - q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_past( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -763,9 +878,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -790,8 +913,6 @@ def gqa_past_func( else: ort_inputs = { "query": 
q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": past_k.detach().cpu().numpy(), "past_value": past_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -805,9 +926,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -1029,9 +1158,12 @@ def parity_check_mha( def parity_check_gqa_prompt( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1080,6 +1212,8 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) + # print(k.shape) + # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1105,19 +1239,47 @@ def parity_check_gqa_prompt( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1125,13 +1287,47 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # print((out - out_ref)[0, :, 0, 0]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1139,10 +1335,16 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1171,9 +1373,12 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_cache_ref + k_cache_ref = 
k_ro + brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1243,9 +1477,39 @@ def parity_check_gqa_prompt_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + None, + None, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + None, + None, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1256,7 +1520,17 @@ def parity_check_gqa_prompt_no_buff( # Compare results print( - "KV-buffer", + "No buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1285,9 +1559,12 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1336,6 +1613,7 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) left_window_size = -1 if local: @@ -1359,18 +1637,45 @@ def parity_check_gqa_past( dtype=torch.int32, device="cuda", ) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - 
k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1378,13 +1683,46 @@ def parity_check_gqa_past( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1394,10 +1732,16 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, " B:", config.batch_size, " S:", @@ -1427,6 +1771,9 @@ def parity_check_gqa_past_no_buff( causal=False, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1503,18 +1850,47 @@ def parity_check_gqa_past_no_buff( device="cuda", ) cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = ( + torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + ) + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, 
None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1522,13 +1898,47 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((out - out_ref)[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # Make sure past-present buffer updating correctly # assert numpy.allclose( # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True @@ -1540,10 +1950,16 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1671,10 +2087,25 @@ def test_gqa_no_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + 
past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1684,7 +2115,6 @@ def test_gqa_past(self): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- TEST GQA PAST (TOKEN GEN) ---------") - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( [(1, 128), (1, 1024), (1, 2048)] @@ -1706,6 +2136,7 @@ def test_gqa_past(self): num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") for b in batches: for s, s2 in seqs: for n, n2 in num_h: @@ -1734,23 +2165,30 @@ def test_gqa_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if __name__ == "__main__": diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 1034a82cb2854..6e5cd7b57e403 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): numpy_init_version = numpy.__version__ pb_init_version = google.protobuf.__version__ run_subprocess( - [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR + [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"], + cwd=SCRIPT_DIR, ) run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements-transformers-test.txt similarity index 94% rename from tools/ci_build/requirements.txt rename to tools/ci_build/requirements-transformers-test.txt index 57fc8f08336d2..a5279781462a7 100644 --- a/tools/ci_build/requirements.txt +++ b/tools/ci_build/requirements-transformers-test.txt @@ -3,7 +3,8 @@ packaging protobuf==3.20.2 numpy==1.24.0 ; python_version < '3.12' numpy==1.26.0 ; python_version >= '3.12' +torch coloredlogs==15.0 transformers==4.36.0 psutil -einops \ No newline at end of file +einops
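
For reference, the following is a minimal, self-contained sketch (not part of the patch itself) of how the packed-QKV and rotary-cache inputs exercised by the new tests are laid out. The batch, sequence, and head dimensions are arbitrary example values, and only tensor construction is shown rather than an actual GroupQueryAttention session.

import math

import torch

# Arbitrary example dimensions (the tests sweep many such combinations).
batch, seqlen, num_heads, kv_num_heads, head_size = 2, 8, 8, 4, 64

q = torch.randn(batch, seqlen, num_heads, head_size).to(torch.float16)
k = torch.randn(batch, seqlen, kv_num_heads, head_size).to(torch.float16)
v = torch.randn(batch, seqlen, kv_num_heads, head_size).to(torch.float16)

# Packed QKV: concatenate along the head axis, then flatten the last two axes so the
# "query" input has hidden size (num_heads + 2 * kv_num_heads) * head_size, matching
# the packed=True branch of the graph builders above.
packed_qkv = torch.concatenate([q, k, v], dim=2).reshape(batch, seqlen, -1)
assert packed_qkv.shape == (batch, seqlen, (num_heads + 2 * kv_num_heads) * head_size)

# Rotary caches: shape (max_sequence_length, rotary_dim // 2), where rotary_dim is
# head_size rounded down to a multiple of 16, matching the cos_cache/sin_cache inputs.
rotary_dim = math.floor(head_size / 16) * 16
angle = torch.rand(seqlen, rotary_dim // 2) * 2 * math.pi
cos_cache = torch.cos(angle).to(torch.float16)
sin_cache = torch.sin(angle).to(torch.float16)

print(packed_qkv.shape, cos_cache.shape, sin_cache.shape)

When packed, the key and value graph inputs are left empty and the single "query" input carries Q, K, and V, which is why the parity checks above pass None for new_k and new_v in the packed branches.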