diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 3c9faef9998e0..540bb3ba48531 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,16 +69,17 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { - const auto& node_arg_name = node_arg.Name(); - const auto* shape_proto = node_arg.Shape(); +bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) { + const auto& input_name = input.Name(); + const auto* shape_proto = input.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. - if (node_arg_name.empty()) { + if (input_name.empty()) { return true; } - // We do not support input/output with no shape. + // We do not support input with no shape. if (!shape_proto) { - LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape"; + LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name + << "] has not shape"; return false; } @@ -86,11 +87,12 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape. if (!dim.has_dim_value()) { LOGS(logger, VERBOSE) << "Dynamic shape is not supported, " - << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; + << "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: " + << input_name; return false; } if (dim.dim_value() == 0) { - LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; + LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has 0 dimension which is not supported by WebNN"; return false; } } @@ -106,12 +108,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v std::vector> supported_node_groups; for (const auto* input : graph_viewer.GetInputs()) { - if (!IsTensorShapeSupported(*input, "graph", logger)) { - return supported_node_groups; - } - } - for (const auto* output : graph_viewer.GetOutputs()) { - if (!IsTensorShapeSupported(*output, "graph", logger)) { + if (!IsInputSupported(*input, "graph", logger)) { return supported_node_groups; } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 932d5eca908e2..9d1d7fb60dcca 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 1e641017f36b6..fffe964e6aaf2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -34,7 +34,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons if (!HasSupportedInputs(node, wnn_limits, logger)) return false; - if (!HasSupportedOutputs(node, wnn_limits, logger)) + if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; if (!HasSupportedOpSet(node, logger)) @@ -47,7 +47,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsTensorShapeSupported(*input, node_name, logger)) { + if (!IsInputSupported(*input, node_name, logger)) { return false; } } @@ -68,18 +68,6 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); } -bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { - const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); - for (const auto* output : node.OutputDefs()) { - if (!IsTensorShapeSupported(*output, node_name, logger)) { - return false; - } - } - - return HasSupportedOutputsImpl(node, wnn_limits, logger); -} - bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index a632876dab2b9..584455f62cb4e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -54,7 +54,6 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; - bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 84f8cc4b14665..192f399ecdc26 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -222,7 +222,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (!shape.empty()) { dims.reserve(shape.size()); for (const auto& dim : shape) { - // dim_param free dimensions should have already been excluded by IsTensorShapeSupported(). + // dim_param free dimensions should have already been excluded by IsInputSupported(). assert(dim.has_dim_value()); dims.push_back(SafeInt(dim.dim_value())); }