From c06a5c284fff24ba2a8bf7fb98865749ced773e0 Mon Sep 17 00:00:00 2001 From: shiyi Date: Wed, 11 Dec 2024 07:48:16 +0800 Subject: [PATCH] [WebNN] Improve data type check of slice op (#22988) A follow-up of [[WebNN] Support negative steps for slice](https://github.com/microsoft/onnxruntime/pull/22871#discussion_r1847929774). Slice op is emulated by reverse+slice when steps < 0 so `SliceOpBuilder::HasSupportedInputsImpl()` should also check the supported data types of reverse. --------- Co-authored-by: Wanming Lin --- .../webnn/builders/impl/base_op_builder.h | 2 +- .../webnn/builders/impl/einsum_op_builder.cc | 14 +++++++------- .../builders/impl/gatherElements_op_builder.cc | 7 ++++--- .../webnn/builders/impl/gatherND_op_builder.cc | 8 ++++---- .../webnn/builders/impl/gru_op_builder.cc | 4 ++-- .../builders/impl/scatterElements_op_builder.cc | 7 ++++--- .../webnn/builders/impl/scatterND_op_builder.cc | 7 ++++--- 7 files changed, 26 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 111882c2a678e..0a4367a71add4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index 931854d0f33c1..ef713f48b8135 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. @@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten: return false; } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { // Map to WebNN's gemm or matmul - return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); } else { - return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 225cfcdfc852c..cb7b7de74e121 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde // Operator support related. -bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index cb4f85a40ee12..002a1a6a63026 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial return true; } -bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index fdaaae28df051..b240e30d38b22 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index c786aa468736c..8c70525835059 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -22,8 +22,8 @@ class ScatterElementsOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,7 +65,8 @@ bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* return true; } -bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index feb93cc14b7c4..8089b9706886f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -22,8 +22,8 @@ class ScatterNDOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -57,7 +57,8 @@ bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1];