From d8de3c4096d61c70e1c7efd94b39068f44152139 Mon Sep 17 00:00:00 2001 From: amancini-N <63410090+amancini-N@users.noreply.github.com> Date: Wed, 11 Dec 2024 01:20:47 +0100 Subject: [PATCH] [CUDA EP] Fix BeamSearch on T5 with sequence_as_input_ids (#20667) (#20668) ### Description Change the implementation of BeamSearch op when using CUDA EP: in case of T5 model, and in case the decoder input_ids are sequences, copy the sequences device-to-device instead of host-to-device ### Motivation and Context - Fixes #20667 --- .../cpu/transformers/beam_search_impl_t5.h | 26 ++++++--- .../cpu/transformers/generation_shared.h | 1 + .../contrib_ops/cpu/transformers/sequences.cc | 4 ++ .../contrib_ops/cpu/transformers/sequences.h | 3 + .../cpu/transformers/subgraph_t5_decoder.cc | 52 ++++++++++++------ .../cpu/transformers/subgraph_t5_decoder.h | 3 +- .../transformers/generation_device_helper.cc | 18 +++--- .../test/contrib_ops/beam_search_test.cc | 14 +++++ .../dummy_t5_with_sequence_input_ids.onnx | Bin 0 -> 6754 bytes 9 files changed, 84 insertions(+), 37 deletions(-) create mode 100644 onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 8f5cdc97f27e5..b67d003eaceeb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -258,7 +258,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches current_length, cpu_state.sequences, parameters->max_length, - decoder_subgraph_.has_decoder_masked_attention_)); + decoder_subgraph_.has_decoder_masked_attention_, + this->cuda_device_prop_ != nullptr)); if (decoder_subgraph_.past_present_share_buffer_) { decoder_fetches.reserve(static_cast(decoder_subgraph_.GetFirstPresentOutputIndex()) + @@ -302,17 +303,24 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cur_len = std::to_string(current_length); dumper->Print("***CurrentLength", cur_len, true); - for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) { + for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { dumper->Print("decoder_feeds", i, true); dumper->Print("", decoder_feeds[i]); } - auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers; - dumper->Print("past_sequence_length", offset, true); - dumper->Print("", decoder_feeds[offset]); - dumper->Print("beam_width", offset + 1, true); - dumper->Print("", decoder_feeds[offset + 1]); - dumper->Print("cache_redir", offset + 2, true); - dumper->Print("", decoder_feeds[offset + 2]); + for (int i = 0; i < decoder_subgraph_.num_layers; i++) { + int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; + int self_value_idx = self_key_idx + 1; + dumper->Print("past_key_self", i, true); + dumper->Print("", decoder_feeds[self_key_idx]); + dumper->Print("past_value_self", i + 1, true); + dumper->Print("", decoder_feeds[self_value_idx]); + int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i; + int cross_value_idx = cross_key_idx + 1; + dumper->Print("past_key_cross", i, true); + dumper->Print("", decoder_feeds[cross_key_idx]); + dumper->Print("past_value_cross", i, true); + dumper->Print("", decoder_feeds[cross_value_idx]); + } #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 30bf3aa0a1212..8145fbd4a4123 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -100,6 +100,7 @@ struct ISequences { virtual gsl::span GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA) virtual gsl::span GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA) virtual int GetSequenceLength() const = 0; + virtual int GetMaxLength() const = 0; }; struct ILogitsProcessorList { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 723c271897a78..ecad146da6777 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const { return current_length_; } +int Sequences::GetMaxLength() const { + return max_length_; +} + #ifdef DEBUG_GENERATION void Sequences::PrintSequences(const IConsoleDumper* dumper) const { for (int i = 0; i < batch_beam_size_; i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 440a07e14a6cc..7dd1f28d270c7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -25,6 +25,9 @@ class Sequences : public ISequences { // Returns current sequence length. int GetSequenceLength() const override; + // Returns max sequence length. + int GetMaxLength() const override; + #ifdef DEBUG_GENERATION // Print the sequences to StdOut in debug mode void PrintSequences(const IConsoleDumper* dumper) const; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 6c66bfc2816e4..f4e7173c917c1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len, - bool need_cache_indir) { + bool need_cache_indir, + bool use_cuda) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); // Allocate subgraph inputs from same device as inputs of encoder subgraph. @@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); + size_t total_size_bytes = total_size * sizeof(int); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { @@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds( stream, DeviceCopyDirection::hostToDevice)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + if (use_cuda) { + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + for (int i = 0; i < batch_beam_size; i++) { + size_t batch_beam_stride = static_cast(i) * static_cast(sequences.GetMaxLength()); + int seq_size = sequences.GetSequenceLength(); + gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); + gsl::span temp_input(input_ids_data + static_cast(i) * seq_size, seq_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + sequence, + stream, + DeviceCopyDirection::deviceToDevice)); + } + } else { + const size_t cur_len_bytes = cur_len * sizeof(int); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span sequence = sequences.GetSequence(i); + const int32_t* sequence_data = sequence.data(); + ptrdiff_t seq_index = static_cast(i) * cur_len; + memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); + } + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + temp_sequence, + stream, + DeviceCopyDirection::hostToDevice)); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); - ORT_RETURN_IF_ERROR(device_copy_int32_func( - temp_input, - temp_sequence, - stream, - DeviceCopyDirection::hostToDevice)); } // The ordering is the same as used in Setup. @@ -230,7 +248,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } else { ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, @@ -238,7 +256,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } decoder_feeds.push_back(expanded_hidden_states); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 83dae49c7dcbd..a72ce37a93aba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph { int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len = -1, - bool need_cache_indir = false); + bool need_cache_indir = false, + bool use_cuda = false); Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e047bd948434d..4e65336665bf7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds( CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, cuda_stream)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - CUDA_RETURN_IF_ERROR( - cudaMemcpyAsync(input_ids_data + static_cast(i) * current_length, - sequence_data, - current_length * sizeof(int32_t), - cudaMemcpyHostToDevice, - cuda_stream)); - } + // We expect sequences to point directly to device memory + int max_length = sequences.GetMaxLength(); + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + CUDA_RETURN_IF_ERROR( + cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t), + sequences_buffer.data(), max_length * sizeof(int32_t), + current_length * sizeof(int32_t), batch_beam_size, + cudaMemcpyDeviceToDevice, cuda_stream)); } next_inputs[0] = input_ids; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index ca600c0700682..8c69e2d9810b8 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -424,5 +424,19 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { tester.RunWithConfig(); } +TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5a5c302914890e5f1c381eae3676c985e85cc0c6 GIT binary patch literal 6754 zcmds6d0bOR*M_i#D+pXxH$;4EAu1w*2$DGgcLWy%t5{3}2?9a_Nzfv#qT+%Z6$PuJ zEUs8;QPEZ*bE0Auwc=K_E|gZO`%)EAtM$EX6$SKt`}U9T_vN3Pd+(WN&YXMB%=66f z8SgT8CdZAjkjWzyQkha5EssxBi=(9~##AknizO;CcTp*njHyB{Yj$g$a@1N|unwbtqG@hN=jKQl@fb{A7~2aWaWAB1&-AmNAmb>fLdx zv}FjxHr|P4ap5wlG+I7I#TftPC~MvHIx1Q!<$PADC29`RaF|3LB~vm1jCIrC-Y{T0 zw)$8cCzHq>nIM@oF+vu|{TH-j^d+ClR3hD6T_YxtX}vq<9cZj&Q%_BvsuJTI^?apL z#;XlD*0jkqT%wMMax@qhC5e{_7)DPfm)dcj@R$VVO&i8pC5w%07`-@L7Rkj`tX4|o z@d}kptg34+mMg^daIC}oOdSmeDdZ6nb;IDJRdzacUG`;K8wtZiMsuH79IsTwu!gaU zDbZ>!`hgO4U}7xuaqCUq>j2xaZ5JJlf@Jjr;=aZ^%vcK*mnyN`;OpWtRhA@nbuj#NLQp+_i;zR6V9?Z+%!pSqU^nZfs9$aM5S)#0n^ecmW}61 zKqgn$`&jR+Ym4~v>f%3+v3L#F=mFE3TdY+RWP{Z^K-s2R6Rb1fTk-kkx_+(Hn}3SI zhRd_)9jzP-&X=By*?eUofg)3tic_c zY>ns4a9lXOOc>lb2%S% zVpJ;47JzBLfKuF*2YFsU!h@1b+!c@m36ppFXf87FT&a&vrkn8R?sMq=M_Sx@drst2ZaoeWF-Ea3Kl9GYEYA$)Q^ zS<9>K459APLUX7U9@m>olYR3vYeUxH@c2?;)>9pbbnXrEeFdbSPClu6_6ONM$_LhY z+#<(x4pR3{KTz`#M`?!LN&4N^&+yliF_>?*oWvB&Be7d|({1)!$)$7`*m`XfIeRMw zt;2Lh-X3X~z4AO{Z~u+f^fJb<9tPUat&V}v=yz$R7&eZw^YWDry4=(VYl9>;Np>c#t z_$f1z9+^=>Om9s>lk^<$U$C0&sa!(?LUo8>##AcEG~o03R$P8In{T+hjN2oiBy2Wz zw*8*Woo6AkC|(MMUHZ{3;73<&PJyhAQRKG%eDHMXfXbsE!ymd|g1h@tw0JZWj|?#r z1qVlw8=H(p)AQ`L#nDS(t7|OjyK5QIn+4GBkP9*CxC3sTs1}BpZGvCZlCZXX06Z(} zL;F|0#7;3zaF6dHGE`O5g2=hxRy2zIl;w^+M;*cqA#32i)<7$%@r02Xw=tmS1=8+p z2ySuc2Xk!#P`kPvZZ5AxW&2!2<2Ay48;Wt&JSU>)aTr&wmsm{C{ca)IYQuupEd@E&x2MqY^IYgzuv>GZ zj~*tP-4=G+{}KBBmW;aPuIN-P(9U#Rik%&b;AQS)=ox4Z6=QBv*Rx@8YSBtaxK~NP z@U+0Ym(##r-W_5NOd<=-R^kuawn5RuM|AS(VRW;1Uy)(JDm1@hO#S*5mad(!3@qoU zfaE{*dBkCFH|Ymgx}}3s=1+|a9iUhH{d8rv9Tt?XFP&y@sP!D=OXU$>S6_d-SB*cOPQy>X@ftd?Ij?bHdhrce zC&q+j>Q`;CDtd;Dv14r-r~$XAG_WXf1JMc{&e$~Z42ha!ZR%)+TpTA+O=a6eYOrsC zG_ozqQehR5f%Jvt6bKXha(>)+o7^JmQK z7t;u(LZxb;2k*vebBy|-0Swy=t8r!iyJ+D)Oe2e8+l+VXDZmBEDm<9;r3Ig03FE_s zd_XH9cVO+C?At`u4OcMSb1WcI7%jws>Bkxw7#5QOM?auT~BLUkiUBJ-3;~(9rK!CH-lj9 zo1*cCt(Z%UIwc@|O`*pz?=fX9o6@6>LN`>)`-pVNdqJ(6KpTqXKZ9-KBR5H<{~wa@ z|D4Dg^-fjP%4G8%l{My_c&$z|??z;sCW2-XrYZejZy9h@ajv=$Q*t`M&M+Uk{p3vW zyJ7^)sezE3by;{uEQS1{fw;qQA30mJkH$?3qKA{JsNP;ne7VsbD-)a$M;6n?jzdZJ zsip9H)D3zqZ5c_#A#@&e1X9ap(n+ThF>6f)tQgWx zl;!I~`VYK{{fvf_bDOr&;;I28_}m7nR`5kSv+j{&-Wt>_+(g3X42K1KKBS@;ClUQ> zsir2+km!~r;s^ZIWXSIWaKYId`0C3Pm~8tQ^*i-a_$2)Rsky!g2Re<%yRK6)?4A=G zd~8BBff@+3nhZq`Sg?9XVACEq?Bx@HCr9LIto$cH*iL;=_C;dgXN*q+yf+Detgbya(2Q~4M8d|98gnKS70=?-W z_$j$WFFpSRJ&Xb{CF>G7qbY*iL-Xm~OG$Y3lUKODA_rHy`crwu7u4ui1TXtyN*bVwpSv-`VKLX#lPND88a`Q;eg2~VRq=G|b!k?;U!}azHa7+4H zG}(FxUd&D*yQn3Y9!rCNSxuq0&u*p3HD99dw<(acek0y4oq)Po;ppGvTk_(jKK#ow z3Hz75Bxh%;i9o>Taa6IDb->^JZ7$fB+_}w;!A&6_YR|3fjbS%Z;eIh6U!Wl!?63`U zDnj6w+mYnR$J>Cpv%))2=OL|KZ9_7Hd%+Q{H_^Kl31POsK~UN@@XR_!syr5uSrg0f z!Q?{Nc0H@KytWX#I%g4wYcB9)#aem*Y@kDBjiyUUS5aKr2O_N90t+Vzi1O1+I1qb+ z+Rd9wA9=l?USr2&nr#Z|l?;VH0~f>3S3bd-UB;rk_)++!+dQ-rPr^B;?0w35bjE~- zduefS2`w(UOQP00(17#L>4TBWaZr{X=50Iyfgk09?c5^Lb^cDhev zaY`lJ%nJgSqC0pqwuE@P@5QGBcEgQzqcxi!wa10!3$SE@nMiwN35G_j#-bTJaGG~9 z*?aE<{d)a6e4aP~Z;CEJ;N3DP$WPK_pVrg#nU@b;%{Ie@Awe)<{sD4o!)Rim9F5

evRX(`*+tuU?2_NRR&j>25C zfv`@yiWWqvY4uDq*yj-jQoqHx?6Cy$stTZc5+m}?xCoBbyM59e7Lom#qj7jn3T%zv zr1A9A1Ho{w(5=%`;RD(r_3v*P$h5`^|F+^8&9u%3|1Myi ztgvKEWD4&1;terj!7y>$IY+EmofI#t-$$>riu9MWQq3CF zHxk&I8~VL@EaI=vaM|7P;*A3mfBm|&7j)oSynR&Sudi44%8}d2YV(Z5U;l^WOW+@= zkZ@p(o3;v6^@`H4g@Wht_1c!+;f_q>F%4T+4Z5vVgPtE32XhNQ3+Bzhnd|-=8-F7c literal 0 HcmV?d00001