Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Support If on WebGPU #17478

Merged
merged 9 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
Expand Down
5 changes: 5 additions & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,11 @@
// // "test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
// TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
// supports Sequence and Optional types
// "test_if_seq",
// "test_if_opt",
"test_instancenorm_epsilon",
"test_instancenorm_example",
// "test_isinf_negative",
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() {
for (auto& node : graph_.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
const auto& ep = node.GetExecutionProviderType();
if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) {
if (!ep.empty() &&
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
ep != kRocmExecutionProvider && ep != kDmlExecutionProvider &&
ep != kJsExecutionProvider) {
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
// node containing the subgraph it will create whatever state it needs internally.
continue;
Expand Down
12 changes: 9 additions & 3 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad);
Expand All @@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If);

std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();

Expand Down Expand Up @@ -580,15 +584,17 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
65 changes: 65 additions & 0 deletions onnxruntime/core/providers/js/operators/if.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "if.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;

namespace onnxruntime {
namespace js {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
// branches to be the same were relaxed in opset-11
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);

// opset-13 supports sequence type for If's subgraph outputs
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
13, 18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);

// opset-19 supports float8
ONNX_OPERATOR_KERNEL_EX(If,
kOnnxDomain,
19,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);

Status If::Compute(OpKernelContext* ctx) const {
// call the base CPU version.
return onnxruntime::If::Compute(ctx);
}

} // namespace js
} // namespace onnxruntime
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/js/operators/if.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <functional>

#include "core/providers/js/js_kernel.h"
#include "core/common/common.h"
#include "core/providers/cpu/controlflow/if.h"

namespace onnxruntime {
class SessionState;

namespace js {

// Use the CPU implementation for the logic
class If final : public onnxruntime::If {
public:
If(const OpKernelInfo& info) : onnxruntime::If(info) {}

Status Compute(OpKernelContext* ctx) const override;
};
} // namespace js
} // namespace onnxruntime