forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JS/WebGPU] Support If on WebGPU (microsoft#17478)
- Loading branch information
1 parent
caae13c
commit 259f077
Showing
6 changed files
with
108 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "if.h" | ||
|
||
using namespace ONNX_NAMESPACE; | ||
using namespace onnxruntime::common; | ||
|
||
namespace onnxruntime { | ||
namespace js { | ||
|
||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, | ||
kOnnxDomain, | ||
1, 10, | ||
kJsExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU | ||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()) | ||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
If); | ||
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE' | ||
// branches to be the same were relaxed in opset-11 | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, | ||
kOnnxDomain, | ||
11, 12, | ||
kJsExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU | ||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()) | ||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
If); | ||
|
||
// opset-13 supports sequence type for If's subgraph outputs | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, | ||
kOnnxDomain, | ||
13, 18, | ||
kJsExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU | ||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()) | ||
// Support sequence/optional tensors when all JSEP infra | ||
// (including tests runner) supports it | ||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
If); | ||
|
||
// opset-19 supports float8 | ||
ONNX_OPERATOR_KERNEL_EX(If, | ||
kOnnxDomain, | ||
19, | ||
kJsExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU | ||
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>()) | ||
// Support sequence/optional tensors when all JSEP infra | ||
// (including tests runner) supports it | ||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
If); | ||
|
||
Status If::Compute(OpKernelContext* ctx) const { | ||
// call the base CPU version. | ||
return onnxruntime::If::Compute(ctx); | ||
} | ||
|
||
} // namespace js | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include <functional> | ||
|
||
#include "core/providers/js/js_kernel.h" | ||
#include "core/common/common.h" | ||
#include "core/providers/cpu/controlflow/if.h" | ||
|
||
namespace onnxruntime { | ||
class SessionState; | ||
|
||
namespace js { | ||
|
||
// Use the CPU implementation for the logic | ||
class If final : public onnxruntime::If { | ||
public: | ||
If(const OpKernelInfo& info) : onnxruntime::If(info) {} | ||
|
||
Status Compute(OpKernelContext* ctx) const override; | ||
}; | ||
} // namespace js | ||
} // namespace onnxruntime |