From 9f589d87d2dc14e726bb5644880d03148b73095b Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 18 Mar 2024 13:42:51 -0700 Subject: [PATCH] handle fp16 for where op (#19969) this prevents falling back from webgpu to cpu, aka helps performance --- .../core/providers/js/operators/where.cc | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc index 2f8f5e275aa98..dcdf9bee2f783 100644 --- a/onnxruntime/core/providers/js/operators/where.cc +++ b/onnxruntime/core/providers/js/operators/where.cc @@ -6,18 +6,19 @@ namespace onnxruntime { namespace js { -#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, \ - kOnnxDomain, \ - VERSION, \ - kJsExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", \ - {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}), \ +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); #define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ @@ -29,6 +30,7 @@ namespace js { KernelDefBuilder() \ .TypeConstraint("T", \ {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \