Skip to content

Commit

Permalink
handle fp16 for where op (#19969)
Browse files Browse the repository at this point in the history
This prevents the op from falling back from WebGPU to CPU, which helps performance.
  • Loading branch information
guschmue authored Mar 18, 2024
1 parent 141966b commit a4ac727
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions onnxruntime/core/providers/js/operators/where.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
namespace onnxruntime {
namespace js {

#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", \
{DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<uint32_t>(), \
DataTypeImpl::GetTensorType<bool>()}), \
// Expands to an ONNX_OPERATOR_KERNEL_EX invocation for KERNEL_CLASS, passing
// OP_TYPE, kOnnxDomain, VERSION and kJsExecutionProvider, together with a
// KernelDefBuilder whose type constraint "T" lists float, MLFloat16, int32_t,
// uint32_t and bool tensor types.
//   OP_TYPE      - operator name forwarded to ONNX_OPERATOR_KERNEL_EX.
//   VERSION      - version argument forwarded to ONNX_OPERATOR_KERNEL_EX
//                  (presumably the opset version; confirm against the macro).
//   KERNEL_CLASS - kernel class forwarded as the final argument.
// NOTE(review): the expansion already ends with ';', so a call site that adds
// its own ';' produces an empty declaration.
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", \
{DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<uint32_t>(), \
DataTypeImpl::GetTensorType<bool>()}), \
KERNEL_CLASS);

#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
Expand All @@ -29,6 +30,7 @@ namespace js {
KernelDefBuilder() \
.TypeConstraint("T", \
{DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<uint32_t>(), \
DataTypeImpl::GetTensorType<bool>()}), \
Expand Down

0 comments on commit a4ac727

Please sign in to comment.