diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 537e552af2763..f36f8283e9bf6 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input) { const auto& node_arg_name = node_arg.Name(); const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. @@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; return false; } - if (dim.dim_value() == 0) { + if (dim.dim_value() == 0 && !allow_empty_input) { LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 23489f142df3e..7fdfc5aefa798 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input = false); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 1a4a07faa94fc..70fa0f9516c5c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const emscripten::val& wnn_limits, const logging::Logger& logger) const { - if (!HasSupportedInputs(initializers, node, wnn_limits, logger)) + if (!HasSupportedInputs(node, wnn_limits, logger)) return false; if (!HasSupportedOutputs(node, wnn_limits, logger)) @@ -41,12 +41,11 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, - const emscripten::val& wnn_limits, const logging::Logger& logger) const { +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - // ONNX initializers should have shape information, skip the shape check if the input is an initializer. - if (!Contains(initializers, input->Name()) && !IsTensorShapeSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) { return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index e6026e805679b..9412fa8026fb3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: + explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false) + : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { + } virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; @@ -53,9 +56,10 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + + const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 408045571e422..00f8cff25ccf5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -21,6 +21,8 @@ namespace webnn { class ResizeOpBuilder : public BaseOpBuilder { // Add operator related. public: + // Allow roi and scales potentially being empty inputs that are ignored during processing. + ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {} void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; private: