From 5678317bafd219e2b71c72431905b776460e11a4 Mon Sep 17 00:00:00 2001 From: Yvonne Chen Date: Thu, 11 Jan 2024 10:36:33 +0800 Subject: [PATCH] Fix the duplicated QDQ attributes setup issue (#18039) ### Description The copied QDQ node should have exactly the same attributes as the original QDQ node. Otherwise, it might cause errors when the original node has attributes that use non default values (such as axis != 1 case). An example user case is like: A DequantizeLinear node has more than 1 consumer in the graph, and its attributes axis is 0. ### Motivation and Context I see the errors like https://github.com/microsoft/onnxruntime/issues/16188 and this fix could solve the issue. --- .../ensure_unique_dq_for_node_unit.cc | 2 +- .../ensure_unique_dq_for_node_unit_test.cc | 40 ++++++++++++++++++ .../qdq_with_multi_consumer_q_dq_axis.onnx | Bin 0 -> 9361 bytes 3 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx diff --git a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc index cc0f7854791d4..9d53e28921784 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc @@ -53,7 +53,7 @@ Status DuplicateDQForOutputEdge(const graph_utils::GraphEdge& original_dq_output MakeString("Added by ", kTransformerName), dq_inputs, {&new_dq_output_nodearg}, - nullptr, // attributes + &original_dq_node.GetAttributes(), original_dq_node.Domain()); // set up edges diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index 7a67747f7cf4c..89ffb8ec87dcb 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -234,4 +234,44 @@ TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodes) { EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 4, OpCount(op_count_after, "DequantizeLinear")); } +TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodesPreservingAttributes) { + constexpr auto model_uri = ORT_TSTR("testdata/qdq_with_multi_consumer_q_dq_axis.onnx"); + + SessionOptions session_options{}; + // test interaction with level 1 transformers + session_options.graph_optimization_level = TransformerLevel::Level1; + + InferenceSessionWrapper session{session_options, GetEnvironment()}; + + ASSERT_STATUS_OK(session.Load(model_uri)); + + const auto op_count_before = CountOpsInGraph(session.GetGraph()); + + ASSERT_STATUS_OK(session.Initialize()); + + const auto op_count_after = CountOpsInGraph(session.GetGraph()); + + EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 8, OpCount(op_count_after, "DequantizeLinear")); + + int64_t given_axis = 0; // all the following 4 DQ nodes and their duplicated one should have axis = 0 + std::string axis_dq_name0 = "Convolution28_Output_0/fusedmuladd_B/DequantizeLinear"; + std::string axis_dq_name1 = "Parameter5/DequantizeLinear"; + std::string axis_dq_name2 = "Convolution110_Output_0/fusedmuladd_B/DequantizeLinear"; + std::string axis_dq_name3 = "Parameter87/DequantizeLinear"; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "DequantizeLinear") { + if (node.Name().find(axis_dq_name0) == 0 || + node.Name().find(axis_dq_name1) == 0 || + node.Name().find(axis_dq_name2) == 0 || + node.Name().find(axis_dq_name3) == 0) { + const auto& attrs = node.GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + const auto& axis_attr = attrs.at("axis"); + int64_t axis = axis_attr.i(); + EXPECT_EQ(axis, given_axis); + } + } + } +} + } // namespace onnxruntime::test diff --git a/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx b/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4f575ebb2841a02802754e2a449a05ad58220afc GIT binary patch literal 9361 zcmcgy2~<;O+Rh5$W)CDJA?&LG*&)W0gi3^9sdLUR=N`E4eYfv@p6|Q& z`(9D*B+sm^+p_|;?Ap3DVB68W1MeUBDEAm0vTA%2a0?2Mw`U*Aem^%a_gD%m=7SU2 zhx5S4`v9V|`1m{7hjIamvj}u?Wbe@<2M*_v7|!x_xj*}O<-p|^w2J`Ex))Mugr zMD~XVj#DA|P*^-6b0(pkzdxG0DgX?X!d?Xd3ZQ`3lZZ)9DHEKevl-8iEh)w+VIL>O zh4JandiXr)7Q%p-wMP#B?#Q7Nc?XUhj*E-kyX{2Y(Gz)lV`JVuaXfe5`zH=%@7uRG z-DxC|%QKiAWbE>koyMQMJc~u8c}}}usv7ZWkQ3r(Q2q%9L_4J!pYqF9V=ha7wrI|f zEbZH=W=YF(zNlo+XZBnb%Ld#w9|n1q0AOZ+$IK3uICtkX7}T67lm6tdAlzdVz_~v= zZ(UBk_kbt*zZ}G4F37q}`uv}ZN9HonC={5t0iHD_f#<6*3BM!vl~)tupRkIKeG1?p z#LO*|2B5M}Yt5GRdtcqYef^GnD4a5VD6rV_c~T-RObMie3)wPx@lPn=!k9$be&ooZ1Bc&B zPM*ip$2Qo1ZC&&|9q7u6>w^Z7F*yKo1>>RP-^WNkBF zB8#(JmkB)c0|I!?cM^#(!^^n1#3z8!aZU>W_?G^EUID&>0*g`tC<~LEW021xxf6dL zPY~yk#mvpPr)A6H&dIs?_O$8q<>k33X^=Cngv(?MxJ0?KCV9@$C@y*aON$Q73xD-; z08fMe@^Ub%1TF_Zf@#eMdqOXO=0Z3Jr=P1wPKf?IF`P#Pg6dklEL<GMo}T*n%0RIG$YnFBp!MR< zuo(+~(8BeLelBnzX!?v-{)~@c(8tp3nkz&o=;s%BHskAc6an>Fpcus33p**z_9vIlk?VgKcC3GA=p|C{}n;JxfU0jt;|>Tb4Q9)``kafp3b z7Rk0%Y-2N{LD%r~B{q>6geff9HjH$KTBZQ4n?j3Xo~NM*C=?2bLV{-VJ2~fFY6h1! zDxd81@oFqE;@vGnE#3Gj&sy>Hx`dHVZoyBUWcvd(bTlf>*gmd~ixnbT8FT={ZzuV- zi^^Ju`m2je`f)dVPqpjet!RDFV1mc5}o<>aJ;r*+_Uh)kBgHHtuI@ZlBgt z#+wP!;-7}FhJ+Z$3y}wQ?DXj#RX5&JVdF?d8N%M4#4Q@3q2a9xgEc}x@XJzs@LG$a zW821XRBgm>12)`ifJm4=5F}O_OX|beQv8RsoWtu_n_k7a;S&uBFvA!m$|rGU`lDV* z@b>nKv*XaW;}N60cY?w*HpMXPQSQ?)?KI6_*?dK|I>Dm)VI?bgN1_)+8aw4a2)MgN z_-G}yal7L@S~P(>R~&}IiMvwZb_%Ri(>_>J+3D!1sh95ZAyXsJUE(f$2*qs=EOg&T znWvjgDBqa2R&T<7tYoCoT11lyM;pBfBa4N|qLG>c1hAU(y60DSqi?WPW=%hEhJIadW z;z#0P#zRZt!KbtSnWsbA=_3Fin>?oG$RR#gH4QnkItIjM{R?b zR-}5O$|6|y4c+{{<6hb;1BB6b_5IfK;&Bl+>J|Sg^|-O;-lu23jiKK;%@v5?VWHhO zZt+yD8P${FFlN&ll>Wjar`nzyq=NeCp8g7&} zNzCmYLW`$b)G8VH?$dusu~8xE);J`)eW0#hQ;H3#|Mcv*rT&(nN@{LwDlHtf=}9$a z*u;3{w^OK-HC1(e(N@L7f=J@P!|%!(OWNyy653HEZf^(olxFRVc#ESCQfUD>?E<&f zQr}Q{!Q3SnIUn+7fU-XVQCIc@=C8R zVfuuqh?Jaz^6tW#)dBJewM`8d=xuie-vZSTQG4Ir$;2>U-s$KKI9LUQl|U7F;B9so z(LJ>7hfy88{a@89==~1{4g>}wJwtVLB4kp;Tnh)HXg!{Zto|MyQ#GYE)pSN65(`ZuR^iDM5-{){DVa>hK zyL{uQLJwC2cZ9E0U;&J>E2PCu2VW}`8wVn#!^S|4s;ygLZn-N*Cs3MXRvLt_NPGRX zdqQkiqY|&J;j8uB@<3D9=q7x>ikga1kGW0bjYB@}D|KW8k%mFK#f-t+NRh21#`H@T zMVB8j(px%-3-oWl-t}_m;3LaEL_bs#DJC9t8yElH582gO(b|N+)7K5{xq{j9k*GVC z?0Ey`4@{wh?UQ0spuFd*`%!FlL=sz)<%_fLA;yVFZNj&c-OAdx$6ZuLa9$|$)#qL7 zos30{>QTOiqB9h@^r4bGaW08Iar?m8P{dh)_I4=(-K-6;H+Yu3fnqSn1q-=d{@mW^ER`d(3frc!hMPk$3O zYRjy_s6r1TUtJL!0^#?ve1}aBs11p@Q+`-f`{3&IlyXLr?aufH=;bD3#OIALR++jI zj&C*#|uJFrbvXvZ?DRvJEQS65b$4&J5uJO_tR0)3d6P|?wo}>2~y;|TTf>APL zlb!Mo2(PY*^lU3^R7wkqgcVkVskLduVdBaA3~F<|6&bOM7QcFh(m*ol){Mu6wRrS4 zC`Yju2_ks=uxivOF3O8aZ|DkRCjUg}G9+%;ZliKL)QQS6XuNuWLbq+C!8 zk$%Ay8~tM3YHe!2ydGL}Xcmlb4df6y+NV?OG*hp--#^h{s_WNRivA()m4)*k2wJF+ zG#{Nt6&zXVC->lspghE-n^3Xlkw$gHy|z!P20U8@ba7vSK3XmonD$Es&9D%%HNU*m z7cp{mR4y4TZLznvw)*O8dvWn)MPJpem6|o`OfAz+c%8OAjQL(Xf!Yu89-$ynqimB9 z;*YXgtBJ*((Y3noN*kyB!)(Zb*q2(>y2)NoFGsoO3kj3n`=ir1ME!`^64ZerSYsZl zwf>vojL`8Udi**M4ZpD(qHG&86Rgk)VqkCRI`4<6dp@e?0+#;1WX$wbu$w=0g;(UL zr&(vbE54;V3hHIp$AQUi2%QAWN27B27?f{>;U9%?Ot(f z>=HexsJBGkD3~-=Q&(=#A5* z=0VhCQ^ORTBD}1JMHyuVbH%mdPLZYX%luo7Llsp*{B%&?0138#%0fNF8MBzx4SFoB z)6#;Pm||-?WxW`Kt!z{l>JVxi6;NEg#pFgFARF6lQv%JTwvJ2hX2y=8r;HQRQcY`} z3{uN$Xe45!VnJ*y@?4qGMl3t8ZSNM|gidJdZ&(Iy@+{OB`S-=SzcDa03VBY{n{COuo;QK zkMKzN@+%#X5_EGzH*bwQPO6lvsvQI0D0m*jl=|YL|Jg75K~^S)k!YK5|NCy+7_x0J zj*y6$4yU|oGwfO07K}?<$LrEuyKuGE92V2xAnwx-6zBi#8V{o7^9gF60a8?Oxz*3o z_7~okuf-a zs(5(+YQ&Wr`r@9+o{4h#8Zx1)Uc1*zYQb?3H}O8q3|2^W`I*}%FY#m&Wn=RNQ8nKF zRbhW2gw*YIK4}ZIv2XQD5MxT3i8m0zz)u=F-VL+V8906#P2|BTp(ZE|c6Taq;!&Gn z3`1(cQV02#=wiWrS!5$$E^oR#VnX-ldo!rgvFM1>A^ZNo5&>*C>gHH3cCZwF>d!Vu z$4_D(pvFw1YP3BRavw}XQeyb!7tRl1Mjt{t(O3$ckSIhdqx269N_9_vd5J~}Ttbkb z;dSJs(B|Y9|*c9g84KRHCE)`+pH*oz7l3}%*d+l&7Re~8{rl3&Zj%iWvDI-8s zEM(ug`M+#@eFvKsx`+M2rhN9k!Q1TUl8@Pk4t>IAhJu#PI~(U7RR&6iAma^g1`Bfp z3hsiSbNe8dov(~~XkYO^J57K7SJ!TjuZB3R({-#DZ*vFXkG{D*#rp6)yrvq?$&o+k z+x-Ujs^oEvxmbv&b?H1PhjrvbD-~CD=;^41@^SGlc+?- z&PYxsX8i{rQn#L{>o(N1B7NIxxT=lXFS$L0sA-)s^f)J2-Jl=@g#XArReoT77x)dw z+3cq$z8Lpr42Dz6eYExc&;EM%jn&*sy)O{zPd8@-A;~I(7#FkIe@FOQjwGB4r#eJE z4cu|wmxVm$@y8p`U(*1~NNb4qNvT+aNO?<1OSxBsO<6N(Wvxu4rc{_hQr_j7Sy)DV zN)?}$^4-gjl(-466mEV6>tRT6${M|fbvY5A!u${n>X!bi8)+G7kpJ)>^S>TPcJ4+k zr$LyJmT7i4{-k#}rnL@7rrO~+170oQbq|zL1|5zpmBYaqa5y%Y9FE^XA&{3G4#(@A z;5{h+ddlHo?s!bwU(qtl2!Mam%G$Z=Ggj=$P8P?yhxMW|pY`rHcUYvPJkbs5Qb=u*uV<8Zu)#0cFi(>){KtnmI5TBanK|A|9Ombc7N2te9PsHg<{Yxi)7=@d zi(^KofKQ)i=1{@^wlLdPcPFgrQK^+X5r^w@U!41t^sx#?6CjP--C+awj>8aUqaK$1@El`~) z!K}YX7gSpid|o-q6$O`c4QekYp$lysm<;YMp47LVbo>h=cP3eQKs;x=3f%Aa<##wN!|71OOqB{tq3W Bly?9C literal 0 HcmV?d00001