From 3f0a1fa821fce62a90dcdbe7171fbf72e9b50c73 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 7 Dec 2024 09:58:15 +0800 Subject: [PATCH] [WebNN] Allow ops to handle ignoring an empty tensor as input (#22972) ### Description Some ops should allow empty tensor as input, e.g. roi, scales inputs in Resize ### Motivation and Context It avoid some unexpected fallback for optional input with empty tensor. e.g. roi and scales are both optional inputs in Resize, in some models they have non-empty name but with empty initializer presented as `[0]`, WebNN currently will fallback all nodes with 0 dimension, which is not expected. ![image](https://github.com/user-attachments/assets/599ba351-b5f6-49ac-8a1f-69fb28dbaf9b) --- onnxruntime/core/providers/webnn/builders/helper.cc | 5 +++-- onnxruntime/core/providers/webnn/builders/helper.h | 3 ++- .../core/providers/webnn/builders/impl/base_op_builder.cc | 2 +- .../core/providers/webnn/builders/impl/base_op_builder.h | 2 ++ 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index ff1ff57e3fb98..45a87960126cd 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input) { const auto& node_arg_name = node_arg.Name(); const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. @@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; return false; } - if (dim.dim_value() == 0) { + if (dim.dim_value() == 0 && !allow_empty_input) { LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index f732d96df9bb0..a06f46f1bdf0a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input = false); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 18cb1bd2d467a..290d16a48dd83 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsTensorShapeSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) { return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index a4f7e648798d3..111882c2a678e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -58,6 +58,8 @@ class BaseOpBuilder : public IOpBuilder { bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + + const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. }; } // namespace webnn