From 259f0773cadc2af338645e7ea19c4240d5496369 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 19 Sep 2023 12:20:18 -0700 Subject: [PATCH] [JS/WebGPU] Support If on WebGPU (#17478) --- js/web/docs/webgpu-operators.md | 1 + js/web/test/suite-test-list.jsonc | 5 ++ onnxruntime/core/framework/session_state.cc | 5 +- .../providers/js/js_execution_provider.cc | 12 +++- onnxruntime/core/providers/js/operators/if.cc | 65 +++++++++++++++++++ onnxruntime/core/providers/js/operators/if.h | 24 +++++++ 6 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/js/operators/if.cc create mode 100644 onnxruntime/core/providers/js/operators/if.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 71d98f5d73671..a87a894e3b3c5 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -46,6 +46,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 94592884ccad6..6e65645ef4756 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -602,6 +602,11 @@ // // "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", + "test_if", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", // "test_isinf_negative", diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index b7d26d87f2705..f0e5fbbd38721 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() { for (auto& node : graph_.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { const auto& ep = node.GetExecutionProviderType(); - if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) { + if (!ep.empty() && + ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && + ep != kRocmExecutionProvider && ep != kDmlExecutionProvider && + ep != kJsExecutionProvider) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. continue; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9dccd7c47fbb6..0674fe02d093d 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -580,15 +584,17 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc new file mode 100644 index 0000000000000..ef072bb1635dd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.cc @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "if.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); +// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' +// branches to be the same were relaxed in opset-11 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-13 supports sequence type for If's subgraph outputs +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +// opset-19 supports float8 +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + +Status If::Compute(OpKernelContext* ctx) const { + // call the base CPU version. + return onnxruntime::If::Compute(ctx); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.h b/onnxruntime/core/providers/js/operators/if.h new file mode 100644 index 0000000000000..d060444ccc1d2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/if.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/providers/js/js_kernel.h" +#include "core/common/common.h" +#include "core/providers/cpu/controlflow/if.h" + +namespace onnxruntime { +class SessionState; + +namespace js { + +// Use the CPU implementation for the logic +class If final : public onnxruntime::If { + public: + If(const OpKernelInfo& info) : onnxruntime::If(info) {} + + Status Compute(OpKernelContext* ctx) const override; +}; +} // namespace js +} // namespace onnxruntime