Skip to content

Commit

Permalink
Allow ops to handle ignoring an empty tensor as input
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Dec 5, 2024
1 parent 85ceffe commit 58b598f
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 10 deletions.
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}

bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
Expand All @@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
return false;
}
if (dim.dim_value() == 0) {
if (dim.dim_value() == 0 && !allow_empty_input) {
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}

bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input = false);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
if (!HasSupportedInputs(node, wnn_limits, logger))
return false;

if (!HasSupportedOutputs(node, wnn_limits, logger))
Expand All @@ -41,12 +41,11 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
return IsOpSupportedImpl(initializers, node, device_type, logger);
}

bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
// ONNX initializers should have shape information, skip the shape check if the input is an initializer.
if (!Contains(initializers, input->Name()) && !IsTensorShapeSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder {
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;

protected:
explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false)
: allow_empty_tensor_as_input_(allow_empty_tensor_as_input) {
}
virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0;

Expand Down Expand Up @@ -53,9 +56,10 @@ class BaseOpBuilder : public IOpBuilder {

private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;

const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
};

} // namespace webnn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace webnn {
class ResizeOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
// Allow roi and scales potentially being empty inputs that are ignored during processing.
ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {}
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Expand Down

0 comments on commit 58b598f

Please sign in to comment.