From 8f3384b4c130fcf98b3db7689a3cef4d7497e726 Mon Sep 17 00:00:00 2001 From: amancini-N <63410090+amancini-N@users.noreply.github.com> Date: Tue, 10 Dec 2024 00:15:20 +0100 Subject: [PATCH] Fix BeamSearch T5 if initializers are on outer scope (#23044) ### Description This PR adds the logic needed to consider only the needed implicit inputs on BeamSearch op in case of T5 model (encoder/decoder, 2 graphs). The logic added is similar to what happens in the _If_ kernel setup. ### Motivation and Context Fixes #23043 --- .../cpu/transformers/subgraph_base.cc | 15 +++++++-- .../cpu/transformers/subgraph_base.h | 1 + .../cpu/transformers/subgraph_t5_decoder.cc | 7 ++-- .../cpu/transformers/subgraph_t5_encoder.cc | 7 ++-- .../test/contrib_ops/beam_search_test.cc | 30 ++++++++++++++++++ onnxruntime/test/testdata/dummy_t5.onnx | Bin 0 -> 6815 bytes ...ummy_t5_with_outer_scope_initializers.onnx | Bin 0 -> 7135 bytes 7 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/test/testdata/dummy_t5.onnx create mode 100644 onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index d675ba742e03b..7757435990a65 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -31,6 +31,7 @@ Subgraph::Subgraph( allocator_(nullptr), is_output_float16_(false) { num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + used_implicit_inputs = std::vector(num_implicit_inputs, true); auto& subgraph_inputs = subgraph.GetInputs(); auto& subgraph_outputs = subgraph.GetOutputs(); @@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state, // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); + const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); + + const auto& implicit_input_defs = node.ImplicitInputDefs(); + for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) { + const auto* entry = implicit_input_defs[i]; + int idx; + if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) { + feed_names.push_back(entry->Name()); + } else { + --num_implicit_inputs; + used_implicit_inputs[i] = false; + } } InlinedVector feed_locations; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index bde591626bb83..8ec9c9cbdc20f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -31,6 +31,7 @@ class Subgraph { const GraphViewer& subgraph; // The subgraph int num_implicit_inputs; + std::vector used_implicit_inputs; int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience. int num_subgraph_outputs; // Same as subgraph_output_names.size() diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 9037e58aaf31f..6c66bfc2816e4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -281,8 +281,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds( } // Pass through implicit inputs. - for (const auto* entry : implicit_inputs) { - decoder_feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + decoder_feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 51473c0c931b9..d59db4afac2c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds( pinned_allocator, location)); - for (const auto* entry : implicit_inputs) { - feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index f6fc9ea7662cb..ca600c0700682 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,6 +7,8 @@ #include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/model_tester.h" +#include "test/util/include/current_test_name.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -394,5 +396,33 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } +TEST(BeamSearchTest, DummyT5) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/dummy_t5.onnx b/onnxruntime/test/testdata/dummy_t5.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3a3bbf47675231bca6d80570ec6cedb01ed4440a GIT binary patch literal 6815 zcmb_hc|cRg_J**9D=1tP(I+k_sIjh9+{m1O2#Rb1q6H)-r-(nF4)%hE=!6^Kwp0^|J)?^oNvyYIWzNpGeTxq z>LBLI0LGqE#i$irf;3j8Nz_SW6As|1N65n9F8yT8SXn)QV9s&A2g6jiIDkrA))I zU7K!9%J;p-M&V)=QumDsXN>oU`4|}1!O)f=C)&h#7c)DsS$6m)me#_t-9Bi_ zH}7IOLamCC>FV+wt93G|X|_LOtTzlZIhHRdsU|@^m9>mhPl?s>V;vyV1ti8XlZ?B3 zR1j?YrbBHQo&cs9GuB?q7iS!wgqr-+mYR!o7@rw!O=?8;jJbwOh?kg11QJt;Sxrp^ zH4EqPHrxzZTp}m!_W{5Frd4gVH$=cR4vJ+peBp5_U2TlDru`a?0blb1CX}&nfUA#y zF%}lv+5lN+)eIQNhn(-Znaj;H2E21|hY)S(6RgUynXejQ^+(L_|l@rfk zJGA~H1iLLrB54)L#OeoAeBA3PP05KT=OPYbY4^Y5iwA@8*0Kod+jk>Lswjf2C6Q2B z+Ld1aVO`OUkQBVKb#;;bi4`<@cpgz64aM8Pn#1vaC9p+*Ro}DV6!9(S2S?*h(Ue&n zi3|?(P%>rWS2suE(qm;Fo2%at2RsO+%M*c~vX_l_xmt1Is7OEft!$nhj@(X_+y7o#7xrGz}XZ=tUaHZNY5wqJkw)wj4!-3v;-D~C!fFR2wCsP@C60hje>ze>Xwg|=uHmW|Gz{|0FqJF$D0PjQiR zB$3$M(O(`Zr&hz9!SbOeJUnfMpmIm+9)XcQKbhTZH~Cp7g1RF{4fY!Si|g+H0*ZIi_YsY3A6sa2dmq~L&=mX z7-7B-i#yuE+@)@qqO7J}^4nrcpNY7E`iR{QhT+(x43bV)`Sgap{#KVyEo7O;O)Pdvo3P^B1wlPvFIaOfF4SmiDL>~$I$lR5=Y zCU=EBJI`T7Sh{|tU>z1-b;Ay>USeDCD4cn#gbqBN0{!@Ud+>N1c&uHDu*nKS=6w%w zeLEJ1nfjwip#yYz5JZ2rq@*;bJ^nFnulQ1ICwe0(3$!D?B3_qvl8ozT>HPd^>OJe9 zbo29L#CPEZDxJ8JM34GQtBmYJ|?a|H&L9f9(bU1*P`m*_W{sW{z|?_wNOoxmj+MMdJkd!Gs5+44(7x-3N2Zdv0<6oeP61L zP39PZ&@^Hc<7nU{G9Aw)b+KxdG+w5i#`a-8;eC3%=i|iI+w+@f9ZQtf5y^qEqK+y_D17(nZ$#mB)mF~|nZ^U*5OzhQ*lN#%bnC3!_<#RN1V#+p-@4dZSI{!eY^ z-Y}qVy<=L%s$zB0X5>}oDWEw;AZiRt~JdgvWC^* zEg@o<0}U$zPx!`;eM5R=+0qB6U17_!!*KVgnr0tT)4Z<7VeJfWcozMH4!+Gmn~{~Y zi|!G9UNVzJ->Je5$FHEnvsS3xl8RQhbBVW3iP5EhN7>p!xZ8gU9p;}Qc8lJRJEB^k zyw^?f;m{R0C$g9%9%+X+CS=hB>-MN?TM09}PAp<$(lKy(2(0s7kK?b6BeTxGMiaqy ze3LbrxY~%J*X+JTaPgsjglH`tB+=u3-yv90@qm^ErGx&_ZW6HLI)*>WBdI-)(j$Jo z@uyWC;Ka+hFiEL`7w5y#acmBaTDOga>XRTe{5g$rCRDzD5&Ao4<6hHX@X0QK7Rf90 zCE65->~WARb92QC;Xxw5S^}Yz)f1mGxVx=;@jzt=6pj2&Jg03ob*L;9Z&kj*9Zx6G z#0&e$=3aV;EHFn+bs)NFP4L9{sbC+nl}yaB!Y?^9SaYUS+_8El6d(K)#~sP1lMk2C z#IQnayMUn&&cM_*Wew89v~741w7$1vKjcCYFkJ}n6+%^3|>{8xh5DFlDoI0Gyvl#>MsA>e%c8Cd$xrkAQDIOISKKCKAG zl)D@BNjr9tY1Cq)T=WTs3OtSVFBh^z9g!01=W|i z!$ALMB(!^$KIw-En0Dt7uHSHrR(V8%Vt6Ep)@{V`;r&U-sH0@Iw;Y|PIuq|WmH3kX za=h2~09mo56~t}IfZuP-hgswIU_cuUMqD(5 z+~C#JP4vexUU2QgNL-ncK$oY!qMHix#nD|`OELwI5qI^4r1;a=a??{~7w z*9GX>aWHm#-5Y|chha$mM3S*>n0VBWZYXoNmZXhxfhX%si?3G_a(-%0WQRt;mK%Qf zu}_=A4#5K6Pi>{1?5jxlC#~@;6UyI%fQxRdgymbidvy1njpx=pB!TD0kTd!1kgHlOF84i3%VysbbE!Qc zTW|vOnjFlTE0maBt;R+1f1}UWIYO_iONd8oCp76ej^<4Lovy#;i1WYv42q8LC-dd~ zvA6$Zm>9l+6t)e9$i7{`d}tV2DqG>^K!9PM9pLQH9Yy0uKM~L8PYH%jZ;LBh+Q8}- zLGBqC|Pu^qM&B>S6@-a=3 zktV_FwlyP?LnwDQO8z@D9r>`@zZX`Gvf@QujfRpBX@cH=*fW4J#t;9p$_ZkOGsC|M zSR*hT7#mK_-+jCzE9@C2p1&lClj@Q*Tt4KbyW(eFQ}&{EyvO_ZEPm!@HLqOwL#ZY&Sp3XSoy>&) z$qbPLV{OpkTDGRVEF0CTKxo#q+?KmA^>;jM&1^6=azoD5pvJm&!HwJ#5J8bmK(uTzAx{EA5=l_O)}MyvDxGWh53Z@4 zI!UjRX*sn#RiD|kmbKGyGcx!aayr&Vm#9hSbRLW!Cr^vySD!)k!+W?srGoE~d+>WTle#S*Mrlc}SZ;xn9MkGl7hwR<6^_64Nz0 zooqUnDI3amH*L4Y5?LB2S9>rKoH8Sk3*ol0NT z){#bs#0+D_sg-n<|k?TV;Qkkjd zm;50o*zRq++H^c2Ogm<*la5c$R6YpC_%x)Nhg}pO8J#gSVkbtZ<HpFiSN4THe-4dABWvWMnDMQm!+JAsiY(coFF^=I0)~faERC08 zc&CIh6U;>BFM|b{;iz^UYmA3k$ATHQ8Q;A)v-Uk}8%^QGyr?m@2h*JoluN5XS@PlN ztcc?8+A615!uarOMzmTZV0y9cCi}L7tcYgJrq%z7o6^^urm!6ud)`}-aXVT8rJ77b%!-WZQJTOs|>FD_pYD)mWY4 zE=)&WGE+18iyEz#r=AAl6fflk^hLqEp3{*DU|3W6G!xL)iqCd2nYECz--XndF>lbc z?nT{lGmQ22l9Dvg$)qNT7Pbtl&PbD~IJr_MldF{trQf7wjfCGIl&x5*K`ENSws5^k z4Vpo)?q;`XMv5buSD3QSrW`hMePe_1DuSN$O3*GQ(8k*RW3X);vPs+jaJLlPbV4mc z`KMbZzH_i;LfU9L+Uum6&fVnv?>&rn(}6a-#&k+nC+lVHv;}eI*Y#=;sfKpZd>iw} zT7uAaHQWp>zCyyo+r=|Wi0vb?CN_u}W2Vvn(<1u+ljA{*lj$6r612@eK;iSd7QN04 z<(c}mmvp5)!S1Y>&{k%*3D(FHW<*w`@HQ)Vy%=}it?nc;6s{Sm`2XV z9KysA=VyuqKo67NXC3-1>9@i5Ky-S&&OBb?fj3x(Q_v` zy%Er>pc1NY-j#U&v>JK`=1V4Avyuw-UWSYZbLr*JXG7d@Ply<@7e-poLU!vA&}Zer zRQD?=%vgg}?+eO5y}y8dzi&ABgp`9Nu90NUJPqRFeW(?k$6x0s8b)oeAop%t8|uS_ z<;!vp7%n{h7@4ADYKsT!gRt!N z3x+fA=i-x62Xu^Dhi>ou0=Zg8iFfZeaFJUqk=oxhTo|jMwxis@=B_W?J!Om4Dra0# zT|i4BoXe{o9>%iAdmv=-83?s{k4n}#!{E5hL}5QxIy!U-h*w#QM54}oem0xX`Mf;) ztAGX<>QS`iBp#g-PTt(&i^=u}@Z@|?6l?R*ukIXp?w7-!&4RM8)dmRA6-s)(8;7Nj zkAUEb9n7i7#lAoK(fR$S;;R4b!}@M%P?1yzV}$#$+}#o8E%m}|RXy!p+!eD2$Kxg% zAn`gBg%dOLNmgZu;lv(2b&Xcx_+n>_ko!Up9}DQ@n}ucSa|7K^)efCF0x z;9-`9I^`IgYI6(2BTwU@I)BOA&vMCxoFqJv)d%+NK8rO`d4`pO4On`~3wwC_NgVv+ zaMtw-`sS%@7|Q3{o%@r(XZ=!yt+o&`|8qzk;$9wQ8H^UC&d~c#82#3UlFGvF_}ipU zCFhfS(yN)PKsV-n;&*;G$-i=jeo$Ob{b&D+ZhQO{30!!N%Hmg&gleJGRv?4LcR!}` z!|TBJ{%N{RD=4oQ`9uF1Bz5vw z2=_DKhtrlrQdkKo+;Wya|E)KOBAujOqeqeaL32sWusbwq$=i5sX%vL@kHv^(%O$x> z>S&G6V{+u%outR9T+l2Pp#7kH>bEotRaw{Qhl36i!7*R3bxy^(g$b~5ehC;#s!7eL z2jI{(0BgcGv~e0QS3e##PJkgLtHIE!FQEEFZ`yz9 zdHPX74o*2!jbWt`P}*;`MC(?BeUBEQc)ZBU(bjLktD3#vpIHUD{^!L@z*DpFC2ah# z5RYhpYV%QF(@G=X8(*Tv4~u5x-C8C2BAGbxIqAR3f!;ag3ELkXfm=s4blqVME$VX& z*3a~ZM+pz;@EZ(t8Cy$x>+jLW6|+df%{uIH>>@fp>V&%OIcR%h3-Q;hFro7AC|_R+ zw}vIrQNj5VuY?1*E3N}7240gKiClqmW6Md#7v1pclvOm{t~=_x*21hl@nvjc9)>QD zfDQf|aq{I!WcGK@&_b{ipRXEEJnbbgaLy1S`2Mb8jCehLOKQLafg`Y{<_@h2%LBu` zy(DDU6^y=DL~;fkrC$UM!f#ghfa5>UgQ+StJozpfT_zUNxD7iaqh6-Ia#P&Z#mU($%jpz_jT&jRb${I*O72Mj!~z%zfEqFhhz4eDqh zYgVE4xaF{SE>AK!K~Q*7g%ykU!v)QL^!IUrw4wIWV=LapYfl|f=yF0*bq zQXpAYpNId*dPEXs(^0CdgrIrj5N>RzJN?4pNx>%cyQ3od@M5q!auRnqN=RI%`DCHz zI7o^r#6!Xz_+*ncOzpUqtdnMw9r8~}cbCCmQMA@2tg} z&iz3w776$wF*Ak-Iku{T^(LZ!dGO*B*^rxwQ|FF{i@_^F;u?Y9Y@P`=Q>w{=^ayY} z_6TeO=g{+YQXFwG5g*osWA?30hRj`i$n>3y$haesaLsy>kILbn&^;%QO1lo0mJRF$ zCF)qTbY6g4Y4BtgF!Cr1w> z5#x@MIsOWCOK~IqscOmj;N|$^kb`8!l1`AiH6MPx`T@+IybnXVXfft{E6B>T0`Yk( z=(l2~^ol17!(*1g%#nIndDsh{rfj8OPw<1w=f>j7>~y+3=PBJe`9Cc{Pxs;I{%jD0)sMo6;&_t3W0Yjv*Ip?1wv*GylT5VXlN6&e< zDLjY99o&o;9-W3CIv&J%Ay??d%hsf5&gU@c*g;%$btNqSq@Pbe|2cSe?OhW3-2`&F zxEpeHizU^8M`_iZA0=GQ09YqD4hC%@7S0n%tuEE$qO`x!#~WN=;H4$RC%GqDxKE;m zDZkQm=@3TG+)ALIZ!ds%LL*@yUj+BM6@%1qx6kaAgJJxILh4ap zL`M2sQ}H;FkWb`S4hTWLOeAlj>X*~}|1vBqfUz?@gx0afNzZ2MWW!h?D7 cKQh)h7+acoKyKy7&mE3VeojowLB-MXzj%F;%m4rY literal 0 HcmV?d00001