Skip to content

Commit

Permalink
[WebNN EP] Support subgraph of the control flow nodes (#18923)
Browse files Browse the repository at this point in the history
This PR also makes some processing on the subgraph's initializers. The
subgraph doesn't contain all its required initializers, some common
initializers are stored in its ancestor graphs. We need to collect all
required initializers and re-map to the subgraph.
  • Loading branch information
Honry authored Jan 9, 2024
1 parent 76dfe53 commit fa14dcd
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 6 deletions.
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@
namespace onnxruntime {
namespace webnn {

InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer) {
InitializedTensorSet all_initializers;
if (graph_viewer.IsSubgraph()) {
const Graph* cur_graph = &graph_viewer.GetGraph();
// Traverse up to the top-level graph, collecting all initializers.
while (cur_graph->IsSubgraph()) {
const auto& current_initializers = cur_graph->GetAllInitializedTensors();
all_initializers.insert(current_initializers.begin(), current_initializers.end());
cur_graph = cur_graph->ParentGraph();
}
// Collect initializers in top-level graph.
const auto& current_initializers = cur_graph->GetAllInitializedTensors();
all_initializers.insert(current_initializers.begin(), current_initializers.end());
}

return all_initializers;
}

bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const logging::Logger& logger) {
const auto* shape_proto = node_arg.Shape();
if (!shape_proto) {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ typedef struct {
bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP).
} WebnnOpInfo;

// Collects all the initializer tensors in the subGraph and its ancestor graphs.
InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer);

bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const logging::Logger& logger);

template <typename T>
Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,25 @@ Status ModelBuilder::Initialize() {
return Status::OK();
}

InitializedTensorSet ModelBuilder::GetInitializerTensors() {
if (graph_viewer_.IsSubgraph()) {
auto all_initializers = CollectAllInitializedTensors(graph_viewer_);
const auto sub_graph_id = graph_viewer_.GetFilterInfo();
const auto subgraph_initializer_names = sub_graph_id->GetMetaDef()->constant_initializers;
InitializedTensorSet subgraph_initializers;

for (const auto& name : subgraph_initializer_names) {
auto it = all_initializers.find(name);
if (it != all_initializers.end()) {
subgraph_initializers.insert(*it);
}
}
return subgraph_initializers;
} else {
return graph_viewer_.GetAllInitializedTensors();
}
}

/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) {
const auto& op_builders = GetOpBuilders();
const auto it = op_builders.find(node.OpType());
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ModelBuilder {

// Accessors for members.
const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); }
InitializedTensorSet GetInitializerTensors();

const emscripten::val& GetBuilder() const { return wnn_builder_; }
const emscripten::val& GetContext() const { return wnn_context_; }
Expand Down
26 changes: 21 additions & 5 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
const IKernelLookup& /*kernel_registries*/) const {
std::vector<std::unique_ptr<ComputeCapability>> result;

// We do not run WebNN EP on subgraph, instead we cover this in the control flow nodes.
// TODO investigate whether we want to support subgraph using WebNN EP.
if (graph_viewer.IsSubgraph()) {
return result;
// For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its
// ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for
// identifying the required initializer names and storing into 'meta_def->constant_initializers'.
// Thus we are able to get the required initialized tensors for this subgraph via the GetInitializerTensors()
// method defined in the model_builder.h file.
InitializedTensorSet all_initializers;
const bool is_subgraph = graph_viewer.IsSubgraph();
if (is_subgraph) {
all_initializers = webnn::CollectAllInitializedTensors(graph_viewer);
}

/*
Expand Down Expand Up @@ -110,6 +115,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view

std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();

std::vector<std::string> subgraph_initializers;
InlinedHashSet<const NodeArg*> node_outputs;
InlinedHashSet<const NodeArg*> subgraph_inputs;
InlinedHashSet<const NodeArg*> subgraph_outputs;
Expand All @@ -126,7 +132,11 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
// skip the placeholder inputs.
continue;
}
// if the node input was not produced by this subgraph, add it to the subgraph inputs.
// If it is a subgraph of a control flow node, collect the constant initializer.
if (is_subgraph && Contains(all_initializers, input->Name())) {
subgraph_initializers.push_back(input->Name());
}
// If the node input was not produced by this subgraph, add it to the subgraph inputs.
if (node_outputs.count(input) == 0) {
if (subgraph_inputs.count(input) == 0) {
subgraph_inputs.insert(input);
Expand Down Expand Up @@ -165,6 +175,12 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
meta_def->since_version = 1;
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;

if (is_subgraph) {
for (const auto& initializer : subgraph_initializers) {
meta_def->constant_initializers.push_back(initializer);
}
}

for (const auto& input : ordered_subgraph_inputs) {
meta_def->inputs.push_back(input->Name());
}
Expand Down

0 comments on commit fa14dcd

Please sign in to comment.