Skip to content

Commit

Permalink
Fix issue with handling QDQ node group that has output edge from targ…
Browse files Browse the repository at this point in the history
…et node (int64_t indices output that is not quantized) as well as edge through Q node (values output)
  • Loading branch information
skottmckay committed Mar 7, 2024
1 parent 20ba3eb commit 4df61e9
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 152 deletions.
123 changes: 57 additions & 66 deletions onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,32 @@ const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
for (const auto& node_idx : src_nodes) {
io_nodes.push_back(graph_viewer.GetNode(node_idx));
}

return io_nodes;
}

// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group,
bool is_input) {
std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) {
const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
const size_t target_node_io_defs_size = target_node_io_defs.size();

// Find all the quantized IO defs and indices (for the input to the target node)
// Find all the quantized IO defs and indices (for the input/output of the target node)
std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
quantized_io_defs.reserve(target_node_io_defs_size);

auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();

for (; cur != end; ++cur) {
const Node& node = cur->GetNode();

// If we can find the node index in the dq or q nodes, then this is a quantize node (can be DQ or Q depends on is_input)
// If we can find the node index in the dq or q nodes this is a quantized input/output
if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
const auto node_inputs = node.InputDefs();
// quantization scale and zp are always the input[1, 2]
NodeUnitIODef::QuantParam quant_param{
*node_inputs[1],
node_inputs.size() == 3 ? node_inputs[2] : nullptr};
NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr};

if (is_input) {
// DQ is input to the target node, use the DstArgIndex
auto idx = cur->GetDstArgIndex();
Expand All @@ -131,7 +131,7 @@ std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::Node
std::vector<NodeUnitIODef> io_defs;
io_defs.reserve(target_node_io_defs_size);
for (size_t i = 0; i < target_node_io_defs_size; i++) {
// If we can find the NodeUnitIODef for this index, this is a quantized input
// If we can find the NodeUnitIODef for this index, this is a quantized input/output
if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
io_defs.push_back(std::move(quantized_io_defs.at(i)));
} else {
Expand All @@ -153,20 +153,41 @@ NodeUnit::NodeUnit(const Node& node)
}

NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
: q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
: dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
target_node_(*graph_viewer.GetNode(node_group.target_node)),
q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupDQNodes(graph_viewer, target_node_, dq_nodes_));
ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupQDQNodes(graph_viewer, target_node_, dq_nodes_, q_nodes_));

input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });

// add edges for inputs that are not from DQ nodes. there is one edge to each DQ node.
// other inputs could come from initializers or graph inputs (no edges) or other nodes (edge).
input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size();

// create output edges. each target node output either goes to Q node/s or non-Q node/s.
// ValidateNodeGroupQDQNodes ensures this.
auto cur_edge = target_node_.OutputEdgesBegin();
auto end_edge = target_node_.OutputEdgesEnd();
for (; cur_edge != end_edge; ++cur_edge) {
const Node& node = cur_edge->GetNode();

// if node is in q_nodes we hide the Q node.
if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) {
auto src_idx = cur_edge->GetSrcArgIndex();
auto q_cur_edge = node.OutputEdgesBegin();
auto q_end_edge = node.OutputEdgesEnd();
for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()});
}
} else {
// non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is.
output_edges_.insert(*cur_edge);
}
}
}

const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
Expand All @@ -181,8 +202,7 @@ void NodeUnit::InitForSingleNode() {
const auto& input_defs = target_node_.InputDefs();
const auto& output_defs = target_node_.OutputDefs();
auto qlinear_type = GetQLinearOpType(target_node_);
if (qlinear_type == QLinearOpType::Unknown ||
IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
// Not a Qlinear op, add all inputs / outputs
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
Expand All @@ -192,86 +212,57 @@ void NodeUnit::InitForSingleNode() {
defs.push_back(NodeUnitIODef{*def, std::nullopt});
}
};

add_all_io(inputs_, input_defs);
add_all_io(outputs_, output_defs);
} else if (IsUnaryQLinearOp(qlinear_type)) {
// Unary QLinear Op has 5 inputs
// x, x_scale, x_zp, y_scale, y_zp (optional)
inputs_.push_back(NodeUnitIODef{
*input_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});

outputs_.push_back(NodeUnitIODef{
*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[3],
input_defs.size() > 4
? input_defs[4]
: nullptr}});
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
outputs_.push_back(NodeUnitIODef{*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[3],
input_defs.size() > 4 ? input_defs[4] : nullptr}});

} else if (IsBinaryQLinearOp(qlinear_type)) {
// Binary QLinear Op has 9 inputs
// x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B
inputs_.push_back(NodeUnitIODef{
*input_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
inputs_.push_back(NodeUnitIODef{
*input_defs[3],
NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});

if (input_defs.size() == 9) { // has Bias
inputs_.push_back(NodeUnitIODef{
*input_defs[8],
std::nullopt}); // for Bias the scale and zp are optional
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});

if (input_defs.size() == 9) { // has Bias
inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional
}

outputs_.push_back(NodeUnitIODef{
*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});

} else if (qlinear_type == QLinearOpType::DequantizeLinear) {
// DequantizeLinear has 3 inputs
// x, x_scale, x_zp
// output is not quantized
inputs_.push_back(NodeUnitIODef{
*input_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1],
input_defs.size() == 3
? input_defs[2]
: nullptr}});
inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
? input_defs[2]
: nullptr}});
outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});

} else if (qlinear_type == QLinearOpType::QuantizeLinear) {
// QuantizeLinear the input is not quantized and has 3 inputs
// x, y_scale, y_zp (optional)
// The output is quantized
inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
outputs_.push_back(NodeUnitIODef{
*output_defs[0],
NodeUnitIODef::QuantParam{*input_defs[1],
input_defs.size() == 3
? input_defs[2]
: nullptr}});
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
? input_defs[2]
: nullptr}});
} else {
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
}
}

Node::EdgeConstIterator NodeUnit::OutputEdgesBegin(size_t index) const {
// q_nodes_ can be empty for logical operators with DQ inputs. as they produce bool output no Q is possible
if (type_ == Type::SingleNode || q_nodes_.empty()) {
ORT_ENFORCE(index == 0, "invalid output node index");
return target_node_.OutputEdgesBegin();
} else {
ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
return q_nodes_[index]->OutputEdgesBegin();
}
// Returns an iterator to the first output edge of this NodeUnit.
// A SingleNode unit exposes the target node's own output edges. Any other unit type iterates output_edges_,
// the edge set owned by this NodeUnit.
Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const {
  if (type_ != Type::SingleNode) {
    return output_edges_.begin();
  }

  return target_node_.OutputEdgesBegin();
}

Node::EdgeConstIterator NodeUnit::OutputEdgesEnd(size_t index) const {
if (type_ == Type::SingleNode || q_nodes_.empty()) {
ORT_ENFORCE(index == 0, "invalid output node index");
return target_node_.OutputEdgesEnd();
} else {
ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
return q_nodes_[index]->OutputEdgesEnd();
}
// Returns the past-the-end iterator for the output edges of this NodeUnit.
// A SingleNode unit exposes the target node's own output edges. Any other unit type iterates output_edges_,
// the edge set owned by this NodeUnit.
Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const {
  if (type_ != Type::SingleNode) {
    return output_edges_.end();
  }

  return target_node_.OutputEdgesEnd();
}

std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
Expand Down
18 changes: 12 additions & 6 deletions onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,28 @@ class NodeUnit {
/// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes
/// plus any other edges to the target node for inputs that are not via a DQ node.
size_t InputEdgeCount() const { return input_edge_count_; }
Node::EdgeConstIterator OutputEdgesBegin(size_t index) const;
Node::EdgeConstIterator OutputEdgesEnd(size_t index) const;

// Output edges of the node unit. The src index refers to an output of the target node. The dst index and node
// refer to the consumer of the node unit's output. Any Q nodes in the group are hidden.
Node::EdgeConstIterator OutputEdgesBegin() const;
Node::EdgeConstIterator OutputEdgesEnd() const;

private:
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit
const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not all inputs
// Initialization for a NodeUnit that contains a single node
void InitForSingleNode();

const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs
const Node& target_node_;
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
const Type type_;

std::vector<NodeUnitIODef> inputs_;
std::vector<NodeUnitIODef> outputs_;

size_t input_edge_count_; // total number of input edges

// Initializing for a single Node
void InitForSingleNode();
// output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group.
Node::EdgeSet output_edges_;
};

// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
return false;
}

if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
!dq_validation_status.IsOK()) {
if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}

Expand Down Expand Up @@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}

if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
!dq_validation_status.IsOK()) {
if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}

Expand Down Expand Up @@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}

if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
!dq_validation_status.IsOK()) {
if (const auto qdq_validation_status = QDQ::ValidateNodeGroupQDQNodes(graph_viewer, node, dq_nodes, q_nodes);
!qdq_validation_status.IsOK()) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,10 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
return qdq_selections;
}

Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes) {
Status ValidateNodeGroupQDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes) {
// Within a QDQ node group, a target node input is the only consumer of each DQ.
// This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
// may have happened since. Verify that this is still true.
Expand All @@ -345,6 +346,41 @@ Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
dq_node->Name(), ", target node: ", target_node.Name());
}

// an output from the target node can have either Q consumers or direct consumers. it cannot have both.
// this must be checked on a per output basis.
// NOTE: rules about the target node not producing a graph output must be checked by the selector as it's operator
// dependent.
// e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
// node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output.
if (!q_nodes.empty()) {
auto cur_edge = target_node.OutputEdgesBegin();
auto end_edge = target_node.OutputEdgesEnd();
std::vector<const Node*> output_consumers(target_node.OutputDefs().size(), nullptr);

for (; cur_edge != end_edge; ++cur_edge) {
auto output_idx = cur_edge->GetSrcArgIndex();
const Node& this_consumer = cur_edge->GetNode();
const Node* existing_consumer = output_consumers[output_idx];

if (existing_consumer != nullptr) {
// another edge for this output. either both are Q or both are not.
bool valid = true;
if (existing_consumer->OpType() == "QuantizeLinear") {
valid = this_consumer.OpType() == "QuantizeLinear";
} else {
valid = this_consumer.OpType() != "QuantizeLinear";
}

ORT_RETURN_IF_NOT(valid,
"QDQ node group cannot have an output from the target node being consumed by a Q node and "
"a non-Q node. target node: ",
target_node.Name());
} else {
output_consumers[output_idx] = &this_consumer;
}
}
}

return Status::OK();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ class SelectorManager {

// Checks whether the provided DQ and Q nodes are valid for forming a QDQ node group with the provided target node.
// Returns successful status if so, failed status with reason otherwise.
Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes);
Status ValidateNodeGroupQDQNodes(const GraphViewer& graph_viewer,
const Node& target_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes);

} // namespace QDQ
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {

int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
bool fuse_code_assigned_from_activation = false;
for (auto it = node_unit.OutputEdgesBegin(0), end = node_unit.OutputEdgesEnd(0); it != end; ++it) {
for (auto it = node_unit.OutputEdgesBegin(), end = node_unit.OutputEdgesEnd(); it != end; ++it) {
const auto& dst_node = it->GetNode();
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];

Expand Down
Loading

0 comments on commit 4df61e9

Please sign in to comment.