Skip to content

Commit

Permalink
[js/webgpu] FP16 Cast, Resize (#18035)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

Cast/Resize with f16 are missing in vae-decoder-f16. With this change,
vae-decoder-f16 becomes 315 ms from over than 1 seconds.
  • Loading branch information
qjia7 authored Oct 24, 2023
1 parent 688524a commit eb47008
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 25 deletions.
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/js/operators/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ const std::vector<MLDataType>& CastOpTypeConstraints() {
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
//
static std::vector<MLDataType> types{
// TODO(fs-eire): support f16 when it's ready
// DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
Expand Down
46 changes: 23 additions & 23 deletions onnxruntime/core/providers/js/operators/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

namespace onnxruntime {
namespace js {
#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
Resize, \
domain, \
10, 10, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
Resize, \
domain, \
10, 10, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.TypeConstraint("T", JsepSupportedFloatTypes()), \
Resize);

#define REGISTER_RESIZE_VERSIONED_KERNEL(domain, sinceVersion, endVerion) \
Expand All @@ -26,22 +26,22 @@ namespace js {
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()), \
.TypeConstraint("T1", JsepSupportedFloatTypes()) \
.TypeConstraint("T2", JsepSupportedFloatTypes()), \
Resize);

#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \
ONNX_OPERATOR_KERNEL_EX( \
Resize, \
domain, \
sinceVersion, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()), \
#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \
ONNX_OPERATOR_KERNEL_EX( \
Resize, \
domain, \
sinceVersion, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", JsepSupportedFloatTypes()) \
.TypeConstraint("T2", JsepSupportedFloatTypes()), \
Resize);

#define REGISTER_RESIZE_KERNEL_DOMAIN(domain) \
Expand Down

0 comments on commit eb47008

Please sign in to comment.