Skip to content

Commit

Permalink
Extend DoubleQDQPairsRemover to handle sequences that end in duplicat…
Browse files Browse the repository at this point in the history
…e DQ nodes (#20759)

### Description
Extend the DoubleQDQPairsRemover optimizer to also handle sequences that
end in duplicate DQ nodes.

For example, the following sequence:
```
 Q1 --> DQ1 --> Q2 --+--> DQ2
                     |
                     +--> DQ2'
```
Is now simplified to:
```
 Q1 ---+--> DQ2
       |
       +--> DQ2'
```


### Motivation and Context
The EnsureUniqueDQForNodeUnit pass may add duplicate DQ nodes to ensure
valid QDQ node units. The DoubleQDQPairsRemover should still be able to
remove unnecessary QDQ ops if the target sequence ends in duplicate DQ
nodes.

---------

Co-authored-by: Edward Chen <[email protected]>
  • Loading branch information
adrianlizarraga and edgchen1 authored May 25, 2024
1 parent a7bc49a commit 5bae32e
Show file tree
Hide file tree
Showing 7 changed files with 471 additions and 88 deletions.
204 changes: 129 additions & 75 deletions onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,46 @@
// Licensed under the MIT License.
#include "core/optimizer/double_qdq_pairs_remover.h"
#include <cassert>
#include <string>

#include "core/common/span_utils.h"
#include "core/common/inlined_containers_fwd.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"

namespace onnxruntime {

/// <summary>
/// Determines the zero-point (quantized) data type of a QuantizeLinear node.
/// An explicit zero-point input takes precedence; otherwise the "output_dtype" attribute is consulted,
/// and uint8 is used when neither is present.
/// </summary>
/// <param name="graph">Graph used to look up the zero-point constant initializer</param>
/// <param name="q_node">QuantizeLinear node to inspect (asserted in debug builds)</param>
/// <param name="zp_data_type">Output parameter that receives the zero-point data type on success</param>
/// <returns>
/// True if the type was determined. False if the node has a zero-point input that cannot be resolved
/// to a constant initializer; zp_data_type is left untouched in that case.
/// </returns>
static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node,
                                  /*out*/ ONNX_NAMESPACE::TensorProto_DataType& zp_data_type) {
  assert(q_node.OpType() == "QuantizeLinear");
  const auto input_defs = q_node.InputDefs();

  const bool has_zero_point_input = QDQ::InputIndex::ZERO_POINT_ID < input_defs.size() &&
                                    input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Exists();

  if (has_zero_point_input) {
    // Explicit zero-point: its tensor element type is the quantization type, but only a constant
    // initializer can be inspected here.
    const auto* zp_initializer =
        graph.GetConstantInitializer(input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true);
    if (zp_initializer == nullptr) {
      return false;
    }

    zp_data_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(zp_initializer->data_type());
    return true;
  }

  // No zero-point input: fall back to the "output_dtype" attribute (added in ONNX opset 21),
  // and finally to the uint8 default.
  const auto* output_dtype_attr = graph_utils::GetNodeAttribute(q_node, "output_dtype");
  if (output_dtype_attr != nullptr) {
    zp_data_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(output_dtype_attr->i());
  } else {
    zp_data_type = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
  }
  return true;
}

// Applies a new zero point or scale as the input for a Q/DQ node.
template <typename T>
static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) {
Expand Down Expand Up @@ -81,38 +114,64 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons
return true;
}

// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2). This is necessary because
// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2, where
// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2(s)). This is necessary because
// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2*, where
// the first pair has (zp1, scale1) and the second pair has (zp2, scale2).
// After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed
// for correctness.
template <typename ZeroPointType>
static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, Node& dq2) {
bool skip_reset = false;
static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2,
gsl::span<gsl::not_null<Node*>> dq2s) {
if (dq2s.empty()) {
return false;
}

bool no_change_needed = false;
float new_scale = 0.0f;
ZeroPointType new_zero_point = 0;
if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, skip_reset)) {
if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, no_change_needed)) {
return false;
}
if (skip_reset) {
if (no_change_needed) {
return true;
}
ApplyNewInputValue(graph, dq2, QDQ::InputIndex::SCALE_ID, new_scale);

ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale);
ApplyNewInputValue(graph, dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);

for (gsl::not_null<Node*> dq2 : dq2s) {
ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::SCALE_ID, new_scale);
ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
}

return true;
}

// Checks if the provided node index (dq1_index) is a part of a valid double QDQ pair sequence
// (i.e., Q1 -> DQ1 -> Q2 -> DQ2) that can be reduced to the outer Q/DQ nodes (i.e., Q1 -> DQ2).
// If so, the zero point and scale of the outer Q/DQ nodes are recomputed and the node indices of the other nodes
// in the sequence (i.e., Q1, Q2, and DQ2) are returned via output parameters.
static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, NodeIndex dq1_index,
NodeIndex& q2_index, NodeIndex& dq2_index) {
/// <summary>
/// Tries to reduce a double QDQ sequence (Q1 -> DQ1 -> Q2 -> DQ2*) beginning with the provided Q1 node index.
/// The scale/zero-point values of the outer Q1 and DQ2* nodes may need to be recomputed.
/// Supports multiple identical DQ2 nodes.
/// </summary>
/// <param name="graph">Graph to modify</param>
/// <param name="q1_index">Index of potential Q1 node</param>
/// <returns>True if the double QDQ sequence was reduced</returns>
static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
return graph.GetConstantInitializer(initializer_name, true);
};

// Ensure that q1 is a Q operator, has only one output, and is not a graph output
Node* q1 = graph.GetNode(q1_index);
if (q1 == nullptr ||
q1->OpType() != "QuantizeLinear" ||
q1->GetOutputEdgesCount() != 1 ||
graph.NodeProducesGraphOutput(*q1)) {
return false;
}

// Ensure that dq1 is a DQ operator, has one parent and one child, and is not a graph output
Node* dq1 = graph.GetNode(dq1_index);
NodeIndex dq1_index = q1->OutputEdgesBegin()->GetNode().Index();
const Node* dq1 = graph.GetNode(dq1_index);
if (dq1 == nullptr ||
dq1->OpType() != "DequantizeLinear" ||
dq1->GetInputEdgesCount() != 1 ||
Expand All @@ -121,75 +180,80 @@ static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, Node
return false;
}

// Ensure that q2 is a Q operator, has only one child, and is not a graph output
q2_index = dq1->OutputEdgesBegin()->GetNode().Index();
const Node* q2 = graph.GetNode(q2_index);
if (q2 == nullptr ||
q2->OpType() != "QuantizeLinear" ||
q2->GetOutputEdgesCount() != 1 ||
graph.NodeProducesGraphOutput(*q2)) {
return false;
}

// Ensure that q1 is a Q operator, has only one output, and is not a graph output
q1_index = dq1->InputEdgesBegin()->GetNode().Index();
Node* q1 = graph.GetNode(q1_index);
if (q1 == nullptr ||
q1->GetOutputEdgesCount() != 1 ||
q1->OpType() != "QuantizeLinear" ||
graph.NodeProducesGraphOutput(*q1)) {
// The Q1 and DQ1 nodes must have equal zero-point and scale values (scalar/constant).
if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath())) {
return false;
}

// Ensure the dq2 is a DQ operator.
dq2_index = q2->OutputEdgesBegin()->GetNode().Index();
Node* dq2 = graph.GetNode(dq2_index);
if (dq2 == nullptr ||
dq2->OpType() != "DequantizeLinear") {
auto q1_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
if (!GetQNodeZeroPointType(graph, *q1, q1_quant_type)) {
return false;
}

const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
return graph.GetConstantInitializer(initializer_name, true);
};
// Ensure that q2 is a Q operator, its output is not a graph output, and that its zero-point quantization type
// is equal to q1's.
NodeIndex q2_index = dq1->OutputEdgesBegin()->GetNode().Index();
const Node* q2 = graph.GetNode(q2_index);
auto q2_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;

// Each QDQ pair (i.e., q1 -> dq1, q2 -> dq2) has to meet the following additional requirements:
// - Scalar/constant zero-point and scale.
// - The DQ and Q ops within a pair must have the same scale and zero-point.
// However, each pair is allowed to have different scales and zero-points.
//
// TODO: IsQDQPairSupported() requires an explicit zero-point input, but technically a default
// value of 0 could be fine.
if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath()) ||
!QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) {
if (q2 == nullptr ||
q2->OpType() != "QuantizeLinear" ||
graph.NodeProducesGraphOutput(*q2) ||
!GetQNodeZeroPointType(graph, *q2, q2_quant_type) ||
q1_quant_type != q2_quant_type) {
return false;
}

const auto& dq1_input_defs = dq1->InputDefs();
const ONNX_NAMESPACE::TensorProto* dq1_zp_tensor_proto = graph.GetConstantInitializer(
dq1_input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true);
// All of q2's children should be DQ nodes with zero-point and scale values equal to those of q2.
InlinedVector<gsl::not_null<Node*>> dq2_nodes;
dq2_nodes.reserve(q2->GetOutputEdgesCount());

assert(dq1_zp_tensor_proto != nullptr); // IsQDQPairSupported should have checked that this exists.
for (auto it = q2->OutputEdgesBegin(); it != q2->OutputEdgesEnd(); it++) {
NodeIndex dq2_index = it->GetNode().Index();
Node* dq2 = graph.GetNode(dq2_index);

auto dq1_zp_type = dq1_zp_tensor_proto->data_type();
if (dq2 == nullptr || dq2->OpType() != "DequantizeLinear") {
// Child is not a DQ op.
return false;
}

if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
return RecomputeOuterQDQZeroPointAndScale<uint8_t>(graph, *q1, *dq1, *q2, *dq2);
// The Q2 and DQ2 nodes must have equal zero-point and scale values (scalar/constant).
if (!QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) {
return false;
}

dq2_nodes.push_back(dq2);
}

if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
return RecomputeOuterQDQZeroPointAndScale<int8_t>(graph, *q1, *dq1, *q2, *dq2);
bool can_recompute = false;
if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
can_recompute = RecomputeOuterQDQZeroPointAndScale<uint8_t>(graph, *q1, *dq1, *q2, dq2_nodes);
} else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
can_recompute = RecomputeOuterQDQZeroPointAndScale<int8_t>(graph, *q1, *dq1, *q2, dq2_nodes);
} else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
can_recompute = RecomputeOuterQDQZeroPointAndScale<uint16_t>(graph, *q1, *dq1, *q2, dq2_nodes);
} else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
can_recompute = RecomputeOuterQDQZeroPointAndScale<int16_t>(graph, *q1, *dq1, *q2, dq2_nodes);
}

if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
return RecomputeOuterQDQZeroPointAndScale<uint16_t>(graph, *q1, *dq1, *q2, *dq2);
if (!can_recompute) {
return false;
}

if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
return RecomputeOuterQDQZeroPointAndScale<int16_t>(graph, *q1, *dq1, *q2, *dq2);
graph.RemoveEdge(q1_index, dq1_index, 0, 0); // Disconnect Q1 -> DQ1
graph.RemoveEdge(dq1_index, q2_index, 0, 0); // Disconnect DQ1 -> Q2

// Disconnect Q2 --> DQ2(s)
// Connect Q1 -> DQ2(s)
for (gsl::not_null<Node*> dq2 : dq2_nodes) {
graph.RemoveEdge(q2_index, dq2->Index(), 0, 0);
graph.AddEdge(q1_index, dq2->Index(), 0, 0);
}

return false; // Unsupported zero-point type
graph.RemoveNode(q2_index);
graph.RemoveNode(dq1_index);

return true;
}

Status DoubleQDQPairsRemover::ApplyImpl(
Expand All @@ -200,18 +264,8 @@ Status DoubleQDQPairsRemover::ApplyImpl(
const GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

for (const auto& dq1_index : node_topology_list) {
NodeIndex q1_index = 0;
NodeIndex q2_index = 0;
NodeIndex dq2_index = 0;
if (IsReducibleDoubleQDQSequence(graph, q1_index, dq1_index, q2_index, dq2_index)) {
graph.RemoveEdge(q1_index, dq1_index, 0, 0);
graph.RemoveEdge(dq1_index, q2_index, 0, 0);
graph.RemoveEdge(q2_index, dq2_index, 0, 0);
graph_utils::ReplaceNodeInput(*graph.GetNode(dq2_index), 0, *graph.GetNode(dq1_index)->MutableInputDefs()[0]);
graph.AddEdge(q1_index, dq2_index, 0, 0);
graph.RemoveNode(q2_index);
graph.RemoveNode(dq1_index);
for (NodeIndex node_index : node_topology_list) {
if (TryReduceDoubleQDQSequence(graph, node_index)) {
modified = true;
}
}
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/optimizer/double_qdq_pairs_remover.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ namespace onnxruntime {
* Specifically, this transformer converts the sequence Q1 -> DQ1 -> Q2 -> DQ2, where the first pair has (zp1, scale1)
* and the second pair has (zp2, scale2), into the sequence Q1 -> DQ2 by removing the middle two nodes. The zero-point
* and scale of the final QDQ pair are recomputed to preserve equality to the original sequence.
*
* Also supports multiple identical DQ2 nodes, which may have been inserted by the EnsureUniqueDQForNodeUnit optimizer.
* Q1 --> DQ1 --> Q2 --+--> DQ2
* |
* +--> DQ2'
*
* The above becomes:
* Q1 ---+--> DQ2
* |
* +--> DQ2'
*/
class DoubleQDQPairsRemover : public GraphTransformer {
public:
Expand Down
86 changes: 86 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <string>
#include <vector>

#include "core/common/inlined_containers_fwd.h"
#include "core/common/span_utils.h"
#include "core/graph/model.h"
#include "core/session/inference_session.h"
#include "test/compare_ortvalue.h"
Expand All @@ -20,6 +22,90 @@
namespace onnxruntime {
namespace test {

// Converts a zero-point value into the raw bytes of the requested integer element type.
// The value is narrowed with static_cast to the concrete type (int8/uint8/int16/uint16/int32) and the
// bytes of that narrowed value are returned, so the result has sizeof(element type) entries.
// Throws (via ORT_THROW) for any other element type.
static InlinedVector<std::byte> GetZeroPointBytes(int64_t zero_point, ONNX_NAMESPACE::TensorProto_DataType type) {
  // The argument is only a type tag: it selects the integer type to narrow to; its value is ignored.
  const auto narrow_to_bytes = [zero_point](auto type_tag) {
    using ZeroPointT = decltype(type_tag);
    const ZeroPointT narrowed = static_cast<ZeroPointT>(zero_point);
    const auto byte_view = gsl::as_bytes(gsl::make_span(&narrowed, 1));
    return InlinedVector<std::byte>(byte_view.begin(), byte_view.end());
  };

  switch (type) {
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
      return narrow_to_bytes(int8_t{});
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
      return narrow_to_bytes(uint8_t{});
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:
      return narrow_to_bytes(int16_t{});
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
      return narrow_to_bytes(uint16_t{});
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
      return narrow_to_bytes(int32_t{});
    default:
      ORT_THROW("Unhandled zero-point type ", type, ".");
  }
}

// Adds an initializer tensor to the graph from raw bytes and returns the NodeArg that refers to it.
//   shape:     dimensions of the tensor (an empty span yields a tensor with no dims, i.e. a scalar).
//   elem_type: ONNX element type recorded in the tensor proto.
//   raw_data:  bytes stored as the tensor's raw_data; they are not interpreted or validated here.
// The initializer gets a freshly generated name based on "constant".
NodeArg* ModelTestBuilder::MakeInitializer(gsl::span<const int64_t> shape,
                                           ONNX_NAMESPACE::TensorProto_DataType elem_type,
                                           gsl::span<const std::byte> raw_data) {
  const std::string initializer_name = graph_.GenerateNodeArgName("constant");

  ONNX_NAMESPACE::TensorProto initializer;
  initializer.set_name(initializer_name);
  initializer.set_data_type(elem_type);
  for (const int64_t dim_value : shape) {
    initializer.add_dims(dim_value);
  }
  initializer.set_raw_data(raw_data.data(), raw_data.size());

  graph_.AddInitializedTensor(initializer);

  return &graph_.GetOrCreateNodeArg(initializer_name, nullptr);
}

// Adds a QuantizeLinear node whose zero-point initializer has an explicitly chosen element type.
//   input_arg:        tensor to quantize.
//   input_scale:      scale, stored as a scalar float initializer.
//   input_zero_point: zero-point value; narrowed to zero_point_type before being stored.
//   zero_point_type:  ONNX element type of the scalar zero-point initializer.
//   output_arg:       the node's single output.
//   use_ms_domain:    when true the node is created in kMSDomain, otherwise in the default ("") domain.
// Returns the newly added node.
Node& ModelTestBuilder::AddQuantizeLinearNode(NodeArg* input_arg,
                                              float input_scale,
                                              int64_t input_zero_point,
                                              ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
                                              NodeArg* output_arg,
                                              bool use_ms_domain) {
  // Create the initializers in the same order as the node inputs: scale first, then zero-point.
  NodeArg* scale_arg = MakeScalarInitializer<float>(input_scale);

  const InlinedVector<std::byte> zero_point_bytes = GetZeroPointBytes(input_zero_point, zero_point_type);
  NodeArg* zero_point_arg = MakeInitializer({}, zero_point_type, zero_point_bytes);

  const std::vector<NodeArg*> q_inputs{input_arg, scale_arg, zero_point_arg};

  const std::string node_domain = use_ms_domain ? kMSDomain : "";
  return AddNode("QuantizeLinear", q_inputs, {output_arg}, node_domain);
}

// Adds a DequantizeLinear node whose zero-point initializer has an explicitly chosen element type.
//   input_arg:        quantized tensor to dequantize.
//   input_scale:      scale, stored as a scalar float initializer.
//   input_zero_point: zero-point value; narrowed to zero_point_type before being stored.
//   zero_point_type:  ONNX element type of the scalar zero-point initializer.
//   output_arg:       the node's single output.
//   use_ms_domain:    when true the node is created in kMSDomain, otherwise in the default ("") domain.
// Returns the newly added node.
Node& ModelTestBuilder::AddDequantizeLinearNode(NodeArg* input_arg,
                                                float input_scale,
                                                int64_t input_zero_point,
                                                ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
                                                NodeArg* output_arg,
                                                bool use_ms_domain) {
  // Create the initializers in the same order as the node inputs: scale first, then zero-point.
  NodeArg* scale_arg = MakeScalarInitializer<float>(input_scale);

  const InlinedVector<std::byte> zero_point_bytes = GetZeroPointBytes(input_zero_point, zero_point_type);
  NodeArg* zero_point_arg = MakeInitializer({}, zero_point_type, zero_point_bytes);

  const std::vector<NodeArg*> dq_inputs{input_arg, scale_arg, zero_point_arg};

  const std::string node_domain = use_ms_domain ? kMSDomain : "";
  return AddNode("DequantizeLinear", dq_inputs, {output_arg}, node_domain);
}

void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
TransformerLevel baseline_level,
Expand Down
Loading

0 comments on commit 5bae32e

Please sign in to comment.