From 7c573054b61bb44e5ee690fbee80aab359b28282 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 21 Nov 2023 21:31:31 -0800 Subject: [PATCH] [QDQ Optimizer] Fix logic that drops Q/DQ ops from QDQ split node groups (#18394) ### Description - Fix QDQ optimizer logic that drops Q/DQ ops from Split node groups so that it only occurs when all input/output quantization parameters are equal. - Currently, the selector used for this optimization does not ensure that all quantization parameters are equal. - Support dropping Q/DQ ops from Split node groups with optional split inputs (introduced in opset 13). This was not working previously. ### Motivation and Context Fix bugs in handling of QDQ Split node groups. --------- Signed-off-by: adrianlizarraga --- .../selectors_actions/qdq_actions.cc | 22 +++++++--- .../qdq_selector_action_transformer.cc | 2 +- .../selectors_actions/qdq_selectors.cc | 34 +++++++++++++- .../selectors_actions/qdq_selectors.h | 25 +++++++++-- .../selectors_actions/shared/utils.cc | 15 ++++++- onnxruntime/test/optimizer/qdq_test_utils.h | 37 +++++++++++----- .../test/optimizer/qdq_transformer_test.cc | 44 ++++++++++++++----- 7 files changed, 147 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index f42766267b0f9..3d2a81ce7f8cd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -87,12 +87,19 @@ std::vector WhereMoves() { MoveAll(q, ArgType::kOutput)}; return moves; } -QDQReplaceWithNew SplitReplacer() { +QDQReplaceWithNew SplitReplacer(bool has_split_as_input) { NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; - std::vector moves{ - MoveAndAppend(dq, ArgType::kInput, 0, 
ArgType::kInput), - MoveAll(q, ArgType::kOutput)}; + std::vector moves{MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput)}; + + if (has_split_as_input) { + // Move the optional split input to the new node. + moves.push_back(MoveAndAppend(target, ArgType::kInput, 1, ArgType::kInput, true)); + } + + moves.push_back(MoveAll(q, ArgType::kOutput)); + return QDQReplaceWithNew(kOnnxDomain, "Split", std::move(moves)); } @@ -247,7 +254,12 @@ MatMulReplaceWithQLinear::MatMulReplaceWithQLinear() } Status SplitReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { - return SplitReplacer().Run(graph, selected_nodes); + const auto& target_node = selected_nodes.Target(); + const auto& input_defs = target_node.InputDefs(); + + // The 'split' attribute became an optional input at opset 13. + bool has_split_as_input = target_node.SinceVersion() >= 13 && input_defs.size() == 2; + return SplitReplacer(has_split_as_input).Run(graph, selected_nodes); } Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 0e383c3031ca6..29178fe87f75c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -20,7 +20,7 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { const std::string action_name{"dropSplitQDQ"}; std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::unique_ptr selector = std::make_unique(true /*req_equal_quant_params*/); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Split", {}}}, std::move(selector), diff --git 
a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 3880288bdba2e..15b501c667046 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -253,7 +253,39 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder builder.num_input_defs = 1; // set to 1 as the first input is variadic } -void OutputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { +bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + const Node& dq_node = *dq_nodes.front(); + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + // All Q outputs should have same data type and (optionally) equal quantization parameters as the input. 
+ for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) { + const Node& q_node = *q_nodes[q_idx]; + + if (dt_input != q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + return false; + } + + if (req_equal_quant_params_ && + !IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + + return true; +} + +void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.num_output_defs = 1; // set to 1 as the first output is variadic } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index be7f7e0288eda..d0d7fb2c2af17 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -115,6 +115,24 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; }; +// DQ node -> Split -> multiple Q nodes with equal quantization types. +// Optionally, the selector can require all input and output quantization parameters to be +// equal and constant. +class SplitNodeGroupSelector : public NodeGroupSelector { + public: + explicit SplitNodeGroupSelector(bool req_equal_quant_params = false) + : req_equal_quant_params_(req_equal_quant_params) {} + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + + bool req_equal_quant_params_; // If true, only selects a node group if the input and output + // quantization parameters are all equal/constant, which enables the + // optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP. 
+}; + // DQ nodes for X, W and optionally B -> node -> Q class ConvNodeGroupSelector : public NodeGroupSelector { public: @@ -288,10 +306,11 @@ class InputVariadicSelector : public BaseSelector { void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; -// DQ -> node -> Variadic Q nodes -class OutputVariadicSelector : public BaseSelector { +// DQ -> Split -> variadic Q nodes +class SplitSelector : public BaseSelector { public: - OutputVariadicSelector() : BaseSelector(std::make_unique()) {} + SplitSelector(bool req_equal_quant_params = false) + : BaseSelector(std::make_unique(req_equal_quant_params)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 1a4d3a0c18151..e2aa25897ee06 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -27,6 +27,9 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops } /* static methods to return different operator's OpVersionMap */ + +// These are operators that do not change the data and therefore the input DQ and +// output Q have the same scale and zero_point. 
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"Reshape", {}}, @@ -35,7 +38,6 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Transpose", {}}, {"MaxPool", {12}}, {"Resize", {}}, - {"Split", {}}, {"Squeeze", {}}, {"Unsqueeze", {}}, {"Tile", {}}}; @@ -97,6 +99,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() { {"Max", {}}, {"Min", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetSplitOpVersionsMap() { + return {{"Split", {}}}; +} static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { return {{"Conv", {}}}; } @@ -170,6 +175,13 @@ void RegisterVariadicSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterSplitSelector(Selectors& qdq_selectors) { + /* register selectors for Split op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetSplitOpVersionsMap(), + std::move(selector)); +} + void RegisterConvSelector(Selectors& qdq_selectors) { /* register selector for conv op */ std::unique_ptr selector = std::make_unique(); @@ -247,6 +259,7 @@ void SelectorManager::CreateSelectors() { RegisterUnarySelectors(qdq_selectors_); RegisterBinarySelectors(qdq_selectors_); RegisterVariadicSelectors(qdq_selectors_); + RegisterSplitSelector(qdq_selectors_); RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 2008d96539dca..e64117925eb57 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -466,11 +466,11 @@ GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_cont } template -GetQDQTestCaseFn BuildQDQSplitTestCase( - const std::vector& input_shape, - const int64_t& axis, - bool use_contrib_qdq = false) { - return 
[input_shape, axis, use_contrib_qdq](ModelTestBuilder& builder) { +GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, + const int64_t& axis, + bool use_diff_output_scale, + bool use_contrib_qdq = false) { + return [input_shape, axis, use_diff_output_scale, use_contrib_qdq](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), std::numeric_limits::max()); @@ -478,16 +478,30 @@ GetQDQTestCaseFn BuildQDQSplitTestCase( InputType dq_zp = std::numeric_limits::max() / 2; OutputType q_zp = std::numeric_limits::max() / 2; auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input_arg, .003f, dq_zp, dq_output, use_contrib_qdq); + constexpr float input_scale = 0.003f; + builder.AddDequantizeLinearNode(input_arg, input_scale, dq_zp, dq_output, use_contrib_qdq); // add Split + std::vector split_inputs; + split_inputs.push_back(dq_output); + + // Use the optional 'split' input when testing Split 13 + int opset = builder.DomainToVersionMap().find(kOnnxDomain)->second; + if (opset >= 13 && opset < 18) { + int64_t dim = input_shape[axis]; + int64_t split_size = dim / 3; + split_inputs.push_back(builder.Make1DInitializer(std::vector{split_size, + split_size, dim - (2 * split_size)})); + } auto* split_output_1 = builder.MakeIntermediate(); auto* split_output_2 = builder.MakeIntermediate(); auto* split_output_3 = builder.MakeIntermediate(); - Node& split_node = builder.AddNode("Split", {dq_output}, {split_output_1, split_output_2, split_output_3}); + Node& split_node = builder.AddNode("Split", split_inputs, {split_output_1, split_output_2, split_output_3}); split_node.AddAttribute("axis", axis); - if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) { + + // Use the 'num_outputs' attribute when testing Split >= 18 + if (opset >= 18) { split_node.AddAttribute("num_outputs", static_cast(3)); } @@ -495,11 +509,12 @@ GetQDQTestCaseFn BuildQDQSplitTestCase( auto* q_split_output_1 
= builder.MakeOutput(); auto* q_split_output_2 = builder.MakeOutput(); auto* q_split_output_3 = builder.MakeOutput(); - builder.AddQuantizeLinearNode(split_output_1, .003f, q_zp, q_split_output_1, + float output_scale = use_diff_output_scale ? input_scale + 0.001f : input_scale; + builder.AddQuantizeLinearNode(split_output_1, output_scale, q_zp, q_split_output_1, use_contrib_qdq); // Model input (node_token_1) - builder.AddQuantizeLinearNode(split_output_2, .003f, q_zp, q_split_output_2, + builder.AddQuantizeLinearNode(split_output_2, output_scale, q_zp, q_split_output_2, use_contrib_qdq); // Model input (node_token_2) - builder.AddQuantizeLinearNode(split_output_3, .003f, q_zp, q_split_output_3, + builder.AddQuantizeLinearNode(split_output_3, output_scale, q_zp, q_split_output_3, use_contrib_qdq); }; } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1bf1cbacf479e..17dd2e80f9f88 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -1210,27 +1210,51 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. template static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, - bool use_contrib_qdq = false) { - auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + bool all_same_quant_params, bool use_contrib_qdq = false) { + auto check_graph = [all_same_quant_params, use_contrib_qdq](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + int expected_q_ops = all_same_quant_params ? 0 : 3; + int expected_dq_ops = all_same_quant_params ? 
0 : 1; EXPECT_EQ(op_to_count["Split"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_q_ops); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_dq_ops); }; - TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), + TransformerTester(BuildQDQSplitTestCase(input_shape, axis, !all_same_quant_params, + use_contrib_qdq), check_graph, TransformerLevel::Level1, TransformerLevel::Level2, - {12, 18, 19}); + {12, 13, 18, 19}); // Test different ways to specify the split in each opset: + // 12 - split into equal parts without explicit 'split' attribute + // 13 - use optional 'split' input to split into 3 parts + // 18 - use 'num_outputs' attribute to split into 3 parts + // 19 - use 'num_outputs' attribute to split into 3 parts } // Test that DQ -> Split -> Q (many) is replaced with just Split for various quantization types. TEST(QDQTransformerTests, Split) { - RunDropSplitQDQTestCase({6, 18, 54}, 0); - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int8 QDQ ops - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int16 QDQ ops - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft uint16 QDQ ops + // Test cases that drop Q/DQ ops from DQ -> Split -> Q (many). + // This happens when all the Q/DQ ops have equal and constant quantization parameters. 
+ { + constexpr bool ALL_SAME_QUANT_PARAMS = true; + constexpr bool USE_CONTRIB_QDQ_OPS = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + } + + // Test cases that DO NOT drop Q/DQ ops from DQ -> Split -> Q (many) + // This happens when the Q/DQ ops do not have equal and constant quantization parameters. + { + constexpr bool DIFF_QUANT_PARAMS = false; + constexpr bool USE_CONTRIB_QDQ_OPS = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + } } // Because split isn't one the supported ops, this will stay the same