diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index cb75b0b8751bb..e4fefdbf86369 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,13 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") - # Needs to update onnxruntime/test/xctest/xcgtest.mm + if (IOS OR ANDROID) + # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing + # any args to gtest executables, such as using --gtest_filter to debug a specific test. + # Processing of compile definitions: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21 + # If set, this code throws away the flag and does nothing on registration, which results in no flags being known: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217 set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) else() set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 8f3b1828e1c61..b8ebc4ca53239 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -70,8 +70,8 @@ list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$") source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs}) # These are shared utils, -# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML -file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS +# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML +file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" ) diff --git a/cmake/onnxruntime_providers_nnapi.cmake b/cmake/onnxruntime_providers_nnapi.cmake index 5ac25a3b76efb..b718a976eb26f 100644 --- a/cmake/onnxruntime_providers_nnapi.cmake +++ b/cmake/onnxruntime_providers_nnapi.cmake @@ -49,12 +49,10 @@ endif() # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML + # TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML list(APPEND onnxruntime_provider_nnapi_cc_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns}) @@ -81,4 +79,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index a93a06e960c81..b68d84c23bb32 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -4,12 +4,10 @@ add_compile_definitions(USE_QNN=1) # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + # TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML + file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB_RECURSE @@ -42,4 +40,4 @@ # ignore the warning unknown-pragmas on "pragma region" if(NOT MSVC) target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 6342c24b2917e..796536ac9d12b 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -7,9 +7,6 @@ "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc" - # utils for handling QDQ models - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc new file mode 100644 index 0000000000000..4dee1c14b3761 --- /dev/null +++ b/onnxruntime/core/framework/node_unit.cc @@ -0,0 +1,351 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "node_unit.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { + +namespace { + +enum class QLinearOpType : uint8_t { + Unknown, // Unknown or not a linear quantized op + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + QLinearReduceMean, + QLinearConcat, + QLinearGlobalAveragePool, + QLinearLeakyRelu, +}; + +QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { + const auto& op_type = node.OpType(); + if (op_type == "DequantizeLinear") + return QLinearOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QLinearOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QLinearOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QLinearOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QLinearOpType::QLinearAdd; + else if (op_type == "QLinearSigmoid") + return QLinearOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QLinearOpType::QLinearAveragePool; + else if (op_type == "QLinearMul") + return QLinearOpType::QLinearMul; + else if (op_type == "QLinearReduceMean") + return QLinearOpType::QLinearReduceMean; + else if (op_type == "QLinearConcat") + return QLinearOpType::QLinearConcat; + else if (op_type == "QLinearGlobalAveragePool") + return QLinearOpType::QLinearGlobalAveragePool; + else if (op_type == "QLinearLeakyRelu") + return QLinearOpType::QLinearLeakyRelu; + + return QLinearOpType::Unknown; +} + +// Ops have 1 input +bool IsUnaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearSigmoid || + type == QLinearOpType::QLinearAveragePool || + type == QLinearOpType::QLinearGlobalAveragePool || + type == QLinearOpType::QLinearLeakyRelu || + type == QLinearOpType::QLinearReduceMean; +} + +// Ops have 2 inputs +bool IsBinaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConv || + type == QLinearOpType::QLinearMatMul || + type == QLinearOpType::QLinearAdd || + type == QLinearOpType::QLinearMul; +} + +// Ops have 1 or more inputs +bool IsVariadicQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConcat; +} + +const std::vector GetQDQIONodes(const GraphViewer& graph_viewer, + const QDQ::NodeGroup& node_group, bool is_input) { + std::vector io_nodes; + const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + io_nodes.reserve(src_nodes.size()); + for (const auto& node_idx : src_nodes) { + io_nodes.push_back(graph_viewer.GetNode(node_idx)); + } + + return io_nodes; +} + +// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup +std::vector GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) { + const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs(); + const size_t target_node_io_defs_size = target_node_io_defs.size(); + + // Find all the quantized IO defs and indices (for the input/output of the target node) + std::unordered_map quantized_io_defs; + quantized_io_defs.reserve(target_node_io_defs_size); + + auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin(); + auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd(); + + for (; cur != end; ++cur) { + const Node& node = cur->GetNode(); + + // If we can find the node index in the dq or q nodes this is a quantized input/output + if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { + const auto node_inputs = node.InputDefs(); + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr}; + + if (is_input) { + // DQ is input to the target node, use the DstArgIndex + auto idx = cur->GetDstArgIndex(); + // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2]) + quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}}); + } else { + // Q is output of the target node, use the SrcArgIndex + auto idx = cur->GetSrcArgIndex(); + // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2]) + const auto node_outputs = node.OutputDefs(); + quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}}); + } + } + } + + // Construct the IODefs for this QDQ NodeGroup + std::vector io_defs; + io_defs.reserve(target_node_io_defs_size); + for (size_t i = 0; i < target_node_io_defs_size; i++) { + // If we can find the NodeUnitIODef for this index, this is a quantized input/output + if (quantized_io_defs.find(i) != quantized_io_defs.cend()) { + io_defs.push_back(std::move(quantized_io_defs.at(i))); + } else { + // This is a regular input + io_defs.push_back({*target_node_io_defs[i], std::nullopt}); + } + } + + return io_defs; +} + +} // namespace + +Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes) { + // Within a QDQ node group, a target node input is the only consumer of each DQ. + // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications + // may have happened since. Verify that this is still true. + for (const auto* dq_node : dq_nodes) { + const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); + ORT_RETURN_IF(dq_produces_graph_output, + "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), + ", target node: ", target_node.Name()); + + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); + ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, + "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " + "DQ node: ", + dq_node->Name(), ", target node: ", target_node.Name()); + } + + // an output from the target node can have either Q consumers or direct consumers. it cannot have both. + // this must be checked on a per output basis. + // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ + // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output. + if (!q_nodes.empty()) { + auto cur_edge = target_node.OutputEdgesBegin(); + auto end_edge = target_node.OutputEdgesEnd(); + std::vector output_consumers(target_node.OutputDefs().size(), nullptr); + + for (; cur_edge != end_edge; ++cur_edge) { + auto output_idx = cur_edge->GetSrcArgIndex(); + const Node& this_consumer = cur_edge->GetNode(); + const Node* existing_consumer = output_consumers[output_idx]; + + if (existing_consumer != nullptr) { + // another edge for this output. either both are Q or both are not. + bool valid = true; + if (existing_consumer->OpType() == "QuantizeLinear") { + valid = this_consumer.OpType() == "QuantizeLinear"; + } else { + valid = this_consumer.OpType() != "QuantizeLinear"; + } + + ORT_RETURN_IF_NOT(valid, + "QDQ node group cannot have an output from the target node being consumed by a Q node and " + "a non-Q node. target node: ", + target_node.Name()); + } else { + output_consumers[output_idx] = &this_consumer; + } + } + + const auto& graph_outputs = graph_viewer.GetOutputs(); + for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) { + // any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to + // a quantized op. + if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") { + const auto& output_name = target_node.OutputDefs()[idx]->Name(); + bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(), + [&output_name](const NodeArg* node_arg) { + return node_arg->Name() == output_name; + }); + ORT_RETURN_IF(is_graph_output, + "QDQ node group cannot have an output from the target node that is consumed by a Q node and " + "a graph output. target node: ", + target_node.Name(), " output idx:", idx); + } + } + } + + return Status::OK(); +} +NodeUnit::NodeUnit(const Node& node) + : target_node_(node), + type_(Type::SingleNode), + input_edge_count_(node.GetInputEdgesCount()) { + InitForSingleNode(); +} + +NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) + : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, + target_node_(*graph_viewer.GetNode(node_group.target_node)), + q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, + type_(Type::QDQGroup), + inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, + outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { + ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_)); + + input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), + [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); + + // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. + // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). + input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); + + // create output edges. each target node output either goes to Q node/s or non-Q node/s. + // ValidateNodeGroupQDQNodes ensures this. + auto cur_edge = target_node_.OutputEdgesBegin(); + auto end_edge = target_node_.OutputEdgesEnd(); + for (; cur_edge != end_edge; ++cur_edge) { + const Node& node = cur_edge->GetNode(); + + // if node is in q_nodes we hide the Q node. + if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { + auto src_idx = cur_edge->GetSrcArgIndex(); + auto q_cur_edge = node.OutputEdgesBegin(); + auto q_end_edge = node.OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); + } + } else { + // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. + output_edges_.insert(*cur_edge); + } + } +} + +const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } +const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } +const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } +int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } +NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } +const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } +ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } + +void NodeUnit::InitForSingleNode() { + const auto& input_defs = target_node_.InputDefs(); + const auto& output_defs = target_node_.OutputDefs(); + auto qlinear_type = GetQLinearOpType(target_node_); + if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + // Not a Qlinear op, add all inputs / outputs + auto add_all_io = [](std::vector& defs, + const ConstPointerContainer>& node_defs) { + defs.reserve(node_defs.size()); + + for (const auto def : node_defs) { + defs.push_back(NodeUnitIODef{*def, std::nullopt}); + } + }; + + add_all_io(inputs_, input_defs); + add_all_io(outputs_, output_defs); + } else if (IsUnaryQLinearOp(qlinear_type)) { + // Unary QLinear Op has 5 inputs + // x, x_scale, x_zp, y_scale, y_zp (optional) + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[3], + input_defs.size() > 4 ? input_defs[4] : nullptr}}); + + } else if (IsBinaryQLinearOp(qlinear_type)) { + // Binary QLinear Op has 9 inputs + // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); + + if (input_defs.size() == 9) { // has Bias + inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional + } + + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); + + } else if (qlinear_type == QLinearOpType::DequantizeLinear) { + // DequantizeLinear has 3 inputs + // x, x_scale, x_zp + // output is not quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); + + } else if (qlinear_type == QLinearOpType::QuantizeLinear) { + // QuantizeLinear the input is not quantized and has 3 inputs + // x, y_scale, y_zp (optional) + // The output is quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + } else { + ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); + } +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin(); +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesEnd() : output_edges_.end(); +} + +std::vector NodeUnit::GetAllNodesInGroup() const noexcept { + std::vector all_nodes = dq_nodes_; + all_nodes.push_back(&target_node_); + all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end()); + return all_nodes; +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.h b/onnxruntime/core/framework/node_unit.h similarity index 54% rename from onnxruntime/core/providers/shared/node_unit/node_unit.h rename to onnxruntime/core/framework/node_unit.h index b47204ca3c42d..66afaec8ee1e2 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -3,6 +3,9 @@ #pragma once +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include #include @@ -18,8 +21,21 @@ class NodeArg; class Path; namespace QDQ { -struct NodeGroup; -} +// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group +struct NodeGroup { + std::vector dq_nodes; + std::vector q_nodes; + NodeIndex target_node; + + // Validator to check if the set of nodes can form a valid QDQ NodeGroup. + // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to + // be converted into a single node with a quantized operator. + static Status CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes); +}; +} // namespace QDQ // Definition of one input or output // If the optional quant_param is present, then this is a quantized input, @@ -69,26 +85,33 @@ class NodeUnit { const std::vector& GetQNodes() const noexcept { return q_nodes_; } std::vector GetAllNodesInGroup() const noexcept; - Node::EdgeConstIterator OutputEdgesBegin(size_t index) const; - Node::EdgeConstIterator OutputEdgesEnd(size_t index) const; + /// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes + /// plus any other edges to the target node for inputs that are not via a DQ node. + size_t InputEdgeCount() const { return input_edge_count_; } + + // output edges. src index is for outputs of the target node. dest index and node is for consumer of node unit + // output. any Q nodes are hidden. + Node::EdgeConstIterator OutputEdgesBegin() const; + Node::EdgeConstIterator OutputEdgesEnd() const; private: - const std::vector q_nodes_; // q-nodes for this NodeUnit - const std::vector dq_nodes_; // dq nodes for this NodeUnit, not all inputs + // Initialization for a NodeUnit that contains a single node + void InitForSingleNode(); + + const std::vector dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs const Node& target_node_; + const std::vector q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs const Type type_; std::vector inputs_; std::vector outputs_; - // Initializing for a single Node - void InitForSingleNode(); -}; + size_t input_edge_count_; // total number of input edges -// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) -// And return a map to quick query the NodeUnit which contains the given Node, -// Note, the value of the map is owned by the vector of std::unique_ptr -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); + // output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group. + Node::EdgeSet output_edges_; +}; } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 8535b8c9a944a..6b4f62ae1343d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } @@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } @@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index deee6e7f25f1a..c90a42a36483d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#include "core/framework/node_unit.h" #include "core/optimizer/selectors_actions/selector_action_transformer.h" namespace onnxruntime { @@ -13,13 +14,6 @@ class Node; namespace QDQ { -// Struct to represent a DQ->Op->Q node group -struct NodeGroup { - std::vector dq_nodes; - std::vector q_nodes; - NodeIndex target_node; -}; - class NodeGroupSelector { public: // This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 544fe82a268c8..1876f7826c968 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -13,6 +13,7 @@ #include #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" namespace onnxruntime { namespace QDQ { @@ -43,6 +44,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Tile", {}}}; } +// These produce int64 indices output, which can't be quantized, so there's no downstream Q node. static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { return {{"ArgMax", {}}, {"ArgMin", {}}}; @@ -324,28 +326,48 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap return qdq_selections; } -Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes) { - // Within a QDQ node group, a target node input is the only consumer of each DQ. - // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications - // may have happened since. Verify that this is still true. - for (const auto* dq_node : dq_nodes) { - const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); - ORT_RETURN_IF(dq_produces_graph_output, - "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), - ", target node: ", target_node.Name()); - - const bool dq_has_single_output_edge_to_target = - dq_node->GetOutputEdgesCount() == 1 && - dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); - ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, - "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " - "DQ node: ", - dq_node->Name(), ", target node: ", target_node.Name()); +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer) { + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + + const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { + for (const auto& node_idx : node_indices) { + const auto* node = graph_viewer.GetNode(node_idx); + node_unit_map.insert({node, node_unit}); + } + }; + + // Get QDQ NodeUnits first + QDQ::SelectorManager selector_mgr; + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + + for (const auto& qdq_selection : qdq_selections) { + auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); + + // Fill the node to node_unit map for all nodes in the QDQ Group + add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); + add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); + add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); + + node_unit_holder.push_back(std::move(qdq_unit)); + } + + // Get the left over SingleNode NodeUnits + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + const auto* node(graph_viewer.GetNode(node_idx)); + + // This is already part of a QDQ NodeUnit + if (node_unit_map.find(node) != node_unit_map.cend()) + continue; + + auto node_unit = std::make_unique(*node); + node_unit_map[node] = node_unit.get(); + node_unit_holder.push_back(std::move(node_unit)); } - return Status::OK(); + return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); } } // namespace QDQ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index 246f26c1760ec..de36202afff29 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -7,6 +7,7 @@ #include "core/common/common.h" #include "core/common/gsl.h" #include "core/common/inlined_containers.h" +#include "core/framework/node_unit.h" #include "core/graph/basic_types.h" #if !defined(ORT_MINIMAL_BUILD) @@ -78,11 +79,16 @@ class SelectorManager { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager); }; -// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node. -// Returns successful status if so, failed status with reason otherwise. -Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes); +// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) +// And return a map to quick query the NodeUnit which contains the given Node, +// Note, the value of the map is owned by the vector of std::unique_ptr +// +// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific +// functionality. +// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer +// library whereas it should be able to be used by an EP with no dependency on optimizers. +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 2d2c89f36f1a7..038423104d92e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -21,7 +21,6 @@ #include "core/framework/kernel_registry.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 0b32508a5bb38..745504ca04941 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -11,6 +11,7 @@ #include "core/common/logging/logging.h" #include "core/common/safeint.h" +#include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" #include "core/graph/graph.h" @@ -18,7 +19,6 @@ #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h index 6a54bf7bdb938..0c0bc7b2e4674 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/framework/node_unit.h" #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h" diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 6962a7be94bb6..d0ae32378379d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -11,17 +11,19 @@ #include "core/common/safeint.h" #include "core/common/status.h" #include "core/framework/execution_provider.h" +#include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" -#include "core/providers/shared/node_unit/node_unit.h" -#include "core/providers/shared/utils/utils.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" -#include "core/optimizer/initializer.h" +#include "core/providers/shared/utils/utils.h" using namespace android::nn::wrapper; @@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input @@ -664,7 +666,7 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) { int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE; bool fuse_code_assigned_from_activation = false; - for (auto it = node_unit.OutputEdgesBegin(0), end = node_unit.OutputEdgesEnd(0); it != end; ++it) { + for (auto it = node_unit.OutputEdgesBegin(), end = node_unit.OutputEdgesEnd(); it != end; ++it) { const auto& dst_node = it->GetNode(); const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index 466865f23f49a..dab7bccf43396 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -21,7 +21,6 @@ #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h" namespace onnxruntime::nnapi::op_builder_helpers { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 61a16ceff752f..0844857a06d61 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -7,12 +7,12 @@ #include #include "core/common/common.h" +#include "core/framework/node_unit.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h" -#include "core/providers/shared/node_unit/node_unit.h" namespace onnxruntime::nnapi::op_builder_helpers { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index b04703d7611ee..4d2888222ff0f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -7,7 +7,10 @@ #include "core/common/logging/logging.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" @@ -17,7 +20,6 @@ #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/session/onnxruntime_cxx_api.h" namespace onnxruntime { @@ -119,7 +121,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -181,7 +183,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, - gen_metadef_name, NNAPI, kNnapiExecutionProvider); + gen_metadef_name, NNAPI, kNnapiExecutionProvider, &node_unit_map); // Generally, NNAPI supports sub-graphs with at least one non-constant initializer input and one output. // So far, we have a few cases that sub-graph has zero valid inputs, like `CastLike` diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index d537a4cf58b2d..c45f5cd0848dd 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include "core/providers/partitioning_utils.h" #include @@ -10,6 +13,7 @@ #include "core/framework/compute_capability.h" #include "core/framework/execution_provider.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" #include "core/providers/common.h" @@ -76,6 +80,11 @@ When selecting the next node to process, we first take: The remaining unsupported nodes mark the border of the current group so they will be processed later when we consider the next group. +If node_unit_map is provided, we process NodeUnit instances (a logical 'Node' that can be a single node or a +QDQ node group) instead of individual Node instances. As an EP must take complete NodeUnit instances (i.e. it +must not break up a QDQ node group by taking a subset of nodes in it), this granularity of processing is valid. +It is required to ensure we do not break up a QDQ node unit during partitioning. + @param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with. @param is_node_supported_fn Callback to check whether a node is supported. @param on_group_closed_fn Callback to indicate a completed partition node group. @@ -88,6 +97,7 @@ std::vector> CreateSupportedPartitionNodeGroups( const IsNodeSupportedFn& is_node_supported_fn, const OnGroupClosedFn& on_group_closed_fn, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, bool debug_output) { #ifdef NDEBUG ORT_UNUSED_PARAMETER(debug_output); @@ -111,7 +121,18 @@ std::vector> CreateSupportedPartitionNodeGroups( // initialize in-degrees and find root nodes for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) { const auto& node = *graph_viewer.GetNode(node_index); - const auto node_input_edge_count = node.GetInputEdgesCount(); + auto node_input_edge_count = node.GetInputEdgesCount(); + + if (node_unit_map != nullptr) { + const auto& node_unit = node_unit_map->at(&node); + if (&node_unit->GetNode() != &node) { + // only process the target node + continue; + } + + node_input_edge_count = node_unit->InputEdgeCount(); + } + in_degree.insert({node.Index(), node_input_edge_count}); if (node_input_edge_count == 0) { nodes_to_process.push_back(&node); @@ -151,6 +172,8 @@ std::vector> CreateSupportedPartitionNodeGroups( } }; + size_t num_nodes_processed = 0; + while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) { if (nodes_to_process.empty()) { // we have processed all the nodes that we can while building this partition node group, start a new one @@ -162,9 +185,13 @@ std::vector> CreateSupportedPartitionNodeGroups( const Node& node = *nodes_to_process.front(); nodes_to_process.pop_front(); + const NodeUnit* node_unit = node_unit_map ? node_unit_map->at(&node) : nullptr; + const bool is_qdq_node_unit = node_unit && node_unit->UnitType() == NodeUnit::Type::QDQGroup; + // a node that is already assigned to an EP other than current EP is unsupported - const bool is_node_supported = - (node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node); + const bool is_node_supported = (node.GetExecutionProviderType().empty() || + node.GetExecutionProviderType() == execution_provider_type) && + is_node_supported_fn(node); if (!is_node_supported && Contains(supported_group_border, &node)) { // an unsupported node on the border will be processed after the current partition node group @@ -173,34 +200,62 @@ std::vector> CreateSupportedPartitionNodeGroups( } if (is_node_supported) { - // add node to the partition node group - supported_group.push_back(&node); + if (is_qdq_node_unit) { + // add DQ -> node -> Q for the node unit. must be in topological order + for (const auto& dq : node_unit->GetDQNodes()) { + supported_group.push_back(dq); + } - // remove node from the border and add its outputs to the border + supported_group.push_back(&node); + + for (const auto& q : node_unit->GetQNodes()) { + supported_group.push_back(q); + } + } else { + supported_group.push_back(&node); + } + + // remove node from the border supported_group_border.erase(&node); + } - std::for_each( - node.OutputNodesBegin(), node.OutputNodesEnd(), - [&supported_group_border](const Node& output) { - supported_group_border.insert(&output); - }); + // For each downstream node: + // 1: add the downstream node to the border if the current node is supported + // 2: adjust in-degrees of the nodes consuming the current node's outputs, and add any new nodes to process + const auto process_downstream_node = [&](const Node& downstream_node) { + if (is_node_supported) { + supported_group_border.insert(&downstream_node); + } + + auto& downstream_node_in_degree = in_degree[downstream_node.Index()]; + --downstream_node_in_degree; + + if (downstream_node_in_degree == 0) { + nodes_to_process.push_back(&downstream_node); + } + }; + + if (node_unit_map) { + std::for_each(node_unit->OutputEdgesBegin(), node_unit->OutputEdgesEnd(), + [&](const Node::EdgeEnd& edge_end) { + const Node& n = edge_end.GetNode(); + const NodeUnit& downstream_node_unit = *node_unit_map->at(&n); + const Node& output = downstream_node_unit.GetNode(); + + process_downstream_node(output); + }); + } else { + std::for_each(node.OutputNodesBegin(), node.OutputNodesEnd(), process_downstream_node); } - // adjust in-degrees of the node outputs and add any new nodes to process - std::for_each( - node.OutputNodesBegin(), node.OutputNodesEnd(), - [&](const Node& output) { - auto& output_node_in_degree = in_degree[output.Index()]; - --output_node_in_degree; - - if (output_node_in_degree == 0) { - nodes_to_process.push_back(&output); - } - }); + ++num_nodes_processed; } close_group(); + ORT_ENFORCE(num_nodes_processed == in_degree.size(), + "Processed ", num_nodes_processed, " nodes. Expected to process ", in_degree.size()); + return supported_groups; } } // namespace @@ -318,11 +373,13 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, bool debug_output) { const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer, is_node_supported_fn, on_partition_closed_fn, execution_provider_type, + node_unit_map, debug_output); std::vector> partitions{}; @@ -346,6 +403,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, bool debug_output) { const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops); const bool check_excluded_nodes = !excluded_nodes.empty(); @@ -360,8 +418,11 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, generate_metadef_name_fn, execution_provider_name, execution_provider_type, + node_unit_map, debug_output); } } // namespace utils } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h index 136725c2f7250..c3f6b104e3f6a 100644 --- a/onnxruntime/core/providers/partitioning_utils.h +++ b/onnxruntime/core/providers/partitioning_utils.h @@ -3,6 +3,9 @@ #pragma once +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include #include @@ -14,8 +17,9 @@ namespace onnxruntime { struct ComputeCapability; class GraphViewer; -class NodeArg; class Node; +class NodeArg; +class NodeUnit; namespace utils { @@ -56,6 +60,8 @@ Create the supported partitions for the execution provider. @param generate_metadef_name_fn Callback to create the name for the MetaDef. @param execution_provider_name Name of execution provider creating the ComputeCapability instance. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. +@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. + Should be created by EP calling GetAllNodeUnits. @param debug_output Print diagnostic output about the partitions and reasons for partition breaks. No-op in a release build. @@ -68,6 +74,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map = nullptr, bool debug_output = false); /** @@ -79,6 +86,8 @@ Create the supported partitions for the execution provider. @param generate_metadef_name Functor to create the name for the MetaDef. @param execution_provider_name Name of execution provider creating the ComputeCapability instance. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. +@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. + Should be created by EP calling GetAllNodeUnits. @param debug_output Print diagnostic output about the partitions and reasons for partition breaks. No-op in a release build. @@ -91,6 +100,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map = nullptr, bool debug_output = false); /** @@ -125,3 +135,5 @@ InlinedHashSet CreateExcludedNodeSet(const GraphViewer& graph_viewe const std::unordered_set& stop_ops); } // namespace utils } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/qnn/builder/op_builder.h b/onnxruntime/core/providers/qnn/builder/op_builder.h index 018d9a2797a66..05398c3f22ea2 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder.h @@ -4,7 +4,7 @@ #pragma once #include "core/graph/graph_viewer.h" -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index dc91b9dfa199e..b3501dfec1ba8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -9,6 +9,8 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { @@ -95,7 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index d0dd091cb1688..8fed2f364ba5a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -6,13 +6,13 @@ #include #include "core/common/status.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" #include "core/platform/ort_mutex.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/providers/shared/node_unit/node_unit.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 8ae489c749f31..1e2993f246ae4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -11,8 +11,8 @@ #include "QnnInterface.h" #include "qnn_def.h" #include "core/common/logging/logging.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3d9cfd92b7922..5c4fa3e0fb88b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -10,6 +10,8 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" @@ -494,7 +496,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), is_qnn_ctx_model, logger); @@ -534,44 +536,39 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer size_t num_of_supported_nodes = 0; // Create partitions from supported nodes. - { - std::vector> partitions = utils::CreateSupportedPartitions(graph_viewer, - supported_nodes, {}, - gen_metadef_name, QNN, - kQnnExecutionProvider, - true); - - // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node. - // We also count the number of supported nodes in all valid partitions. - for (auto& partition : partitions) { - bool is_valid_partition = true; - size_t nodes_in_partition = 0; - - if (partition && partition->sub_graph) { - nodes_in_partition = partition->sub_graph->nodes.size(); - - if (nodes_in_partition == 1 && !is_qnn_ctx_model) { - const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); - - if (!node) { - LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node."; - is_valid_partition = false; - } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") { - LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition."; - is_valid_partition = false; - } + std::vector> partitions = utils::CreateSupportedPartitions( + graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true); + + // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node. + // We also count the number of supported nodes in all valid partitions. + for (auto& partition : partitions) { + bool is_valid_partition = true; + size_t nodes_in_partition = 0; + + if (partition && partition->sub_graph) { + nodes_in_partition = partition->sub_graph->nodes.size(); + + if (nodes_in_partition == 1 && !is_qnn_ctx_model) { + const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); + + if (!node) { + LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node."; + is_valid_partition = false; + } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") { + LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition."; + is_valid_partition = false; } - } else { - LOGS(logger, ERROR) << "QNN EP: Invalid partition."; - is_valid_partition = false; } + } else { + LOGS(logger, ERROR) << "QNN EP: Invalid partition."; + is_valid_partition = false; + } - if (is_valid_partition) { - result.push_back(std::move(partition)); - num_of_supported_nodes += nodes_in_partition; - } - } // for - } + if (is_valid_partition) { + result.push_back(std::move(partition)); + num_of_supported_nodes += nodes_in_partition; + } + } // for const size_t num_of_partitions = result.size(); const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions, diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.cc b/onnxruntime/core/providers/shared/node_unit/node_unit.cc deleted file mode 100644 index 10dd58ba28375..0000000000000 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.cc +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "node_unit.h" -#include "core/graph/graph_viewer.h" -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" -#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" - -namespace onnxruntime { - -namespace { - -enum class QLinearOpType : uint8_t { - Unknown, // Unknown or not a linear quantized op - DequantizeLinear, - QuantizeLinear, - QLinearConv, - QLinearMatMul, - QLinearAdd, - QLinearSigmoid, - QLinearAveragePool, - QLinearMul, - QLinearReduceMean, - QLinearConcat, - QLinearGlobalAveragePool, - QLinearLeakyRelu, -}; - -QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { - const auto& op_type = node.OpType(); - if (op_type == "DequantizeLinear") - return QLinearOpType::DequantizeLinear; - else if (op_type == "QuantizeLinear") - return QLinearOpType::QuantizeLinear; - else if (op_type == "QLinearConv") - return QLinearOpType::QLinearConv; - else if (op_type == "QLinearMatMul") - return QLinearOpType::QLinearMatMul; - else if (op_type == "QLinearAdd") - return QLinearOpType::QLinearAdd; - else if (op_type == "QLinearSigmoid") - return QLinearOpType::QLinearSigmoid; - else if (op_type == "QLinearAveragePool") - return QLinearOpType::QLinearAveragePool; - else if (op_type == "QLinearMul") - return QLinearOpType::QLinearMul; - else if (op_type == "QLinearReduceMean") - return QLinearOpType::QLinearReduceMean; - else if (op_type == "QLinearConcat") - return QLinearOpType::QLinearConcat; - else if (op_type == "QLinearGlobalAveragePool") - return QLinearOpType::QLinearGlobalAveragePool; - else if (op_type == "QLinearLeakyRelu") - return QLinearOpType::QLinearLeakyRelu; - - return QLinearOpType::Unknown; -} - -// Ops have 1 input -bool IsUnaryQLinearOp(QLinearOpType type) { - return type == QLinearOpType::QLinearSigmoid || - type == QLinearOpType::QLinearAveragePool || - type == QLinearOpType::QLinearGlobalAveragePool || - type == QLinearOpType::QLinearLeakyRelu || - type == QLinearOpType::QLinearReduceMean; -} - -// Ops have 2 inputs -bool IsBinaryQLinearOp(QLinearOpType type) { - return type == QLinearOpType::QLinearConv || - type == QLinearOpType::QLinearMatMul || - type == QLinearOpType::QLinearAdd || - type == QLinearOpType::QLinearMul; -} - -// Ops have 1 or more inputs -bool IsVariadicQLinearOp(QLinearOpType type) { - return type == QLinearOpType::QLinearConcat; -} - -const std::vector GetQDQIONodes(const GraphViewer& graph_viewer, - const QDQ::NodeGroup& node_group, bool is_input) { - std::vector io_nodes; - const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; - io_nodes.reserve(src_nodes.size()); - for (const auto& node_idx : src_nodes) { - io_nodes.push_back(graph_viewer.GetNode(node_idx)); - } - return io_nodes; -} - -// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup -std::vector GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, - bool is_input) { - const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; - const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs(); - const size_t target_node_io_defs_size = target_node_io_defs.size(); - - // Find all the quantized IO defs and indices (for the input to the target node) - std::unordered_map quantized_io_defs; - quantized_io_defs.reserve(target_node_io_defs_size); - - auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin(); - auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd(); - for (; cur != end; ++cur) { - const Node& node = cur->GetNode(); - - // If we can find the node index in the dq or q nodes, then this is a quantize node (can be DQ or Q depends on is_input) - if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { - const auto node_inputs = node.InputDefs(); - // quantization scale and zp are always the input[1, 2] - NodeUnitIODef::QuantParam quant_param{ - *node_inputs[1], - node_inputs.size() == 3 ? node_inputs[2] : nullptr}; - if (is_input) { - // DQ is input to the target node, use the DstArgIndex - auto idx = cur->GetDstArgIndex(); - // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2]) - quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}}); - } else { - // Q is output of the target node, use the SrcArgIndex - auto idx = cur->GetSrcArgIndex(); - // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2]) - const auto node_outputs = node.OutputDefs(); - quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}}); - } - } - } - - // Construct the IODefs for this QDQ NodeGroup - std::vector io_defs; - io_defs.reserve(target_node_io_defs_size); - for (size_t i = 0; i < target_node_io_defs_size; i++) { - // If we can find the NodeUnitIODef for this index, this is a quantized input - if (quantized_io_defs.find(i) != quantized_io_defs.cend()) { - io_defs.push_back(std::move(quantized_io_defs.at(i))); - } else { - // This is a regular input - io_defs.push_back({*target_node_io_defs[i], std::nullopt}); - } - } - - return io_defs; -} - -} // namespace - -NodeUnit::NodeUnit(const Node& node) - : target_node_(node), - type_(Type::SingleNode) { - InitForSingleNode(); -} - -NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) - : q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, - dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, - target_node_(*graph_viewer.GetNode(node_group.target_node)), - type_(Type::QDQGroup), - inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, - outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { - ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupDQNodes(graph_viewer, target_node_, dq_nodes_)); -} - -const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } -const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } -const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } -int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } -NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } -const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } -ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } - -void NodeUnit::InitForSingleNode() { - const auto& input_defs = target_node_.InputDefs(); - const auto& output_defs = target_node_.OutputDefs(); - auto qlinear_type = GetQLinearOpType(target_node_); - if (qlinear_type == QLinearOpType::Unknown || - IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support - // Not a Qlinear op, add all inputs / outputs - auto add_all_io = [](std::vector& defs, - const ConstPointerContainer>& node_defs) { - defs.reserve(node_defs.size()); - - for (const auto def : node_defs) { - defs.push_back(NodeUnitIODef{*def, std::nullopt}); - } - }; - add_all_io(inputs_, input_defs); - add_all_io(outputs_, output_defs); - } else if (IsUnaryQLinearOp(qlinear_type)) { - // Unary QLinear Op has 5 inputs - // x, x_scale, x_zp, y_scale, y_zp (optional) - inputs_.push_back(NodeUnitIODef{ - *input_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); - - outputs_.push_back(NodeUnitIODef{ - *output_defs[0], - NodeUnitIODef::QuantParam{*input_defs[3], - input_defs.size() > 4 - ? input_defs[4] - : nullptr}}); - } else if (IsBinaryQLinearOp(qlinear_type)) { - // Binary QLinear Op has 9 inputs - // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B - inputs_.push_back(NodeUnitIODef{ - *input_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); - inputs_.push_back(NodeUnitIODef{ - *input_defs[3], - NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); - - if (input_defs.size() == 9) { // has Bias - inputs_.push_back(NodeUnitIODef{ - *input_defs[8], - std::nullopt}); // for Bias the scale and zp are optional - } - - outputs_.push_back(NodeUnitIODef{ - *output_defs[0], - NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); - } else if (qlinear_type == QLinearOpType::DequantizeLinear) { - // DequantizeLinear has 3 inputs - // x, x_scale, x_zp - // output is not quantized - inputs_.push_back(NodeUnitIODef{ - *input_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], - input_defs.size() == 3 - ? input_defs[2] - : nullptr}}); - outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); - } else if (qlinear_type == QLinearOpType::QuantizeLinear) { - // QuantizeLinear the input is not quantized and has 3 inputs - // x, y_scale, y_zp (optional) - // The output is quantized - inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); - outputs_.push_back(NodeUnitIODef{ - *output_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], - input_defs.size() == 3 - ? input_defs[2] - : nullptr}}); - } else { - ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); - } -} - -Node::EdgeConstIterator NodeUnit::OutputEdgesBegin(size_t index) const { - if (type_ == Type::SingleNode) { - ORT_ENFORCE(index == 0, "invalid output node index"); - return target_node_.OutputEdgesBegin(); - } else { - ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index"); - return q_nodes_[index]->OutputEdgesBegin(); - } -} - -Node::EdgeConstIterator NodeUnit::OutputEdgesEnd(size_t index) const { - if (type_ == Type::SingleNode) { - ORT_ENFORCE(index == 0, "invalid output node index"); - return target_node_.OutputEdgesEnd(); - } else { - ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index"); - return q_nodes_[index]->OutputEdgesEnd(); - } -} - -std::vector NodeUnit::GetAllNodesInGroup() const noexcept { - std::vector all_nodes = dq_nodes_; - all_nodes.push_back(&target_node_); - all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end()); - return all_nodes; -} - -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { - std::vector> node_unit_holder; - std::unordered_map node_unit_map; - - const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { - for (const auto& node_idx : node_indices) { - const auto* node = graph_viewer.GetNode(node_idx); - node_unit_map.insert({node, node_unit}); - } - }; - - // Get QDQ NodeUnits first - QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); - - for (const auto& qdq_selection : qdq_selections) { - auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); - - // Fill the node to node_unit map for all nodes in the QDQ Group - add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); - add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); - add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); - - node_unit_holder.push_back(std::move(qdq_unit)); - } - - // Get the left over SingleNode NodeUnits - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (const auto node_idx : node_indices) { - const auto* node(graph_viewer.GetNode(node_idx)); - - // This is already part of a QDQ NodeUnit - if (node_unit_map.find(node) != node_unit_map.cend()) - continue; - - auto node_unit = std::make_unique(*node); - node_unit_map[node] = node_unit.get(); - node_unit_holder.push_back(std::move(node_unit)); - } - - return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index c07a0929353b1..2088618538de5 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -4,12 +4,12 @@ #include "utils.h" -#include -#include -#include -#include -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/common/safeint.h" +#include "core/framework/node_unit.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" #include "core/optimizer/initializer.h" +#include "core/providers/common.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc index ca3fc4fc1972b..b2f9d265ca053 100644 --- a/onnxruntime/core/providers/utils.cc +++ b/onnxruntime/core/providers/utils.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensorprotoutils.h" -#include "utils.h" +#include "core/providers/utils.h" namespace onnxruntime { namespace utils { @@ -23,6 +23,5 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& return Status::OK(); } #endif - } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc index 8e7e228f974e6..e2d71cda68ec4 100644 --- a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc +++ b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc @@ -6,12 +6,12 @@ #include #include "core/common/common.h" +#include "core/framework/node_unit.h" #include "core/framework/op_node_proto_helper.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/providers/common.h" #include "core/providers/cpu/nn/pool_attributes.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/xnnpack/detail/utils.h" // each operator provides a helper to check if supported diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 1a32612981120..f9cb45ebc8abc 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -6,14 +6,14 @@ #include #include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/node_attr_utils.h" +#include "core/optimizer/initializer.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "onnx/defs/attr_proto_util.h" -#include "core/common/safeint.h" -#include "core/optimizer/initializer.h" namespace onnxruntime { namespace xnnpack { diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index 2bbf3ac8c2cb5..d555ee2286b84 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -10,10 +10,10 @@ #include #include +#include "core/framework/node_unit.h" #include "core/framework/op_kernel.h" #include "core/graph/indexed_sub_graph.h" #include "core/providers/common.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "xnnpack.h" diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index eafbfae6f01e1..12e567e7080b3 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -6,17 +6,17 @@ #include #include -#include "core/graph/function_utils.h" -#include "xnnpack_execution_provider.h" -#include "detail/utils.h" -#include "detail/node_support_checker.h" - #include "core/framework/compute_capability.h" #include "core/framework/kernel_registry.h" -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/framework/node_unit.h" +#include "core/graph/function_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" - -#include "xnnpack_init.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/xnnpack/xnnpack_execution_provider.h" +#include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/detail/node_support_checker.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { @@ -268,7 +268,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for multiple times diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp index 484a9a22429d5..969997d2b84ec 100644 --- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "test_fp16.h" +#include #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 13333f1558cc6..fbd5c9b5a137b 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/compute_capability.h" +#include "core/framework/node_unit.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" @@ -9,7 +10,6 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" -#include "core/optimizer/utils.h" #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -30,10 +30,6 @@ #pragma warning(disable : 4127) #endif // #if defined(_MSC_VER) -#ifdef USE_NNAPI -#include "core/providers/shared/node_unit/node_unit.h" -#endif // #ifdef USE_NNAPI - struct QDQOpKeys { const char* quantize_linear; const char* dequantize_linear; @@ -3243,14 +3239,14 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { ASSERT_EQ(std::vector({4}), qdq_group.q_nodes); } -// The function GetAllNodeUnits is enabled for NNAPI EP only for now -#ifdef USE_NNAPI +// The function GetAllNodeUnits is used by NNAPI, XNNPACK and QNN +#if defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK) { // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); // We should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -3288,7 +3284,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { verify_io_def(qdq_node_unit.Inputs()[2], *whole_graph_viewer.GetNode(2)); // DQ_bias verify_io_def(qdq_node_unit.Outputs()[0], *whole_graph_viewer.GetNode(4)); // Q_output } -#endif // #ifdef USE_NNAPI +#endif // defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK) // Create a graph viewer covers part of the graph // Make sure the qdq conv selector will fail for the partial graph diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 0167f7a7718b1..2e073def5d643 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -220,6 +220,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_, generate_metadef_name, ep_name_, onnxruntime::utils::kInternalTestingExecutionProvider, + /*QDQ NodeUnit map*/ nullptr, debug_output_); if (!static_capabilities.empty()) { diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc new file mode 100644 index 0000000000000..5db69489afaef --- /dev/null +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/common/common.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/framework/node_unit.h" +#include "core/framework/compute_capability.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/partitioning_utils.h" + +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/test_utils.h" +#include "test/util/include/test/test_environment.h" + +namespace onnxruntime { +namespace test { + +// Test handling of a DQ node that is connected to an initializer at the start of the graph, but not used +// in a QDQ node group until after an unsupported node in the graph. If we do not process QDQ node units +// correctly this DQ will incorrectly be in the first partition, with the rest of the QDQ node group in +// the second partition. +TEST(PartitioningUtilsTest, TestQDQHandling) { + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/ort_github_issue_19590.onnx"); + auto& logger = DefaultLoggingManager().DefaultLogger(); + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, logger)); + Graph& graph = p_model->MainGraph(); + GraphViewer graph_viewer = GraphViewer(graph); + + // we want everything but the Cast in the test model to be supported + const auto is_node_supported = [&](const Node& node) -> bool { + return node.OpType() != "Cast"; + }; + + const auto on_group_closed = [&](const std::vector& /*group*/) -> bool { + return true; + }; + + const auto gen_metadef_name = [&]() { + static int metadef_id = 0; + return "TestMetaDef_" + std::to_string(metadef_id++); + }; + + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + + auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, + true); + + // we should have 2 supported partitions, split by the Cast node. + // the first should have the Mul and NOT the DQ for the initializer if everything worked correctly. + ASSERT_EQ(result.size(), size_t(2)) << "Expected 2 partitions"; + ASSERT_EQ(result[0]->sub_graph->nodes.size(), size_t(1)) << "First partition should only have the Mul and not a DQ"; + ASSERT_EQ(result[1]->sub_graph->nodes.size(), size_t(5)); // everything else except the unsupported Cast +} + +/// Check that CreateSupportedPartitions processes all nodes without error. +static void CheckAllNodesProcessed(const std::function& build_model) { + auto& logger = DefaultLoggingManager().DefaultLogger(); + const std::unordered_map domain_to_version = {{"", 15}}; + + Model model("PartitioningUtils_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, logger); + + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + build_model(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + GraphViewer graph_viewer = GraphViewer(graph); + + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + + const auto is_node_supported = [&](const Node& /*node*/) -> bool { + return true; + }; + + const auto on_group_closed = [&](const std::vector& /*group*/) -> bool { + return true; + }; + + const auto gen_metadef_name = [&]() { + static int metadef_id = 0; + return "TestMetaDef_" + std::to_string(metadef_id++); + }; + + auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, + true); + + // the 'real' test is that CreateSupportedPartitions doesn't throw due to a mismatch with expected vs processed nodes + // as all ops are supported there should only ever be 1 partition + ASSERT_EQ(result.size(), size_t(1)) << "Expected 1 partition"; +} + +TEST(PartitioningUtilsTest, TestHandlingQDQNodeUnitWithNoQNodes) { + // build graph with QDQ node unit for logical operator (Equal) that has no Q node and a downstream node (Cast). + auto build_model = [](ModelTestBuilder& builder) { + constexpr uint8_t zero_point = 0; + constexpr float qdq_scale = 0.0038f; + const std::vector input_shape = {1, 3, 8, 8}; + + auto* input0 = builder.MakeInput(input_shape, -1.0f, 1.0f); + auto* input1 = builder.MakeInput(input_shape, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + // input -> Q -> DQ -> Op + auto* qdq0_output = AddQDQNodePair(builder, input0, qdq_scale, zero_point); + auto* qdq1_output = AddQDQNodePair(builder, input1, qdq_scale, zero_point); + + // Equal -> + auto* equal_output = builder.MakeIntermediate(); + builder.AddNode("Equal", {qdq0_output, qdq1_output}, {equal_output}); + + // -> Cast -> output + Node& cast_node = builder.AddNode("Cast", {equal_output}, {output}); + cast_node.AddAttribute("to", + static_cast(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + }; + + CheckAllNodesProcessed(build_model); +} + +// TopK produces 2 outputs, one of which is used in a QDQ node group (Q of values output) +// and the other (indices output) is not. A downstream node consuming the indices output has an edge from the target +// node and not a Q node. +// To process this correctly, the QDQ NodeUnit must return output edges for both the Q node/s of the values output, +// and the downstream node (Cast in this case) of the indices output. +TEST(PartitioningUtilsTest, TestQDQNodeGroupWithOutputFromTargetNode) { + const auto build_model = [](ModelTestBuilder& builder) { + constexpr uint8_t zero_point = 0; + constexpr float qdq_scale = 0.0038f; + const std::vector input_shape = {1, 3, 8, 8}; + + auto* input0 = builder.MakeInput(input_shape, -1.0f, 1.0f); + + // input -> Q -> DQ -> + auto* qdq0_output = AddQDQNodePair(builder, input0, qdq_scale, zero_point); + + // K input + NodeArg* k_input = builder.MakeInput({1}, {10}); + + // TopK op + NodeArg* values_output = builder.MakeIntermediate(); + NodeArg* indices_output = builder.MakeIntermediate(); + builder.AddNode("TopK", {qdq0_output, k_input}, {values_output, indices_output}); + + // values -> Q -> DQ -> graph output + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, qdq_scale, zero_point); + + // indices -> Cast -> graph output + auto* i_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {i_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + }; + + CheckAllNodesProcessed(build_model); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/ort_github_issue_19590.onnx b/onnxruntime/test/testdata/ort_github_issue_19590.onnx new file mode 100644 index 0000000000000..fa07b624780bb Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_19590.onnx differ diff --git a/onnxruntime/test/testdata/ort_github_issue_19590.py b/onnxruntime/test/testdata/ort_github_issue_19590.py new file mode 100644 index 0000000000000..9be07134fd8ad --- /dev/null +++ b/onnxruntime/test/testdata/ort_github_issue_19590.py @@ -0,0 +1,77 @@ +import onnx +from onnx import TensorProto, helper + +# graph with a QDQ MatMul node unit where one input is and initializer -> DQ and the other is on a path that +# contains a supported node followed by an unsupported node followed by the DQ -> MatMul. +# The DQ of the initializer is prior to the unsupported node. If the partitioning utils do not process the QDQ node +# unit together, the DQ for the initializer and the first supported node will be in the first partition, which +# incorrectly breaks up the QDQ node unit. +graph_proto = helper.make_graph( + [ + # DQ of initializer for MatMul B input + helper.make_node( + "DequantizeLinear", + inputs=["matmul_b_uint8", "scale0"], + outputs=["dq_matmul_b"], + name="dq_matmul_b", + ), + # Treat as supported + helper.make_node( + "Mul", + inputs=["input:0", "scale_input"], + outputs=["mul:0"], + name="mul0", + ), + # Treat as unsupported + helper.make_node("Cast", inputs=["mul:0"], outputs=["mul_uint8"], name="cast0", to=2), + # DQ of MatMul A input + helper.make_node( + "DequantizeLinear", + inputs=["mul_uint8", "scale1"], + outputs=["dq_matmul_a"], + name="dq_matmul_a", + ), + # MatMul + helper.make_node( + "MatMul", + inputs=[ + "dq_matmul_a", + "dq_matmul_b", + ], + outputs=["matmul_ab"], + name="matmul_ab", + ), + # Q + helper.make_node( + "QuantizeLinear", + inputs=["matmul_ab", "scale2"], + outputs=["q_matmul_ab"], + name="q_matmul_ab", + ), + # DQ for model output + helper.make_node( + "DequantizeLinear", + inputs=["q_matmul_ab", "scale2"], + outputs=["out:0"], + name="dq_graph_output", + ), + ], + "Main_graph", + [ + helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor_value_info("out:0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor("scale0", TensorProto.FLOAT, [1], [20.0]), + helper.make_tensor("scale1", TensorProto.FLOAT, [1], [30.0]), + helper.make_tensor("scale2", TensorProto.FLOAT, [1], [40.0]), + helper.make_tensor("matmul_b_uint8", TensorProto.UINT8, [2, 2], [1, 2, 3, 4]), + helper.make_tensor("scale_input", TensorProto.FLOAT, [2], [3.0, 4.0]), + ], +) + +model = helper.make_model(graph_proto) +onnx.checker.check_model(model, True) +onnx.save(model, "ort_github_issue_19590.onnx")