From 16f8c0971b91149b4d597056eefc53df622f8e3d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 7 Mar 2024 22:20:47 +1000 Subject: [PATCH] - Refactor to build in the various configurations. - The whole QDQ setup needs a rethink at some point as it's currently spread across too many places (framework, optimizer, base providers lib, EP specific providers lib) - move NodeGroup to framework/node_unit.h and ValidateNodeGroupQDQNodes to NodeGroup::CanCreateNodeGroup so it's in the framework lib as it's used by NodeUnit - move GetAllNodeUnits to optimizer - doesn't quite belong there but this works with all the current EPs that use it. --- onnxruntime/core/framework/node_unit.cc | 122 +++++++++++------- onnxruntime/core/framework/node_unit.h | 23 ++-- .../selectors_actions/qdq_selectors.cc | 6 +- .../selectors_actions/qdq_selectors.h | 8 +- .../selectors_actions/shared/utils.cc | 93 ++++++------- .../selectors_actions/shared/utils.h | 17 ++- .../nnapi_builtin/builders/model_builder.cc | 4 +- .../nnapi_builtin/nnapi_execution_provider.cc | 4 +- .../core/providers/qnn/builder/qnn_model.cc | 4 +- .../providers/qnn/qnn_execution_provider.cc | 4 +- onnxruntime/core/providers/utils.cc | 3 +- .../xnnpack/xnnpack_execution_provider.cc | 16 +-- .../mlas/unittest/test_fp16_activation.cpp | 1 + .../test/optimizer/qdq_transformer_test.cc | 14 +- .../test/providers/partitioning_utils_test.cc | 9 +- 15 files changed, 177 insertions(+), 151 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 5de51500bd074..9be3fd53bb7ae 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -3,8 +3,6 @@ #include "node_unit.h" #include "core/graph/graph_viewer.h" -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" -#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" namespace onnxruntime { @@ -145,6 +143,80 @@ std::vector GetQDQIODefs(const Node& 
target_node, const QDQ::Node } // namespace +Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes) { + // Within a QDQ node group, a target node input is the only consumer of each DQ. + // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications + // may have happened since. Verify that this is still true. + for (const auto* dq_node : dq_nodes) { + const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); + ORT_RETURN_IF(dq_produces_graph_output, + "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), + ", target node: ", target_node.Name()); + + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); + ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, + "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " + "DQ node: ", + dq_node->Name(), ", target node: ", target_node.Name()); + } + + // an output from the target node can have either Q consumers or direct consumers. it cannot have both. + // this must be checked on a per output basis. + // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ + // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output. 
+ if (!q_nodes.empty()) { + auto cur_edge = target_node.OutputEdgesBegin(); + auto end_edge = target_node.OutputEdgesEnd(); + std::vector output_consumers(target_node.OutputDefs().size(), nullptr); + + for (; cur_edge != end_edge; ++cur_edge) { + auto output_idx = cur_edge->GetSrcArgIndex(); + const Node& this_consumer = cur_edge->GetNode(); + const Node* existing_consumer = output_consumers[output_idx]; + + if (existing_consumer != nullptr) { + // another edge for this output. either both are Q or both are not. + bool valid = true; + if (existing_consumer->OpType() == "QuantizeLinear") { + valid = this_consumer.OpType() == "QuantizeLinear"; + } else { + valid = this_consumer.OpType() != "QuantizeLinear"; + } + + ORT_RETURN_IF_NOT(valid, + "QDQ node group cannot have an output from the target node being consumed by a Q node and " + "a non-Q node. target node: ", + target_node.Name()); + } else { + output_consumers[output_idx] = &this_consumer; + } + } + + const auto& graph_outputs = graph_viewer.GetOutputs(); + for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) { + // any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to + // a quantized op. + if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") { + const auto& output_name = target_node.OutputDefs()[idx]->Name(); + bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(), + [&output_name](const NodeArg* node_arg) { + return node_arg->Name() == output_name; + }); + ORT_RETURN_IF(is_graph_output, + "QDQ node group cannot have an output from the target node that is consumed by a Q node and " + "a graph output. 
target node: ", + target_node.Name(), " output idx:", idx); + } + } + } + + return Status::OK(); +} NodeUnit::NodeUnit(const Node& node) : target_node_(node), type_(Type::SingleNode), @@ -159,7 +231,7 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g type_(Type::QDQGroup), inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { - ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupQDQNodes(graph_viewer, target_node_, dq_nodes_, q_nodes_)); + ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_)); input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); @@ -272,48 +344,4 @@ std::vector NodeUnit::GetAllNodesInGroup() const noexcept { return all_nodes; } -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { - std::vector> node_unit_holder; - std::unordered_map node_unit_map; - - const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { - for (const auto& node_idx : node_indices) { - const auto* node = graph_viewer.GetNode(node_idx); - node_unit_map.insert({node, node_unit}); - } - }; - - // Get QDQ NodeUnits first - QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); - - for (const auto& qdq_selection : qdq_selections) { - auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); - - // Fill the node to node_unit map for all nodes in the QDQ Group - add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); - add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); - add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); - - node_unit_holder.push_back(std::move(qdq_unit)); - } - - // Get the left over SingleNode NodeUnits - const auto& node_indices = 
graph_viewer.GetNodesInTopologicalOrder(); - for (const auto node_idx : node_indices) { - const auto* node(graph_viewer.GetNode(node_idx)); - - // This is already part of a QDQ NodeUnit - if (node_unit_map.find(node) != node_unit_map.cend()) - continue; - - auto node_unit = std::make_unique(*node); - node_unit_map[node] = node_unit.get(); - node_unit_holder.push_back(std::move(node_unit)); - } - - return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index af7772af9eba7..d15a0c96c08b4 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -18,8 +18,21 @@ class NodeArg; class Path; namespace QDQ { -struct NodeGroup; -} +// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group +struct NodeGroup { + std::vector dq_nodes; + std::vector q_nodes; + NodeIndex target_node; + + // Validator to check if the set of nodes can form a valid QDQ NodeGroup. + // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to + // be converted into a single node with a quantized operator. 
+ static Status CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes); +}; +} // namespace QDQ // Definition of one input or output // If the optional quant_param is present, then this is a quantized input, @@ -96,10 +109,4 @@ class NodeUnit { Node::EdgeSet output_edges_; }; -// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) -// And return a map to quick query the NodeUnit which contains the given Node, -// Note, the value of the map is owned by the vector of std::unique_ptr -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); - } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 2212e846104db..6b4f62ae1343d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -58,7 +58,7 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod return false; } - if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes); + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); !qdq_validation_status.IsOK()) { return false; } @@ -153,7 +153,7 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes); + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); !qdq_validation_status.IsOK()) { return false; } @@ -544,7 +544,7 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto qdq_validation_status = 
QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes); + if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); !qdq_validation_status.IsOK()) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index deee6e7f25f1a..c90a42a36483d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#include "core/framework/node_unit.h" #include "core/optimizer/selectors_actions/selector_action_transformer.h" namespace onnxruntime { @@ -13,13 +14,6 @@ class Node; namespace QDQ { -// Struct to represent a DQ->Op->Q node group -struct NodeGroup { - std::vector dq_nodes; - std::vector q_nodes; - NodeIndex target_node; -}; - class NodeGroupSelector { public: // This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index a45fdddd347c7..337316b4f37bb 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -13,6 +13,7 @@ #include #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" namespace onnxruntime { namespace QDQ { @@ -324,64 +325,48 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap return qdq_selections; } -Status ValidateNodeGroupQDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes, - 
gsl::span q_nodes) { - // Within a QDQ node group, a target node input is the only consumer of each DQ. - // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications - // may have happened since. Verify that this is still true. - for (const auto* dq_node : dq_nodes) { - const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); - ORT_RETURN_IF(dq_produces_graph_output, - "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), - ", target node: ", target_node.Name()); - - const bool dq_has_single_output_edge_to_target = - dq_node->GetOutputEdgesCount() == 1 && - dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); - ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, - "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " - "DQ node: ", - dq_node->Name(), ", target node: ", target_node.Name()); - } +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer) { + std::vector> node_unit_holder; + std::unordered_map node_unit_map; - // an output from the target node can have either Q consumers or direct consumers. it cannot have both. - // this must be checked on a per output basis. - // NOTE: rules about the target node not producing a graph output must be checked by the selector as it's operator - // dependent. - // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ - // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output. 
- if (!q_nodes.empty()) { - auto cur_edge = target_node.OutputEdgesBegin(); - auto end_edge = target_node.OutputEdgesEnd(); - std::vector output_consumers(target_node.OutputDefs().size(), nullptr); - - for (; cur_edge != end_edge; ++cur_edge) { - auto output_idx = cur_edge->GetSrcArgIndex(); - const Node& this_consumer = cur_edge->GetNode(); - const Node* existing_consumer = output_consumers[output_idx]; - - if (existing_consumer != nullptr) { - // another edge for this output. either both are Q or both are not. - bool valid = true; - if (existing_consumer->OpType() == "QuantizeLinear") { - valid = this_consumer.OpType() == "QuantizeLinear"; - } else { - valid = this_consumer.OpType() != "QuantizeLinear"; - } - - ORT_RETURN_IF_NOT(valid, - "QDQ node group cannot have an output from the target node being consumed by a Q node and " - "a non-Q node. target node: ", - target_node.Name()); - } else { - output_consumers[output_idx] = &this_consumer; - } + const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { + for (const auto& node_idx : node_indices) { + const auto* node = graph_viewer.GetNode(node_idx); + node_unit_map.insert({node, node_unit}); } + }; + + // Get QDQ NodeUnits first + QDQ::SelectorManager selector_mgr; + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + + for (const auto& qdq_selection : qdq_selections) { + auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); + + // Fill the node to node_unit map for all nodes in the QDQ Group + add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); + add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); + add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); + + node_unit_holder.push_back(std::move(qdq_unit)); + } + + // Get the left over SingleNode NodeUnits + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + const auto* 
node(graph_viewer.GetNode(node_idx)); + + // This is already part of a QDQ NodeUnit + if (node_unit_map.find(node) != node_unit_map.cend()) + continue; + + auto node_unit = std::make_unique(*node); + node_unit_map[node] = node_unit.get(); + node_unit_holder.push_back(std::move(node_unit)); } - return Status::OK(); + return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); } } // namespace QDQ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index 918bf32b90d45..de36202afff29 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -7,6 +7,7 @@ #include "core/common/common.h" #include "core/common/gsl.h" #include "core/common/inlined_containers.h" +#include "core/framework/node_unit.h" #include "core/graph/basic_types.h" #if !defined(ORT_MINIMAL_BUILD) @@ -78,12 +79,16 @@ class SelectorManager { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager); }; -// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node. -// Returns successful status if so, failed status with reason otherwise. -Status ValidateNodeGroupQDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes, - gsl::span q_nodes); +// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) +// And return a map to quick query the NodeUnit which contains the given Node, +// Note, the value of the map is owned by the vector of std::unique_ptr +// +// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific +// functionality. 
+// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer +// library whereas it should be able to be used by an EP with no dependency on optimizers. +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 1595aa8bfaca3..d0ae32378379d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -15,6 +15,8 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" @@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 9cc6591feaa3f..4d2888222ff0f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -9,6 +9,8 @@ #include "core/framework/compute_capability.h" #include 
"core/framework/node_unit.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" @@ -119,7 +121,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index dc91b9dfa199e..b3501dfec1ba8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -9,6 +9,8 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { @@ -95,7 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc 
b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 37969d19aa016..5c4fa3e0fb88b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -10,6 +10,8 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" @@ -494,7 +496,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), is_qnn_ctx_model, logger); diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc index ca3fc4fc1972b..b2f9d265ca053 100644 --- a/onnxruntime/core/providers/utils.cc +++ b/onnxruntime/core/providers/utils.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#include "core/framework/tensorprotoutils.h" -#include "utils.h" +#include "core/providers/utils.h" namespace onnxruntime { namespace utils { @@ -23,6 +23,5 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& return Status::OK(); } #endif - } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 4930cd561dbe7..12e567e7080b3 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -6,17 +6,17 @@ #include #include -#include "core/graph/function_utils.h" -#include "xnnpack_execution_provider.h" -#include "detail/utils.h" -#include "detail/node_support_checker.h" - #include "core/framework/compute_capability.h" #include "core/framework/kernel_registry.h" #include "core/framework/node_unit.h" +#include "core/graph/function_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" - -#include "xnnpack_init.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/xnnpack/xnnpack_execution_provider.h" +#include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/detail/node_support_checker.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { @@ -268,7 +268,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for 
multiple times diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp index 484a9a22429d5..969997d2b84ec 100644 --- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "test_fp16.h" +#include #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 45ab0ca1b5509..fbd5c9b5a137b 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/compute_capability.h" +#include "core/framework/node_unit.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" @@ -9,7 +10,6 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" -#include "core/optimizer/utils.h" #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -30,10 +30,6 @@ #pragma warning(disable : 4127) #endif // #if defined(_MSC_VER) -#ifdef USE_NNAPI -#include "core/framework/node_unit.h" -#endif // #ifdef USE_NNAPI - struct QDQOpKeys { const char* quantize_linear; const char* dequantize_linear; @@ -3243,14 +3239,14 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { ASSERT_EQ(std::vector({4}), qdq_group.q_nodes); } -// The function GetAllNodeUnits is enabled for NNAPI EP only for now -#ifdef USE_NNAPI +// The function GetAllNodeUnits is used by NNAPI, XNNPACK and QNN +#if defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK) { // Get all the 
NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); // We should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -3288,7 +3284,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { verify_io_def(qdq_node_unit.Inputs()[2], *whole_graph_viewer.GetNode(2)); // DQ_bias verify_io_def(qdq_node_unit.Outputs()[0], *whole_graph_viewer.GetNode(4)); // Q_output } -#endif // #ifdef USE_NNAPI +#endif // defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK) // Create a graph viewer covers part of the graph // Make sure the qdq conv selector will fail for the partial graph diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 9cc93e7e02a12..67a6b68ae2c9a 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -7,9 +7,12 @@ #include "core/common/common.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/framework/node_unit.h" #include "core/framework/compute_capability.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/partitioning_utils.h" -#include "core/framework/node_unit.h" + #include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/qdq_test_utils.h" #include "test/util/include/asserts.h" @@ -48,7 +51,7 @@ TEST(PartitioningUtilsTest, TestQDQHandling) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); auto result = 
utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, @@ -79,7 +82,7 @@ static void CheckAllNodesProcessed(const std::function& std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); const auto is_node_supported = [&](const Node& /*node*/) -> bool { return true;