Skip to content

Commit

Permalink
- Refactor to build in the various configurations.
Browse files Browse the repository at this point in the history
  - The whole QDQ setup needs a rethink at some point as it's currently spread across too many places (framework, optimizer, base providers lib, EP specific providers lib)
- move NodeGroup to framework/node_unit.h and ValidateNodeGroupQDQNodes to NodeGroup::CanCreateNodeGroup so it's in the framework lib as it's used by NodeUnit
- move GetAllNodeUnits to optimizer
  - doesn't quite belong there but this works with all the current EPs that use it.
  • Loading branch information
skottmckay committed Mar 7, 2024
1 parent cb41315 commit 16f8c09
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 151 deletions.
122 changes: 75 additions & 47 deletions onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

#include "node_unit.h"

Check warning on line 4 in onnxruntime/core/framework/node_unit.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/framework/node_unit.cc:4: Include the directory when naming header files [build/include_subdir] [4]
#include "core/graph/graph_viewer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"

namespace onnxruntime {

Expand Down Expand Up @@ -145,6 +143,80 @@ std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::Node

} // namespace

// Checks whether `dq_nodes` -> `target_node` -> `q_nodes` can form a valid QDQ node group.
//
// The following rules are enforced:
//   1. No DQ node produces a graph output, and every DQ node has exactly one output edge, which goes to the
//      target node. The EnsureUniqueDQForNodeUnit graph transformer should already have guaranteed this, but
//      other graph modifications may have happened since, so it is re-verified here.
//   2. When Q nodes are present, each individual target node output is consumed either only by Q nodes or only
//      by non-Q nodes. This has to be a per-output check: e.g. TopK produces values and indices. The indices
//      output won't be quantized, so even if the TopK QDQ node group is replaced with a quantized TopK, an
//      int64_t indices value will be produced and can provide a graph output.
//   3. When Q nodes are present, a target node output that feeds a Q node is not also a graph output, as that
//      value will disappear if the QDQ node unit is converted to a quantized op.
//
// Returns Status::OK() if all rules hold, otherwise a failed Status whose message names the violated rule.
Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
                                          const Node& target_node,
                                          gsl::span<const Node* const> dq_nodes,
                                          gsl::span<const Node* const> q_nodes) {
  const auto is_quantize_linear = [](const Node& node) { return node.OpType() == "QuantizeLinear"; };

  // Rule 1: the target node must be the sole consumer of every DQ node.
  for (const Node* dq : dq_nodes) {
    const bool feeds_graph_output = graph_viewer.NodeProducesGraphOutput(*dq);
    ORT_RETURN_IF(feeds_graph_output,
                  "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq->Name(),
                  ", target node: ", target_node.Name());

    // The edge is only dereferenced once we know there is exactly one.
    const bool feeds_only_target = dq->GetOutputEdgesCount() == 1 &&
                                   dq->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
    ORT_RETURN_IF_NOT(feeds_only_target,
                      "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
                      "DQ node: ",
                      dq->Name(), ", target node: ", target_node.Name());
  }

  // The remaining rules only apply when the group has Q nodes.
  if (q_nodes.empty()) {
    return Status::OK();
  }

  // Rule 2: remember the first consumer seen for each output and require every later consumer of the same
  // output to be of the same kind (Q vs. non-Q). An entry stays nullptr if the output has no consumers.
  std::vector<const Node*> first_consumer(target_node.OutputDefs().size(), nullptr);

  const auto edges_end = target_node.OutputEdgesEnd();
  for (auto edge = target_node.OutputEdgesBegin(); edge != edges_end; ++edge) {
    const auto src_output_idx = edge->GetSrcArgIndex();
    const Node& consumer = edge->GetNode();
    const Node* previous = first_consumer[src_output_idx];

    if (previous == nullptr) {
      first_consumer[src_output_idx] = &consumer;
      continue;
    }

    const bool same_kind = is_quantize_linear(*previous) == is_quantize_linear(consumer);
    ORT_RETURN_IF_NOT(same_kind,
                      "QDQ node group cannot have an output from the target node being consumed by a Q node and "
                      "a non-Q node. target node: ",
                      target_node.Name());
  }

  // Rule 3: an output consumed by a Q node must not also be exposed as a graph output.
  const auto& graph_outputs = graph_viewer.GetOutputs();
  for (size_t idx = 0; idx < first_consumer.size(); ++idx) {
    const Node* consumer = first_consumer[idx];
    if (consumer == nullptr || !is_quantize_linear(*consumer)) {
      continue;
    }

    const auto& output_name = target_node.OutputDefs()[idx]->Name();
    const bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(),
                                             [&output_name](const NodeArg* graph_output) {
                                               return graph_output->Name() == output_name;
                                             });
    ORT_RETURN_IF(is_graph_output,
                  "QDQ node group cannot have an output from the target node that is consumed by a Q node and "
                  "a graph output. target node: ",
                  target_node.Name(), " output idx:", idx);
  }

  return Status::OK();
}
NodeUnit::NodeUnit(const Node& node)
: target_node_(node),
type_(Type::SingleNode),
Expand All @@ -159,7 +231,7 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupQDQNodes(graph_viewer, target_node_, dq_nodes_, q_nodes_));
ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_));

input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });
Expand Down Expand Up @@ -272,48 +344,4 @@ std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
return all_nodes;
}

// Builds a NodeUnit for every node visible through `graph_viewer`.
//
// QDQ groups found by QDQ::SelectorManager are turned into QDQGroup NodeUnits first; every node that is not
// part of such a group then becomes its own SingleNode NodeUnit, visited in topological order.
//
// Returns a pair of:
//   - the vector that owns all created NodeUnits, and
//   - a lookup from each Node to the (non-owning) NodeUnit that contains it. The pointers in this map are only
//     valid while the owning vector is alive.
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer) {
  std::vector<std::unique_ptr<NodeUnit>> owned_units;
  std::unordered_map<const Node*, const NodeUnit*> unit_by_node;

  // Associates the node at `node_idx` with `unit`. An existing entry for the node is left untouched.
  const auto register_node = [&graph_viewer, &unit_by_node](NodeIndex node_idx, const NodeUnit* unit) {
    unit_by_node.insert({graph_viewer.GetNode(node_idx), unit});
  };

  // Pass 1: QDQ groups.
  QDQ::SelectorManager selector_manager;
  const auto qdq_groups = selector_manager.GetQDQSelections(graph_viewer);

  for (const auto& group : qdq_groups) {
    auto group_unit = std::make_unique<NodeUnit>(graph_viewer, group);
    const NodeUnit* group_unit_ptr = group_unit.get();

    // Every member of the group (DQ nodes, Q nodes, then the target) resolves to the same NodeUnit.
    for (const auto& dq_idx : group.dq_nodes) {
      register_node(dq_idx, group_unit_ptr);
    }
    for (const auto& q_idx : group.q_nodes) {
      register_node(q_idx, group_unit_ptr);
    }
    register_node(group.target_node, group_unit_ptr);

    owned_units.push_back(std::move(group_unit));
  }

  // Pass 2: everything not claimed by a QDQ group becomes a SingleNode unit.
  for (const auto node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
    const auto* node = graph_viewer.GetNode(node_idx);

    if (unit_by_node.count(node) != 0) {
      // Already part of a QDQ NodeUnit.
      continue;
    }

    auto single_unit = std::make_unique<NodeUnit>(*node);
    unit_by_node[node] = single_unit.get();
    owned_units.push_back(std::move(single_unit));
  }

  return std::make_pair(std::move(owned_units), std::move(unit_by_node));
}

} // namespace onnxruntime
23 changes: 15 additions & 8 deletions onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,21 @@ class NodeArg;
class Path;

namespace QDQ {
struct NodeGroup;
}
// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group.
// Member nodes are stored as NodeIndex values rather than Node pointers.
struct NodeGroup {
// Indices of the DequantizeLinear (DQ) nodes in the group.
std::vector<NodeIndex> dq_nodes;
// Indices of the QuantizeLinear (Q) nodes in the group.
std::vector<NodeIndex> q_nodes;
// Index of the "Op" node that sits between the DQ and Q nodes.
NodeIndex target_node;

// Validator to check if the set of nodes can form a valid QDQ NodeGroup.
// Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to
// be converted into a single node with a quantized operator.
//
// Note that this takes the candidate nodes as Node pointers, not as the NodeIndex values stored in the struct.
//
// @param graph_viewer  Graph containing the nodes; used to determine which values are graph outputs.
// @param target_node   The candidate target (Op) node.
// @param dq_nodes      The candidate DQ nodes providing inputs to `target_node`.
// @param q_nodes       The candidate Q nodes consuming outputs of `target_node`. May be empty.
// @return Status::OK() if the nodes can form a group, otherwise a failed Status with the reason.
static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes);
};
} // namespace QDQ

// Definition of one input or output
// If the optional quant_param is present, then this is a quantized input,
Expand Down Expand Up @@ -96,10 +109,4 @@ class NodeUnit {
Node::EdgeSet output_edges_;
};

// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
// And return a map to quick query the NodeUnit which contains the given Node,
// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer);

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
return false;
}

if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes);
if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}
Expand Down Expand Up @@ -153,7 +153,7 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}

if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes);
if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}
Expand Down Expand Up @@ -544,7 +544,7 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}

if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes);
if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#include "core/framework/node_unit.h"
#include "core/optimizer/selectors_actions/selector_action_transformer.h"

namespace onnxruntime {
Expand All @@ -13,13 +14,6 @@ class Node;

namespace QDQ {

// Struct to represent a DQ->Op->Q node group
// (DQ = DequantizeLinear, Q = QuantizeLinear). Member nodes are stored as NodeIndex values.
struct NodeGroup {
// Indices of the DQ nodes in the group.
std::vector<NodeIndex> dq_nodes;
// Indices of the Q nodes in the group.
std::vector<NodeIndex> q_nodes;
// Index of the Op node between the DQ and Q nodes.
NodeIndex target_node;
};

class NodeGroupSelector {
public:
// This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <core/providers/common.h>

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"

namespace onnxruntime {
namespace QDQ {
Expand Down Expand Up @@ -324,64 +325,48 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
return qdq_selections;
}

Status ValidateNodeGroupQDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes) {
// Within a QDQ node group, a target node input is the only consumer of each DQ.
// This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
// may have happened since. Verify that this is still true.
for (const auto* dq_node : dq_nodes) {
const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
ORT_RETURN_IF(dq_produces_graph_output,
"QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
", target node: ", target_node.Name());

const bool dq_has_single_output_edge_to_target =
dq_node->GetOutputEdgesCount() == 1 &&
dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
"QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
"DQ node: ",
dq_node->Name(), ", target node: ", target_node.Name());
}
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
GetAllNodeUnits(const GraphViewer& graph_viewer) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;

Check warning on line 331 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc:331: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]

// an output from the target node can have either Q consumers or direct consumers. it cannot have both.
// this must be checked on a per output basis.
// NOTE: rules about the target node not producing a graph output must be checked by the selector as it's operator
// dependent.
// e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
// node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output.
if (!q_nodes.empty()) {
auto cur_edge = target_node.OutputEdgesBegin();
auto end_edge = target_node.OutputEdgesEnd();
std::vector<const Node*> output_consumers(target_node.OutputDefs().size(), nullptr);

for (; cur_edge != end_edge; ++cur_edge) {
auto output_idx = cur_edge->GetSrcArgIndex();
const Node& this_consumer = cur_edge->GetNode();
const Node* existing_consumer = output_consumers[output_idx];

if (existing_consumer != nullptr) {
// another edge for this output. either both are Q or both are not.
bool valid = true;
if (existing_consumer->OpType() == "QuantizeLinear") {
valid = this_consumer.OpType() == "QuantizeLinear";
} else {
valid = this_consumer.OpType() != "QuantizeLinear";
}

ORT_RETURN_IF_NOT(valid,
"QDQ node group cannot have an output from the target node being consumed by a Q node and "
"a non-Q node. target node: ",
target_node.Name());
} else {
output_consumers[output_idx] = &this_consumer;
}
const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
for (const auto& node_idx : node_indices) {
const auto* node = graph_viewer.GetNode(node_idx);
node_unit_map.insert({node, node_unit});
}
};

// Get QDQ NodeUnits first
QDQ::SelectorManager selector_mgr;
const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);

for (const auto& qdq_selection : qdq_selections) {
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);

// Fill the node to node_unit map for all nodes in the QDQ Group
add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());

node_unit_holder.push_back(std::move(qdq_unit));
}

// Get the left over SingleNode NodeUnits
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
const auto* node(graph_viewer.GetNode(node_idx));

// This is already part of a QDQ NodeUnit
if (node_unit_map.find(node) != node_unit_map.cend())
continue;

auto node_unit = std::make_unique<NodeUnit>(*node);

Check warning on line 364 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_unique<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc:364: Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
node_unit_map[node] = node_unit.get();
node_unit_holder.push_back(std::move(node_unit));
}

return Status::OK();
return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));

Check warning on line 369 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc:369: Add #include <utility> for move [build/include_what_you_use] [4]
}

} // namespace QDQ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "core/common/common.h"
#include "core/common/gsl.h"
#include "core/common/inlined_containers.h"
#include "core/framework/node_unit.h"
#include "core/graph/basic_types.h"

#if !defined(ORT_MINIMAL_BUILD)
Expand Down Expand Up @@ -78,12 +79,16 @@ class SelectorManager {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
};

// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node.
// Returns successful status if so, failed status with reason otherwise.
Status ValidateNodeGroupQDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes);
// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
// And return a map to quick query the NodeUnit which contains the given Node,
// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
//
// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific

Check warning on line 86 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h:86: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// functionality.
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
// library whereas it should be able to be used by an EP with no dependency on optimizers.
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>

Check warning on line 90 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for pair<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h:90: Add #include <utility> for pair<> [build/include_what_you_use] [4]
GetAllNodeUnits(const GraphViewer& graph_viewer);

} // namespace QDQ
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/providers/common.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
Expand Down Expand Up @@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
}

void ModelBuilder::PreprocessNodeUnits() {
std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_);
std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
}

// Help to get all quantized operators' input and the NodeUnit(s) using the input
Expand Down
Loading

0 comments on commit 16f8c09

Please sign in to comment.