From 8621bdcd444f99891eaf8ab4ea7f5f60431400ba Mon Sep 17 00:00:00 2001 From: MistEO Date: Thu, 28 Sep 2023 05:23:26 +0800 Subject: [PATCH 01/20] Remove unnecessary #include (#17716) --- onnxruntime/core/platform/windows/stacktrace.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index cac6f4f29043b..d7d423e4a483e 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -10,7 +10,6 @@ #include #endif #endif -#include #include "core/common/logging/logging.h" #include "core/common/gsl.h" From 1f4a3529ddde7796cd018b6c2bd97a6e099c4b88 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 27 Sep 2023 14:53:37 -0700 Subject: [PATCH 02/20] Bugfix: Add initializer to model in AttentionMask directly (#17719) --- .../python/tools/transformers/fusion_attention.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 40f2aee875382..a7460157ba409 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -78,7 +78,15 @@ def process_mask(self, input: str) -> str: # ReduceSum-13: axes is moved from attribute to input axes_name = "ort_const_1_reduce_sum_axes" if self.model.get_initializer(axes_name) is None: - self.add_initializer(name=axes_name, data_type=TensorProto.INT64, dims=[1], vals=[1], raw=False) + self.model.add_initializer( + helper.make_tensor( + name=axes_name, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + raw=False, + ) + ) mask_index_node = helper.make_node( "ReduceSum", inputs=[input_name, axes_name], From 91367484627a25b46520674dcc320efe505437ef Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 28 Sep 2023 13:46:44 +0800 Subject: [PATCH 03/20] Fix: Fail to skip 
disabledmodel in winml (#17728) ### Description Move appending source name behind the ModifyNameIfDisabledTest ### Motivation and Context In winml, disabled test name doesn't include the model source name. WinML job will be broken in the new image. https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1151451&view=logs&s=4eef7ad1-5202-529d-b414-e2b14d056c05 ### Verified https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1151691&view=logs&s=4eef7ad1-5202-529d-b414-e2b14d056c05 --- winml/test/model/model_tests.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 0b4c10eac9142..5057f74046638 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -380,13 +380,6 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin name += tokenizedModelPath[tokenizedModelPath.size() - 2] += "_"; // model name name += tokenizedModelPath[tokenizedModelPath.size() - 3]; // opset version - // To introduce models from model zoo, the model path is structured like this "///?.onnx" - std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; - // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. - if (source != "models") { - name += "_" + source; - } - std::replace_if( name.begin(), name.end(), [](char c) { return !absl::ascii_isalnum(c); }, '_' ); @@ -405,6 +398,13 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin ModifyNameIfDisabledTest(/*inout*/ name, deviceKind); } + // To introduce models from model zoo, the model path is structured like this "///?.onnx" + std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; + // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. 
+ if (source != "models") { + name += "_" + source; + } + return name; } From fc9a69dcae4f81839593feaa69d081edf9b3c564 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 28 Sep 2023 09:30:42 -0700 Subject: [PATCH 04/20] Update VecAddMoveOnlyFunctor and VecAddWithIsSupportedMethod with Default constructor (#17705) ### Description ### Motivation and Context --- onnxruntime/test/framework/tunable_op_test.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 6aa7c5ee9f8ac..0d9e557ebc813 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -263,8 +263,7 @@ TEST(TunableOp, OpWrapsMutableFunctor) { class VecAddMoveOnlyFunctor { public: - VecAddMoveOnlyFunctor() { - } + VecAddMoveOnlyFunctor() = default; VecAddMoveOnlyFunctor(VecAddMoveOnlyFunctor&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor); @@ -290,8 +289,7 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) { class VecAddWithIsSupportedMethod { public: - VecAddWithIsSupportedMethod() { - } + VecAddWithIsSupportedMethod() = default; VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod); From 20f96fd096e8d3b866c0e947d420f971234eb795 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 28 Sep 2023 14:32:08 -0700 Subject: [PATCH 05/20] Fix Attention Runtime Error for CLIP model (#17729) ### Description The condition check is not correct ``` if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT } else { // BERT } ``` Change it to ``` if (is_unidirectional_) { // GPT } else { // BERT } ``` Another workaround is to enable fused causal attention by adding an environment variable `ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1` before running stable diffusion. 
### Motivation and Context Without the fix, optimized CLIP model of stable diffusion will encounter error in running Attention node: 2023-09-24 16:15:31.206037898 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Attention node. Name:'Attention_0' Status Message: /onnxruntime_src/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu:207 bool onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl::is_flash_attention(int) const interface->mHasCausalMask == false was false. Note that the bug has been there for a long time. It is just surfaced since we recently added a fusion for CLIP, which will trigger the error. We will add a comprehensive test for causal attention later to avoid such corner cases. --- .../contrib_ops/cuda/bert/attention.cc | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..48af0a9a6adec 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -140,27 +140,29 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { #endif if (!use_flash_attention) { - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. 
- bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. + bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. 
+ if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); } - - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); } } else { // BERT bool use_fused_runner = !disable_fused_self_attention_ && From 9cb60c5b86cd13aab1b782915e1ac1f794fd7fe1 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 29 Sep 2023 08:11:36 +1000 Subject: [PATCH 06/20] Resize and EP specific transpose optimization updates (#17664) ### Description - Treat Resize as layout sensitive by default - whilst the ONNX spec does not specify a layout, EPs tend to implement only one - add second usage in L2 of TransposeOptimizer to plug in the ability to push a Transpose through a Resize assigned to the CPU EP - Allow EP specific logic for changing the ops considered to be layout sensitive to be plugged in - expected usage is for #17200 ### Motivation and Context Finish simplifying/clarifying transpose optimization and layout transformation that was proposed in #15552. This PR along with #17618 should complete the changes. 
--------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../core/framework/graph_partitioner.cc | 4 +- .../core/framework/kernel_registry_manager.cc | 5 +- .../contrib_ops/internal_nhwc_onnx_schemas.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 19 +-- .../layout_transformation.cc | 123 ++++++++++++------ .../onnx_transpose_optimization.cc | 27 ++-- .../onnx_transpose_optimization.h | 1 + .../ort_transpose_optimization.cc | 27 +++- .../ort_transpose_optimization.h | 1 - .../core/optimizer/transpose_optimizer.cc | 16 ++- .../core/optimizer/transpose_optimizer.h | 8 +- .../core/providers/xnnpack/nn/resize.cc | 8 +- .../xnnpack/xnnpack_execution_provider.cc | 7 +- .../test/optimizer/qdq_transformer_test.cc | 4 +- .../optimizer/transpose_optimizer_test.cc | 96 +++++++++----- .../providers/cpu/tensor/resize_op_test.cc | 2 +- 16 files changed, 235 insertions(+), 115 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index dede1ecc95885..1b492a3561396 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -177,9 +177,9 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC + // Run layout transformer for EPs with preferred layout of NHWC // CPU EP layout transformation happens later when level 3 transformers are run. 
- if (params.mode != GraphPartitioner::Mode::kAssignOnly && + if (params.mode != GraphPartitioner::Mode::kAssignOnly && params.transform_layout.get() && current_ep.GetPreferredLayout() == DataLayout::NHWC) { for (auto& capability : capabilities) { TryAssignNodes(graph, *capability->sub_graph, ep_type); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 38c8a4a4e3d5e..c4eef5b27c1bb 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -63,8 +63,9 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). "; - if (!status.IsOK()) errormsg << status.ErrorMessage(); + errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). 
"; + if (!status.IsOK()) + errormsg << status.ErrorMessage(); return errormsg.str(); }; diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 3ce7c40e754dc..d3fc5873cb274 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -94,7 +94,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function> GenerateTransformers( const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; + AllocatorPtr cpu_allocator = std::make_shared(); + switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) @@ -240,13 +242,14 @@ InlinedVector> GenerateTransformers( // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. - // local CPU allocator is enough as this allocator is finally passed to a local tensor. - // We will also benefit by using a local allocator as we don't need to pass allocator as parameter for EP API refactor - AllocatorPtr cpu_allocator = std::make_shared(); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); } break; case TransformerLevel::Level2: { + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be + // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). 
+ transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + const bool enable_quant_qdq_cleanup = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; #if !defined(DISABLE_CONTRIB_OPS) @@ -366,16 +369,16 @@ InlinedVector> GenerateTransformers( if (MlasNchwcGetBlockSize() > 1) { transformers.emplace_back(std::make_unique()); } - AllocatorPtr cpu_allocator = std::make_shared(); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } - // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things - // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for - // x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So - // we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled. + + // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar + // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is + // only available on x86-64. // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. 
transformers.emplace_back(std::make_unique(cpu_ep)); diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 2d12c407e6e31..6c91949e467ae 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -13,27 +13,91 @@ using namespace onnx_transpose_optimization; namespace onnxruntime { namespace layout_transformation { +namespace { +// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. +CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose) { + // we aggressively push the layout transpose nodes. + // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which + // can potentially be worse for performance. Use the cost check in that case. + if (node.OpType() != "Concat" && + (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { + return CostCheckResult::kPushTranspose; + } + + // for other nodes use the default ORT cost check + return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); +} + +/// +/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the +/// default set of layout sensitive operators if required. +/// +/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we +/// don't hardcode it here. +/// +/// Node to check +/// true if the node should have its layout converted to NHWC. 
+bool ConvertNodeLayout(const api::NodeRef& node) { + // skip if op is not an ONNX or contrib op + auto domain = node.Domain(); + if (domain != kOnnxDomain && domain != kMSDomain) { + return false; + } + + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + // handle special cases +#if defined(USE_XNNPACK) + if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) { + if (node.OpType() == "Resize") { + // XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize + // with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through + // the Resize during transpose optimization. + return false; + } + } +#endif + +#if defined(USE_JSEP) + // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed + if (node.GetExecutionProviderType() == kJsExecutionProvider) { + if (node.OpType() == "Resize") { + // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain + // with the original input layout. + return false; + } + } +#endif + + // #if defined(USE_CUDA) + // if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + // Update as per https://github.com/microsoft/onnxruntime/pull/17200 with CUDA ops that support NHWC + // } + // #endif + + return layout_sensitive_ops.count(node.OpType()) != 0; +} +} // namespace // Layout sensitive NCHW ops. TransformLayoutForEP will wrap these with Transpose nodes to convert the input // data to NHWC and output data back to NCHW, and move the op to the internal NHWC domain (kMSInternalNHWCDomain). -// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain. +// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain domain. 
// Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove // as many of the layout transposes as possible. const std::unordered_set& GetORTLayoutSensitiveOps() { static std::unordered_set ort_layout_sensitive_ops = []() { const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); std::unordered_set ort_specific_ops = - { "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. - // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and - // onnx_layout_transformation::HandleResize is used. - , - "Resize" -#endif - }; + { + "FusedConv", + "QLinearAveragePool", + "QLinearGlobalAveragePool", + // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default + // as EPs tend to only support one layout. + "Resize", + }; ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); return ort_specific_ops; @@ -42,45 +106,21 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { return ort_layout_sensitive_ops; } -// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. -static CostCheckResult -PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose) { - // we aggressively push the layout transpose nodes. - // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which - // can potentially be worse for performance. Use the cost check in that case. 
- if (node.OpType() != "Concat" && - (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { - return CostCheckResult::kPushTranspose; - } - - // for other nodes use the default ORT cost check - return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); -} - Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider, AllocatorPtr cpu_allocator, const DebugGraphFn& debug_graph_fn) { // We pass in nullptr for the new_node_ep param as new nodes will be assigned by the graph partitioner after // TransformLayoutForEP returns. - // sub graph recurse will be added later. + // sub graph recurse will be added later auto api_graph = MakeApiGraph(graph, cpu_allocator, /*new_node_ep*/ nullptr); - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. for (auto& node : api_graph->Nodes()) { - if (layout_sensitive_ops.count(node->OpType())) { - if (node->GetExecutionProviderType() != execution_provider.Type()) { - continue; - } - - auto domain = node->Domain(); - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } + if (ConvertNodeLayout(*node)) { // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP // knows this op is in the expected format. if (node->GetAttributeIntDefault("channels_last", 0) == 1) { @@ -137,7 +177,6 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. 
SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); modified = true; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index f6d9a60726ccc..81b415c2e40ae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -1242,18 +1242,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con node.SetInput(i, gather_output); } -static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize - // in ORT builds with CUDA enabled. - // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. - // The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled for QNN builds. - // - // TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" (i.e.) aligning with the - // generic nature of the ONNX spec. - // See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize kernel. - return false; -#else +bool HandleResize([[maybe_unused]] HandlerArgs& args) { auto inputs = args.node.Inputs(); int64_t rank_int = gsl::narrow_cast(args.perm.size()); @@ -1279,10 +1268,10 @@ static bool HandleResize([[maybe_unused]] HandlerArgs& args) { TransposeOutputs(args.ctx, args.node, args.perm); return true; -#endif } -constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; +// Not currently registered by default. 
+// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; static bool HandlePad(HandlerArgs& args) { size_t rank = args.perm.size(); @@ -2034,8 +2023,11 @@ static const std::unordered_map handler_ma {"Split", split_handler}, {"Shape", shape_handler}, {"Pad", pad_handler}, - {"Resize", resize_handler}, - {"ReduceSum", reduce_op_handler}, + + // Execution providers tend to only implement Resize for specific layouts. Due to that, it's safer to not + // push a Transpose through a Resize unless the EP specifically checks that it can handle the change via an + // extended handler. + // {"Resize", resize_handler}, {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, @@ -2043,6 +2035,7 @@ static const std::unordered_map handler_ma {"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler}, + {"ReduceSum", reduce_op_handler}, {"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler}, @@ -2385,6 +2378,8 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { continue; } + // NOTE: this bleeds ORT specific logic into the base optimizer, however we justify that for now because we expect + // the types that the ORT DQ provides to be added to the ONNX spec, at which point this special case can go away. if (IsMSDomain(dq_domain) && !TransposeQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node)) { continue; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index f8aaeca915171..cc1552704c187 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -98,6 +98,7 @@ bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_ // base handlers that are used by extended handlers. 
add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index ead82a6b56741..f4f3505128737 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -5,12 +5,35 @@ #include #include "core/graph/constants.h" +#include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" using namespace onnx_transpose_optimization; namespace onnxruntime { +static bool EPAwareHandleResize(HandlerArgs& args) { + // Whilst Resize is not technically layout sensitive, execution providers typically implement handling for only one + // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled + // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. + const auto ep_type = args.node.GetExecutionProviderType(); + if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) { + // allow NCHW <-> NHWC for now. 
not clear any other sort of transpose has a valid usage in a real model + int64_t rank_int = gsl::narrow_cast(args.perm.size()); + if (rank_int == 4) { + static const std::vector nchw_to_nhwc_perm{0, 2, 3, 1}; + static const std::vector nhwc_to_nchw_perm{0, 3, 1, 2}; + if (args.perm == nchw_to_nhwc_perm || args.perm == nhwc_to_nchw_perm) { + return HandleResize(args); + } + } + } + + return false; +} + +constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; + static bool HandleQLinearConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } @@ -62,7 +85,7 @@ static bool HandleMaxPool(HandlerArgs& args) { ORT_UNUSED_PARAMETER(args); return false; #else - if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") { + if (args.node.GetExecutionProviderType() != kCpuExecutionProvider) { return false; } @@ -103,6 +126,7 @@ static bool HandleContribQuantizeDequantizeLinear(HandlerArgs& args) { } constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; + constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, @@ -113,6 +137,7 @@ const HandlerMap& OrtExtendedHandlers() { static const HandlerMap extended_handler_map = []() { HandlerMap map = { {"MaxPool", max_pool_op_handler}, + {"Resize", ep_aware_resize_handler}, {"com.microsoft.QuantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.DequantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h index 0a5dbd6d13d06..8245d8c3b4eae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h +++ 
b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h @@ -10,7 +10,6 @@ namespace onnxruntime { /// /// Get the extended handlers for ORT specific transpose optimization. /// These include handlers for contrib ops, and where we have an NHWC version of a layout sensitive op. -/// Extends the handlers returned by OrtHandlers. /// /// HandlerMap const onnx_transpose_optimization::HandlerMap& OrtExtendedHandlers(); diff --git a/onnxruntime/core/optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer.cc index 33e3f5eeaf0fa..092df9cc7dcfb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer.cc @@ -18,10 +18,18 @@ namespace onnxruntime { Status TransposeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); - - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, - OrtExtendedHandlers()); + OptimizeResult result; + + if (ep_.empty()) { + // basic usage - no EP specific optimizations + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); + result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, + OrtExtendedHandlers()); + } else { + // EP specific optimizations enabled. Currently only used for CPU EP. + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ ep_.c_str()); + result = onnx_transpose_optimization::Optimize(*api_graph, ep_, OrtEPCostCheck, OrtExtendedHandlers()); + } if (result.error_msg) { // currently onnx_layout_transformation::Optimize only fails if we hit an unsupported opset. 
diff --git a/onnxruntime/core/optimizer/transpose_optimizer.h b/onnxruntime/core/optimizer/transpose_optimizer.h index 1ae6d611d2f0e..97d7ab4d0e220 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.h +++ b/onnxruntime/core/optimizer/transpose_optimizer.h @@ -15,10 +15,14 @@ Push transposes through ops and eliminate them. class TransposeOptimizer : public GraphTransformer { private: AllocatorPtr cpu_allocator_; + const std::string ep_; public: - explicit TransposeOptimizer(AllocatorPtr cpu_allocator) noexcept - : GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {} + explicit TransposeOptimizer(AllocatorPtr cpu_allocator, + const std::string& ep = {}) noexcept + : GraphTransformer(ep.empty() ? "TransposeOptimizer" : "TransposeOptimizer_" + ep), + cpu_allocator_(std::move(cpu_allocator)), + ep_{ep} {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.cc b/onnxruntime/core/providers/xnnpack/nn/resize.cc index 672b2597279db..76c6b6acbfe32 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.cc +++ b/onnxruntime/core/providers/xnnpack/nn/resize.cc @@ -331,7 +331,13 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecution DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 18, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Resize); + +ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc 
b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index ba577ac38d48c..494c718cde081 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -46,7 +46,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); @@ -84,7 +85,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>, BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize)>, + ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>, + BuildKernelCreateInfo< + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>, BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>, BuildKernelCreateInfo< diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc 
b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d3616a14d8a5d..1bf1cbacf479e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3085,12 +3085,12 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 18); // disable TransposeOptimizer for simplicity + 18); TransformerTester(build_test_case, check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 19); // disable TransposeOptimizer for simplicity + 19); }; test_case({1, 13, 13, 23}, {0, 2, 3, 1}, false /*use_contrib_qdq*/); diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 0d66e6f8d5f6e..4f4157bd7b1cf 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -320,20 +320,6 @@ TEST(TransposeOptimizerTests, TestPadNonconst) { /*opset_version*/ {11, 18}); } -// The CUDA Resize kernel assumes that the input is NCHW and -// Resize can't be supported in ORT builds with CUDA enabled. -// TODO: Enable this once the CUDA Resize kernel is implemented -// "generically" (i.e.) aligning with the generic nature of the -// ONNX spec. -// See https://github.com/microsoft/onnxruntime/pull/10824 for -// a similar fix applied to the CPU Resize kernel. -// Per tests included in #10824, the ROCM EP also generates -// incorrect results when this handler is used, so the Resize -// handler is not enabled even for those builds. -// -// The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled -// for QNN builds. 
-#if !defined(USE_CUDA) && !defined(USE_ROCM) && !defined(USE_QNN) TEST(TransposeOptimizerTests, TestResize) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{4, -1, 2, -1}}, {4, 6, 2, 10}, 0.0, 1.0); @@ -362,7 +348,9 @@ TEST(TransposeOptimizerTests, TestResize) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {10, 18}); } @@ -390,7 +378,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset11) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -418,7 +408,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset15) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -448,7 +440,9 @@ TEST(TransposeOptimizerTests, TestResizeSizeRoi) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -482,7 +476,9 @@ TEST(TransposeOptimizerTests, TestResizeRoiScalesZeroRank0) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 
TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, {12, 18}); } @@ -511,7 +507,9 @@ TEST(TransposeOptimizerTests, TestResizeNonconst) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -540,11 +538,12 @@ TEST(TransposeOptimizerTests, TestResizeNonconstOpset13) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {13, 18}); } -#endif TEST(TransposeOptimizerTests, TestAdd) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = builder.MakeInput({4, 6, 10}, 0.0, 1.0); @@ -4454,12 +4453,13 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { testing::ContainerEq(fetches[0].Get().DataAsSpan())); } +// These tests uses internal testing EP with static kernels which requires a full build, +// and the NHWC Conv which requires contrib ops +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + // Test a Transpose node followed by a Reshape that is logically equivalent to an Transpose can be merged. // The test graph was extracted from a model we were trying to use with the QNN EP. 
TEST(TransposeOptimizerTests, QnnTransposeReshape) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.onnx"); @@ -4509,13 +4509,9 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { EXPECT_TRUE(inputs[1]->Exists()); } } -#endif } TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx"); @@ -4552,9 +4548,49 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; } -#endif } +// Validate handling for EP with layout specific Resize that prefers NHWC +TEST(TransposeOptimizerTests, QnnResizeOpset11) { + Status status; + auto model_uri = ORT_TSTR("testdata/nhwc_resize_scales_opset11.onnx"); + + SessionOptions so; + // Uncomment to debug + // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + + // set the test EP to support all ops in the model so that the layout transform applies to all nodes + const std::unordered_set<std::string> empty_set; + auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_uri)); +
ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout + std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + for (const auto& node : graph.Nodes()) { + EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() + << "' was not assigned to the internal testing EP."; + if (node.OpType() == "Resize") { + EXPECT_EQ(node.Domain(), kMSInternalNHWCDomain) << "Resize was not converted to NHWC layout"; + } + } + + std::map<std::string, int> op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 2) << "Resize should have been wrapped in 2 Transpose nodes to convert to NHWC"; + + // And the post-Resize Transpose should have been pushed all the way to the end + GraphViewer viewer(graph); + EXPECT_EQ(graph.GetNode(viewer.GetNodesInTopologicalOrder().back())->OpType(), "Transpose"); +} +#endif // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + static void CheckSharedInitializerHandling(bool broadcast) { auto model_uri = broadcast ?
ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 0434b16dc66ce..2ead9ec91f93f 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -780,7 +780,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) { } TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; From caf98128c1dbbee1ce3c831364cee88031471a32 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 28 Sep 2023 21:43:29 -0700 Subject: [PATCH 07/20] Update linux-wasm-ci.yml: remove the ln command (#17735) ### Description /usr/local/bin can only be modified by root. 
This command seems unnecessary --- .../ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 96a0ebd753d8e..fe7f752513f3c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -98,7 +98,6 @@ jobs: cd '$(Build.SourcesDirectory)/cmake/external/emsdk' ./emsdk install 3.1.44 ccache-git-emscripten-64bit ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - ln -s $(Build.SourcesDirectory)/cmake/external/emsdk/ccache/git-emscripten_64bit/bin/ccache /usr/local/bin/ccache displayName: 'emsdk install and activate ccache for emscripten' condition: eq('${{ parameters.WithCache }}', 'true') From b4fbc25b1f5c3942b59ae2a75a72d070f92d7bb1 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Fri, 29 Sep 2023 11:00:44 -0700 Subject: [PATCH 08/20] [JS/Web] Add ConvTranspose implementation using MatMul (#17573) ### Description Add ConvTranspose implementation using MatMul to increase perf. 
### Motivation and Context --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 243 ++++++++++++++++++ .../ops/3rd-party/conv_backprop_webgpu.ts | 4 +- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 66 ++++- .../jsep/webgpu/ops/conv2dtranspose-mm.ts | 29 +++ js/web/test/data/ops/conv-transpose.jsonc | 131 +++++++++- .../providers/js/operators/conv_transpose.h | 17 ++ 6 files changed, 471 insertions(+), 19 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts new file mode 100644 index 0000000000000..3925e1cb4f564 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -0,0 +1,243 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor-view'; +import {ShapeUtil} from '../../../util'; +import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ConvTransposeAttributes} from '../conv-transpose'; + +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dTransposeCommonSnippet = + (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, + innerElementSize = 4): string => { + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` + let coord1 = vec4(coordX, coordY, col + 1, rowInner); + let coord2 = vec4(coordX, coordY, col + 2, rowInner); + let coord3 = vec4(coordX, coordY, col + 3, rowInner); + let v0 = W[getIndexFromCoords4D(coord, wShape)]; + let v1 = W[getIndexFromCoords4D(coord1, wShape)]; + let v2 = W[getIndexFromCoords4D(coord2, wShape)]; + let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + return vec4(v0, v1, v2, v3); + `; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, iXR, iXC, xCh); + ` : + ` + let coord = vec4(batch, xCh, iXR, iXC); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 
'outBackprop[1]' : 'outBackprop[2]'; + const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + + const readASnippet = ` + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { + return ${typeSnippet(innerElementSize)}(0.0); + } + if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { + return ${typeSnippet(innerElementSize)}(0.0); + } + let iXR = i32(xR); + let iXC = i32(xC); + let xCh = ${col} % inChannels; + ${coordASnippet} + return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimInner) { + ${readASnippet} + } + return ${typeSnippet(innerElementSize)}(0.0);` : + ` + let col = colIn * ${innerElementSize}; + if (row < dimInner && col < dimBOuter) { + ${readASnippet} + } + return ${typeSnippet(innerElementSize)}(0.0);`; + + const sampleW = ` + let col = colIn * ${innerElementSize}; + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ + isChannelsLast ? 
'row < dimInner && col < dimBOuter' : + 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${typeSnippet(innerElementSize)}(0.0); + `; + + + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + ${isChannelsLast ? sampleA : sampleW} + } + + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + ${isChannelsLast ? sampleW : sampleA} + } + + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) { + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimBOuter) { + var value = valueInput; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + ${coordResSnippet} + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + } + }`; + return userCode; + }; + +export const createConv2DTransposeMatMulProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, + outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = + isChannelsLast ? 
inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = isVec4 ? + [8, 8, 1] : + [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) + ]; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + + + const declareInputs = [ + `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, + '@group(0) @binding(1) var W: array;' + ]; + let declareFunctions = ''; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + getShaderSource: () => ` + ${utilFunctions} + ${declareInputs.join('\n')} + @group(0) @binding(${declareInputs.length}) var result: array<${ + isVec4 ? 
'vec4' : 'f32'}>; + const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); + const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); + const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); + const outShape : vec4 = vec4(${outputShape.join(',')}); + const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); + const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + isVec4 ? 
+ makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..4c8922238ac5b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -197,14 +197,14 @@ const createConvTranspose2DOpProgramShaderSource = continue; } let idyC: u32 = u32(dyC); - + var inputChannel = groupId * ${inputChannelsPerGroup}; for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { - let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e7d1ddf771650..5641386cce849 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -8,7 +8,9 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '. 
import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; +import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm'; import {parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => @@ -63,7 +65,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -95,9 +97,11 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign( - newAttributes, - {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey}); + const cacheKey = attributes.cacheKey + [ + kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), + dilations.join(',') + ].join('_'); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); return newAttributes; }; @@ -226,12 +230,64 @@ const createConvTranspose2DProgramInfoLoader = }; }; +// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 
0]}); + const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; + if (adjustedAttributes.group !== 1) { + context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + return; + } + const outputShape = adjustedAttributes.outputShape; + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; + + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + - context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + // STEP.1: transpose weight + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? 
-2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfoLoader( + convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convTransposeInputs}); }; + const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts new file mode 100644 index 0000000000000..da04b5063a9f0 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; +import {ConvTransposeAttributes} from './conv-transpose'; + + +const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'Conv2DTransposeMatMul', + inputTypes: hasBias ? 
[GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +export const createConv2DTransposeMatMulProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], + dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfoLoader => { + const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey); + return { + ...metadata, + get: () => createConv2DTransposeMatMulProgramInfo( + inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads) + }; + }; diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index a249dc807fa0b..7038e2a4f8766 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -28,6 +28,37 @@ } ] }, + { + "name": "ConvTranspose without bias addition A - NHWC", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose without bias addition B", "operator": "ConvTranspose", @@ -74,26 +105,22 @@ }, { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 
26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 ], "dims": [4, 4, 2, 2], "type": "float32" }, { - "data": [0.1, 0.2, 0.3, 0.4], + "data": [65, 66, 67, 68], "dims": [4], "type": "float32" } ], "outputs": [ { - "data": [ - 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, - 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, - 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, - 100.4000015258789 - ], + "data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868], "dims": [1, 4, 2, 2], "type": "float32" } @@ -115,7 +142,43 @@ "type": "float32" }, { - "data": [1, 1, 1, 1], + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B - NHWC", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], "type": "float32" }, @@ -127,7 +190,7 @@ ], "outputs": [ { - "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], "dims": [1, 1, 4, 4], "type": "float32" } @@ -251,7 +314,6 @@ } ] }, - { "name": "ConvTranspose- pointwise", "operator": "ConvTranspose", @@ -285,5 +347,50 @@ ] } ] + }, + 
{ + "name": "ConvTranspose with bias addition C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122, + 1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259, + 1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404, + 1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index a5aeae8646373..c3babbc5ce81f 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -108,6 +108,23 @@ class ConvTranspose : public JsKernel { } } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /* prepacked_weights */) override { + is_packed = false; + + if (input_idx == 1) { + // Only handle the common case of conv2D + if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) { + return 
Status::OK(); + } + + w_is_const_ = true; + } + + return Status::OK(); + } + protected: ConvTransposeAttributes conv_transpose_attrs_; bool w_is_const_; From 561aca97cfcf76ce6d190a2403cae34c17bee75a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 29 Sep 2023 11:24:42 -0700 Subject: [PATCH 09/20] [js/webgpu] support IO binding (#17480) **This PR is based on a few prerequisites PRs. They are listed as below:** - #17465 - #17469 - #17470 - #17472 - #17473 - #17484 Please review the current change by only looking at commit e2e6623e673ec6de55a5c1f8edcbd3a46b535a89 and later. ### Description This PR introduces WebGPU IO binding. This new feature allows onnxruntime-web users to use tensors created from GPU as model input/output so that a model inferencing can be done without unnecessary data copy between CPU and GPU for model input/output. ### Examples An E2E demo/example is being worked on. Following is some simple demo with code snippet. Let's first check today how we do: ```js // STEP.1 - create an inference session: const mySession = await ort.InferenceSession.create('./my_model.onnx', { executionProviders: ['webgpu'] }); // STEP.2 - create model input: (supposing myImageCpuData is a Float32Array) const feeds = { 'input_image:0': new ort.Tensor('float32', myImageCpuData, [1, 224, 224, 3]) }; // STEP.3 - run model const myResults = await mySession.run(feeds); // STEP.4 - get output data const myData = myResults['output_image:0'].data; // Float32Array ``` #### for inputs (GPU tensor): Now, with IO binding, you can create a tensor from a GPU buffer, and feed it to the model: ```js // new STEP.2.A - create model input from a GPU buffer: (supposing myInputGpuBuffer is a `GPUBuffer` object with input data) const feeds = { 'input_image:0': ort.Tensor.fromGpuBuffer(myInputGpuBuffer, { dataType: 'float32', dims: [1, 224, 224, 3] }) }; ``` ### for outputs (pre-allocated GPU tensor) you can also do that for output, **if you know 
the output shape**: ```js // new STEP.2.B - create model output from a GPU buffer: (supposing myOutputGpuBuffer is a pre-allocated `GPUBuffer` object) const fetches = { 'output_image:0': ort.Tensor.fromGpuBuffer(myOutputGpuBuffer, { dataType: 'float32', dims: [1, 512, 512, 3] }) }; // new STEP.3 - run model with pre-allocated output (fetches) const myResults = await mySession.run(feeds, fetches); ``` ### for outputs (specify location) if you do not know the output shape, you can specify the output location when creating the session: ```js // new STEP.1 - create an inference session with an option "preferredOutputLocation": const mySession = await ort.InferenceSession.create('./my_model.onnx', { executionProviders: ['webgpu'], preferredOutputLocation: "gpu-buffer" }); ``` if the model has multiple outputs, you can specify them separately: ```js // new STEP.1 - create an inference session with an option "preferredOutputLocation": const mySession = await ort.InferenceSession.create('./my_model.onnx', { executionProviders: ['webgpu'], preferredOutputLocation: { "output_image:0": "gpu-buffer" } }); ``` now you don't need to prepare the `fetches` object and onnxruntime-web will prepare output data on the location that is specified. #### read data when you get the output tensor, you can: ```js // get the gpu buffer object: const gpuBuffer = myOutputTensor.gpuBuffer; // GPUBuffer // get the CPU data asynchronously const cpuData = await myOutputTensor.getData(); // get the CPU data asynchronously and release the underlying GPU resources const cpuData = await myOutputTensor.getData(true); // dispose the tensor (release the underlying GPU resources). This tensor object will be invalid after dispose() is called. myOutputTensor.dispose(); ``` #### resource management JavaScript has GC so you don't need to worry about managing JavaScript objects. 
But there are 2 types of resources that are not managed by GC: - GPU buffers that are used in tensors - Underlying ORT native resources To simplify, most of the unmanaged resources are handled inside ORT web. But there are a few resources that need users to manage: - All external GPU resources, including GPU buffers inside all tensors created by `Tensor.fromGpuBuffer()`, will not be managed by ORT. Users should manage those GPU buffers themselves. - When a session is created with `preferredOutputLocation` == "gpu-buffer" specified in session options, and the corresponding output is not pre-allocated, users need to call the output tensor's `dispose()` or `getData(true)` to manually release the underlying GPU buffers. - ORT internal errors (including providing a pre-allocated output tensor with wrong type/dims) will invalidate the whole wasm memory and are not recoverable. An exception is thrown in this situation. --- js/web/lib/wasm/binding/ort-wasm.d.ts | 80 +++- js/web/lib/wasm/jsep/backend-webgpu.ts | 66 ++- js/web/lib/wasm/jsep/init.ts | 15 +- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 125 ++++- js/web/lib/wasm/proxy-messages.ts | 32 +- js/web/lib/wasm/proxy-wrapper.ts | 30 +- js/web/lib/wasm/session-handler.ts | 51 +- js/web/lib/wasm/wasm-common.ts | 32 ++ js/web/lib/wasm/wasm-core-impl.ts | 434 ++++++++++++------ js/web/script/test-runner-cli-args.ts | 16 + js/web/script/test-runner-cli.ts | 21 +- js/web/test/test-runner.ts | 181 +++++++- js/web/test/test-types.ts | 14 + onnxruntime/core/providers/js/js_kernel.h | 2 +- onnxruntime/wasm/api.cc | 113 ++++- onnxruntime/wasm/api.h | 58 ++- onnxruntime/wasm/js_internal_api.js | 178 +++++-- .../azure-pipelines/templates/win-web-ci.yml | 17 +- 18 files changed, 1177 insertions(+), 288 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 59da1369e152e..b7b2ff4537095 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ 
-1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import type {Tensor} from 'onnxruntime-common'; + export declare namespace JSEP { type BackendType = unknown; type AllocFunction = (size: number) => number; @@ -9,11 +11,8 @@ export declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number; - export interface SessionState { - sessionId: number; - errors: Array>; - } + type RunFunction = + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; } export interface OrtWasmModule extends EmscriptenModule { @@ -40,14 +39,23 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtFree(stringHandle: number): void; - _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): - number; + _OrtCreateTensor( + dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, + dataLocation: number): number; _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): number; _OrtReleaseTensor(tensorHandle: number): void; + _OrtCreateBinding(sessionHandle: number): number; + _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; + _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; + _OrtClearBoundOutputs(ioBindingHandle: number): void; + _OrtReleaseBinding(ioBindingHandle: number): void; + _OrtRunWithBinding( + sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, + runOptionsHandle: number): Promise; 
_OrtRun( sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, - outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number; + outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; _OrtCreateSessionOptions( graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, @@ -102,17 +110,67 @@ export interface OrtWasmModule extends EmscriptenModule { // #endregion // #region JSEP + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. + * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. + */ jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ _JsepGetNodeName(kernel: number): number; - jsepOnRunStart?(sessionId: number): void; - jsepOnRunEnd?(sessionId: number): Promise; - jsepRunPromise?: Promise; + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. 
+ * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. + * + * @param sessionId - specify the session ID. + */ + jsepUnregisterBuffers?: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5e77a0343b4ee..5bec562b157ac 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Env} from 'onnxruntime-common'; +import {Env, Tensor} from 'onnxruntime-common'; import {configureLogger, LOG_DEBUG} from './log'; -import {TensorView} from './tensor-view'; -import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; +import {createView, TensorView} from './tensor-view'; +import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; @@ -98,6 +98,11 @@ export class WebGpuBackend { env: Env; + /** + * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. + */ + sessionExternalDataMapping: Map> = new Map(); + async initialize(env: Env): Promise { if (!navigator.gpu) { // WebGPU is not available. @@ -192,11 +197,13 @@ export class WebGpuBackend { } flush(): void { - this.endComputePass(); - this.device.queue.submit([this.getCommandEncoder().finish()]); - this.gpuDataManager.refreshPendingBuffers(); - this.commandEncoder = null; - this.pendingDispatchNumber = 0; + if (this.commandEncoder) { + this.endComputePass(); + this.device.queue.submit([this.getCommandEncoder().finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + } } /** @@ -304,12 +311,9 @@ export class WebGpuBackend { } async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise { - const arrayBuffer = await this.gpuDataManager.download(gpuDataId); - // the underlying buffer may be changed after the async function is called. so we use a getter function to make sure // the buffer is up-to-date. 
- const data = getTargetBuffer(); - data.set(new Uint8Array(arrayBuffer, 0, data.byteLength)); + await this.gpuDataManager.download(gpuDataId, getTargetBuffer); } alloc(size: number): number { @@ -372,7 +376,7 @@ export class WebGpuBackend { kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`); + errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`)); return 1; // ORT_FAIL } finally { if (useErrorScope) { @@ -387,4 +391,40 @@ export class WebGpuBackend { this.currentKernelId = null; } } + + // #region external buffer + registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number { + let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (!sessionInputOutputMapping) { + sessionInputOutputMapping = new Map(); + this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping); + } + + const previousBuffer = sessionInputOutputMapping.get(index); + const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]); + sessionInputOutputMapping.set(index, [id, buffer]); + return id; + } + unregisterBuffers(sessionId: number): void { + const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (sessionInputOutputMapping) { + sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + this.sessionExternalDataMapping.delete(sessionId); + } + } + getBuffer(gpuDataId: number): GPUBuffer { + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error(`no GPU data for buffer: ${gpuDataId}`); + } + return gpuData.buffer; + } + createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes): + () => Promise { + return async () => { + const data = await downloadGpuData(this, gpuBuffer, size); + return createView(data.buffer, 
type); + }; + } + // #endregion } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 78316cbe1c825..6ff3971d720fd 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,7 +3,7 @@ import {Env} from 'onnxruntime-common'; -import {JSEP, OrtWasmModule} from '../binding/ort-wasm'; +import {OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; @@ -120,6 +120,11 @@ class ComputeContextImpl implements ComputeContext { this.module.HEAPU32[offset++] = dims[i]; } return this.module._JsepOutput(this.opKernelContext, index, data); + } catch (e) { + throw new Error( + `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + + 'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' + + `Error: ${e}`); } finally { this.module.stackRestore(stack); } @@ -138,7 +143,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { init( // backend - {backend}, + backend, // jsepAlloc() (size: number) => backend.alloc(size), @@ -178,13 +183,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => { + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${ + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, sessionState.errors); + return backend.computeKernel(kernel, context, errors); }); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts 
b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 92fdd5abc3892..131f7a9bfa29b 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -35,7 +35,7 @@ export interface GpuDataManager { /** * copy data from GPU to CPU. */ - download(id: GpuDataId): Promise; + download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise; /** * refresh the buffers that marked for release. @@ -46,6 +46,19 @@ export interface GpuDataManager { */ refreshPendingBuffers(): void; + /** + * register an external buffer for IO Binding. If the buffer is already registered, return the existing GPU data ID. + * + * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of + * the external buffer. + */ + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number; + + /** + * unregister an external buffer for IO Binding. + */ + unregisterExternalBuffer(buffer: GPUBuffer): void; + /** * destroy all gpu buffers. Call this when the session.release is called. */ @@ -62,12 +75,56 @@ interface StorageCacheValue { */ const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; -let guid = 0; +let guid = 1; const createNewGpuDataId = () => guid++; +/** + * exported standard download function. This function is used by the session to download the data from GPU, and also by + * factory to create GPU tensors with the capacity of downloading data from GPU. + * + * @param backend - the WebGPU backend + * @param gpuBuffer - the GPU buffer to download + * @param originalSize - the original size of the data + * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer + * will be created and returned. 
+ */ +export const downloadGpuData = + async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array): + Promise => { + const bufferSize = calcNormalizedBufferSize(originalSize); + const gpuReadBuffer = backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); + try { + const commandEncoder = backend.getCommandEncoder(); + backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, bufferSize /* size */ + ); + backend.flush(); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + if (getTargetBuffer) { + // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. + const targetBuffer = getTargetBuffer(); + targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); + return targetBuffer; + } else { + // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the + // ArrayBuffer. + return new Uint8Array(arrayBuffer.slice(0, originalSize)); + } + } finally { + gpuReadBuffer.destroy(); + } + }; + class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) - storageCache: Map; + private storageCache: Map; // pending buffers for uploading ( data is unmapped ) private buffersForUploadingPending: GPUBuffer[]; @@ -77,11 +134,15 @@ class GpuDataManagerImpl implements GpuDataManager { // The reusable storage buffers for computing. private freeBuffers: Map; + // The external buffers registered users for IO Binding. 
+ private externalBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); this.buffersForUploadingPending = []; this.buffersPending = []; + this.externalBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -143,6 +204,42 @@ class GpuDataManagerImpl implements GpuDataManager { sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); } + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + let id: number|undefined; + if (previousBuffer) { + id = this.externalBuffers.get(previousBuffer); + if (id === undefined) { + throw new Error('previous buffer is not registered'); + } + if (buffer === previousBuffer) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id}, buffer is the same, skip.`); + return id; + } + this.externalBuffers.delete(previousBuffer); + } else { + id = createNewGpuDataId(); + } + + this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.externalBuffers.set(buffer, id); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); + return id; + } + + unregisterExternalBuffer(buffer: GPUBuffer): void { + const id = this.externalBuffers.get(buffer); + if (id !== undefined) { + this.storageCache.delete(id); + this.externalBuffers.delete(buffer); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); + } + } + // eslint-disable-next-line no-bitwise create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData { const bufferSize = calcNormalizedBufferSize(size); @@ -193,31 +290,13 @@ class GpuDataManagerImpl implements GpuDataManager { return cachedData.originalSize; } - async download(id: GpuDataId): Promise { + async 
download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('data does not exist'); } - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - const bufferSize = calcNormalizedBufferSize(cachedData.originalSize); - const gpuReadBuffer = this.backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); - commandEncoder.copyBufferToBuffer( - cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, bufferSize /* size */ - ); - this.backend.flush(); - - return new Promise((resolve) => { - gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => { - const data = gpuReadBuffer.getMappedRange().slice(0); - gpuReadBuffer.destroy(); - resolve(data); - }); - }); + await downloadGpuData(this.backend, cachedData.gpuData.buffer, cachedData.originalSize, getTargetBuffer); } refreshPendingBuffers(): void { diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index e5a2d8c2351b8..43f70c23f7193 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -3,20 +3,24 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -/** - * tuple elements are: ORT element type; dims; tensor data - */ -export type SerializableTensor = [Tensor.Type, readonly number[], Tensor.DataType]; +export type SerializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; -/** - * tuple elements are: InferenceSession handle; input names; output names - */ -export type SerializableSessionMetadata = [number, string[], string[]]; +export type GpuBufferMetadata = { + gpuBuffer: Tensor.GpuBufferType; + download?: () => Promise; + dispose?: () => void; +}; -/** - * tuple elements are: 
modeldata.offset, modeldata.length - */ -export type SerializableModeldata = [number, number]; +export type UnserializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; + +export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; + +export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; + +export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; interface MessageError { err?: string; @@ -58,10 +62,10 @@ interface MessageReleaseSession extends MessageError { interface MessageRun extends MessageError { type: 'run'; in ?: { - sessionId: number; inputIndices: number[]; inputs: SerializableTensor[]; outputIndices: number[]; + sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[]; options: InferenceSession.RunOptions; }; - out?: SerializableTensor[]; + out?: SerializableTensorMetadata[]; } interface MesssageEndProfiling extends MessageError { diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 815b223e40379..202209ed3bfed 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -3,7 +3,7 @@ import {Env, env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -22,7 +22,7 @@ const createSessionAllocateCallbacks: Array> = []; const createSessionCallbacks: Array> = []; const releaseSessionCallbacks: Array> = []; -const 
runCallbacks: Array> = []; +const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; const ensureWorker = (): void => { @@ -177,6 +177,10 @@ export const createSessionFinalize = async(modeldata: SerializableModeldata, opt export const createSession = async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } ensureWorker(); return new Promise((resolve, reject) => { createSessionCallbacks.push([resolve, reject]); @@ -202,17 +206,27 @@ export const releaseSession = async(sessionId: number): Promise => { }; export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[], + outputs: Array, options: InferenceSession.RunOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check inputs location + if (inputs.some(t => t[3] !== 'cpu')) { + throw new Error('input tensor on GPU is not supported for proxy.'); + } + // check outputs location + if (outputs.some(t => t)) { + throw new Error('pre-allocated output tensor is not supported for proxy.'); + } ensureWorker(); - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { runCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}}; - proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs)); + const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. 
+ const message: OrtWasmMessage = + {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; + proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs)); }); } else { - return core.run(sessionId, inputIndices, inputs, outputIndices, options); + return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options); } }; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index d8c5ae7886fe4..4e00878d0063b 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,12 +5,41 @@ import {readFile} from 'fs'; import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; import {promisify} from 'util'; -import {SerializableModeldata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; +const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { + switch (tensor.location) { + case 'cpu': + return [tensor.type, tensor.dims, tensor.data, 'cpu']; + case 'gpu-buffer': + return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer']; + default: + throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); + } +}; + +const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { + switch (tensor[3]) { + case 'cpu': + return new Tensor(tensor[0], tensor[2], tensor[1]); + case 'gpu-buffer': { + const dataType = tensor[0]; + if (!isGpuBufferSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`); + } + const {gpuBuffer, download, dispose} = tensor[2]; + return 
Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose}); + } + default: + throw new Error(`invalid data location: ${tensor[3]}`); + } +}; + export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { private sessionId: number; @@ -74,25 +103,31 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { inputIndices.push(index); }); + const outputArray: Array = []; const outputIndices: number[] = []; Object.entries(fetches).forEach(kvp => { const name = kvp[0]; - // TODO: support pre-allocated output + const tensor = kvp[1]; const index = this.outputNames.indexOf(name); if (index === -1) { throw new Error(`invalid output '${name}'`); } + outputArray.push(tensor); outputIndices.push(index); }); - const outputs = - await run(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options); + const inputs = + inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + const outputs = outputArray.map( + (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - const result: SessionHandler.ReturnType = {}; - for (let i = 0; i < outputs.length; i++) { - result[this.outputNames[outputIndices[i]]] = new Tensor(outputs[i][0], outputs[i][2], outputs[i][1]); + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? 
decodeTensorMetadata(results[i]); } - return result; + return resultMap; } startProfiling(): void { diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 389773f3e8884..b9eff45e890c4 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,3 +164,35 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; + +/** + * Check whether the given tensor type is supported by GPU buffer + */ +export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || + type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + +/** + * Map string data location to integer value + */ +export const dataLocationStringToEnum = (location: Tensor.DataLocation): number => { + switch (location) { + case 'none': + return 0; + case 'cpu': + return 1; + case 'cpu-pinned': + return 2; + case 'texture': + return 3; + case 'gpu-buffer': + return 4; + default: + throw new Error(`unsupported data location: ${location}`); + } +}; + +/** + * Map integer data location to string value + */ +export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined => + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index fcca82ab2aa54..5b49a1d4202e3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,10 +3,10 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from 
'./session-options'; -import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; @@ -60,9 +60,36 @@ export const initRuntime = async(env: Env): Promise => { }; /** - * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded + * valid data locations for input/output tensors. */ -type SessionMetadata = [number, number[], number[]]; +type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer'; + +type IOBindingState = { + /** + * the handle of IO binding. + */ + readonly handle: number; + + /** + * the preferred location for each output tensor. + * + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + */ + readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; + + /** + * enum value of the preferred location for each output tensor. 
+ */ + readonly outputPreferredLocationsEncoded: readonly number[]; +}; + +/** + * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState + */ +type SessionMetadata = [ + inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], + bindingState: IOBindingState|null +]; const activeSessions = new Map(); @@ -92,6 +119,7 @@ export const createSessionFinalize = let sessionHandle = 0; let sessionOptionsHandle = 0; + let ioBindingHandle = 0; let allocs: number[] = []; const inputNamesUTF8Encoded = []; const outputNamesUTF8Encoded = []; @@ -108,6 +136,7 @@ export const createSessionFinalize = const inputNames = []; const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; for (let i = 0; i < inputCount; i++) { const name = wasm._OrtGetInputName(sessionHandle, i); if (name === 0) { @@ -122,15 +151,45 @@ export const createSessionFinalize = checkLastError('Can\'t get an output name.'); } outputNamesUTF8Encoded.push(name); - outputNames.push(wasm.UTF8ToString(name)); + const nameString = wasm.UTF8ToString(name); + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } + outputPreferredLocations.push(location); + } + } + + // use IO binding only when at least one output is preffered to be on GPU. 
+ let bindingState: IOBindingState|null = null; + if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); + if (ioBindingHandle === 0) { + checkLastError('Can\'t create IO binding.'); + } + + bindingState = { + handle: ioBindingHandle, + outputPreferredLocations, + outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded]); + activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + if (ioBindingHandle !== 0) { + wasm._OrtReleaseBinding(ioBindingHandle); + } + if (sessionHandle !== 0) { wasm._OrtReleaseSession(sessionHandle); } @@ -161,7 +220,13 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + if (ioBindingState) { + wasm._OrtReleaseBinding(ioBindingState.handle); + } + + wasm.jsepUnregisterBuffers?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -169,18 +234,84 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; +const prepareInputOutputTensor = + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): + void => { + if (!tensor) { + tensorHandles.push(0); + return; + } + + const wasm = getInstance(); + + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; + + let rawData: number; + let dataByteLength: number; + + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); + } + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + 
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; + /** * perform inference run */ export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { const wasm = getInstance(); const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -188,171 +319,200 @@ export const run = async( let runOptionsHandle = 0; let runOptionsAllocs: number[] = []; - const inputValues: number[] = []; - const inputAllocs: number[] = []; + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors for (let i = 0; i < inputCount; i++) { - const dataType = inputs[i][0]; - const dims = inputs[i][1]; - const data = inputs[i][2]; - - let dataOffset: number; - let dataByteLength: number; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - let dataIndex = dataOffset / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); - } - } else { - dataByteLength = data.byteLength; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); - } + prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + } - const stack = wasm.stackSave(); - const dimsOffset = 
wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); - if (tensor === 0) { - checkLastError(`Can't create tensor for input[${i}].`); - } - inputValues.push(tensor); - } finally { - wasm.stackRestore(stack); - } + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; + wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); - - try { - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + + if (inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count 
(${inputNamesUTF8Encoded.length}).`); + } + + // process inputs for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); + if (errorCode !== 0) { + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); + } } + + // process pre-allocated outputs for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = 0; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + + if (location) { + // output is pre-allocated. bind the tensor. + const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } + } } + } - // jsepOnRunStart is only available when JSEP is enabled. 
- wasm.jsepOnRunStart?.(sessionId); + let errorCode: number; - // support RunOptions - let errorCode = wasm._OrtRun( + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + errorCode = await wasm._OrtRunWithBinding( + sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); + } else { + errorCode = await wasm._OrtRun( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); + } - const runPromise = wasm.jsepRunPromise; - if (runPromise) { - // jsepRunPromise is a Promise object. It is only available when JSEP is enabled. - // - // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to - // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than - // the async functions are finished. - // - // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of - // OrtRun(). If the promise exists, we need to await for the promise to be resolved. - errorCode = await runPromise; - } + if (errorCode !== 0) { + checkLastError('failed to call OrtRun().'); + } - const jsepOnRunEnd = wasm.jsepOnRunEnd; - if (jsepOnRunEnd) { - // jsepOnRunEnd is only available when JSEP is enabled. - // - // it returns a promise, which is resolved or rejected when the following async functions are finished: - // - collecting GPU validation errors. - await jsepOnRunEnd(sessionId); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. 
+ output.push(outputTensors[i]!); + continue; } - const output: SerializableTensor[] = []; + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - if (errorCode !== 0) { - checkLastError('failed to call OrtRun().'); - } + let keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); + if (type === 'string') { + if (preferredLocation === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = 
wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } - wasm._OrtFree(dimsOffset); - - const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); - type = tensorDataTypeEnumToString(dataType); - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + output.push([type, dims, stringData, 'cpu']); + } else { + // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU + // tensor for it. There is no mapping GPU buffer for an empty tensor. + if (preferredLocation === 'gpu-buffer' && size > 0) { + const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const elementSize = getTensorElementSize(dataType); + if (elementSize === undefined || !isGpuBufferSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); } - output.push([type, dims, stringData]); + + // do not release the tensor right now. it will be released when user calls tensor.dispose(). 
+ keepOutputTensor = true; + + output.push([ + type, dims, { + gpuBuffer, + download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + dispose: () => { + wasm._OrtReleaseTensor(tensor); + } + }, + 'gpu-buffer' + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); new Uint8Array(data.buffer, data.byteOffset, data.byteLength) .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data]); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); + output.push([type, dims, data, 'cpu']); } + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } } + } - return output; - } finally { - wasm.stackRestore(beforeRunStack); + if (ioBindingState) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); } + + return output; } finally { - inputValues.forEach(v => wasm._OrtReleaseTensor(v)); - inputAllocs.forEach(p => wasm._free(p)); + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); @@ -380,11 +540,11 @@ export const endProfiling = (sessionId: number): void => { wasm._OrtFree(profileFileName); }; -export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => { +export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => { const buffers: ArrayBufferLike[] = []; for (const tensor of tensors) { const data = tensor[2]; - if (!Array.isArray(data) && data.buffer) { + if (!Array.isArray(data) && 'buffer' in data) { 
buffers.push(data.buffer); } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index f90f568879146..3f903515694db 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -51,6 +51,10 @@ Options: -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: + none (default) + gpu-tensor use pre-allocated GPU tensors for inputs and outputs + gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' *** Session Options *** -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. @@ -109,6 +113,7 @@ export declare namespace TestRunnerCliArgs { type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; + type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; } export interface TestRunnerCliArgs { @@ -140,6 +145,8 @@ export interface TestRunnerCliArgs { */ bundleMode: TestRunnerCliArgs.BundleMode; + ioBindingMode: TestRunnerCliArgs.IOBindingMode; + logConfig: Test.Config['log']; /** @@ -416,6 +423,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}}); } + // Option: -i=<...>, --io-binding=<...> + const ioBindingArg = args['io-binding'] || args.i; + const ioBindingMode = (typeof ioBindingArg !== 'string') ? 
'none' : ioBindingArg; + if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + throw new Error(`not supported io binding mode ${ioBindingMode}`); + } + // Option: -u, --optimized-model-file-path const optimizedModelFilePath = args['optimized-model-file-path'] || args.u || undefined; if (typeof optimizedModelFilePath !== 'undefined' && typeof optimizedModelFilePath !== 'string') { @@ -455,6 +469,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); npmlog.verbose('TestRunnerCli.Init', ` Backend: ${backend}`); + npmlog.verbose('TestRunnerCli.Init', ` IO Binding Mode: ${ioBindingMode}`); npmlog.verbose('TestRunnerCli.Init', 'Parsing commandline arguments... DONE'); return { @@ -467,6 +482,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig, profile, times: perf ? times : undefined, + ioBindingMode: ioBindingMode as TestRunnerCliArgs['ioBindingMode'], optimizedModelFilePath, graphOptimizationLevel: graphOptimizationLevel as TestRunnerCliArgs['graphOptimizationLevel'], fileCache, diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index f3764e63fcf45..d8fecec1b8084 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -257,7 +257,7 @@ async function main() { times?: number): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); - return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []}; + return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode}; } let modelUrl: string|null = null; @@ -323,6 +323,16 @@ async function main() { } } + let ioBinding: Test.IOBindingMode; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( 
+ 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + ioBinding = 'none'; + } else { + ioBinding = args.ioBindingMode; + } + + npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); @@ -330,7 +340,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases}; + return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -390,6 +400,13 @@ async function main() { for (const test of tests) { test.backend = backend; test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + test.ioBinding = 'none'; + } else { + test.ioBinding = args.ioBindingMode; + } } npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 46d80a9f56f35..628e5408150f8 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -14,7 +14,8 @@ import {Operator} from '../lib/onnxjs/operators'; import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {Tensor} from '../lib/onnxjs/tensor'; import {ProtoUtil} from '../lib/onnxjs/util'; -import 
{tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; +import {createView} from '../lib/wasm/jsep/tensor-view'; +import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; import {base64toBuffer, createMockGraph, readFile} from './test-shared'; import {Test} from './test-types'; @@ -136,8 +137,8 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: string, profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean, + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -146,8 +147,14 @@ async function initializeSession( preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`); const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined; - const sessionConfig = - {...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile}; + const sessionConfig = { + ...sessionOptions, + executionProviders: [backendHint], + profiler: profilerConfig, + enableProfiling: profile, + preferredOutputLocation: ioBindingMode === 'gpu-location' ? 
('gpu-buffer' as const) : undefined + }; + let session: ort.InferenceSession; try { @@ -181,6 +188,7 @@ export class ModelTestContext { readonly session: ort.InferenceSession, readonly backend: string, readonly perfData: ModelTestContext.ModelTestPerfData, + readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, ) {} @@ -232,8 +240,8 @@ export class ModelTestContext { this.initializing = true; const initStart = now(); - const session = - await initializeSession(modelTest.modelUrl, modelTest.backend!, profile, sessionOptions || {}, this.cache); + const session = await initializeSession( + modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache); const initEnd = now(); for (const testCase of modelTest.cases) { @@ -244,6 +252,7 @@ export class ModelTestContext { session, modelTest.backend!, {init: initEnd - initStart, firstRun: -1, runs: [], count: 0}, + modelTest.ioBinding, profile, ); } finally { @@ -481,6 +490,130 @@ export class TensorResultValidator { } } +function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { + if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(cpuTensor.data.byteLength / 16) * 16, + mappedAtCreation: true + }); + const arrayBuffer = gpuBuffer.getMappedRange(); + new Uint8Array(arrayBuffer) + .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength)); + gpuBuffer.unmap(); + + // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data? 
+ + return ort.Tensor.fromGpuBuffer( + gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()}); +} + +function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { + if (!isGpuBufferSupportedType(type)) { + throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); + } + + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!; + const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(size / 16) * 16 + }); + + return ort.Tensor.fromGpuBuffer(gpuBuffer, { + dataType: type, + dims, + dispose: () => gpuBuffer.destroy(), + download: async () => { + const stagingBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + size: gpuBuffer.size + }); + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size); + device.queue.submit([encoder.finish()]); + + await stagingBuffer.mapAsync(GPUMapMode.READ); + const arrayBuffer = stagingBuffer.getMappedRange().slice(0, size); + stagingBuffer.destroy(); + + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes]; + } + }); +} + +export async function sessionRun(options: { + session: ort.InferenceSession; feeds: Record; + outputsMetaInfo: Record>; + ioBinding: Test.IOBindingMode; +}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { + const session = options.session; + const feeds = options.feeds; + const fetches: Record = {}; + + // currently we only support IO Binding for WebGPU + // + // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing 
mode. + // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. + // in 'gpu-location' binding mode, outputs are not pre-allocated. + const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; + const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + try { + if (shouldUploadInput) { + // replace the CPU tensors in feeds into GPU tensors + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } + } + } + + if (shouldUploadOutput) { + for (const name in options.outputsMetaInfo) { + if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { + const {type, dims} = options.outputsMetaInfo[name]; + fetches[name] = createGpuTensorForOutput(type, dims); + } + } + } + + const start = now(); + Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); + const outputs = await ( + shouldUploadOutput ? session.run(feeds, fetches) : + session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); + const end = now(); + Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + + // download each output tensor if needed + for (const name in outputs) { + if (Object.hasOwnProperty.call(outputs, name)) { + const tensor = outputs[name]; + // Tensor.getData(true) releases the underlying resource + await tensor.getData(true); + } + } + + return [start, end, outputs]; + } finally { + // dispose the GPU tensors in feeds + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + const tensor = feeds[name]; + tensor.dispose(); + } + } + } +} + /** * run a single model test case. the inputs/outputs tensors should already been prepared.
*/ @@ -491,12 +624,11 @@ export async function runModelTestSet( const validator = new TensorResultValidator(context.backend); try { const feeds: Record = {}; + const outputsMetaInfo: Record = {}; testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - const start = now(); - Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); - const outputs = await context.session.run(feeds); - const end = now(); - Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + const [start, end, outputs] = + await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -575,6 +707,7 @@ export class ProtoOpTestContext { private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; + readonly ioBindingMode: Test.IOBindingMode; constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; @@ -713,6 +846,7 @@ export class ProtoOpTestContext { model.graph.name = test.name; this.backendHint = test.backend!; + this.ioBindingMode = test.ioBinding; this.loadedData = onnx.ModelProto.encode(model).finish(); // in debug mode, open a new tab in browser for the generated onnx model. @@ -729,8 +863,11 @@ export class ProtoOpTestContext { } } async init(): Promise { - this.session = await ort.InferenceSession.create( - this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions}); + this.session = await ort.InferenceSession.create(this.loadedData, { + executionProviders: [this.backendHint], + preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? 
('gpu-buffer' as const) : undefined, + ...this.sessionOptions + }); } async dispose(): Promise { @@ -739,10 +876,11 @@ export class ProtoOpTestContext { } async function runProtoOpTestcase( - session: ort.InferenceSession, testCase: Test.OperatorTestCase, validator: TensorResultValidator): Promise { + session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode, + validator: TensorResultValidator): Promise { const feeds: Record = {}; - const fetches: string[] = []; - testCase.inputs!.forEach((input, i) => { + const fetches: Record> = {}; + testCase.inputs.forEach((input, i) => { if (input.data) { let data: number[]|BigUint64Array|BigInt64Array = input.data; if (input.type === 'uint64') { @@ -756,7 +894,7 @@ async function runProtoOpTestcase( const outputs: ort.Tensor[] = []; const expectedOutputNames: string[] = []; - testCase.outputs!.forEach((output, i) => { + testCase.outputs.forEach((output, i) => { if (output.data) { let data: number[]|BigUint64Array|BigInt64Array = output.data; if (output.type === 'uint64') { @@ -766,11 +904,11 @@ async function runProtoOpTestcase( } outputs.push(new ort.Tensor(output.type, data, output.dims)); expectedOutputNames.push(`output_${i}`); - fetches.push(`output_${i}`); + fetches[`output_${i}`] = {dims: output.dims, type: output.type}; } }); - const results = await session.run(feeds, fetches); + const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode}); const actualOutputNames = Object.getOwnPropertyNames(results); expect(actualOutputNames.length).to.equal(expectedOutputNames.length); @@ -821,7 +959,8 @@ async function runOpTestcase( export async function runOpTest( testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise { if (context instanceof ProtoOpTestContext) { - await runProtoOpTestcase(context.session, testcase, new TensorResultValidator(context.backendHint)); + await runProtoOpTestcase( + context.session, 
testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint)); } else { await runOpTestcase( context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint)); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 1f95d1cd8e682..88915e7972383 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -43,6 +43,18 @@ export declare namespace Test { */ export type PlatformCondition = string; + /** + * The IOBindingMode represents how to test a model with GPU data. + * + * - none: inputs will be pre-allocated as CPU tensors; no output will be pre-allocated; `preferredOutputLocation` + * will not be set. + * - gpu-location: inputs will be pre-allocated as GPU tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `gpu-buffer`. + * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` + * will not be set. + */ + export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + export interface ModelTestCase { name: string; dataFiles: readonly string[]; @@ -54,6 +66,7 @@ export declare namespace Test { name: string; modelUrl: string; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; } @@ -82,6 +95,7 @@ export declare namespace Test { inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; opset?: OperatorTestOpsetImport; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; attributes?: readonly AttributeValue[]; cases: readonly OperatorTestCase[]; diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 177c0a9e691ed..fdd5c7dee5bfc 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -196,7 +196,7 @@ class 
JsKernel : public OpKernel { } int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState); }, + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 174edabbc91fe..968eece361724 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -9,6 +9,7 @@ #include "api.h" #include +#include #include namespace { @@ -17,6 +18,14 @@ OrtErrorCode g_last_error_code; std::string g_last_error_message; } // namespace +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4 +}; + static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); @@ -223,13 +232,23 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length) { +OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { + if (data_location != DATA_LOCATION_CPU && + data_location != DATA_LOCATION_CPU_PINNED && + data_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location: " << data_location; + CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + return nullptr; + } + std::vector shapes(dims_length); for (size_t i = 0; i < dims_length; i++) { shapes[i] = dims[i]; } if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + // data_location is ignored for string tensor. It is always CPU. 
OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -244,12 +263,16 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* return UNREGISTER_AUTO_RELEASE(value); } else { - OrtMemoryInfo* memoryInfo = nullptr; - RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memoryInfo); - REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memoryInfo); + OrtMemoryInfo* memory_info = nullptr; + if (data_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); OrtValue* value = nullptr; - int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memoryInfo, data, data_length, + int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memory_info, data, data_length, dims_length > 0 ? shapes.data() : nullptr, dims_length, static_cast(data_type), &value); @@ -373,15 +396,85 @@ void OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); } +OrtIoBinding* OrtCreateBinding(OrtSession* session) { + OrtIoBinding* binding = nullptr; + int error_code = CHECK_STATUS(CreateIoBinding, session, &binding); + return (error_code == ORT_OK) ? 
binding : nullptr; +} + +int EMSCRIPTEN_KEEPALIVE OrtBindInput(OrtIoBinding* io_binding, + const char* name, + OrtValue* input) { + return CHECK_STATUS(BindInput, io_binding, name, input); +} + +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, + const char* name, + OrtValue* output, + int output_location) { + if (output) { + return CHECK_STATUS(BindOutput, io_binding, name, output); + } else { + if (output_location != DATA_LOCATION_NONE && + output_location != DATA_LOCATION_CPU && + output_location != DATA_LOCATION_CPU_PINNED && + output_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; + return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + } + + OrtMemoryInfo* memory_info = nullptr; + if (output_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); + return CHECK_STATUS(BindOutputToDevice, io_binding, name, memory_info); + } +} + +void OrtClearBoundOutputs(OrtIoBinding* io_binding) { + Ort::GetApi().ClearBoundOutputs(io_binding); +} + +void OrtReleaseBinding(OrtIoBinding* io_binding) { + Ort::GetApi().ReleaseIoBinding(io_binding); +} + +int OrtRunWithBinding(OrtSession* session, + OrtIoBinding* io_binding, + size_t output_count, + OrtValue** outputs, + OrtRunOptions* run_options) { + RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding); + + OrtAllocator* allocator = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + size_t binding_output_count = 0; + OrtValue** binding_outputs = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetBoundOutputValues, io_binding, allocator, 
&binding_outputs, &binding_output_count); + REGISTER_AUTO_RELEASE_BUFFER(OrtValue*, binding_outputs, allocator); + + if (binding_output_count != output_count) { + return CheckStatus( + Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Output count is inconsistent with IO Binding output data.")); + } + + for (size_t i = 0; i < output_count; i++) { + outputs[i] = binding_outputs[i]; + } + + return ORT_OK; +} + int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { - auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); -#if defined(USE_JSEP) - EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code); -#endif - return status_code; + return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } char* OrtEndProfiling(ort_session_handle_t session) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 398c901e0e5ed..9a0664697f0ff 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -15,6 +15,9 @@ struct OrtSession; using ort_session_handle_t = OrtSession*; +struct OrtIoBinding; +using ort_io_binding_handle_t = OrtIoBinding*; + struct OrtSessionOptions; using ort_session_options_handle_t = OrtSessionOptions*; @@ -164,9 +167,10 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); * @param data_length size of the buffer 'data' in bytes. * @param dims a pointer to an array of dims. the array should contain (dims_length) element(s). * @param dims_length the length of the tensor's dimension + * @param data_location specify the memory location of the tensor data. 1 for CPU, 2 for CPU pinned, 4 for GPU buffer. * @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor().
*/ -ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length); +ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location); /** * get type, shape info and data of the specified tensor. @@ -216,6 +220,58 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio */ void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +/** + * create an instance of ORT IO binding. + */ +ort_io_binding_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateBinding(ort_session_handle_t session); + +/** + * bind an input tensor to the IO binding instance. A cross device copy will be performed if necessary. + * @param io_binding handle of the IO binding + * @param name name of the input + * @param input handle of the input tensor + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindInput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t input); + +/** + * bind an output tensor or location to the IO binding instance. + * @param io_binding handle of the IO binding + * @param name name of the output + * @param output handle of the output tensor. nullptr for output location binding. + * @param output_location specify the memory location of the output tensor data. + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t output, + int output_location); + +/** + * clear all bound outputs. + */ +void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); + +/** + * release the specified ORT IO binding. 
+ */ +void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); + +/** + * inference the model. + * @param session handle of the specified session + * @param io_binding handle of the IO binding + * @param run_options handle of the run options + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, + ort_io_binding_handle_t io_binding, + size_t output_count, + ort_tensor_handle_t* outputs, + ort_run_options_handle_t run_options); + /** * inference the model. * @param session handle of the specified session diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 15d393f4ce62d..427ad6f6d14f3 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -14,40 +14,156 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; - Module['jsepOnRunStart'] = sessionId => { - Module['jsepRunPromise'] = new Promise(r => { - Module.jsepRunPromiseResolve = r; - }); - - if (Module.jsepSessionState) { - throw new Error('Session already started'); - } - - Module.jsepSessionState = { - sessionId, - errors: [] + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwrap() and ccall() that we don't need. + // + // Currently in JSEP build, we only use this for the following functions: + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // + // - In Debug build, it will generate a wrapper function for each exported function.
For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka when the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). + // 4. 
Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release + // build. + // + // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const jsepWrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; + + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; + } + + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; }; }; - Module['jsepOnRunEnd'] = sessionId => { - if (Module.jsepSessionState.sessionId !== sessionId) { - throw new Error('Session ID mismatch'); - } - - const errorPromises = Module.jsepSessionState.errors; - Module.jsepSessionState = null; - - return errorPromises.length === 0 ? 
Promise.resolve() : new Promise((resolve, reject) => { - Promise.all(errorPromises).then(errors => { - errors = errors.filter(e => e); - if (errors.length > 0) { - reject(new Error(errors.join('\n'))); - } else { - resolve(); + // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. + const runAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error('Session already started'); + } + const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; + + // Run the asyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); + + // Check if the session is still valid. this object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error('Session mismatch'); + } + + // Flush the backend. This will submit all pending commands to the GPU. + backend['flush'](); + + // Await all pending promises. This includes GPU validation promises for diagnostic purposes.
+ const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter(e => e); + if (errors.length > 0) { + throw new Error(errors.join('\n')); + } } - }, reason => { - reject(reason); - }); - }); + + return ret; + } finally { + Module.jsepSessionState = null; + } + }; + }; + + // replace the original functions with asyncified versions + Module['_OrtRun'] = runAsync(jsepWrapAsync( + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); + Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtBindInput'] = jsepWrapAsync( + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); + + // expose webgpu backend functions + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); + }; + Module['jsepUnregisterBuffers'] = sessionId => { + backend['unregisterBuffers'](sessionId); + }; + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); + }; + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); }; }; diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index d737376eb99b5..788b02f539821 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -29,6 +29,7 @@ jobs: pool: ${{ parameters.PoolName }} variables: + webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de' runCodesignValidationInjection: false timeoutInMinutes: 60 workspace: @@ -159,12 +160,22 @@ jobs: npm test -- -e=edge -b=webgl,wasm,xnnpack 
workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'true') + condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu --chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de + npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'false') + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | npm test -- --webgl-texture-pack-mode -b=webgl -e=edge workingDirectory: '$(Build.SourcesDirectory)\js\web' From 14d349e29075db52eb6971569b572ebe54215596 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 29 Sep 2023 12:32:56 -0700 Subject: [PATCH 10/20] Enable backtrace in unit tests (#17655) ### Description Google test can be built either with absl/re2 or not. This PR enables the build option so that google test framework can print out a nice stacktrace when something went wrong. It helps locate test errors in CI build pipelines. Also, Google test will remove the build option and make it always ON. So sooner or later we must make this change. 
--- cmake/CMakeLists.txt | 14 ++-- .../external/onnxruntime_external_deps.cmake | 8 +- cmake/onnxruntime_providers.cmake | 22 +++-- cmake/onnxruntime_unittests.cmake | 80 ++++++++++--------- .../core/platform/windows/debug_alloc.cc | 6 +- .../contrib_ops/qordered_attention_test.cc | 6 -- .../qordered_longformer_attention_op_test.cc | 6 -- .../contrib_ops/qordered_matmul_op_test.cc | 6 -- .../test/contrib_ops/qordered_qdq_op_test.cc | 12 --- onnxruntime/test/providers/cpu/model_tests.cc | 6 ++ onnxruntime/test/xctest/xcgtest.mm | 2 +- 11 files changed, 83 insertions(+), 85 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index cc63844f46d28..b03a3019764ca 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -73,6 +73,11 @@ option(onnxruntime_ENABLE_PYTHON "Enable python buildings" OFF) # Enable it may cause LNK1169 error option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF) option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) +# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through +# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical +# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. +cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) + option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." 
OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -146,10 +151,11 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) -cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON" OFF) +cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) -option(onnxruntime_DISABLE_ABSEIL "Do not link to Abseil. Redefine Inlined containers to STD containers." OFF) +# Even when onnxruntime_DISABLE_ABSEIL is ON, ONNX Runtime still needs to link to abseil. +option(onnxruntime_DISABLE_ABSEIL "Do not use Abseil data structures in ONNX Runtime source code. Redefine Inlined containers to STD containers." OFF) option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF) option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." 
OFF) @@ -269,10 +275,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() endif() -if (onnxruntime_USE_CUDA) - set(onnxruntime_DISABLE_RTTI OFF) -endif() - if (onnxruntime_USE_ROCM) if (WIN32) message(FATAL_ERROR "ROCM does not support build in Windows!") diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e1671bcf43ed9..019c6341d2e46 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,12 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - # Set it to ON will cause crashes in onnxruntime_test_all when onnxruntime_USE_CUDA is ON - set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + # Needs to update onnxruntime/test/xctest/xcgtest.mm + set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + else() + set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) + endif() # gtest and gmock FetchContent_Declare( googletest diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 4861643832cab..96c05e5282bb5 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -460,13 +460,17 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_REDUCED_OPS_BUILD) substitute_op_reduction_srcs(onnxruntime_providers_cuda_src) endif() - # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and - # add to the lib onnxruntime_providers_cuda separatedly. - # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. 
- set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) - onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and + # added to the lib onnxruntime_providers_cuda separately. + # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. + set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) + onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + else() + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_src}) + endif() # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut. # This function guarantees that all 3 targets have the same configurations.
function(config_cuda_provider_shared_module target) @@ -600,7 +604,9 @@ if (onnxruntime_USE_CUDA) target_compile_definitions(${target} PRIVATE ENABLE_ATEN) endif() endfunction() - config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) install(TARGETS onnxruntime_providers_cuda diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ec83eb2095071..0e642c5a7e0aa 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -35,13 +35,17 @@ function(AddTest) if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) #TODO: fix the warnings, they are dangerous - target_compile_options(${_UT_TARGET} PRIVATE "/wd4244") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") endif() if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "/wd6330") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6330>" + "$<$>:/wd6330>") #Abseil has a lot of C4127/C4324 warnings. - target_compile_options(${_UT_TARGET} PRIVATE "/wd4127") - target_compile_options(${_UT_TARGET} PRIVATE "/wd4324") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4127>" + "$<$>:/wd4127>") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4324>" + "$<$>:/wd4324>") endif() set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest") @@ -60,6 +64,11 @@ function(AddTest) Threads::Threads) target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_ONNXRUNTIME_DLL) else() + if(onnxruntime_USE_CUDA) + #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, + # otherwise it will impact when CUDA DLLs can be unloaded. 
+ target_link_libraries(${_UT_TARGET} PRIVATE cudart) + endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -85,29 +94,22 @@ function(AddTest) # include dbghelp in case tests throw an ORT exception, as that exception includes a stacktrace, which requires dbghelp. target_link_libraries(${_UT_TARGET} PRIVATE debug dbghelp) - if (onnxruntime_USE_CUDA) - # disable a warning from the CUDA headers about unreferenced local functions - if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd4505>" - "$<$>:/wd4505>") - endif() - endif() if (MSVC) # warning C6326: Potential comparison of a constant with another constant. # Lot of such things came from gtest - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") # Raw new and delete. A lot of such things came from googletest. - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") # "Global initializer calls a non-constexpr function." 
- target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings}) else() target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") endif() @@ -698,7 +700,7 @@ onnxruntime_add_static_library(onnxruntime_test_utils ${onnxruntime_test_utils_s if(MSVC) target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") else() target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11) @@ -755,13 +757,8 @@ set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src}) -if(NOT TARGET onnxruntime AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) -endif() -if (onnxruntime_USE_CUDA) - onnxruntime_add_static_library(onnxruntime_test_cuda_ops_lib ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) - list(APPEND onnxruntime_test_common_libs onnxruntime_test_cuda_ops_lib) +if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/test_cases/*" ) @@ -822,7 
+819,9 @@ if (onnxruntime_USE_TENSORRT) # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu_* or *xxx_*. list(APPEND test_all_args "--gtest_filter=-*cpu_*:*cuda_*" ) endif () - +if(NOT onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + list(REMOVE_ITEM all_tests ${TEST_SRC_DIR}/providers/cuda/cuda_provider_test.cc) +endif() AddTest( TARGET onnxruntime_test_all SOURCES ${all_tests} ${onnxruntime_unittest_main_src} @@ -832,11 +831,15 @@ AddTest( DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. # However, this is test code. We are less concerned. - target_compile_options(onnxruntime_test_all PRIVATE "/wd26451" "/wd4244") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd26451>" + "$<$>:/wd26451>") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() @@ -1092,18 +1095,18 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) if(WIN32) - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd4141>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd4141>" "$<$>:/wd4141>") # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. 
- target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26400>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26400>" "$<$>:/wd26400>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26814>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26497>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") @@ -1255,7 +1258,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_test_cuda_ops_lib cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) @@ -1270,7 +1273,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) - + if (onnxruntime_USE_CUDA) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE 
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" @@ -1356,13 +1362,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ) onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") @@ -1476,7 +1482,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) if (HAS_QSPECTRE) - list(APPEND custom_op_lib_option "$<$:SHELL:-Xcompiler /Qspectre>") + list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() endif() @@ -1503,7 +1509,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def") if (NOT onnxruntime_USE_CUDA) - target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(custom_op_library PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") endif() endif() diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index b08d189f79866..2b612a4303442 100644 --- 
a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -12,7 +12,7 @@ // #ifndef NDEBUG #ifdef ONNXRUNTIME_ENABLE_MEMLEAK_CHECK -constexpr int c_callstack_limit = 16; // Maximum depth of callstack in leak trace +constexpr int c_callstack_limit = 32; // Maximum depth of callstack in leak trace #define VALIDATE_HEAP_EVERY_ALLOC 0 // Call HeapValidate on every new/delete #pragma warning(disable : 4073) // initializers put in library initialization area (this is intentional) @@ -223,6 +223,10 @@ Memory_LeakCheck::~Memory_LeakCheck() { // empty_group_names = new std::map; }); if (string.find("RtlRunOnceExecuteOnce") == std::string::npos && string.find("re2::RE2::Init") == std::string::npos && + string.find("dynamic initializer for 'FLAGS_") == std::string::npos && + string.find("AbslFlagDefaultGenForgtest_") == std::string::npos && + string.find("::SetProgramUsageMessage") == std::string::npos && + string.find("testing::internal::ParseGoogleTestFlagsOnly") == std::string::npos && string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetThreadLocalsMapLocked") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetValueOnCurrentThread") == std::string::npos && diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index 24e4bff528285..1dd0162ad722f 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -240,12 +240,6 @@ static std::vector transpose(const T& src, size_t h, size_t w) { } TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if 
(NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc index 55209d9422fdd..06fe42ca989d7 100644 --- a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc @@ -34,12 +34,6 @@ static void run_qordered_longformer_attention_op_test( const int64_t head_size, const int64_t window, int64_t input_hidden_size = 0) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc index e5b3d59ef86e3..e3905db6355d9 100644 --- a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc @@ -21,12 +21,6 @@ static void RunQOrdered_MatMul_Test( OrderCublasLt weight_order, float scale_A, float scale_B, float scale_C, float scale_Y, bool add_bias = false, bool broadcast_c_batch = false, bool per_channel = false) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc index 15e97751acf2d..0f3f702695b80 100644 --- a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc +++ 
b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc @@ -73,12 +73,6 @@ static void RunQOrdered_Quantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto qvec = QuantizeTransform(shape, scale, fvec, order_q); std::vector> execution_providers; @@ -153,12 +147,6 @@ static void RunQOrdered_Dequantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto fvec = DequantizeTransform(shape, scale, qvec, order_q); std::vector> execution_providers; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 9b41ba8c0d2ba..da906ebf76f79 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1172,6 +1172,12 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("coreml_VGG16_ImageNet"), + ORT_TSTR("VGG 16-fp32"), + ORT_TSTR("VGG 19-caffe2"), + ORT_TSTR("VGG 19-bn"), + ORT_TSTR("VGG 16-bn"), + ORT_TSTR("VGG 19"), + ORT_TSTR("VGG 16"), ORT_TSTR("faster_rcnn"), ORT_TSTR("GPT2"), ORT_TSTR("GPT2_LM_HEAD"), diff --git a/onnxruntime/test/xctest/xcgtest.mm b/onnxruntime/test/xctest/xcgtest.mm index 5367f3e89c07c..c02f18d906cbe 100644 --- a/onnxruntime/test/xctest/xcgtest.mm +++ b/onnxruntime/test/xctest/xcgtest.mm @@ -201,7 +201,7 @@ + (void)registerTestClasses { delete listeners.Release(listeners.default_result_printer()); free(argv); - BOOL runDisabledTests = testing::GTEST_FLAG(also_run_disabled_tests); + BOOL runDisabledTests = GTEST_FLAG_GET(also_run_disabled_tests); 
NSMutableDictionary* testFilterMap = [NSMutableDictionary dictionary]; NSCharacterSet* decimalDigitCharacterSet = [NSCharacterSet decimalDigitCharacterSet]; From 5a623dca0118a1fc75419875d2af813161d097a0 Mon Sep 17 00:00:00 2001 From: shaahji <96227573+shaahji@users.noreply.github.com> Date: Fri, 29 Sep 2023 14:11:05 -0700 Subject: [PATCH 11/20] Python API to check whether collective ops are available or not (#17730) Python API to check whether collective ops are available or not ### Description Adding an API to check whether collective ops are available or not. Since there is no independent MPI enabled build, this flag can be used on Python front for branching. Specifically, to conditionally enable tests. ### Motivation and Context Flag to be used in Python to check whether onnxruntime supports collective ops or not. Handy for conditionally enabling/disabling tests and for other branching decisions. --- onnxruntime/__init__.py | 1 + onnxruntime/python/onnxruntime_pybind_module.cc | 7 +++++++ onnxruntime/test/python/onnxruntime_test_collective.py | 6 ++++++ .../orttraining/python/orttraining_python_module.cc | 8 ++++++++ 4 files changed, 22 insertions(+) diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index fd147eaa11f3f..0ed7d887fc5e5 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -42,6 +42,7 @@ from onnxruntime.capi._pybind_state import get_build_info # noqa: F401 from onnxruntime.capi._pybind_state import get_device # noqa: F401 from onnxruntime.capi._pybind_state import get_version_string # noqa: F401 + from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index 
f320707697c9e..6824a5d0bf98f 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -10,6 +10,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + void CreateInferencePybindStateModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { @@ -23,6 +29,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.def("get_version_string", []() -> std::string { return ORT_VERSION; }); m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; }); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); } } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_collective.py b/onnxruntime/test/python/onnxruntime_test_collective.py index db1ebb5384730..4882b403c3c91 100644 --- a/onnxruntime/test/python/onnxruntime_test_collective.py +++ b/onnxruntime/test/python/onnxruntime_test_collective.py @@ -155,6 +155,7 @@ def _create_alltoall_ut_model_for_boolean_tensor( ) return ORTBertPretrainTest._create_model_with_opsets(graph_def) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT), @@ -193,6 +194,7 @@ def test_all_reduce(self, np_elem_type, elem_type): outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch" ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -231,6 +233,7 @@ def test_all_gather(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not 
ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_bool(self): model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64) rank, _ = self._get_rank_size() @@ -250,6 +253,7 @@ def test_all_gather_bool(self): np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_axis1(self): model = self._create_allgather_ut_model((128, 128), 1) rank, size = self._get_rank_size() @@ -268,6 +272,7 @@ def test_all_gather_axis1(self): np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -349,6 +354,7 @@ def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_to_all_bool(self): rank, _ = self._get_rank_size() diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 7024244629c3e..88ef90a7feaa8 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + using namespace onnxruntime::logging; std::unique_ptr CreateExecutionProviderInstance( @@ -361,6 +367,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) 
{ }, "Clean the execution provider instances used in ort training module."); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); + // See documentation for class TrainingEnvInitialzer earlier in this module // for an explanation as to why this is needed. auto atexit = py::module_::import("atexit"); From e106b1eb8f22c57414a8bb4e69cbb62b702b12c7 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Fri, 29 Sep 2023 18:03:28 -0700 Subject: [PATCH 12/20] Fix react native load from Uint8Array buffer bug (#17739) ### Description Use `.buffer` of Uint8Array to get ArrayBuffer. TODO: Add E2E React Native test case to cover JS level testing to avoid future breakage. ### Motivation and Context #17732 Co-authored-by: rachguo --- js/react_native/lib/backend.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index b3f0c466308a5..058531f415d61 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler { let results: Binding.ModelLoadInfoType; // load a model if (typeof this.#pathOrBuffer === 'string') { + // load model from model path results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { + // load model from buffer if (!this.#inferenceSession.loadModelFromBlob) { throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer); results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created From 6a5f469d44aca607bd08cc2aca117c33bab31da8 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Fri, 29 Sep 2023 19:05:10 -0700 Subject: [PATCH 13/20] Add training 
interfaces to js/common (#17333) ### Description Following the design document: * Added createTrainingSessionHandler to the Backend interface * All existing Backend implementations throw an error for the new method createTrainingSessionHandler * Created TrainingSession namespace, interface, and TrainingSessionFactory interface * Created TrainingSessionImpl class implementation As methods are implemented, the TrainingSession interface will be extended or modified. ### Motivation and Context Adding the public-facing interfaces to the onnxruntime-common package is one of the first steps to support ORT training for web bindings. --------- Co-authored-by: Caroline Zhu --- js/common/lib/backend-impl.ts | 2 +- js/common/lib/backend.ts | 35 ++++++- js/common/lib/env.ts | 1 + js/common/lib/index.ts | 1 + js/common/lib/inference-session-impl.ts | 8 +- js/common/lib/training-session-impl.ts | 49 +++++++++ js/common/lib/training-session.ts | 134 ++++++++++++++++++++++++ js/node/lib/backend.ts | 8 +- js/react_native/lib/backend.ts | 8 +- js/web/lib/backend-onnxjs.ts | 6 +- js/web/lib/backend-wasm.ts | 12 ++- js/web/lib/onnxjs/session-handler.ts | 4 +- js/web/lib/wasm/session-handler.ts | 4 +- 13 files changed, 243 insertions(+), 29 deletions(-) create mode 100644 js/common/lib/training-session-impl.ts create mode 100644 js/common/lib/training-session.ts diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 75feba1d0ae08..e129c6971a85c 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -26,7 +26,7 @@ const backendsSortedByPriority: string[] = []; * @ignore */ export const registerBackend = (name: string, backend: Backend, priority: number): void => { - if (backend && typeof backend.init === 'function' && typeof backend.createSessionHandler === 'function') { + if (backend && typeof backend.init === 'function' && typeof backend.createInferenceSessionHandler === 'function') { const currentBackend = backends.get(name); if
(currentBackend === undefined) { backends.set(name, {backend, priority}); diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 804f33f00d103..dd04ef3f15997 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,6 +3,7 @@ import {InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; +import {TrainingSession} from './training-session.js'; /** * @ignore @@ -14,16 +15,23 @@ export declare namespace SessionHandler { } /** - * Represent a handler instance of an inference session. + * Represents shared SessionHandler functionality * * @ignore */ -export interface SessionHandler { +interface SessionHandler { dispose(): Promise; readonly inputNames: readonly string[]; readonly outputNames: readonly string[]; +} +/** + * Represent a handler instance of an inference session. + * + * @ignore + */ +export interface InferenceSessionHandler extends SessionHandler { startProfiling(): void; endProfiling(): void; @@ -31,6 +39,20 @@ export interface SessionHandler { options: InferenceSession.RunOptions): Promise; } +/** + * Represent a handler instance of a training session. + * + * @ignore + */ +export interface TrainingSessionHandler extends SessionHandler { + runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; + + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; +} + /** * Represent a backend that provides implementation of model inferencing. * @@ -42,8 +64,13 @@ export interface Backend { */ init(): Promise; - createSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise; + createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + + createTrainingSessionHandler?
+ (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, + evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + options: InferenceSession.SessionOptions): Promise; } export {registerBackend} from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 525272294c587..c78ae0fc83010 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -9,6 +9,7 @@ export declare namespace Env { 'ort-wasm.wasm'?: string; 'ort-wasm-threaded.wasm'?: string; 'ort-wasm-simd.wasm'?: string; + 'ort-training-wasm-simd.wasm'?: string; 'ort-wasm-simd-threaded.wasm'?: string; /* eslint-enable @typescript-eslint/naming-convention */ }; diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 85df1747f8576..9cbfcc4e8bcdc 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -22,3 +22,4 @@ export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; export * from './onnx-value.js'; +export * from './training-session.js'; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 06949b4a26c0d..9bc2088f2088a 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
import {resolveBackend} from './backend-impl.js'; -import {SessionHandler} from './backend.js'; +import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; import {Tensor} from './tensor.js'; @@ -14,7 +14,7 @@ type FetchesType = InferenceSessionInterface.FetchesType; type ReturnType = InferenceSessionInterface.ReturnType; export class InferenceSession implements InferenceSessionInterface { - private constructor(handler: SessionHandler) { + private constructor(handler: InferenceSessionHandler) { this.handler = handler; } run(feeds: FeedsType, options?: RunOptions): Promise; @@ -195,7 +195,7 @@ export class InferenceSession implements InferenceSessionInterface { const eps = options.executionProviders || []; const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); - const handler = await backend.createSessionHandler(filePathOrUint8Array, options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); return new InferenceSession(handler); } @@ -213,5 +213,5 @@ export class InferenceSession implements InferenceSessionInterface { return this.handler.outputNames; } - private handler: SessionHandler; + private handler: InferenceSessionHandler; } diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts new file mode 100644 index 0000000000000..f06d06bda035f --- /dev/null +++ b/js/common/lib/training-session-impl.ts @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {TrainingSessionHandler} from './backend.js'; +import {InferenceSession as InferenceSession} from './inference-session.js'; +import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; + +type SessionOptions = InferenceSession.SessionOptions; + +export class TrainingSession implements TrainingSessionInterface { + private constructor(handler: TrainingSessionHandler) { + this.handler = handler; + } + private handler: TrainingSessionHandler; + + get inputNames(): readonly string[] { + return this.handler.inputNames; + } + get outputNames(): readonly string[] { + return this.handler.outputNames; + } + + static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + Promise { + throw new Error('Method not implemented'); + } + + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): + Promise; + runTrainStep( + feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions|undefined): Promise; + async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): + Promise { + throw new Error('Method not implemented.'); + } + + async release(): Promise { + return this.handler.dispose(); + } +} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts new file mode 100644 index 0000000000000..0967d79b33434 --- /dev/null +++ b/js/common/lib/training-session.ts @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {InferenceSession} from './inference-session.js'; +import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; + +/* eslint-disable @typescript-eslint/no-redeclare */ + +export declare namespace TrainingSession { + /** + * Either URI file path (string) or Uint8Array containing model or checkpoint information. + */ + type URIorBuffer = string|Uint8Array; +} + +/** + * Represent a runtime instance of an ONNX training session, + * which contains a model that can be trained, and, optionally, + * an eval and optimizer model. + */ +export interface TrainingSession { + // #region run() + + /** + * Run TrainStep asynchronously with the given feeds and options. + * + * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for + * detail. + * @param options - Optional. A set of options that controls the behavior of model training. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. + */ + runTrainStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single train step with the given inputs and options. + * + * @param feeds - Representation of the model input. + * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model inference. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + * values. + */ + runTrainStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + + // #endregion + + // #region copy parameters + /** + * Copies from a buffer containing parameters to the TrainingSession parameters. + * + * @param array - buffer containing parameters + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise.
+ */ + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + + /** + * Copies from the TrainingSession parameters to a buffer. + * + * @param trainableOnly - True if trainable parameters only to be copied, false otherwise. + * @returns A promise that resolves to a buffer of the requested parameters. + */ + getContiguousParameters(trainableOnly: boolean): Promise; + // #endregion + + // #region release() + + /** + * Release the training session and the underlying resources. + */ + release(): Promise; + // #endregion + + // #region metadata + + /** + * Get input names of the loaded model. + */ + readonly inputNames: readonly string[]; + + /** + * Get output names of the loaded model. + */ + readonly outputNames: readonly string[]; + // #endregion +} + +/** + * Represents the optional parameters that can be passed into the TrainingSessionFactory. + */ +export interface TrainingSessionCreateOptions { + /** + * URI or buffer for a .ckpt file that contains the checkpoint for the training model. + */ + checkpointState: TrainingSession.URIorBuffer; + /** + * URI or buffer for the .onnx training file. + */ + trainModel: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx optimizer model file. + */ + optimizerModel?: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx eval model file. + */ + evalModel?: TrainingSession.URIorBuffer; +} + +/** + * Defines method overload possibilities for creating a TrainingSession.
+ */ +export interface TrainingSessionFactory { + // #region create() + + /** + * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions + * + * @param trainingOptions specify models and checkpoints to load into the Training Session + * @param sessionOptions specify configuration for training session behavior + * + * @returns Promise that resolves to a TrainingSession object + */ + create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: InferenceSession.SessionOptions): + Promise; + + // #endregion +} + +// eslint-disable-next-line @typescript-eslint/naming-convention +export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl; diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index d3680f9d44236..5f5ad49a2dea8 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler} from 'onnxruntime-common'; import {Binding, binding} from './binding'; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { @@ -53,8 +53,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { return new Promise((resolve, reject) => { process.nextTick(() => { try { diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index 
058531f415d61..3b1852699ac48 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Platform} from 'react-native'; import {binding, Binding, JSIBlob, jsiHelper} from './binding'; @@ -43,7 +43,7 @@ const normalizePath = (path: string): string => { return path; }; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; #key: string; @@ -165,8 +165,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeSessionHandler(pathOrBuffer); await handler.loadModel(options || {}); return handler; diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 18a068e0ced8b..5ea7de809a495 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
/* eslint-disable import/no-internal-modules */ -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; import {OnnxjsSessionHandler} from './onnxjs/session-handler'; @@ -11,8 +11,8 @@ class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function async init(): Promise {} - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { // NOTE: Session.Config(from onnx.js) is not compatible with InferenceSession.SessionOptions(from // onnxruntime-common). // In future we should remove Session.Config and use InferenceSession.SessionOptions. diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index ceb20044d97b6..04108c2ad0f66 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {cpus} from 'os'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; @@ -40,10 +40,12 @@ class OnnxruntimeWebAssemblyBackend implements Backend { // init wasm await initializeWebAssemblyInstance(); } - createSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; - createSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): Promise; - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): + Promise; + createInferenceSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeWebAssemblySessionHandler(); await handler.loadModel(pathOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler.ts index 0b06a7a747a44..47e50aeab673a 100644 --- a/js/web/lib/onnxjs/session-handler.ts +++ b/js/web/lib/onnxjs/session-handler.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Session} from './session'; import {Tensor as OnnxjsTensor} from './tensor'; -export class OnnxjsSessionHandler implements SessionHandler { +export class OnnxjsSessionHandler implements InferenceSessionHandler { constructor(private session: Session) { this.inputNames = this.session.inputNames; this.outputNames = this.session.outputNames; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index 4e00878d0063b..7bc467449c33a 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {readFile} from 'fs'; -import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {promisify} from 'util'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; @@ -40,7 +40,7 @@ const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { } }; -export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { +export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler { private sessionId: number; inputNames: string[]; From 668c70ee11b6b20c56997a9bc68e93317674e803 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Fri, 29 Sep 2023 19:46:55 -0700 Subject: [PATCH 14/20] Add support for specifying a custom logging function per session. (#17727) ### Description Add support for specifying a custom logging function per session. Bindings for other languages will be added after this PR is merged. ### Motivation and Context Users want a way to override the logging provided by the environment. 
--- .../core/session/onnxruntime_c_api.h | 31 ++++++++++-- onnxruntime/core/framework/session_options.h | 4 ++ .../core/session/abi_session_options.cc | 8 +++ onnxruntime/core/session/inference_session.cc | 49 +++++++++++++------ onnxruntime/core/session/inference_session.h | 8 ++- onnxruntime/core/session/onnxruntime_c_api.cc | 3 ++ onnxruntime/core/session/ort_apis.h | 2 + onnxruntime/core/session/ort_env.cc | 16 ++---- onnxruntime/core/session/ort_env.h | 14 +----- onnxruntime/core/session/user_logging_sink.h | 28 +++++++++++ .../test/framework/inference_session_test.cc | 41 ++++++++++++++++ winml/adapter/winml_adapter_environment.cpp | 5 +- 12 files changed, 160 insertions(+), 49 deletions(-) create mode 100644 onnxruntime/core/session/user_logging_sink.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e483c67a0cfe6..4b911e3482e90 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -745,6 +745,8 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv @@ -755,17 +757,20 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. If you want to provide your + * own logging function, consider setting it using the SetUserLoggingFunction API instead. * \param[in] logging_function A pointer to a logging function. 
* \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. + * `logging_function`. This parameter is optional. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** \brief Enable Telemetry * @@ -4413,6 +4418,26 @@ struct OrtApi { * \since Version 1.16. */ ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource); + + /** \brief Set user logging function + * + * By default the logger created by the CreateEnv* functions is used to create the session logger as well. + * This function allows a user to override this default session logger with a logger of their own choosing. This way + * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when + * the user already created an env but now wants to use a different logger for a specific session (for debugging or + * other reasons). + * + * \param[in] options + * \param[in] user_logging_function A pointer to a logging function. + * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `user_logging_function`. This parameter is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. 
+ */ + ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); }; /* diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index aff90b8d40bde..8deeb4c2b8b64 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -148,6 +148,10 @@ struct SessionOptions { std::shared_ptr custom_op_libs; void AddCustomOpLibraryHandle(PathString library_name, void* library_handle); #endif + + // User specified logging func and param + OrtLoggingFunction user_logging_function = nullptr; + void* user_logging_param = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 4fcc6de561f8c..fb314b161f1ad 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -143,6 +143,14 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionLogId, _In_ OrtSessionOptions* options, c return nullptr; } +///< logging function and optional logging param to use for session output +ORT_API_STATUS_IMPL(OrtApis::SetUserLoggingFunction, _In_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param) { + options->value.user_logging_function = user_logging_function; + options->value.user_logging_param = user_logging_param; + return nullptr; +} + ///< applies to session load, initialization, etc ORT_API_STATUS_IMPL(OrtApis::SetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, int session_log_verbosity_level) { options->value.session_log_verbosity_level = session_log_verbosity_level; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 21c8fbe0cd2c9..b4d47652942b7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ 
b/onnxruntime/core/session/inference_session.cc @@ -56,6 +56,7 @@ #include "core/providers/dml/dml_session_options_config_keys.h" #endif #include "core/session/environment.h" +#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -298,6 +299,35 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session return Status::OK(); } +logging::Severity GetSeverity(const SessionOptions& session_options) { + logging::Severity severity = logging::Severity::kWARNING; + if (session_options.session_log_severity_level == -1) { + severity = logging::LoggingManager::DefaultLogger().GetSeverity(); + } else { + ORT_ENFORCE(session_options.session_log_severity_level >= 0 && + session_options.session_log_severity_level <= static_cast(logging::Severity::kFATAL), + "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ", + session_options.session_log_severity_level); + severity = static_cast(session_options.session_log_severity_level); + } + return severity; +} + +void InferenceSession::SetLoggingManager(const SessionOptions& session_options, + const Environment& session_env) { + logging_manager_ = session_env.GetLoggingManager(); + if (session_options.user_logging_function) { + std::unique_ptr user_sink = std::make_unique(session_options.user_logging_function, + session_options.user_logging_param); + user_logging_manager_ = std::make_unique(std::move(user_sink), + GetSeverity(session_options), + false, + logging::LoggingManager::InstanceType::Temporal, + &session_options.session_logid); + logging_manager_ = user_logging_manager_.get(); + } +} + void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); @@ -306,6 +336,8 @@ 
void InferenceSession::ConstructorCommon(const SessionOptions& session_options, ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + SetLoggingManager(session_options, session_env); + // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. @@ -427,7 +459,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { // Initialize assets of this session instance ConstructorCommon(session_options, session_env); @@ -441,7 +472,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif - logging_manager_(session_env.GetLoggingManager()), external_intra_op_thread_pool_(external_intra_op_thread_pool), external_inter_op_thread_pool_(external_inter_op_thread_pool), environment_(session_env) { @@ -454,7 +484,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const const PathString& model_uri) : model_location_(model_uri), graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { auto status = Model::Load(model_location_, model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. 
Error message: ", @@ -475,7 +504,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, std::istream& model_istream) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { Status st = Model::Load(model_istream, &model_proto_); ORT_ENFORCE(st.IsOK(), "Could not parse model successfully while constructing the inference session"); @@ -487,7 +515,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const void* model_data, int model_data_len) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); @@ -2815,17 +2842,7 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { // create logger for session, using provided logging manager if possible if (logging_manager != nullptr) { - logging::Severity severity = logging::Severity::kWARNING; - if (session_options_.session_log_severity_level == -1) { - severity = logging::LoggingManager::DefaultLogger().GetSeverity(); - } else { - ORT_ENFORCE(session_options_.session_log_severity_level >= 0 && - session_options_.session_log_severity_level <= static_cast(logging::Severity::kFATAL), - "Invalid session log severity level. 
Not a valid onnxruntime::logging::Severity value: ", - session_options_.session_log_severity_level); - severity = static_cast(session_options_.session_log_severity_level); - } - + logging::Severity severity = GetSeverity(session_options_); owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false, session_options_.session_log_verbosity_level); session_logger_ = owned_session_logger_.get(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 9259e014b9860..4db436f132d11 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -595,7 +595,8 @@ class InferenceSession { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); - + void SetLoggingManager(const SessionOptions& session_options, + const Environment& session_env); void ConstructorCommon(const SessionOptions& session_options, const Environment& session_env); @@ -698,7 +699,10 @@ class InferenceSession { SessionOptions session_options_; /// Logging manager if provided. - logging::LoggingManager* const logging_manager_; + logging::LoggingManager* logging_manager_; + + /// User specified logging mgr; logging_manager_ is simply the ptr in this unique_ptr when available + std::unique_ptr user_logging_manager_; /// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr. 
std::unique_ptr owned_session_logger_ = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 60b6296f7f539..67149e1e99f22 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2713,6 +2713,8 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::GetCUDAProviderOptionsByName, &OrtApis::KernelContext_GetResource, // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::SetUserLoggingFunction, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2742,6 +2744,7 @@ static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 22 static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, SetUserLoggingFunction) / sizeof(void*) == 266, "Size of version 17 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.17.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 47da2fa524588..f472932e20b8a 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -491,4 +491,6 @@ ORT_API_STATUS_IMPL(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProv ORT_API_STATUS_IMPL(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); ORT_API_STATUS_IMPL(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, 
_Outptr_ void** ptr); ORT_API_STATUS_IMPL(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, _In_ int resource_id, _Outptr_ void** stream); +ORT_API_STATUS_IMPL(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index eb78d5d799a55..e3957baa990f8 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -9,6 +9,7 @@ #include "core/session/ort_apis.h" #include "core/session/environment.h" #include "core/session/allocator_adapters.h" +#include "core/session/user_logging_sink.h" #include "core/common/logging/logging.h" #include "core/framework/provider_shutdown.h" #include "core/platform/logging/make_platform_default_log_sink.h" @@ -20,17 +21,6 @@ std::unique_ptr OrtEnv::p_instance_; int OrtEnv::ref_count_ = 0; onnxruntime::OrtMutex OrtEnv::m_; -LoggingWrapper::LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) - : logging_function_(logging_function), logger_param_(logger_param) { -} - -void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, - const onnxruntime::logging::Capture& message) { - std::string s = message.Location().ToString(); - logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), - logger_id.c_str(), s.c_str(), message.Message().c_str()); -} - OrtEnv::OrtEnv(std::unique_ptr value1) : value_(std::move(value1)) { } @@ -50,8 +40,8 @@ OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_inf std::unique_ptr lmgr; std::string name = lm_info.logid; if (lm_info.logging_function) { - std::unique_ptr logger = std::make_unique(lm_info.logging_function, - lm_info.logger_param); + std::unique_ptr logger = std::make_unique(lm_info.logging_function, + 
lm_info.logger_param); lmgr = std::make_unique(std::move(logger), static_cast(lm_info.default_warning_level), false, diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 7d609acb2db5d..444134d0612e9 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -5,27 +5,15 @@ #include #include #include "core/session/onnxruntime_c_api.h" -#include "core/common/logging/isink.h" #include "core/platform/ort_mutex.h" #include "core/common/status.h" +#include "core/common/logging/logging.h" #include "core/framework/allocator.h" namespace onnxruntime { class Environment; } -class LoggingWrapper : public onnxruntime::logging::ISink { - public: - LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param); - - void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, - const onnxruntime::logging::Capture& message) override; - - private: - OrtLoggingFunction logging_function_; - void* logger_param_; -}; - struct OrtEnv { public: struct LoggingManagerConstructionInfo { diff --git a/onnxruntime/core/session/user_logging_sink.h b/onnxruntime/core/session/user_logging_sink.h new file mode 100644 index 0000000000000..5a9ceb21d6500 --- /dev/null +++ b/onnxruntime/core/session/user_logging_sink.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/isink.h" + +namespace onnxruntime { +class UserLoggingSink : public onnxruntime::logging::ISink { + public: + UserLoggingSink(OrtLoggingFunction logging_function, void* logger_param) + : logging_function_(logging_function), logger_param_(logger_param) { + } + + void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) override { + std::string s = message.Location().ToString(); + logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), + logger_id.c_str(), s.c_str(), message.Message().c_str()); + } + + private: + OrtLoggingFunction logging_function_{}; + void* logger_param_{}; +}; +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 2298e4afa6de0..486ec37d1eebd 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -890,6 +890,47 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { #endif } +TEST(InferenceSessionTests, UseUserSpecifiedLoggingFunctionInSession) { + SessionOptions so; + /* + typedef void(ORT_API_CALL* OrtLoggingFunction)( + void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message); + */ + std::vector log_msgs; + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + 
so.session_log_severity_level = static_cast(Severity::kVERBOSE); + so.session_log_verbosity_level = 1; + so.session_logid = "InferenceSessionTests.UseUserSpecifiedLoggingFunctionInSession"; + + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); + + RunOptions run_options; + run_options.run_tag = "one session/one tag"; + RunModel(session_object, run_options); + +// vlog output is disabled in release builds +#ifndef NDEBUG + bool have_log_entry_with_vlog_session_msg = + (std::find_if(log_msgs.begin(), log_msgs.end(), + [&](std::string msg) { return msg.find("Added input argument with name") != string::npos; }) != + log_msgs.end()); + ASSERT_TRUE(have_log_entry_with_vlog_session_msg); +#endif +} + TEST(InferenceSessionTests, TestWithIstream) { SessionOptions so; diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index 43babdf43967e..e2da473c7d5b5 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -9,6 +9,7 @@ #include "winml_adapter_apis.h" #include "core/framework/error_code_helper.h" #include "core/session/ort_env.h" +#include "core/session/user_logging_sink.h" #ifdef USE_DML #include "abi_custom_registry_impl.h" @@ -18,12 +19,12 @@ #endif USE_DML namespace winmla = Windows::AI::MachineLearning::Adapter; -class WinmlAdapterLoggingWrapper : public LoggingWrapper { +class WinmlAdapterLoggingWrapper : public onnxruntime::UserLoggingSink { public: WinmlAdapterLoggingWrapper( OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, void* logger_param ) - : LoggingWrapper(logging_function, logger_param), + : onnxruntime::UserLoggingSink(logging_function, logger_param), profiling_function_(profiling_function) {} void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const override { From 
9aad78721c3873e6c9b148584089d3a13a889fc6 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 29 Sep 2023 20:40:09 -0700 Subject: [PATCH 15/20] Update debug_alloc.cc: filter out one more memory leak from absl (#17746) --- onnxruntime/core/platform/windows/debug_alloc.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index 2b612a4303442..ff6a059607367 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -225,6 +225,7 @@ Memory_LeakCheck::~Memory_LeakCheck() { string.find("re2::RE2::Init") == std::string::npos && string.find("dynamic initializer for 'FLAGS_") == std::string::npos && string.find("AbslFlagDefaultGenForgtest_") == std::string::npos && + string.find("AbslFlagDefaultGenForundefok::Gen") == std::string::npos && string.find("::SetProgramUsageMessage") == std::string::npos && string.find("testing::internal::ParseGoogleTestFlagsOnly") == std::string::npos && string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && From a941dd583e6061d57564d770ae727860ff9b237e Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Sat, 30 Sep 2023 11:00:23 +0400 Subject: [PATCH 16/20] [js/web] FP16 Conv, ConvTranspose and MatMul (#17514) ### Description Another three ops for fp16 --------- Co-authored-by: Guenther Schmuelling Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../webgpu/ops/3rd-party/activation_util.ts | 10 ++-- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 42 ++++++++-------- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 27 ++++++----- .../ops/3rd-party/conv_backprop_webgpu.ts | 47 +++++++++--------- .../ops/3rd-party/matmul_packed_webgpu.ts | 43 +++++++++-------- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 10 ---- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 10 ---- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 6 --- 
.../providers/js/js_execution_provider.cc | 24 +++++----- .../core/providers/js/operators/conv.cc | 48 ++++++++----------- .../core/providers/js/operators/conv.h | 2 +- .../providers/js/operators/conv_transpose.cc | 47 ++++++++---------- .../providers/js/operators/conv_transpose.h | 2 +- .../core/providers/js/operators/matmul.cc | 4 +- 14 files changed, 148 insertions(+), 174 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index dd4f13e76ee04..22b91d680a9b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -21,16 +21,16 @@ export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; -export const typeSnippet = (component: number) => { +export const typeSnippet = (component: number, dataType: string) => { switch (component) { case 1: - return 'f32'; + return dataType; case 2: - return 'vec2'; + return `vec2<${dataType}>`; case 3: - return 'vec3'; + return `vec3<${dataType}>`; case 4: - return 'vec4'; + return `vec4<${dataType}>`; default: throw new Error(`${component}-component is not supported.`); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 08b1d1f30b233..e6d4039d8131b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -23,6 +23,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; 
@@ -32,13 +33,13 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, - innerElementSize = 4): string => { + innerElementSize = 4, dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: return 'resData = x[xIndex];'; case 3: - return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: return 'resData = x[xIndex / 4];'; default: @@ -92,7 +93,7 @@ const conv2dCommonSnippet = let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; let xCh = ${col} % inChannels; - var resData = ${typeSnippet(innerElementSizeX)}(0.0); + var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { @@ -110,7 +111,7 @@ const conv2dCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`) : + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : (fitInner && fitBOuter ? ` let col = colIn * ${innerElementSizeX}; ${readXSnippet}` : @@ -119,13 +120,15 @@ const conv2dCommonSnippet = if (row < dimInner && col < dimBOuter) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`); + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); const sampleW = `${getWSnippet(innerElementSizeW)}`; - const resType = typeSnippet(innerElementSize); - const aType = isChannelsLast ? 
typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX); + const resType = typeSnippet(innerElementSize, dataType); + const aType = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { @@ -190,23 +193,24 @@ export const createConv2DMatMulProgramInfo = const fitInner = dimInner % tileInner === 0; const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? 'vec4' : 'f32'}>;` + `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, + `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` ]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { + result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } - fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 
'vec4' : 'f32'}>;`); + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -222,7 +226,7 @@ export const createConv2DMatMulProgramInfo = // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; ${declareInputs.join('')} @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + isVec4 ? `vec4<${t}>` : t}>; //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); @@ -240,12 +244,12 @@ export const createConv2DMatMulProgramInfo = ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2])} + elementsSize[1], elementsSize[2], t)} ${ isVec4 ? 
- makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}` }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 3925e1cb4f564..f41d0d058a624 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -32,6 +32,7 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe const conv2dTransposeCommonSnippet = (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => { + const type = typeSnippet(innerElementSize, 'f32'); const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -89,10 +90,10 @@ const conv2dTransposeCommonSnippet = let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); } if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); } let iXR = i32(xR); let iXC = i32(xC); @@ -105,13 +106,13 @@ const conv2dTransposeCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readASnippet} } - return ${typeSnippet(innerElementSize)}(0.0);` : + return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; if (row < dimInner && col < dimBOuter) { 
${readASnippet} } - return ${typeSnippet(innerElementSize)}(0.0);`; + return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; @@ -125,21 +126,21 @@ const conv2dTransposeCommonSnippet = let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} } - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); `; const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} - fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } - fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleW : sampleA} } - fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) { + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; if (row < dimAOuter && col < dimBOuter) { var value = valueInput; @@ -234,10 +235,10 @@ export const createConv2DTransposeMatMulProgramInfo = ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} ${ - isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : - makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + isVec4 ? 
makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}` }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 4c8922238ac5b..c60b94056a360 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -21,12 +21,13 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper} from '../common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, + dataType: string): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource = const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { + result[flatIndex] = ${isVec4 ? 
`vec4<${dataType}>` : dataType}(value); }`; if (hasBias) { declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource = // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; + var dotProd: array, ${workPerThread}>; for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4(0.0); + dotProd[i] = vec4<${dataType}>(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); - let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -108,7 +109,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 
'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -116,7 +117,7 @@ const createConvTranspose2DOpProgramShaderSource = xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wR % dilations.x != 0) { continue; } - let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,9 +191,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wC % dilations.y != 0) { continue; } - let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) 
/ ${dataType}(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } @@ -256,6 +257,7 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { ...metadata, outputs: [{ @@ -265,6 +267,7 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, + dataType), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 8d43dbb378a69..82f8c82291f4b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -70,8 +70,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = }; export const makeMatMulPackedVec4Source = - (workPerThread: 
number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source = workPerThread[0]} must be 4.`); } return ` -var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; -var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; +var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; @@ -115,7 +115,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc: array, rowPerThread>; + var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; @@ -179,8 +179,9 @@ const readDataFromSubASnippet = (transposeA: boolean) => // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). 
export const makeMatMulPackedSource = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -222,7 +223,7 @@ export const makeMatMulPackedSource = workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; @@ -283,7 +284,7 @@ for (var t = 0; t < numTiles; t = t + 1) { workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][tileCol + inner]; @@ -309,8 +310,8 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { `; return ` - var mm_Asub : array, ${tileAHight}>; - var mm_Bsub : array, ${tileInner}>; + var mm_Asub : array, ${tileAHight}>; + var mm_Bsub : array, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; const tileInner = ${tileInner}; @@ -324,7 +325,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? 
`i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc : array, rowPerThread>; + var acc : array, rowPerThread>; // Without this initialization strange values show up in acc. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { @@ -347,6 +348,7 @@ const matMulReadWriteFnSource = const outputVariable = variables[5]; const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); const getAIndices = () => { const aRank = aVariable.shape.length; const batchRank = batchVariable.shape.length; @@ -377,8 +379,8 @@ const matMulReadWriteFnSource = }; const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimAOuter && col < dimInner) { @@ -389,8 +391,8 @@ const matMulReadWriteFnSource = } fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimInner && col < dimBOuter) { @@ -400,7 +402,7 @@ const matMulReadWriteFnSource = return value; } - fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) { + fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; if (row < dimAOuter && col < dimBOuter) { var value = valueIn; @@ -444,6 +446,7 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; + const dataType = 
tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); @@ -466,8 +469,8 @@ export const createMatmulProgramInfo = ${declareFunctions} ${activationFunction} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} ${batchDims.impl()}`; return { ...metadata, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 5641386cce849..59f11d6d9abba 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; @@ -201,15 +200,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { throw new Error('invalid output shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('ConvTranspose input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('ConvTranspose input(bias) should be float tensor'); - } }; const createConvTranspose2DProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 95a64e5787841..7afc3ce1b9d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -93,15 +92,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('Conv input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('Conv input(bias) should be float tensor'); - } }; const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 837ac8410f291..7dadf9a6205ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; @@ -9,7 +8,6 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {InternalActivationAttributes} from './fuse-utils'; - const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputTypes: hasBias ? 
[GpuDataType.default, GpuDataType.default, GpuDataType.default] : @@ -35,10 +33,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) { throw new Error('shared dimension does not match.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; export const matMul = (context: ComputeContext): void => { diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 72e36a161e9aa..6ced8d4d4a4ad 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -232,18 +232,18 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Uns class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, 
kMSInternalNHWCDomain, 1, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm); @@ -496,18 +496,18 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index c7c9f7f7c3f0e..2e07124dcd901 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ 
b/onnxruntime/core/providers/js/operators/conv.cc @@ -9,33 +9,27 @@ namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + Conv, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Conv, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 22f7721276677..fdf3e5b6c6b66 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -template +template class Conv : public JsKernel { public: Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 1a2fc99eada6a..2228343e1e6e3 
100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -7,33 +7,28 @@ #include "conv_transpose.h" namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index c3babbc5ce81f..18ef73268005d 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -9,7 +9,7 @@ #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template 
class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { diff --git a/onnxruntime/core/providers/js/operators/matmul.cc b/onnxruntime/core/providers/js/operators/matmul.cc index ddfbb454def07..6e6f906f7b42c 100644 --- a/onnxruntime/core/providers/js/operators/matmul.cc +++ b/onnxruntime/core/providers/js/operators/matmul.cc @@ -9,11 +9,11 @@ namespace js { JSEP_KERNEL_IMPL(MatMul, MatMul) ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); ONNX_OPERATOR_KERNEL_EX(MatMul, kOnnxDomain, 13, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); } // namespace js From 0d606046380caa32fddffb0ee4d3434418719ede Mon Sep 17 00:00:00 2001 From: xhcao Date: Sat, 30 Sep 2023 17:05:32 +0800 Subject: [PATCH 17/20] [JS/WebGPU] support Range operator (#17233) The patch also introduces the method which copies data from GPU to CPU synchronously. 
### Description ### Motivation and Context --- js/common/lib/env.ts | 6 ++ js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/range.ts | 66 +++++++++++++++++++ js/web/script/test-runner-cli-args.ts | 6 +- js/web/test/suite-test-list.jsonc | 8 +-- js/web/test/test-main.ts | 3 + .../providers/js/js_execution_provider.cc | 7 ++ .../core/providers/js/operators/range.cc | 22 +++++++ .../core/providers/js/operators/range.h | 14 ++++ 10 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/range.ts create mode 100644 onnxruntime/core/providers/js/operators/range.cc create mode 100644 onnxruntime/core/providers/js/operators/range.h diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index c78ae0fc83010..76575ef7b9368 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -106,6 +106,12 @@ export declare namespace Env { * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; + /** + * Set or get whether to validate input content.
+ * + * @defaultValue `false` + */ + validateInputContent?: boolean; } } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..f8ac29e5f82ca 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9a78..cbe845b882468 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,6 +16,7 @@ import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; +import {range} from './ops/range'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; @@ -83,6 +84,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], + ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin, parseReduceAttributes]], ['ReduceMean', [reduceMean, parseReduceAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts new file mode 100644 index 0000000000000..3ecb3308b1899 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {DataType} from '../../../wasm-common'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {outputVariable, ShaderHelper} from './common'; + +const validateInputsContent = (start: number, limit: number, delta: number): void => { + const sameStartLimit = start === limit; + const increasingRangeNegativeStep = start < limit && delta < 0; + const decreasingRangePositiveStep = start > limit && delta > 0; + + if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) { + throw new Error('Range these inputs\' contents are invalid.'); + } +}; + +const createRangeProgramInfo = + (metadata: ProgramMetadata, start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => { + const numElements = Math.abs(Math.ceil((limit - start) / delta)); + const outputShape: number[] = [numElements]; + const outputSize = numElements; + + const output = outputVariable('output', dataType, outputShape); + const wgslType = output.type.storage; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.declareVariables(output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + output[global_idx] = ${wgslType}(${start}) + ${wgslType}(global_idx) * ${wgslType}(${delta}); + }`; + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const range = (context: ComputeContext): void => { + let start = 0; + let limit = 0; + let delta = 0; + if (context.inputs[0].dataType === DataType.int32) { + start = context.inputs[0].getInt32Array()[0]; + limit = context.inputs[1].getInt32Array()[0]; + delta = context.inputs[2].getInt32Array()[0]; + } else if (context.inputs[0].dataType === 
DataType.float) { + start = context.inputs[0].getFloat32Array()[0]; + limit = context.inputs[1].getFloat32Array()[0]; + delta = context.inputs[2].getFloat32Array()[0]; + } + if (env.webgpu.validateInputContent) { + validateInputsContent(start, limit, delta); + } + + const cacheHint = [start, limit, delta].map(x => x.toString()).join('_'); + const metadata: ProgramMetadata = {name: 'Range', inputTypes: [], cacheHint}; + context.compute( + {...metadata, get: () => createRangeProgramInfo(metadata, start, limit, delta, context.inputs[0].dataType)}, + {inputs: []}); +}; diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 3f903515694db..31bca8b94306d 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -333,7 +333,11 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - return {profilingMode}; + const validateInputContent = args['webgpu-validate-input-content']; + if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { + throw new Error('Flag "webgpu-validate-input-content" is invalid'); + } + return {profilingMode, validateInputContent}; } function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4756..96ced2bdf9216 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -885,10 +885,10 @@ // // "test_qlinearmatmul_3D", // // "test_quantizelinear_axis", // // "test_quantizelinear", - // "test_range_float_type_positive_delta_expanded", - // "test_range_float_type_positive_delta", - // "test_range_int32_type_negative_delta_expanded", - // "test_range_int32_type_negative_delta", + "test_range_float_type_positive_delta_expanded", + 
"test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", "test_reciprocal_example", "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 49d0ac225be2f..d3592875bb6c7 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -57,6 +57,9 @@ if (options.globalEnvFlags) { if (flags.webgpu?.profilingMode !== undefined) { ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; } + if (flags.webgpu?.validateInputContent !== undefined) { + ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; + } } // Set logging configuration diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6ced8d4d4a4ad..ae33fb752fe00 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -318,6 +318,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -584,7 +587,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff 
--git a/onnxruntime/core/providers/js/operators/range.cc b/onnxruntime/core/providers/js/operators/range.cc new file mode 100644 index 0000000000000..e15861f7f227a --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "range.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_KERNEL_EX( + Range, + kOnnxDomain, + 11, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .InputMemoryType(OrtMemTypeCPU, 0) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Range); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/range.h b/onnxruntime/core/providers/js/operators/range.h new file mode 100644 index 0000000000000..8b32bfc3d984b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Range, Range); + +} // namespace js +} // namespace onnxruntime From 63acaf47d2193cbe175f5cabf0cb20b25d6e39aa Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Sun, 1 Oct 2023 03:06:34 +0200 Subject: [PATCH 18/20] Fix onnx quantizer activation and weight type attribute (#17651) In [`quantize_subgraph`](https://github.com/microsoft/onnxruntime/blob/v1.16.0/onnxruntime/python/tools/quantization/onnx_quantizer.py#L188-L189) `self.weight_qType` and `self.activation_qType` are [integers](https://github.com/microsoft/onnxruntime/blob/v1.16.0/onnxruntime/python/tools/quantization/onnx_quantizer.py#L115-L116) while `ONNXQuantizer` expects `QuantType` --- onnxruntime/python/tools/quantization/onnx_quantizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 2d1e418f9d2b4..c3c3d2a837336 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -112,8 +112,8 @@ def __init__( False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] ) - self.activation_qType = activation_qType.tensor_type - self.weight_qType = weight_qType.tensor_type + self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) + self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType) """ Dictionary specifying the min and max values for tensors. 
It has following format: { From ac4e72604605be524f70d1e92da81b70c5db984a Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 2 Oct 2023 12:25:28 +1000 Subject: [PATCH 19/20] Add bytes model loading test to react native e2e (#17749) ### Description Update E2E test to also check InferenceSession.create with bytes. ### Motivation and Context Add tests to validate #17739 --- js/react_native/e2e/package.json | 3 ++- js/react_native/e2e/src/App.tsx | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/package.json b/js/react_native/e2e/package.json index 969c70c110123..cd97ec1d099e4 100644 --- a/js/react_native/e2e/package.json +++ b/js/react_native/e2e/package.json @@ -10,7 +10,8 @@ }, "dependencies": { "react": "^18.1.0", - "react-native": "^0.69.1" + "react-native": "^0.69.1", + "react-native-fs": "^2.20.0" }, "devDependencies": { "@babel/core": "^7.17.0", diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index f3e415f0c5a55..8a76edabc613e 100644 --- a/js/react_native/e2e/src/App.tsx +++ b/js/react_native/e2e/src/App.tsx @@ -8,6 +8,7 @@ import { Image, Text, TextInput, View } from 'react-native'; import { InferenceSession, Tensor } from 'onnxruntime-react-native'; import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; import { Buffer } from 'buffer'; +import { readFile } from 'react-native-fs'; interface State { session: @@ -39,10 +40,21 @@ export default class App extends React.PureComponent<{}, State> { this.setState({ imagePath }); const modelPath = await MNIST.getLocalModelPath(); - const session: InferenceSession = await InferenceSession.create(modelPath); + + // test creating session with path + console.log('Creating with path'); + const pathSession: InferenceSession = await InferenceSession.create(modelPath); + pathSession.release(); + + // and with bytes + console.log('Creating with bytes'); + const base64Str = await readFile(modelPath, 'base64'); + 
const bytes = Buffer.from(base64Str, 'base64'); + const session: InferenceSession = await InferenceSession.create(bytes); this.setState({ session }); - void this.infer(); + console.log('Test session created'); + void await this.infer(); } catch (err) { console.log(err.message); } From f158f394d695a196ae2e06b350523a7dc741321d Mon Sep 17 00:00:00 2001 From: zesongw Date: Tue, 3 Oct 2023 04:01:04 +0800 Subject: [PATCH 20/20] [WebNN EP] Support Softmax since version 13 (#17714) ### Description WebNN only supports 2-D input tensor along axis 1. For now, we use Reshape and Transpose wraparound to get the compatible input. ### Motivation and Context Enable more models to run on WebNN. --- .../webnn/builders/impl/softmax_op_builder.cc | 97 +++++++++++++------ 1 file changed, 69 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index b207b804416aa..6a86ca7aca6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -35,30 +35,79 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const auto input_size = input_shape.size(); - // WebNN Softmax only support 2d input shape, reshape input to 2d. - if (input_size != 2) { - NodeAttrHelper helper(node); + NodeAttrHelper helper(node); + if (node.SinceVersion() < 13) { int32_t axis = helper.Get("axis", 1); - if (node.SinceVersion() >= 13) - // Opset 13 has default value -1. - axis = helper.Get("axis", -1); + axis = static_cast(HandleNegativeAxis(axis, input_size)); // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. 
+ if (input_size != 2) { + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape output to the same shape of input. + if (input_size != 2) { + emscripten::val new_shape = emscripten::val::array(); + for (size_t i = 0; i < input_size; i++) { + new_shape.call("push", static_cast(input_shape[i])); + } + output = model_builder.GetBuilder().call("reshape", output, new_shape); + } + } else { + int32_t axis = helper.Get("axis", -1); axis = static_cast(HandleNegativeAxis(axis, input_size)); - int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, - 1, std::multiplies())); - int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), - 1, std::multiplies())); - emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); - input = model_builder.GetBuilder().call("reshape", input, new_shape); - } - output = model_builder.GetBuilder().call("softmax", input); - // Reshape output to the same shape of input. - if (input_size != 2) { - emscripten::val new_shape = emscripten::val::array(); - for (size_t i = 0; i < input_size; i++) { - new_shape.call("push", static_cast(input_shape[i])); + // Workaround: transpose the target axis to the last position. + // WebNN computes the softmax values of the 2-D input tensor along axis 1. 
+ // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.erase(permutation.begin() + axis); + permutation.push_back(axis); + options.set("permutation", emscripten::val::array(permutation)); + input = model_builder.GetBuilder().call("transpose", input, options); + } + // Workaround: reshape the input tensor to 2-D. + if (input_shape.size() != 2) { + uint32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + first_dim *= static_cast(std::reduce(input_shape.begin() + axis + 1, input_shape.end(), + 1, std::multiplies())); + uint32_t second_dim = static_cast(input_shape[axis]); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape the 2-D result back to the transposed N-D shape (softmax axis last). + if (input_shape.size() != 2) { + std::vector new_shape; + std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + new_shape.push_back(static_cast(input_shape[axis])); + output = model_builder.GetBuilder().call("reshape", + output, emscripten::val::array(new_shape)); + } + // Transpose the softmax axis back to its original position, restoring the original shape. 
+ if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.pop_back(); + permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + options.set("permutation", emscripten::val::array(permutation)); + output = model_builder.GetBuilder().call("transpose", output, options); } - output = model_builder.GetBuilder().call("reshape", output, new_shape); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -80,14 +129,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali << input_size << "d shape"; return false; } - NodeAttrHelper helper(node); - const int64_t axis = helper.Get("axis", 1); - // WebNN softmax only support reshape for the last axis or version before 13. - // TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1). - if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) { - LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; - return false; - } return true; }