Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve the partition logic #19789

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions onnxruntime/core/providers/partitioning_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
std::unordered_map<NodeIndex, size_t> in_degree{};
// nodes that are ready to process
std::deque<const Node*> nodes_to_process{};
// keep track of the initial nodes_to_process
std::list<const Node*> initial_nodes_to_process{};

Check warning on line 109 in onnxruntime/core/providers/partitioning_utils.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <list> for list<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/partitioning_utils.cc:109: Add #include <list> for list<> [build/include_what_you_use] [4]
// nodes that will be processed when considering the next partition node group
std::deque<const Node*> nodes_to_process_with_next_group{};

Expand All @@ -115,6 +117,7 @@
in_degree.insert({node.Index(), node_input_edge_count});
if (node_input_edge_count == 0) {
nodes_to_process.push_back(&node);
initial_nodes_to_process.push_back(&node);
}
}

Expand Down Expand Up @@ -153,6 +156,30 @@

while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) {
if (nodes_to_process.empty()) {
auto node = initial_nodes_to_process.begin();
while (node != initial_nodes_to_process.end()) {
bool node_is_counsumed = false;

for (auto output = (*node)->OutputNodesBegin(); output != (*node)->OutputNodesEnd(); ++output) {
if (std::find(supported_group.begin(), supported_group.end(), &(*output)) != supported_group.end()) {
node_is_counsumed = true;
break;
}
}
// The node output is consumed by nodes in supported_group, remove it from initial_nodes_to_process
if (node_is_counsumed) {
node = initial_nodes_to_process.erase(node);
} else {
++node;
}
}

// nodes left in initial_nodes_to_process are nodes not consumed by current group
// try them with next group
for (const auto* node_not_consumed : initial_nodes_to_process) {
supported_group.erase(std::find(supported_group.begin(), supported_group.end(), node_not_consumed));
nodes_to_process_with_next_group.push_back(node_not_consumed);
}
// we have processed all the nodes that we can while building this partition node group, start a new one
close_group();
nodes_to_process.swap(nodes_to_process_with_next_group);
Expand Down Expand Up @@ -252,23 +279,8 @@
const auto& graph_output_list = graph_viewer.GetOutputs();
std::unordered_set<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());

// Process output first in case nodes not in topological order
for (const Node* node : group) {
sub_graph->nodes.push_back(node->Index());

for (const auto* input : node->InputDefs()) {
if (!input->Exists()) {
// skip the placeholder inputs
continue;
}
// if the node input was not produced by this subgraph, add it to the subgraph inputs.
if (!Contains(node_outputs, input)) {
if (!Contains(subgraph_inputs, input)) {
subgraph_inputs.insert(input);
ordered_subgraph_inputs.push_back(input);
}
}
}

const auto& output_defs = node->OutputDefs();
for (const auto* output_def : output_defs) {
node_outputs.insert(output_def);
Expand All @@ -291,6 +303,24 @@
}
}

for (const Node* node : group) {
sub_graph->nodes.push_back(node->Index());

for (const auto* input : node->InputDefs()) {
if (!input->Exists()) {
// skip the placeholder inputs
continue;
}
// if the node input was not produced by this subgraph, add it to the subgraph inputs.
if (!Contains(node_outputs, input)) {
if (!Contains(subgraph_inputs, input)) {
subgraph_inputs.insert(input);
ordered_subgraph_inputs.push_back(input);
}
}
}
}

// Assign inputs and outputs to subgraph's meta_def
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
meta_def->name = generate_metadef_name();
Expand Down
50 changes: 50 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/inference_session.h"
#include "core/framework/compute_capability.h"
#include "core/providers/partitioning_utils.h"

#include "test/providers/qnn/qnn_test_utils.h"

Expand Down Expand Up @@ -75,6 +77,52 @@ static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) {
};
}

// Builds the Q/non-Q test graph (multi-node variant), treats every node except FusedMatMul as supported and
// checks that utils::CreateSupportedPartitions yields two partitions with the same node and input counts.
//
// The model is not run through graph optimization here, so the ini->Q->DQ chain feeding the 2nd Add is placed
// in the 1st partition. The partitioning logic still needs to be improved for that case, hence the test is
// disabled.
TEST(GraphPartitionTest, DISABLED_2PartitionWithEqualNodes) {
  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};

  auto& logging_manager = DefaultLoggingManager();
  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);

  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
                           logging_manager.DefaultLogger());
  Graph& graph = model.MainGraph();
  ModelTestBuilder helper(graph);
  constexpr bool single_ep_node = false;
  BuildGraphWithQAndNonQ(single_ep_node)(helper);
  helper.SetGraphOutputs();
  ASSERT_STATUS_OK(model.MainGraph().Resolve());
  // Write the model to disk and fail the test if that does not succeed, rather than discarding the status.
  ASSERT_STATUS_OK(Model::Save(model, "test.onnx"));

  // Everything except FusedMatMul is considered supported by the (simulated) execution provider.
  std::unordered_set<const Node*> supported_nodes{};
  for (const auto& node : model.MainGraph().Nodes()) {
    if ("FusedMatMul" != node.OpType()) {
      supported_nodes.insert(&node);
    }
  }

  // Each generated meta-def gets a unique, monotonically increasing suffix.
  int metadef_id = 0;
  const auto gen_metadef_name = [&]() {
    return MakeString("QNN_0971957_", metadef_id++);
  };
  GraphViewer viewer(graph);
  auto partitions = utils::CreateSupportedPartitions(viewer,
                                                     supported_nodes, {},
                                                     gen_metadef_name,
                                                     "QNN",
                                                     kQnnExecutionProvider,
                                                     true);
  // Unsigned literals keep the comparisons with size() free of signed/unsigned mismatch warnings.
  ASSERT_EQ(partitions.size(), 2U);
  for (const auto& partition : partitions) {
    ASSERT_EQ(partition->sub_graph->nodes.size(), 7U);
    const auto input_count = partition->sub_graph->GetMetaDef()->inputs.size();
    ASSERT_EQ(input_count, 8U);
  }
}

void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) {
ProviderOptions provider_options;
#if defined(_WIN32)
Expand Down Expand Up @@ -123,6 +171,8 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) {
for (auto& node : ctx_graph.Nodes()) {
if (node.OpType() == "EPContext") {
++ep_context_node_count;
// validate the fix for the partition issue
ASSERT_EQ(node.InputDefs().size(), 1);
} else {
++non_ep_context_node_count;
}
Expand Down
Loading