From 8caa55aac5d72affb5d7e35f87ba9ca5bae47d94 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 4 Sep 2023 18:14:14 -0700 Subject: [PATCH 1/4] Support If on WebGPU --- js/web/test/suite-test-list.jsonc | 3 + onnxruntime/core/framework/session_state.cc | 5 +- .../providers/js/js_execution_provider.cc | 9 ++- onnxruntime/core/providers/js/operators/if.cc | 65 +++++++++++++++++++ onnxruntime/core/providers/js/operators/if.h | 24 +++++++ 5 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/js/operators/if.cc create mode 100644 onnxruntime/core/providers/js/operators/if.h diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index aca3526115c7e..3a53a070d847e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -601,6 +601,9 @@ // // "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", + "test_if", + "test_if_seq", + "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", // "test_isinf_negative", diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index b7d26d87f2705..f0e5fbbd38721 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() { for (auto& node : graph_.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { const auto& ep = node.GetExecutionProviderType(); - if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) { + if (!ep.empty() && + ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && + ep != kRocmExecutionProvider && ep != kDmlExecutionProvider && + ep != kJsExecutionProvider) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. continue; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 829f3e5f4f143..a49ccde291322 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -314,6 +314,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -564,7 +568,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc new file mode 100644 index 0000000000000..ef072bb1635dd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.cc @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "if.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); +// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' +// branches to be the same were relaxed in opset-11 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-13 supports sequence type for If's subgraph outputs +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-19 supports float8 +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +Status If::Compute(OpKernelContext* ctx) const { + // call the base CPU version. + return onnxruntime::If::Compute(ctx); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.h b/onnxruntime/core/providers/js/operators/if.h new file mode 100644 index 0000000000000..d060444ccc1d2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/providers/js/js_kernel.h" +#include "core/common/common.h" +#include "core/providers/cpu/controlflow/if.h" + +namespace onnxruntime { +class SessionState; + +namespace js { + +// Use the CPU implementation for the logic +class If final : public onnxruntime::If { + public: + If(const OpKernelInfo& info) : onnxruntime::If(info) {} + + Status Compute(OpKernelContext* ctx) const override; +}; +} // namespace js +} // namespace onnxruntime From 0274fbc53b39b00b853acc9f1242b7ca625f7a07 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 8 Sep 2023 18:46:26 -0700 Subject: [PATCH 2/4] Format --- js/web/test/suite-test-list.jsonc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 3a53a070d847e..90f402928dbe5 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -601,9 +601,9 @@ // // "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", - "test_if", - "test_if_seq", - "test_if_opt", + "test_if", + "test_if_seq", + "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", // "test_isinf_negative", From 02308159ccb12af4db88bfdc85aaab4a4b4746db Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 8 Sep 2023 19:10:27 -0700 Subject: [PATCH 3/4] Doc --- js/web/docs/webgpu-operators.md | 1 + 1 file changed, 1 insertion(+) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a969e1b86bf99..ce14e460d1886 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -44,6 +44,7 @@ Do not modify directly.* | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | +| If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | From c3e516eda9e5cd2e4b089068eaf99f30fbc370ef Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sat, 9 Sep 2023 22:49:01 -0700 Subject: [PATCH 4/4] Fix tests --- js/web/test/suite-test-list.jsonc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 90f402928dbe5..71457e0ab74f0 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -602,8 +602,10 @@ // // "test_hardswish_expanded", // // "test_hardswish", "test_if", - "test_if_seq", - "test_if_opt", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", // "test_isinf_negative",