From 5bae32eb3409595ee8bcd6ef17f364b2709e125d Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 24 May 2024 18:30:15 -0700 Subject: [PATCH] Extend DoubleQDQPairsRemover to handle sequences that end in duplicate DQ nodes (#20759) ### Description Extend the DoubleQDQPairsRemover optimizer to also handle sequences that end in duplicate DQ nodes. For example, the following sequence: ``` Q1 --> DQ1 --> Q2 --+--> DQ2 | +--> DQ2' ``` Is now simplified to: ``` Q1 ---+--> DQ2 | +--> DQ2' ``` ### Motivation and Context The EnsureUniqueDQNodeUnits pass may add duplicate DQ nodes to ensure valid QDQ node units. The DoubleQDQPairsRemover should still be able to remove unnecessary QDQ ops if the target sequence ends in duplicate DQ nodes. --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../optimizer/double_qdq_pairs_remover.cc | 204 +++++++++++------- .../core/optimizer/double_qdq_pairs_remover.h | 10 + .../optimizer/graph_transform_test_builder.cc | 86 ++++++++ .../optimizer/graph_transform_test_builder.h | 62 ++++-- onnxruntime/test/optimizer/qdq_test_utils.cc | 49 +++++ onnxruntime/test/optimizer/qdq_test_utils.h | 22 ++ .../test/optimizer/qdq_transformer_test.cc | 126 +++++++++++ 7 files changed, 471 insertions(+), 88 deletions(-) diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index 624679e7b1b4b..22b9dca39dceb 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -2,13 +2,46 @@ // Licensed under the MIT License. #include "core/optimizer/double_qdq_pairs_remover.h" #include +#include +#include "core/common/span_utils.h" +#include "core/common/inlined_containers_fwd.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" namespace onnxruntime { +/// +/// Returns the zero-point type from the given QuantizeLinear node. +/// +/// Graph +/// QuantizeLinear node +/// Output parameter to store the zero-point data type +/// True if successfully extracted the zero-point data type +static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node, + /*out*/ ONNX_NAMESPACE::TensorProto_DataType& zp_data_type) { + assert(q_node.OpType() == "QuantizeLinear"); + const auto input_defs = q_node.InputDefs(); + + if (QDQ::InputIndex::ZERO_POINT_ID >= input_defs.size() || !input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Exists()) { + // If a zero_point input is absent, get the type from the "output_dtype" attribute or default to uint8. + // The "output_dtype" attribute was added in ONNX opset 21. + const auto* attr = graph_utils::GetNodeAttribute(q_node, "output_dtype"); + zp_data_type = attr != nullptr ? static_cast(attr->i()) + : ONNX_NAMESPACE::TensorProto_DataType_UINT8; + return true; + } + + const auto* zp_proto = graph.GetConstantInitializer(input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true); + if (zp_proto == nullptr) { + return false; + } + + zp_data_type = static_cast(zp_proto->data_type()); + return true; +} + // Applies a new zero point or scale as the input for a Q/DQ node. template static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) { @@ -81,38 +114,64 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons return true; } -// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2). This is necessary because -// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2, where +// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2(s)). This is necessary because +// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2*, where // the first pair has (zp1, scale1) and the second pair has (zp2, scale2). // After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed // for correctness. template -static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, Node& dq2) { - bool skip_reset = false; +static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, + gsl::span> dq2s) { + if (dq2s.empty()) { + return false; + } + + bool no_change_needed = false; float new_scale = 0.0f; ZeroPointType new_zero_point = 0; - if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, skip_reset)) { + if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, no_change_needed)) { return false; } - if (skip_reset) { + if (no_change_needed) { return true; } - ApplyNewInputValue(graph, dq2, QDQ::InputIndex::SCALE_ID, new_scale); + ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale); - ApplyNewInputValue(graph, dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); + for (gsl::not_null dq2 : dq2s) { + ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::SCALE_ID, new_scale); + ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point); + } + return true; } -// Checks if the provided node index (dq1_index) is a part of a valid double QDQ pair sequence -// (i.e., Q1 -> DQ1 -> Q2 -> DQ2) that can be reduced to the outer Q/DQ nodes (i.e., Q1 -> DQ2). -// If so, the zero point and scale of the outer Q/DQ nodes are recomputed and the node indices of the other nodes -// in the sequence (i.e., Q1, Q2, and DQ2) are returned via output parameters. -static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, NodeIndex dq1_index, - NodeIndex& q2_index, NodeIndex& dq2_index) { +/// +/// Tries to reduce a double QDQ sequence (Q1 -> DQ1 -> Q2 -> DQ2*) beginning with the provided Q1 node index. +/// The scale/zero-point values of the outer Q1 and DQ2* nodes may need to be recomputed. +/// Supports multiple identical DQ2 nodes. +/// +/// Graph to modify +/// Index of potential Q1 node +/// True if the double QDQ sequence was reduced +static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) { + const auto get_constant_initializer = [&graph](const std::string& initializer_name) { + return graph.GetConstantInitializer(initializer_name, true); + }; + + // Ensure that q1 is a Q operator, has only one output, and is not a graph output + Node* q1 = graph.GetNode(q1_index); + if (q1 == nullptr || + q1->OpType() != "QuantizeLinear" || + q1->GetOutputEdgesCount() != 1 || + graph.NodeProducesGraphOutput(*q1)) { + return false; + } + // Ensure that dq1 is a DQ operator, has one parent and one child, and is not a graph output - Node* dq1 = graph.GetNode(dq1_index); + NodeIndex dq1_index = q1->OutputEdgesBegin()->GetNode().Index(); + const Node* dq1 = graph.GetNode(dq1_index); if (dq1 == nullptr || dq1->OpType() != "DequantizeLinear" || dq1->GetInputEdgesCount() != 1 || @@ -121,75 +180,80 @@ static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, Node return false; } - // Ensure that q2 is a Q operator, has only one child, and is not a graph output - q2_index = dq1->OutputEdgesBegin()->GetNode().Index(); - const Node* q2 = graph.GetNode(q2_index); - if (q2 == nullptr || - q2->OpType() != "QuantizeLinear" || - q2->GetOutputEdgesCount() != 1 || - graph.NodeProducesGraphOutput(*q2)) { - return false; - } - - // Ensure that q1 is a Q operator, has only one output, and is not a graph output - q1_index = dq1->InputEdgesBegin()->GetNode().Index(); - Node* q1 = graph.GetNode(q1_index); - if (q1 == nullptr || - q1->GetOutputEdgesCount() != 1 || - q1->OpType() != "QuantizeLinear" || - graph.NodeProducesGraphOutput(*q1)) { + // The Q1 and DQ1 nodes must have equal zero-point and scale values (scalar/constant). + if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath())) { return false; } - // Ensure the dq2 is a DQ operator. - dq2_index = q2->OutputEdgesBegin()->GetNode().Index(); - Node* dq2 = graph.GetNode(dq2_index); - if (dq2 == nullptr || - dq2->OpType() != "DequantizeLinear") { + auto q1_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + if (!GetQNodeZeroPointType(graph, *q1, q1_quant_type)) { return false; } - const auto get_constant_initializer = [&graph](const std::string& initializer_name) { - return graph.GetConstantInitializer(initializer_name, true); - }; + // Ensure that q2 is a Q operator, its output is not a graph output, and that its zero-point quantization type + // is equal to q1's. + NodeIndex q2_index = dq1->OutputEdgesBegin()->GetNode().Index(); + const Node* q2 = graph.GetNode(q2_index); + auto q2_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - // Each QDQ pair (i.e., q1 -> dq1, q2 -> dq2) has to meet the following additional requirements: - // - Scalar/constant zero-point and scale. - // - The DQ and Q ops within a pair must have the same scale and zero-point. - // However, each pair is allowed to have different scales and zero-points. - // - // TODO: IsQDQPairSupported() requires an explicit zero-point input, but technically a default - // value of 0 could be fine. - if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath()) || - !QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) { + if (q2 == nullptr || + q2->OpType() != "QuantizeLinear" || + graph.NodeProducesGraphOutput(*q2) || + !GetQNodeZeroPointType(graph, *q2, q2_quant_type) || + q1_quant_type != q2_quant_type) { return false; } - const auto& dq1_input_defs = dq1->InputDefs(); - const ONNX_NAMESPACE::TensorProto* dq1_zp_tensor_proto = graph.GetConstantInitializer( - dq1_input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true); + // All of q2's children should be DQ nodes with zero-point and scale values equal to those of q2. + InlinedVector> dq2_nodes; + dq2_nodes.reserve(q2->GetOutputEdgesCount()); - assert(dq1_zp_tensor_proto != nullptr); // IsQDQPairSupported should have checked that this exists. + for (auto it = q2->OutputEdgesBegin(); it != q2->OutputEdgesEnd(); it++) { + NodeIndex dq2_index = it->GetNode().Index(); + Node* dq2 = graph.GetNode(dq2_index); - auto dq1_zp_type = dq1_zp_tensor_proto->data_type(); + if (dq2 == nullptr || dq2->OpType() != "DequantizeLinear") { + // Child is not a DQ op. + return false; + } - if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + // The Q2 and DQ2 nodes must have equal zero-point and scale values (scalar/constant). + if (!QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) { + return false; + } + + dq2_nodes.push_back(dq2); } - if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) { - return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + bool can_recompute = false; + if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + can_recompute = RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, dq2_nodes); + } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) { + can_recompute = RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, dq2_nodes); + } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) { + can_recompute = RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, dq2_nodes); + } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) { + can_recompute = RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, dq2_nodes); } - if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) { - return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + if (!can_recompute) { + return false; } - if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) { - return RecomputeOuterQDQZeroPointAndScale(graph, *q1, *dq1, *q2, *dq2); + graph.RemoveEdge(q1_index, dq1_index, 0, 0); // Disconnect Q1 -> DQ1 + graph.RemoveEdge(dq1_index, q2_index, 0, 0); // Disconnect DQ1 -> Q2 + + // Disconnect Q2 --> DQ2(s) + // Connect Q1 -> DQ2(s) + for (gsl::not_null dq2 : dq2_nodes) { + graph.RemoveEdge(q2_index, dq2->Index(), 0, 0); + graph.AddEdge(q1_index, dq2->Index(), 0, 0); } - return false; // Unsupported zero-point type + graph.RemoveNode(q2_index); + graph.RemoveNode(dq1_index); + + return true; } Status DoubleQDQPairsRemover::ApplyImpl( @@ -200,18 +264,8 @@ Status DoubleQDQPairsRemover::ApplyImpl( const GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - for (const auto& dq1_index : node_topology_list) { - NodeIndex q1_index = 0; - NodeIndex q2_index = 0; - NodeIndex dq2_index = 0; - if (IsReducibleDoubleQDQSequence(graph, q1_index, dq1_index, q2_index, dq2_index)) { - graph.RemoveEdge(q1_index, dq1_index, 0, 0); - graph.RemoveEdge(dq1_index, q2_index, 0, 0); - graph.RemoveEdge(q2_index, dq2_index, 0, 0); - graph_utils::ReplaceNodeInput(*graph.GetNode(dq2_index), 0, *graph.GetNode(dq1_index)->MutableInputDefs()[0]); - graph.AddEdge(q1_index, dq2_index, 0, 0); - graph.RemoveNode(q2_index); - graph.RemoveNode(dq1_index); + for (NodeIndex node_index : node_topology_list) { + if (TryReduceDoubleQDQSequence(graph, node_index)) { modified = true; } } diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h index 1833b007674fd..854b3f52de72c 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h @@ -13,6 +13,16 @@ namespace onnxruntime { * Specifically, this transformer converts the sequence Q1 -> DQ1 -> Q2 -> DQ2, where the first pair has (zp1, scale1) * and the second pair has (zp2, scale2), into the sequence Q1 -> DQ2 by removing the middle two nodes. The zero-point * and scale of the final QDQ pair is recomputed to preserve equality to the original sequence. + * + * Also supports multiple identical DQ2 nodes, which may have been inserted by the EnsureUniqueDQNodeUnit optimizer. + * Q1 --> DQ1 --> Q2 --+--> DQ2 + * | + * +--> DQ2' + * + * The above becomes: + * Q1 ---+--> DQ2 + * | + * +--> DQ2' */ class DoubleQDQPairsRemover : public GraphTransformer { public: diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index a5024f510b3cd..73c8b3f119103 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -7,6 +7,8 @@ #include #include +#include "core/common/inlined_containers_fwd.h" +#include "core/common/span_utils.h" #include "core/graph/model.h" #include "core/session/inference_session.h" #include "test/compare_ortvalue.h" @@ -20,6 +22,90 @@ namespace onnxruntime { namespace test { +static InlinedVector GetZeroPointBytes(int64_t zero_point, ONNX_NAMESPACE::TensorProto_DataType type) { + switch (type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + int8_t val = static_cast(zero_point); + auto span = gsl::as_bytes(gsl::make_span(&val, 1)); + return InlinedVector(span.begin(), span.end()); + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + uint8_t val = static_cast(zero_point); + auto span = gsl::as_bytes(gsl::make_span(&val, 1)); + return InlinedVector(span.begin(), span.end()); + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + int16_t val = static_cast(zero_point); + auto span = gsl::as_bytes(gsl::make_span(&val, 1)); + return InlinedVector(span.begin(), span.end()); + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + uint16_t val = static_cast(zero_point); + auto span = gsl::as_bytes(gsl::make_span(&val, 1)); + return InlinedVector(span.begin(), span.end()); + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + int32_t val = static_cast(zero_point); + auto span = gsl::as_bytes(gsl::make_span(&val, 1)); + return InlinedVector(span.begin(), span.end()); + } + default: + ORT_THROW("Unhandled zero-point type ", type, "."); + } +} + +NodeArg* ModelTestBuilder::MakeInitializer(gsl::span shape, + ONNX_NAMESPACE::TensorProto_DataType elem_type, + gsl::span raw_data) { + std::string name = graph_.GenerateNodeArgName("constant"); + ONNX_NAMESPACE::TensorProto tensor_proto; + tensor_proto.set_name(name); + tensor_proto.set_data_type(elem_type); + tensor_proto.set_raw_data(raw_data.data(), raw_data.size()); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + graph_.AddInitializedTensor(tensor_proto); + + return &graph_.GetOrCreateNodeArg(name, nullptr); +} + +Node& ModelTestBuilder::AddQuantizeLinearNode(NodeArg* input_arg, + float input_scale, + int64_t input_zero_point, + ONNX_NAMESPACE::TensorProto_DataType zero_point_type, + NodeArg* output_arg, + bool use_ms_domain) { + std::vector input_args; + input_args.push_back(input_arg); + input_args.push_back(MakeScalarInitializer(input_scale)); + + InlinedVector zp_bytes = GetZeroPointBytes(input_zero_point, zero_point_type); + input_args.push_back(MakeInitializer({}, zero_point_type, zp_bytes)); + + std::string domain = use_ms_domain ? kMSDomain : ""; + return AddNode("QuantizeLinear", input_args, {output_arg}, domain); +} + +Node& ModelTestBuilder::AddDequantizeLinearNode(NodeArg* input_arg, + float input_scale, + int64_t input_zero_point, + ONNX_NAMESPACE::TensorProto_DataType zero_point_type, + NodeArg* output_arg, + bool use_ms_domain) { + std::vector input_args; + input_args.push_back(input_arg); + input_args.push_back(MakeScalarInitializer(input_scale)); + + InlinedVector zp_bytes = GetZeroPointBytes(input_zero_point, zero_point_type); + input_args.push_back(MakeInitializer({}, zero_point_type, zp_bytes)); + + std::string domain = use_ms_domain ? kMSDomain : ""; + return AddNode("DequantizeLinear", input_args, {output_arg}, domain); +} + void TransformerTester(const std::function& build_test_case, const std::function& check_transformed_graph, TransformerLevel baseline_level, diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 57f10d9a4eb69..fd5770cb70022 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -6,6 +6,7 @@ #include #include +#include "core/common/span_utils.h" #include "core/common/type_utils.h" #include "core/graph/graph.h" #include "core/framework/framework_common.h" @@ -195,21 +196,20 @@ class ModelTestBuilder { return &graph_.GetOrCreateNodeArg(name, &type_proto); } + /// + /// Makes an initializer from the provided shape, element type, and raw data bytes. + /// + /// Initializer shape + /// ONNX tensor element data type + /// Raw data bytes + /// NodeArg pointer + NodeArg* MakeInitializer(gsl::span shape, ONNX_NAMESPACE::TensorProto_DataType elem_type, + gsl::span raw_data); + template NodeArg* MakeInitializer(const std::vector& shape, const std::vector& data) { - std::string name = graph_.GenerateNodeArgName("constant"); - ONNX_NAMESPACE::TensorProto tensor_proto; - tensor_proto.set_name(name); - tensor_proto.set_data_type(utils::ToTensorProtoElementType()); - tensor_proto.set_raw_data(data.data(), data.size() * sizeof(T)); - - for (auto& dim : shape) { - tensor_proto.add_dims(dim); - } - - graph_.AddInitializedTensor(tensor_proto); - - return &graph_.GetOrCreateNodeArg(name, nullptr); + gsl::span raw_data = ReinterpretAsSpan(data); + return MakeInitializer(shape, utils::ToTensorProtoElementType(), raw_data); } // Special handle for std::vector. @@ -342,6 +342,24 @@ class ModelTestBuilder { return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes); } + /// + /// Adds a Q node with a configurable zero-point type. + /// Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types. + /// + /// First input to the Q node + /// Input scale value + /// Input zero point value + /// Input zero point's type + /// Q node's output node arg + /// True to use the 'com.microsoft' domain + /// Reference to the new Q node + Node& AddQuantizeLinearNode(NodeArg* input_arg, + float input_scale, + int64_t input_zero_point, + ONNX_NAMESPACE::TensorProto_DataType zero_point_type, + NodeArg* output_arg, + bool use_ms_domain = false); + template typename std::enable_if::value, Node&>::type AddDequantizeLinearNode(NodeArg* input_arg, @@ -400,6 +418,24 @@ class ModelTestBuilder { return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes); } + /// + /// Adds a DQ node with a configurable zero-point type. + /// Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types. + /// + /// First input to the DQ node + /// Input scale value + /// Input zero point value + /// Input zero point's type + /// DQ node's output node arg + /// True to use the 'com.microsoft' domain + /// Reference to the new DQ node + Node& AddDequantizeLinearNode(NodeArg* input_arg, + float input_scale, + int64_t input_zero_point, + ONNX_NAMESPACE::TensorProto_DataType zero_point_type, + NodeArg* output_arg, + bool use_ms_domain = false); + template Node& AddQLinearConvNode(NodeArg* input_arg, float input_scale, diff --git a/onnxruntime/test/optimizer/qdq_test_utils.cc b/onnxruntime/test/optimizer/qdq_test_utils.cc index 24cace43c6967..a2f32562658c3 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.cc +++ b/onnxruntime/test/optimizer/qdq_test_utils.cc @@ -5,6 +5,8 @@ #include #include #include "core/common/common.h" +#include "core/common/inlined_containers_fwd.h" +#include "core/common/span_utils.h" namespace onnxruntime { namespace test { @@ -164,5 +166,52 @@ std::vector GetNodeOpTypesInTopologicalOrder(const Graph& graph, bo return op_types; } +GetQDQTestCaseFn BuildDoubleQDQTestCaseWithDuplicateLastDQs( + gsl::span input_shape, + gsl::span input_data, + gsl::span zero_points, + gsl::span zero_point_types, + gsl::span scales, + size_t graph_output_index, + bool use_contrib_qdq) { + const size_t num_nodes = zero_points.size(); + bool valid_inputs = (num_nodes >= 4) && + (zero_point_types.size() == num_nodes) && + (scales.size() == num_nodes) && + (graph_output_index < 4); + if (!valid_inputs) { + ORT_THROW("Invalid inputs for call to BuildDoubleQDQTestCaseWithDuplicateLastDQs()"); + } + + return [=](ModelTestBuilder& builder) { + // TODO(adrianlizarraga): Clean up ModelTestBuilder functions (like MakeInput) to work with gsl::span inputs. + // For now, we have to copy data into a std::vector if we want this outer function to take in span inputs. + std::vector input_shape_copy(input_shape.begin(), input_shape.end()); + std::vector input_data_copy(input_data.begin(), input_data.end()); + auto* input_arg = builder.MakeInput(input_shape_copy, input_data_copy); + InlinedVector node_outputs(num_nodes); + + for (size_t i = 0; i < num_nodes; i++) { + if (i == graph_output_index || i >= 3) { + node_outputs[i] = builder.MakeOutput(); + } else { + node_outputs[i] = builder.MakeIntermediate(); + } + } + + builder.AddQuantizeLinearNode(input_arg, scales[0], zero_points[0], zero_point_types[0], node_outputs[0], + use_contrib_qdq); + builder.AddDequantizeLinearNode(node_outputs[0], scales[1], zero_points[1], zero_point_types[1], node_outputs[1], + use_contrib_qdq); + builder.AddQuantizeLinearNode(node_outputs[1], scales[2], zero_points[2], zero_point_types[2], node_outputs[2], + use_contrib_qdq); + + for (size_t i = 3; i < num_nodes; i++) { + builder.AddDequantizeLinearNode(node_outputs[2], scales[i], zero_points[i], zero_point_types[i], + node_outputs[i], use_contrib_qdq); + } + }; +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 414a0fbeb78f5..3dab0ec248f95 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -8,6 +8,7 @@ #include "graph_transform_test_builder.h" +#include "core/common/span_utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/session/inference_session.h" @@ -460,6 +461,27 @@ GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Typ }; } +/// +/// Returns a function that builds a model with a double QDQ sequence (Q1 -> DQ1 -> Q2 -> DQ2*), +/// where DQ2 can be repeated. Must provide at least 4 zero-point and scale values. +/// +/// Shape of input float data. +/// Input float data. +/// Ordered list of zero-point values for each node in the sequence. +/// Ordered list of zero-point types for each node in the sequence. +/// Ordered list of scale values for each node in the sequence. +/// Index of the node that provides a graph output. +/// Set to true to use the 'com.microsoft' domain for Q and DQ ops. +/// A function for building the model +GetQDQTestCaseFn BuildDoubleQDQTestCaseWithDuplicateLastDQs( + gsl::span input_shape, + gsl::span input_data, + gsl::span zero_points, + gsl::span zero_point_types, + gsl::span scales, + size_t graph_output_index, + bool use_contrib_qdq = false); + template GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_contrib_qdq = false) { return [=](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 8c138b22bd52b..31e2280187f76 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2,11 +2,14 @@ // Licensed under the MIT License. #include +#include "core/common/inlined_containers_fwd.h" +#include "core/common/span_utils.h" #include "core/framework/compute_capability.h" #include "core/framework/node_unit.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" +#include "core/optimizer/double_qdq_pairs_remover.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" @@ -1233,6 +1236,129 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1, !use_ms_qdq, 21); } +// Utility function that runs a model with a double QDQ sequence (with duplicate end DQs) through +// the DoubleQDQPairsRemover transformer and checks that the resulting graph contains the expected nodes. +// Also checks that the output from the unmodified model matches the output from the modified model. +static void RunDoubleQDQWithDuplicateLastDQs(int expected_Q_count, int expected_DQ_count, + gsl::span input_shape, + gsl::span input_data, + gsl::span zero_points, + gsl::span zero_point_types, + gsl::span scales, + size_t graph_output_index, + bool use_contrib_qdq = false, + int opset = 19) { + auto graph_checker = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count); + }; + + auto model_build_fn = BuildDoubleQDQTestCaseWithDuplicateLastDQs(input_shape, input_data, zero_points, + zero_point_types, scales, graph_output_index, + use_contrib_qdq); + TransformerTester(model_build_fn, + graph_checker, + TransformerLevel::Default, + TransformerLevel::Level1, + opset, + /*per_sample_tolerance*/ 0.0, + /*relative_per_sample_tolerance*/ 0.0, + std::make_unique()); +} + +// Test QDQDoublePairsRemover when the sequence ends with duplicate DQs. +TEST(QDQTransformerTests, DoubleQDQPairsRemover_DuplicateLastDQs) { + InlinedVector shape = {1, 2, 2, 2}; + InlinedVector input_data = {-3.0f, -2.0f, -1.0f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f}; + + constexpr auto int8_type = ONNX_NAMESPACE::TensorProto_DataType_INT8; + constexpr auto uint8_type = ONNX_NAMESPACE::TensorProto_DataType_UINT8; + constexpr auto int16_type = ONNX_NAMESPACE::TensorProto_DataType_INT16; + constexpr auto uint16_type = ONNX_NAMESPACE::TensorProto_DataType_UINT16; + InlinedVector quant_types = {int8_type, uint8_type, int16_type, uint16_type}; + + // Input graph: + // input -> Q1 -> DQ1 -> Q2 --+--> DQ2 -> output0 + // | + // ... + // | + // +--> DQ2'' -> outputN + // Expected graph after DoubleQDQPairsRemover: + // input -> Q1 --+--> DQ2 -> output0 + // | + // ... + // | + // +--> DQ2'' -> outputN + for (auto quant_type : quant_types) { + for (size_t num_dq2s = 1; num_dq2s <= 3; num_dq2s++) { + const size_t num_nodes = 3 + num_dq2s; + InlinedVector zp_vals(num_nodes, 1); + InlinedVector zp_types(num_nodes, quant_type); + InlinedVector scale_vals(num_nodes, 0.1f); + + const int expected_q_nodes = 1; + const int expected_dq_nodes = static_cast(num_dq2s); + RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types, + scale_vals, 3, false, 21); + RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types, + scale_vals, 3, quant_type == int16_type || quant_type == uint16_type, 19); + } + } + + // Should not remove QDQ pair because the middle nodes produce a graph output. + for (auto quant_type : quant_types) { + for (size_t output_index = 0; output_index < 3; output_index++) { + for (size_t num_dq2s = 1; num_dq2s <= 3; num_dq2s++) { + const size_t num_nodes = 3 + num_dq2s; + InlinedVector zp_vals(num_nodes, 1); + InlinedVector zp_types(num_nodes, quant_type); + InlinedVector scale_vals(num_nodes, 0.1f); + + const int expected_q_nodes = 2; + int expected_dq_nodes = 1 + static_cast(num_dq2s); + if (output_index == 1) { + // EnsureUniqueDQ pass will create a duplicate DQ if it produces a graph output. + expected_dq_nodes += 1; + } + RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types, + scale_vals, output_index, false, 21); + } + } + } + + // Should not remove any nodes because the Q -> DQ pairs are of different quant types. + for (size_t num_dq2s = 1; num_dq2s <= 2; num_dq2s++) { + const size_t num_nodes = 3 + num_dq2s; + InlinedVector zp_vals(num_nodes, 1); + InlinedVector zp_types(num_nodes, int8_type); + for (size_t i = 2; i < num_nodes; i++) { + // Q2 -> DQ2* have a different type + zp_types[i] = int16_type; + } + InlinedVector scale_vals(num_nodes, 0.1f); + + const int expected_q_nodes = 2; + const int expected_dq_nodes = 1 + static_cast(num_dq2s); + RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types, + scale_vals, 3, false, 21); + } + + // Should not remove nodes because 1 of the ending DQ2s has a different zero_point. + const size_t num_dq2s = 2; + const size_t num_nodes = 3 + num_dq2s; + InlinedVector zp_vals(num_nodes, 1); + InlinedVector zp_types(num_nodes, int8_type); + zp_vals[num_nodes - 1] = 2; // Last DQ2 has a different zero-point. + InlinedVector scale_vals(num_nodes, 0.1f); + + const int expected_q_nodes = 2; + const int expected_dq_nodes = 1 + static_cast(num_dq2s); + RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types, + scale_vals, 3, false, 21); +} + // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. template static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis,