From 08cf4fbcad784bd69a9e9d8102c1800ea9dbc2f1 Mon Sep 17 00:00:00 2001
From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com>
Date: Thu, 11 Jan 2024 15:16:44 -0800
Subject: [PATCH] Handle all float types in IsQDQPairSupported (#19085)

### Description
This makes detection of identical QDQ scales work with float16 and bfloat16 rather than failing. A Q/DQ pair whose zero points or scales differ in element type is reported as unsupported before any values are compared.

### Motivation and Context
This addresses failures in customer models
---
 .../optimizer/qdq_transformer/qdq_util.cc     | 23 ++++++-
 .../test/optimizer/graph_transform_test.cc    | 61 ++++++++++--------
 .../qdq_optimization/dup_qdq_bfloat16.onnx    | Bin 0 -> 1446 bytes
 .../qdq_optimization/dup_qdq_float16.onnx     | Bin 0 -> 1446 bytes
 4 files changed, 53 insertions(+), 31 deletions(-)
 create mode 100644 onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_bfloat16.onnx
 create mode 100644 onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_float16.onnx

diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
index 221c06d7c8dcf..b1ab641a23256 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
@@ -54,9 +54,26 @@ bool IsQDQPairSupported(
   Initializer dq_zp(*dq_zp_tensor_proto, model_path);
   Initializer dq_scale(*dq_scale_tensor_proto, model_path);
 
-  return q_zp.data_type() == dq_zp.data_type() &&
-         SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan()) &&
-         *q_scale.data<float>() == *dq_scale.data<float>();
+  if (q_zp.data_type() != dq_zp.data_type() ||
+      q_scale.data_type() != dq_scale.data_type() ||
+      !SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan())) {
+    return false;
+  }
+
+  switch (q_scale.data_type()) {
+    case ONNX_NAMESPACE::TensorProto::FLOAT:
+      return *q_scale.data<float>() == *dq_scale.data<float>();
+
+    case ONNX_NAMESPACE::TensorProto::FLOAT16:
+      return *q_scale.data<MLFloat16>() == *dq_scale.data<MLFloat16>();
+
+    case ONNX_NAMESPACE::TensorProto::BFLOAT16:
+      return *q_scale.data<BFloat16>() == *dq_scale.data<BFloat16>();
+
+    default:
+      assert(false);
+      return false;
+  }
 }
 
 bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index ef6e2d531bc1a..5adcb3c150b8d 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4602,38 +4602,43 @@ TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
 }
 
 // Test DoubleQDQPairsRemover to remove unnecessary DQ->Q nodes in the middle
-TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_optimization/dup_qdq.onnx";
-  std::shared_ptr<Model> p_model;
-  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
-  Graph& graph = p_model->MainGraph();
+TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ_Float16) {
+  auto RunTest = [this](const ORTCHAR_T* model_uri) {
+    std::shared_ptr<Model> p_model;
+    ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+    Graph& graph = p_model->MainGraph();
 
-  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
-  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+    onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+    ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
+    ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
-  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
-  EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
+    EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
 
-  std::string dq_scale_name_before_reshape_node;
-  std::string zp_name_before_reshape_node;
-  std::string dq_scale_name_after_reshape_node;
-  std::string zp_name_after_reshape_node;
-  for (auto& node : graph.Nodes()) {
-    if (node.Name() == "dq_2") {
-      dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
-      zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
-    }
-    if (node.Name() == "q_3") {
-      dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
-      zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+    std::string dq_scale_name_before_reshape_node;
+    std::string zp_name_before_reshape_node;
+    std::string dq_scale_name_after_reshape_node;
+    std::string zp_name_after_reshape_node;
+    for (auto& node : graph.Nodes()) {
+      if (node.Name() == "dq_2") {
+        dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+        zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+      }
+      if (node.Name() == "q_3") {
+        dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+        zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+      }
     }
-  }
-  EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
-  EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
-  EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
-  EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
+    EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
+    EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
+    EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
+    EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
+  };
+
+  RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq.onnx");
+  RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_float16.onnx");
+  RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_bfloat16.onnx");
 }
 
 // Test Gelu -> FastGelu
diff --git a/onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_bfloat16.onnx b/onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_bfloat16.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..a04b7781fca402f7e2f185e030b4648ecc1c96ad
GIT binary patch
literal 1446
zcmcgs&1>5*6p!OJabJ=l)T|AmY!K+h5UAyN4ZY3u*h^c+XfML61Pe{%k2)h{^t@C1
zH&=3lqmAC@4&lb(McPtvn2+zo|Tjrl5AM00-j)PYCn2aM5~h?Xg#@iGwUY|R;b
z36u2UcaU&K9eN^%Gxca1-g7}+j}Z8*eYNIENTYBjQry#)Sv)L;BXrI;3>P!Og9Pqi
zH={1|x!*Z03-gQ3;+NC

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_float16.onnx b/onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_float16.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..691da77969b1c1f1575d459e017efd25e4180e37
GIT binary patch
literal 1446
zcmcgs%}c{D6kk`J^=*?9=LbU+23`t-HC=Uhvv~BPBBGa2O~B!ne%OgjJbTi=*`#)w
z(u&Z78wtGR_4o0TJi|b(fb-`*d-kW{gdM#%;T1YQrDV**S%OJ4^La9!GYY#woIKw9
zF{8vL12OE8OOoJ$apJfb!-n$JoW%hNgE3D~TUn-2Hy`%Q6SiPD9dqW#a0kmNb&=0G
z7%lh?Z5PZmjVM&jZz>XX)u6X_mf6Tirl4+Dplbb42~)YK#M2OR8IUgmT9#mTcl9pd
zPlN}Q_4H#(;R3!kit?~MTvWn!F_<)xXQEU>j|85a3eGPEsH
z$iY%7E|-ccKuRe%m5NKGj#8-@+A$N?Ib?VDg}
zFWw)`#o7^awVGGnV}1q6cMNM=fYjeKuh2Pu1;{@%YfGviUg5o2+oGN;N4nNPUjAC$
m>}Bl?Z9>NYWJ2B0#AswgzukfC2oVBPWVDnFk_Tl%ehb

literal 0
HcmV?d00001