diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
index 624679e7b1b4b..22b9dca39dceb 100644
--- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
@@ -2,13 +2,46 @@
 // Licensed under the MIT License.
 #include "core/optimizer/double_qdq_pairs_remover.h"
 #include
+#include
+#include "core/common/span_utils.h"
+#include "core/common/inlined_containers_fwd.h"
 #include "core/graph/graph_utils.h"
 #include "core/optimizer/initializer.h"
 #include "core/optimizer/qdq_transformer/qdq_util.h"
 
 namespace onnxruntime {
 
+/// <summary>
+/// Returns the zero-point type from the given QuantizeLinear node.
+/// </summary>
+/// <param name="graph">Graph</param>
+/// <param name="q_node">QuantizeLinear node</param>
+/// <param name="zp_data_type">Output parameter to store the zero-point data type</param>
+/// <returns>True if successfully extracted the zero-point data type</returns>
+static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node,
+                                  /*out*/ ONNX_NAMESPACE::TensorProto_DataType& zp_data_type) {
+  assert(q_node.OpType() == "QuantizeLinear");
+  const auto input_defs = q_node.InputDefs();
+
+  if (QDQ::InputIndex::ZERO_POINT_ID >= input_defs.size() || !input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Exists()) {
+    // If a zero_point input is absent, get the type from the "output_dtype" attribute or default to uint8.
+    // The "output_dtype" attribute was added in ONNX opset 21.
+    const auto* attr = graph_utils::GetNodeAttribute(q_node, "output_dtype");
+    zp_data_type = attr != nullptr ? static_cast<ONNX_NAMESPACE::TensorProto_DataType>(attr->i())
+                                   : ONNX_NAMESPACE::TensorProto_DataType_UINT8;
+    return true;
+  }
+
+  const auto* zp_proto = graph.GetConstantInitializer(input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true);
+  if (zp_proto == nullptr) {
+    return false;
+  }
+
+  zp_data_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(zp_proto->data_type());
+  return true;
+}
+
 // Applies a new zero point or scale as the input for a Q/DQ node.
 template <typename T>
 static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) {
@@ -81,38 +114,64 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons
   return true;
 }
 
-// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2). This is necessary because
-// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2, where
+// Recomputes the zero point and scale of the outer Q/DQ nodes (i.e., Q1 and DQ2(s)). This is necessary because
+// the original two QDQ pairs may have different zero-points and scales. Ex: Q1 -> DQ1 -> Q2 -> DQ2*, where
 // the first pair has (zp1, scale1) and the second pair has (zp2, scale2).
 // After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed
 // for correctness.
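[Editor's note] FindNewZeroPointAndScale() is referenced above but its body is not part of this diff.
The sketch below is only a rough illustration of the idea behind recomputing the outer pair's
quantization parameters; the helper name, the range-intersection strategy, and the rounding are
assumptions made for illustration, not the actual ONNX Runtime implementation.

// Hypothetical sketch only -- not part of this diff.
// The fused Q1 -> DQ2 pair can only represent real values that survive BOTH original pairs,
// so a new (scale, zero_point) is derived from the intersection of the two representable ranges.
#include <algorithm>
#include <cmath>
#include <limits>

template <typename ZeroPointType>
static void SketchNewZeroPointAndScale(ZeroPointType zp1, float scale1,
                                       ZeroPointType zp2, float scale2,
                                       float& new_scale, ZeroPointType& new_zp) {
  const float qmin = static_cast<float>(std::numeric_limits<ZeroPointType>::lowest());
  const float qmax = static_cast<float>(std::numeric_limits<ZeroPointType>::max());

  // Real-value interval that each (zero_point, scale) pair can represent.
  const float lo = std::max((qmin - zp1) * scale1, (qmin - zp2) * scale2);
  const float hi = std::min((qmax - zp1) * scale1, (qmax - zp2) * scale2);

  // Spread the intersection across the full quantized range (assumes hi > lo).
  new_scale = (hi - lo) / (qmax - qmin);
  new_zp = static_cast<ZeroPointType>(std::round(qmin - lo / new_scale));
}

If both pairs already share the same scale and zero-point, no rewrite is needed, which is what the
no_change_needed flag returned by FindNewZeroPointAndScale() signals in the function below.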
 template <typename ZeroPointType>
-static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, Node& dq2) {
-  bool skip_reset = false;
+static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2,
+                                               gsl::span<gsl::not_null<Node*>> dq2s) {
+  if (dq2s.empty()) {
+    return false;
+  }
+
+  bool no_change_needed = false;
   float new_scale = 0.0f;
   ZeroPointType new_zero_point = 0;
-  if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, skip_reset)) {
+  if (!FindNewZeroPointAndScale(graph, dq1, q2, new_scale, new_zero_point, no_change_needed)) {
     return false;
   }
-  if (skip_reset) {
+  if (no_change_needed) {
     return true;
   }
-  ApplyNewInputValue(graph, dq2, QDQ::InputIndex::SCALE_ID, new_scale);
+
   ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale);
-  ApplyNewInputValue(graph, dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
   ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
 
+  for (gsl::not_null<Node*> dq2 : dq2s) {
+    ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::SCALE_ID, new_scale);
+    ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
+  }
+
   return true;
 }
 
-// Checks if the provided node index (dq1_index) is a part of a valid double QDQ pair sequence
-// (i.e., Q1 -> DQ1 -> Q2 -> DQ2) that can be reduced to the outer Q/DQ nodes (i.e., Q1 -> DQ2).
-// If so, the zero point and scale of the outer Q/DQ nodes are recomputed and the node indices of the other nodes
-// in the sequence (i.e., Q1, Q2, and DQ2) are returned via output parameters.
-static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, NodeIndex dq1_index,
-                                         NodeIndex& q2_index, NodeIndex& dq2_index) {
+/// <summary>
+/// Tries to reduce a double QDQ sequence (Q1 -> DQ1 -> Q2 -> DQ2*) beginning with the provided Q1 node index.
+/// The scale/zero-point values of the outer Q1 and DQ2* nodes may need to be recomputed.
+/// Supports multiple identical DQ2 nodes.
+/// </summary>
+/// <param name="graph">Graph to modify</param>
+/// <param name="q1_index">Index of potential Q1 node</param>
+/// <returns>True if the double QDQ sequence was reduced</returns>
+static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
+  const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
+    return graph.GetConstantInitializer(initializer_name, true);
+  };
+
+  // Ensure that q1 is a Q operator, has only one output, and is not a graph output
+  Node* q1 = graph.GetNode(q1_index);
+  if (q1 == nullptr ||
+      q1->OpType() != "QuantizeLinear" ||
+      q1->GetOutputEdgesCount() != 1 ||
+      graph.NodeProducesGraphOutput(*q1)) {
+    return false;
+  }
+
   // Ensure that dq1 is a DQ operator, has one parent and one child, and is not a graph output
-  Node* dq1 = graph.GetNode(dq1_index);
+  NodeIndex dq1_index = q1->OutputEdgesBegin()->GetNode().Index();
+  const Node* dq1 = graph.GetNode(dq1_index);
   if (dq1 == nullptr ||
       dq1->OpType() != "DequantizeLinear" ||
       dq1->GetInputEdgesCount() != 1 ||
@@ -121,75 +180,80 @@ static bool IsReducibleDoubleQDQSequence(Graph& graph, NodeIndex& q1_index, Node
     return false;
   }
 
-  // Ensure that q2 is a Q operator, has only one child, and is not a graph output
-  q2_index = dq1->OutputEdgesBegin()->GetNode().Index();
-  const Node* q2 = graph.GetNode(q2_index);
-  if (q2 == nullptr ||
-      q2->OpType() != "QuantizeLinear" ||
-      q2->GetOutputEdgesCount() != 1 ||
-      graph.NodeProducesGraphOutput(*q2)) {
-    return false;
-  }
-
-  // Ensure that q1 is a Q operator, has only one output, and is not a graph output
-  q1_index = dq1->InputEdgesBegin()->GetNode().Index();
-  Node* q1 = graph.GetNode(q1_index);
-  if (q1 == nullptr ||
-      q1->GetOutputEdgesCount() != 1 ||
-      q1->OpType() != "QuantizeLinear" ||
-      graph.NodeProducesGraphOutput(*q1)) {
+  // The Q1 and DQ1 nodes must have equal zero-point and scale values (scalar/constant).
+  if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath())) {
     return false;
   }
 
-  // Ensure the dq2 is a DQ operator.
-  dq2_index = q2->OutputEdgesBegin()->GetNode().Index();
-  Node* dq2 = graph.GetNode(dq2_index);
-  if (dq2 == nullptr ||
-      dq2->OpType() != "DequantizeLinear") {
+  auto q1_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+  if (!GetQNodeZeroPointType(graph, *q1, q1_quant_type)) {
     return false;
   }
 
-  const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
-    return graph.GetConstantInitializer(initializer_name, true);
-  };
+  // Ensure that q2 is a Q operator, its output is not a graph output, and that its zero-point quantization type
+  // is equal to q1's.
+  NodeIndex q2_index = dq1->OutputEdgesBegin()->GetNode().Index();
+  const Node* q2 = graph.GetNode(q2_index);
+  auto q2_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
 
-  // Each QDQ pair (i.e., q1 -> dq1, q2 -> dq2) has to meet the following additional requirements:
-  //  - Scalar/constant zero-point and scale.
-  //  - The DQ and Q ops within a pair must have the same scale and zero-point.
-  //    However, each pair is allowed to have different scales and zero-points.
-  //
-  // TODO: IsQDQPairSupported() requires an explicit zero-point input, but technically a default
-  // value of 0 could be fine.
-  if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath()) ||
-      !QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) {
+  if (q2 == nullptr ||
+      q2->OpType() != "QuantizeLinear" ||
+      graph.NodeProducesGraphOutput(*q2) ||
+      !GetQNodeZeroPointType(graph, *q2, q2_quant_type) ||
+      q1_quant_type != q2_quant_type) {
     return false;
   }
 
-  const auto& dq1_input_defs = dq1->InputDefs();
-  const ONNX_NAMESPACE::TensorProto* dq1_zp_tensor_proto = graph.GetConstantInitializer(
-      dq1_input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Name(), true);
+  // All of q2's children should be DQ nodes with zero-point and scale values equal to those of q2.
+  InlinedVector<gsl::not_null<Node*>> dq2_nodes;
+  dq2_nodes.reserve(q2->GetOutputEdgesCount());
 
-  assert(dq1_zp_tensor_proto != nullptr);  // IsQDQPairSupported should have checked that this exists.
+  for (auto it = q2->OutputEdgesBegin(); it != q2->OutputEdgesEnd(); it++) {
+    NodeIndex dq2_index = it->GetNode().Index();
+    Node* dq2 = graph.GetNode(dq2_index);
 
-  auto dq1_zp_type = dq1_zp_tensor_proto->data_type();
+    if (dq2 == nullptr || dq2->OpType() != "DequantizeLinear") {
+      // Child is not a DQ op.
+      return false;
+    }
 
-  if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
-    return RecomputeOuterQDQZeroPointAndScale<uint8_t>(graph, *q1, *dq1, *q2, *dq2);
+    // The Q2 and DQ2 nodes must have equal zero-point and scale values (scalar/constant).
+    if (!QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) {
+      return false;
+    }
+
+    dq2_nodes.push_back(dq2);
   }
 
-  if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
-    return RecomputeOuterQDQZeroPointAndScale<int8_t>(graph, *q1, *dq1, *q2, *dq2);
+  bool can_recompute = false;
+  if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
+    can_recompute = RecomputeOuterQDQZeroPointAndScale<uint8_t>(graph, *q1, *dq1, *q2, dq2_nodes);
+  } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
+    can_recompute = RecomputeOuterQDQZeroPointAndScale<int8_t>(graph, *q1, *dq1, *q2, dq2_nodes);
+  } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
+    can_recompute = RecomputeOuterQDQZeroPointAndScale<uint16_t>(graph, *q1, *dq1, *q2, dq2_nodes);
+  } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
+    can_recompute = RecomputeOuterQDQZeroPointAndScale<int16_t>(graph, *q1, *dq1, *q2, dq2_nodes);
   }
 
-  if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
-    return RecomputeOuterQDQZeroPointAndScale<uint16_t>(graph, *q1, *dq1, *q2, *dq2);
+  if (!can_recompute) {
+    return false;
   }
 
-  if (dq1_zp_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
-    return RecomputeOuterQDQZeroPointAndScale<int16_t>(graph, *q1, *dq1, *q2, *dq2);
+  graph.RemoveEdge(q1_index, dq1_index, 0, 0);  // Disconnect Q1 -> DQ1
+  graph.RemoveEdge(dq1_index, q2_index, 0, 0);  // Disconnect DQ1 -> Q2
+
+  // Disconnect Q2 --> DQ2(s)
+  // Connect Q1 -> DQ2(s)
+  for (gsl::not_null<Node*> dq2 : dq2_nodes) {
+    graph.RemoveEdge(q2_index, dq2->Index(), 0, 0);
+    graph.AddEdge(q1_index, dq2->Index(), 0, 0);
   }
 
-  return false;  // Unsupported zero-point type
+  graph.RemoveNode(q2_index);
+  graph.RemoveNode(dq1_index);
+
+  return true;
 }
 
 Status DoubleQDQPairsRemover::ApplyImpl(
@@ -200,18 +264,8 @@ Status DoubleQDQPairsRemover::ApplyImpl(
   const GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
 
-  for (const auto& dq1_index : node_topology_list) {
-    NodeIndex q1_index = 0;
-    NodeIndex q2_index = 0;
-    NodeIndex dq2_index = 0;
-    if (IsReducibleDoubleQDQSequence(graph, q1_index, dq1_index, q2_index, dq2_index)) {
-      graph.RemoveEdge(q1_index, dq1_index, 0, 0);
-      graph.RemoveEdge(dq1_index, q2_index, 0, 0);
-      graph.RemoveEdge(q2_index, dq2_index, 0, 0);
-      graph_utils::ReplaceNodeInput(*graph.GetNode(dq2_index), 0, *graph.GetNode(dq1_index)->MutableInputDefs()[0]);
-      graph.AddEdge(q1_index, dq2_index, 0, 0);
-      graph.RemoveNode(q2_index);
-      graph.RemoveNode(dq1_index);
+  for (NodeIndex node_index : node_topology_list) {
+    if (TryReduceDoubleQDQSequence(graph, node_index)) {
       modified = true;
     }
   }
diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
index 1833b007674fd..854b3f52de72c 100644
--- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
@@ -13,6 +13,16 @@ namespace onnxruntime {
  * Specifically, this transformer converts the sequence Q1 -> DQ1 -> Q2 -> DQ2, where the first pair has (zp1, scale1)
  * and the second pair has (zp2, scale2), into the sequence Q1 -> DQ2 by removing the middle two nodes. The zero-point
  * and scale of the final QDQ pair is recomputed to preserve equality to the original sequence.
+ *
+ * Also supports multiple identical DQ2 nodes, which may have been inserted by the EnsureUniqueDQNodeUnit optimizer.
+ *   Q1 --> DQ1 --> Q2 --+--> DQ2
+ *                       |
+ *                       +--> DQ2'
+ *
+ * The above becomes:
+ *   Q1 ---+--> DQ2
+ *         |
+ *         +--> DQ2'
  */
 class DoubleQDQPairsRemover : public GraphTransformer {
  public:
diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc
index a5024f510b3cd..73c8b3f119103 100644
--- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc
@@ -7,6 +7,8 @@
 #include
 #include
+#include "core/common/inlined_containers_fwd.h"
+#include "core/common/span_utils.h"
 #include "core/graph/model.h"
 #include "core/session/inference_session.h"
 #include "test/compare_ortvalue.h"
@@ -20,6 +22,90 @@ namespace onnxruntime {
 namespace test {
 
+static InlinedVector<std::byte> GetZeroPointBytes(int64_t zero_point, ONNX_NAMESPACE::TensorProto_DataType type) {
+  switch (type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
+      int8_t val = static_cast<int8_t>(zero_point);
+      auto span = gsl::as_bytes(gsl::make_span(&val, 1));
+      return InlinedVector<std::byte>(span.begin(), span.end());
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
+      uint8_t val = static_cast<uint8_t>(zero_point);
+      auto span = gsl::as_bytes(gsl::make_span(&val, 1));
+      return InlinedVector<std::byte>(span.begin(), span.end());
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
+      int16_t val = static_cast<int16_t>(zero_point);
+      auto span = gsl::as_bytes(gsl::make_span(&val, 1));
+      return InlinedVector<std::byte>(span.begin(), span.end());
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
+      uint16_t val = static_cast<uint16_t>(zero_point);
+      auto span = gsl::as_bytes(gsl::make_span(&val, 1));
+      return InlinedVector<std::byte>(span.begin(), span.end());
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
+      int32_t val = static_cast<int32_t>(zero_point);
+      auto span = gsl::as_bytes(gsl::make_span(&val, 1));
+      return InlinedVector<std::byte>(span.begin(), span.end());
+    }
+    default:
+      ORT_THROW("Unhandled zero-point type ", type, ".");
+  }
+}
+
+NodeArg* ModelTestBuilder::MakeInitializer(gsl::span<const int64_t> shape,
+                                           ONNX_NAMESPACE::TensorProto_DataType elem_type,
+                                           gsl::span<const std::byte> raw_data) {
+  std::string name = graph_.GenerateNodeArgName("constant");
+  ONNX_NAMESPACE::TensorProto tensor_proto;
+  tensor_proto.set_name(name);
+  tensor_proto.set_data_type(elem_type);
+  tensor_proto.set_raw_data(raw_data.data(), raw_data.size());
+
+  for (auto& dim : shape) {
+    tensor_proto.add_dims(dim);
+  }
+
+  graph_.AddInitializedTensor(tensor_proto);
+
+  return &graph_.GetOrCreateNodeArg(name, nullptr);
+}
+
+Node& ModelTestBuilder::AddQuantizeLinearNode(NodeArg* input_arg,
+                                              float input_scale,
+                                              int64_t input_zero_point,
+                                              ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
+                                              NodeArg* output_arg,
+                                              bool use_ms_domain) {
+  std::vector<NodeArg*> input_args;
+  input_args.push_back(input_arg);
+  input_args.push_back(MakeScalarInitializer<float>(input_scale));
+
+  InlinedVector<std::byte> zp_bytes = GetZeroPointBytes(input_zero_point, zero_point_type);
+  input_args.push_back(MakeInitializer({}, zero_point_type, zp_bytes));
+
+  std::string domain = use_ms_domain ? kMSDomain : "";
+  return AddNode("QuantizeLinear", input_args, {output_arg}, domain);
+}
+
+Node& ModelTestBuilder::AddDequantizeLinearNode(NodeArg* input_arg,
+                                                float input_scale,
+                                                int64_t input_zero_point,
+                                                ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
+                                                NodeArg* output_arg,
+                                                bool use_ms_domain) {
+  std::vector<NodeArg*> input_args;
+  input_args.push_back(input_arg);
+  input_args.push_back(MakeScalarInitializer<float>(input_scale));
+
+  InlinedVector<std::byte> zp_bytes = GetZeroPointBytes(input_zero_point, zero_point_type);
+  input_args.push_back(MakeInitializer({}, zero_point_type, zp_bytes));
+
+  std::string domain = use_ms_domain ? kMSDomain : "";
+  return AddNode("DequantizeLinear", input_args, {output_arg}, domain);
+}
+
 void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
                        const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
                        TransformerLevel baseline_level,
diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h
index 57f10d9a4eb69..fd5770cb70022 100644
--- a/onnxruntime/test/optimizer/graph_transform_test_builder.h
+++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h
@@ -6,6 +6,7 @@
 #include
 #include
+#include "core/common/span_utils.h"
 #include "core/common/type_utils.h"
 #include "core/graph/graph.h"
 #include "core/framework/framework_common.h"
@@ -195,21 +196,20 @@ class ModelTestBuilder {
     return &graph_.GetOrCreateNodeArg(name, &type_proto);
   }
 
+  /// <summary>
+  /// Makes an initializer from the provided shape, element type, and raw data bytes.
+  /// </summary>
+  /// <param name="shape">Initializer shape</param>
+  /// <param name="elem_type">ONNX tensor element data type</param>
+  /// <param name="raw_data">Raw data bytes</param>
+  /// <returns>NodeArg pointer</returns>
+  NodeArg* MakeInitializer(gsl::span<const int64_t> shape, ONNX_NAMESPACE::TensorProto_DataType elem_type,
+                           gsl::span<const std::byte> raw_data);
+
   template <typename T>
   NodeArg* MakeInitializer(const std::vector<int64_t>& shape, const std::vector<T>& data) {
-    std::string name = graph_.GenerateNodeArgName("constant");
-    ONNX_NAMESPACE::TensorProto tensor_proto;
-    tensor_proto.set_name(name);
-    tensor_proto.set_data_type(utils::ToTensorProtoElementType<T>());
-    tensor_proto.set_raw_data(data.data(), data.size() * sizeof(T));
-
-    for (auto& dim : shape) {
-      tensor_proto.add_dims(dim);
-    }
-
-    graph_.AddInitializedTensor(tensor_proto);
-
-    return &graph_.GetOrCreateNodeArg(name, nullptr);
+    gsl::span<const std::byte> raw_data = ReinterpretAsSpan<const std::byte, const T>(data);
+    return MakeInitializer(shape, utils::ToTensorProtoElementType<T>(), raw_data);
   }
 
   // Special handle for std::vector<bool>.
@@ -342,6 +342,24 @@ class ModelTestBuilder {
     return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes);
   }
 
+  /// <summary>
+  /// Adds a Q node with a configurable zero-point type.
+  /// Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types.
+  /// </summary>
+  /// <param name="input_arg">First input to the Q node</param>
+  /// <param name="input_scale">Input scale value</param>
+  /// <param name="input_zero_point">Input zero point value</param>
+  /// <param name="zero_point_type">Input zero point's type</param>
+  /// <param name="output_arg">Q node's output node arg</param>
+  /// <param name="use_ms_domain">True to use the 'com.microsoft' domain</param>
+  /// <returns>Reference to the new Q node</returns>
+  Node& AddQuantizeLinearNode(NodeArg* input_arg,
+                              float input_scale,
+                              int64_t input_zero_point,
+                              ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
+                              NodeArg* output_arg,
+                              bool use_ms_domain = false);
+
   template <typename T>
   typename std::enable_if<IsTypeQuantLinearCompatible<T>::value, Node&>::type
   AddDequantizeLinearNode(NodeArg* input_arg,
@@ -400,6 +418,24 @@ class ModelTestBuilder {
     return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes);
   }
 
+  /// <summary>
+  /// Adds a DQ node with a configurable zero-point type.
+  /// Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types.
+  /// </summary>
+  /// <param name="input_arg">First input to the DQ node</param>
+  /// <param name="input_scale">Input scale value</param>
+  /// <param name="input_zero_point">Input zero point value</param>
+  /// <param name="zero_point_type">Input zero point's type</param>
+  /// <param name="output_arg">DQ node's output node arg</param>
+  /// <param name="use_ms_domain">True to use the 'com.microsoft' domain</param>
+  /// <returns>Reference to the new DQ node</returns>
+  Node& AddDequantizeLinearNode(NodeArg* input_arg,
+                                float input_scale,
+                                int64_t input_zero_point,
+                                ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
+                                NodeArg* output_arg,
+                                bool use_ms_domain = false);
+
   template
   Node& AddQLinearConvNode(NodeArg* input_arg,
                            float input_scale,
diff --git a/onnxruntime/test/optimizer/qdq_test_utils.cc b/onnxruntime/test/optimizer/qdq_test_utils.cc
index 24cace43c6967..a2f32562658c3 100644
--- a/onnxruntime/test/optimizer/qdq_test_utils.cc
+++ b/onnxruntime/test/optimizer/qdq_test_utils.cc
@@ -5,6 +5,8 @@
 #include
 #include
 #include "core/common/common.h"
+#include "core/common/inlined_containers_fwd.h"
+#include "core/common/span_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -164,5 +166,52 @@ std::vector<std::string> GetNodeOpTypesInTopologicalOrder(const Graph& graph, bo
   return op_types;
 }
 
+GetQDQTestCaseFn BuildDoubleQDQTestCaseWithDuplicateLastDQs(
+    gsl::span<const int64_t> input_shape,
+    gsl::span<const float> input_data,
+    gsl::span<const int64_t> zero_points,
+    gsl::span<const ONNX_NAMESPACE::TensorProto_DataType> zero_point_types,
+    gsl::span<const float> scales,
+    size_t graph_output_index,
+    bool use_contrib_qdq) {
+  const size_t num_nodes = zero_points.size();
+  bool valid_inputs = (num_nodes >= 4) &&
+                      (zero_point_types.size() == num_nodes) &&
+                      (scales.size() == num_nodes) &&
+                      (graph_output_index < 4);
+  if (!valid_inputs) {
+    ORT_THROW("Invalid inputs for call to BuildDoubleQDQTestCaseWithDuplicateLastDQs()");
+  }
+
+  return [=](ModelTestBuilder& builder) {
+    // TODO(adrianlizarraga): Clean up ModelTestBuilder functions (like MakeInput) to work with gsl::span inputs.
+    // For now, we have to copy data into a std::vector if we want this outer function to take in span inputs.
+    std::vector<int64_t> input_shape_copy(input_shape.begin(), input_shape.end());
+    std::vector<float> input_data_copy(input_data.begin(), input_data.end());
+    auto* input_arg = builder.MakeInput<float>(input_shape_copy, input_data_copy);
+    InlinedVector<NodeArg*> node_outputs(num_nodes);
+
+    for (size_t i = 0; i < num_nodes; i++) {
+      if (i == graph_output_index || i >= 3) {
+        node_outputs[i] = builder.MakeOutput();
+      } else {
+        node_outputs[i] = builder.MakeIntermediate();
+      }
+    }
+
+    builder.AddQuantizeLinearNode(input_arg, scales[0], zero_points[0], zero_point_types[0], node_outputs[0],
+                                  use_contrib_qdq);
+    builder.AddDequantizeLinearNode(node_outputs[0], scales[1], zero_points[1], zero_point_types[1], node_outputs[1],
+                                    use_contrib_qdq);
+    builder.AddQuantizeLinearNode(node_outputs[1], scales[2], zero_points[2], zero_point_types[2], node_outputs[2],
+                                  use_contrib_qdq);
+
+    for (size_t i = 3; i < num_nodes; i++) {
+      builder.AddDequantizeLinearNode(node_outputs[2], scales[i], zero_points[i], zero_point_types[i],
+                                      node_outputs[i], use_contrib_qdq);
+    }
+  };
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h
index 414a0fbeb78f5..3dab0ec248f95 100644
--- a/onnxruntime/test/optimizer/qdq_test_utils.h
+++ b/onnxruntime/test/optimizer/qdq_test_utils.h
@@ -8,6 +8,7 @@
 
 #include "graph_transform_test_builder.h"
 
+#include "core/common/span_utils.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
 #include "core/session/inference_session.h"
 
@@ -460,6 +461,27 @@ GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Typ
   };
 }
 
+/// <summary>
+/// Returns a function that builds a model with a double QDQ sequence (Q1 -> DQ1 -> Q2 -> DQ2*),
+/// where DQ2 can be repeated. Must provide at least 4 zero-point and scale values.
+/// </summary>
+/// <param name="input_shape">Shape of input float data.</param>
+/// <param name="input_data">Input float data.</param>
+/// <param name="zero_points">Ordered list of zero-point values for each node in the sequence.</param>
+/// <param name="zero_point_types">Ordered list of zero-point types for each node in the sequence.</param>
+/// <param name="scales">Ordered list of scale values for each node in the sequence.</param>
+/// <param name="graph_output_index">Index of the node that provides a graph output.</param>
+/// <param name="use_contrib_qdq">Set to true to use the 'com.microsoft' domain for Q and DQ ops.</param>
+/// <returns>A function for building the model</returns>
+GetQDQTestCaseFn BuildDoubleQDQTestCaseWithDuplicateLastDQs(
+    gsl::span<const int64_t> input_shape,
+    gsl::span<const float> input_data,
+    gsl::span<const int64_t> zero_points,
+    gsl::span<const ONNX_NAMESPACE::TensorProto_DataType> zero_point_types,
+    gsl::span<const float> scales,
+    size_t graph_output_index,
+    bool use_contrib_qdq = false);
+
 template <typename T>
 GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_contrib_qdq = false) {
   return [=](ModelTestBuilder& builder) {
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 8c138b22bd52b..31e2280187f76 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -2,11 +2,14 @@
 // Licensed under the MIT License.
 #include
+#include "core/common/inlined_containers_fwd.h"
+#include "core/common/span_utils.h"
 #include "core/framework/compute_capability.h"
 #include "core/framework/node_unit.h"
 #include "core/graph/model.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/mlas/inc/mlas.h"
+#include "core/optimizer/double_qdq_pairs_remover.h"
 #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
@@ -1233,6 +1236,129 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
   RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1, !use_ms_qdq, 21);
 }
 
+// Utility function that runs a model with a double QDQ sequence (with duplicate end DQs) through
+// the DoubleQDQPairsRemover transformer and checks that the resulting graph contains the expected nodes.
+// Also checks that the output from the unmodified model matches the output from the modified model.
+static void RunDoubleQDQWithDuplicateLastDQs(int expected_Q_count, int expected_DQ_count,
+                                             gsl::span<const int64_t> input_shape,
+                                             gsl::span<const float> input_data,
+                                             gsl::span<const int64_t> zero_points,
+                                             gsl::span<const ONNX_NAMESPACE::TensorProto_DataType> zero_point_types,
+                                             gsl::span<const float> scales,
+                                             size_t graph_output_index,
+                                             bool use_contrib_qdq = false,
+                                             int opset = 19) {
+  auto graph_checker = [&](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
+    EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count);
+    EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count);
+  };
+
+  auto model_build_fn = BuildDoubleQDQTestCaseWithDuplicateLastDQs(input_shape, input_data, zero_points,
+                                                                   zero_point_types, scales, graph_output_index,
+                                                                   use_contrib_qdq);
+  TransformerTester(model_build_fn,
+                    graph_checker,
+                    TransformerLevel::Default,
+                    TransformerLevel::Level1,
+                    opset,
+                    /*per_sample_tolerance*/ 0.0,
+                    /*relative_per_sample_tolerance*/ 0.0,
+                    std::make_unique<DoubleQDQPairsRemover>());
+}
+
+// Test DoubleQDQPairsRemover when the sequence ends with duplicate DQs.
+TEST(QDQTransformerTests, DoubleQDQPairsRemover_DuplicateLastDQs) {
+  InlinedVector<int64_t> shape = {1, 2, 2, 2};
+  InlinedVector<float> input_data = {-3.0f, -2.0f, -1.0f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f};
+
+  constexpr auto int8_type = ONNX_NAMESPACE::TensorProto_DataType_INT8;
+  constexpr auto uint8_type = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
+  constexpr auto int16_type = ONNX_NAMESPACE::TensorProto_DataType_INT16;
+  constexpr auto uint16_type = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
+  InlinedVector<ONNX_NAMESPACE::TensorProto_DataType> quant_types = {int8_type, uint8_type, int16_type, uint16_type};
+
+  // Input graph:
+  // input -> Q1 -> DQ1 -> Q2 --+--> DQ2 -> output0
+  //                            |
+  //                           ...
+  //                            |
+  //                            +--> DQ2'' -> outputN
+  // Expected graph after DoubleQDQPairsRemover:
+  // input -> Q1 --+--> DQ2 -> output0
+  //               |
+  //              ...
+  //               |
+  //               +--> DQ2'' -> outputN
+  for (auto quant_type : quant_types) {
+    for (size_t num_dq2s = 1; num_dq2s <= 3; num_dq2s++) {
+      const size_t num_nodes = 3 + num_dq2s;
+      InlinedVector<int64_t> zp_vals(num_nodes, 1);
+      InlinedVector<ONNX_NAMESPACE::TensorProto_DataType> zp_types(num_nodes, quant_type);
+      InlinedVector<float> scale_vals(num_nodes, 0.1f);
+
+      const int expected_q_nodes = 1;
+      const int expected_dq_nodes = static_cast<int>(num_dq2s);
+      RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
+                                       scale_vals, 3, false, 21);
+      RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
+                                       scale_vals, 3, quant_type == int16_type || quant_type == uint16_type, 19);
+    }
+  }
+
+  // Should not remove QDQ pair because the middle nodes produce a graph output.
+  for (auto quant_type : quant_types) {
+    for (size_t output_index = 0; output_index < 3; output_index++) {
+      for (size_t num_dq2s = 1; num_dq2s <= 3; num_dq2s++) {
+        const size_t num_nodes = 3 + num_dq2s;
+        InlinedVector<int64_t> zp_vals(num_nodes, 1);
+        InlinedVector<ONNX_NAMESPACE::TensorProto_DataType> zp_types(num_nodes, quant_type);
+        InlinedVector<float> scale_vals(num_nodes, 0.1f);
+
+        const int expected_q_nodes = 2;
+        int expected_dq_nodes = 1 + static_cast<int>(num_dq2s);
+        if (output_index == 1) {
+          // EnsureUniqueDQ pass will create a duplicate DQ if it produces a graph output.
+          expected_dq_nodes += 1;
+        }
+        RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
+                                         scale_vals, output_index, false, 21);
+      }
+    }
+  }
+
+  // Should not remove any nodes because the Q -> DQ pairs are of different quant types.
+  for (size_t num_dq2s = 1; num_dq2s <= 2; num_dq2s++) {
+    const size_t num_nodes = 3 + num_dq2s;
+    InlinedVector<int64_t> zp_vals(num_nodes, 1);
+    InlinedVector<ONNX_NAMESPACE::TensorProto_DataType> zp_types(num_nodes, int8_type);
+    for (size_t i = 2; i < num_nodes; i++) {
+      // Q2 -> DQ2* have a different type
+      zp_types[i] = int16_type;
+    }
+    InlinedVector<float> scale_vals(num_nodes, 0.1f);
+
+    const int expected_q_nodes = 2;
+    const int expected_dq_nodes = 1 + static_cast<int>(num_dq2s);
+    RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
+                                     scale_vals, 3, false, 21);
+  }
+
+  // Should not remove nodes because 1 of the ending DQ2s has a different zero_point.
+  const size_t num_dq2s = 2;
+  const size_t num_nodes = 3 + num_dq2s;
+  InlinedVector<int64_t> zp_vals(num_nodes, 1);
+  InlinedVector<ONNX_NAMESPACE::TensorProto_DataType> zp_types(num_nodes, int8_type);
+  zp_vals[num_nodes - 1] = 2;  // Last DQ2 has a different zero-point.
+  InlinedVector<float> scale_vals(num_nodes, 0.1f);
+
+  const int expected_q_nodes = 2;
+  const int expected_dq_nodes = 1 + static_cast<int>(num_dq2s);
+  RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
+                                   scale_vals, 3, false, 21);
+}
+
 // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split.
 template
 static void RunDropSplitQDQTestCase(const std::vector<int64_t>& input_shape, int64_t axis,