Skip to content

Commit

Permalink
[JS/WebGPU] Support If on WebGPU (microsoft#17478)
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharans29 authored and kleiti committed Mar 22, 2024
1 parent caae13c commit 259f077
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 4 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
Expand Down
5 changes: 5 additions & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,11 @@
// // "test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
// TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra
// supports Sequence and Optional types
// "test_if_seq",
// "test_if_opt",
"test_instancenorm_epsilon",
"test_instancenorm_example",
// "test_isinf_negative",
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,10 @@ Status SessionState::CreateSubgraphSessionState() {
for (auto& node : graph_.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
const auto& ep = node.GetExecutionProviderType();
if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider && ep != kDmlExecutionProvider) {
if (!ep.empty() &&
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
ep != kRocmExecutionProvider && ep != kDmlExecutionProvider &&
ep != kJsExecutionProvider) {
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
// node containing the subgraph it will create whatever state it needs internally.
continue;
Expand Down
12 changes: 9 additions & 3 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad);
Expand All @@ -327,6 +326,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If);

std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();

Expand Down Expand Up @@ -580,15 +584,17 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
65 changes: 65 additions & 0 deletions onnxruntime/core/providers/js/operators/if.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "if.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;

namespace onnxruntime {
namespace js {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
1, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
// branches to be the same were relaxed in opset-11
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);

// opset-13 supports sequence type for If's subgraph outputs
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
kOnnxDomain,
13, 18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);

// opset-19 supports float8
ONNX_OPERATOR_KERNEL_EX(If,
kOnnxDomain,
19,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
// Support sequence/optional tensors when all JSEP infra
// (including tests runner) supports it
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);

Status If::Compute(OpKernelContext* ctx) const {
// call the base CPU version.
return onnxruntime::If::Compute(ctx);
}

} // namespace js
} // namespace onnxruntime
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/js/operators/if.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <functional>

#include "core/providers/js/js_kernel.h"
#include "core/common/common.h"
#include "core/providers/cpu/controlflow/if.h"

namespace onnxruntime {
class SessionState;

namespace js {

// Use the CPU implementation for the logic
class If final : public onnxruntime::If {
public:
If(const OpKernelInfo& info) : onnxruntime::If(info) {}

Status Compute(OpKernelContext* ctx) const override;
};
} // namespace js
} // namespace onnxruntime

0 comments on commit 259f077

Please sign in to comment.