diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 8f5cdc97f27e5..b67d003eaceeb 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -258,7 +258,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                                            current_length,
                                                            cpu_state.sequences,
                                                            parameters->max_length,
-                                                           decoder_subgraph_.has_decoder_masked_attention_));
+                                                           decoder_subgraph_.has_decoder_masked_attention_,
+                                                           this->cuda_device_prop_ != nullptr));
 
   if (decoder_subgraph_.past_present_share_buffer_) {
     decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
@@ -302,17 +303,24 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
     auto cur_len = std::to_string(current_length);
     dumper->Print("***CurrentLength", cur_len, true);
 
-    for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) {
+    for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) {
       dumper->Print("decoder_feeds", i, true);
       dumper->Print("", decoder_feeds[i]);
     }
-    auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers;
-    dumper->Print("past_sequence_length", offset, true);
-    dumper->Print("", decoder_feeds[offset]);
-    dumper->Print("beam_width", offset + 1, true);
-    dumper->Print("", decoder_feeds[offset + 1]);
-    dumper->Print("cache_redir", offset + 2, true);
-    dumper->Print("", decoder_feeds[offset + 2]);
+    for (int i = 0; i < decoder_subgraph_.num_layers; i++) {
+      int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i;
+      int self_value_idx = self_key_idx + 1;
+      dumper->Print("past_key_self", i, true);
+      dumper->Print("", decoder_feeds[self_key_idx]);
+      dumper->Print("past_value_self", i + 1, true);
+      dumper->Print("", decoder_feeds[self_value_idx]);
+      int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i;
+      int cross_value_idx = cross_key_idx + 1;
+      dumper->Print("past_key_cross", i, true);
+      dumper->Print("", decoder_feeds[cross_key_idx]);
+      dumper->Print("past_value_cross", i, true);
+      dumper->Print("", decoder_feeds[cross_value_idx]);
+    }
 #endif
 
 #ifdef DEBUG_NODE_INPUTS_OUTPUTS
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 30bf3aa0a1212..8145fbd4a4123 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -100,6 +100,7 @@ struct ISequences {
   virtual gsl::span<const int32_t> GetCurrentDeviceSequences() const = 0;  // Get all current beam_index sequences in one continuous block (to pass to CUDA)
   virtual gsl::span<int32_t> GetNextDeviceSequences() = 0;  // Get all next beam_index sequences in one continuous block (to pass to CUDA)
   virtual int GetSequenceLength() const = 0;
+  virtual int GetMaxLength() const = 0;
 };
 
 struct ILogitsProcessorList {
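[Illustration, not part of the patch] The dump loop added in beam_search_impl_t5.h above relies on the decoder feed ordering of the T5 decoder subgraph: non-past inputs occupy indices [0, GetFirstPastInputIndex()), followed by num_layers self-attention key/value pairs and then num_layers cross-attention key/value pairs. A minimal standalone sketch of that index arithmetic; `first_past` and `num_layers` are placeholder values standing in for `decoder_subgraph_.GetFirstPastInputIndex()` and `decoder_subgraph_.num_layers`:

```cpp
#include <cstdio>

int main() {
  // Hypothetical values for illustration only.
  const int first_past = 3;
  const int num_layers = 2;
  for (int i = 0; i < num_layers; i++) {
    int self_key_idx = first_past + 2 * i;                     // past_key_self for layer i
    int self_value_idx = self_key_idx + 1;                     // past_value_self for layer i
    int cross_key_idx = first_past + 2 * num_layers + 2 * i;   // past_key_cross for layer i
    int cross_value_idx = cross_key_idx + 1;                   // past_value_cross for layer i
    std::printf("layer %d: self K/V at %d/%d, cross K/V at %d/%d\n",
                i, self_key_idx, self_value_idx, cross_key_idx, cross_value_idx);
  }
  return 0;
}
```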
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc
index 723c271897a78..ecad146da6777 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc
@@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const {
   return current_length_;
 }
 
+int Sequences::GetMaxLength() const {
+  return max_length_;
+}
+
 #ifdef DEBUG_GENERATION
 void Sequences::PrintSequences(const IConsoleDumper* dumper) const {
   for (int i = 0; i < batch_beam_size_; i++) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h
index 440a07e14a6cc..7dd1f28d270c7 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h
@@ -25,6 +25,9 @@ class Sequences : public ISequences {
   // Returns current sequence length.
   int GetSequenceLength() const override;
 
+  // Returns max sequence length.
+  int GetMaxLength() const override;
+
 #ifdef DEBUG_GENERATION
   // Print the sequences to StdOut in debug mode
   void PrintSequences(const IConsoleDumper* dumper) const;
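[Illustration, not part of the patch] `GetMaxLength()` added above exposes the row stride of the flat sequences buffer returned by `GetCurrentDeviceSequences()`: one row of `max_length` token slots per beam, of which only the first `GetSequenceLength()` entries are valid so far. The decoder-feed and CUDA changes below slice rows with exactly this stride. A small host-side sketch of that layout, using `std::span` in place of `gsl::span` to stay dependency-free; all sizes are assumed values:

```cpp
#include <cstdint>
#include <cstdio>
#include <span>
#include <vector>

int main() {
  // Assumed sizes for illustration only.
  const int batch_beam_size = 2;
  const int max_length = 8;      // what GetMaxLength() would report (row stride)
  const int current_length = 3;  // what GetSequenceLength() would report

  // Flat buffer: batch_beam_size rows, each max_length tokens long.
  std::vector<int32_t> buffer(static_cast<size_t>(batch_beam_size) * max_length, 0);
  std::span<const int32_t> sequences_buffer(buffer.data(), buffer.size());

  for (int i = 0; i < batch_beam_size; i++) {
    size_t stride = static_cast<size_t>(i) * static_cast<size_t>(max_length);
    // Row i holds the tokens generated so far for beam i.
    std::span<const int32_t> row = sequences_buffer.subspan(stride, current_length);
    std::printf("beam %d holds %d valid tokens\n", i, static_cast<int>(row.size()));
  }
  return 0;
}
```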
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 6c66bfc2816e4..f4e7173c917c1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
     int cur_len,
     transformers::Sequences& sequences,
     int past_present_share_buffer_max_seq_len,
-    bool need_cache_indir) {
+    bool need_cache_indir,
+    bool use_cuda) {
   ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
 
   // Allocate subgraph inputs from same device as inputs of encoder subgraph.
@@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
   int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
   AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
-  size_t total_size = static_cast<size_t>(static_cast<long long>(cur_len) * batch_beam_size * sizeof(int));
-  auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size, false, stream);
+  size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
+  size_t total_size_bytes = total_size * sizeof(int);
+  auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
   int* seq_copy_ptr = seq_copy.get();
 
   if (!use_sequence_as_input_ids_) {
@@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
         stream,
         DeviceCopyDirection::hostToDevice));
   } else {
-    for (int i = 0; i < batch_beam_size; i++) {
-      gsl::span<const int32_t> sequence = sequences.GetSequence(i);
-      const int32_t* sequence_data = sequence.data();
-      long long seq_index = (long long)i * cur_len;
-      memcpy(seq_copy_ptr + seq_index, sequence_data, total_size);
+    if (use_cuda) {
+      auto sequences_buffer = sequences.GetCurrentDeviceSequences();
+      for (int i = 0; i < batch_beam_size; i++) {
+        size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
+        int seq_size = sequences.GetSequenceLength();
+        gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
+        gsl::span<int32_t> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
+        ORT_RETURN_IF_ERROR(device_copy_int32_func(
+            temp_input,
+            sequence,
+            stream,
+            DeviceCopyDirection::deviceToDevice));
+      }
+    } else {
+      const size_t cur_len_bytes = cur_len * sizeof(int);
+      for (int i = 0; i < batch_beam_size; i++) {
+        gsl::span<const int32_t> sequence = sequences.GetSequence(i);
+        const int32_t* sequence_data = sequence.data();
+        ptrdiff_t seq_index = static_cast<ptrdiff_t>(i) * cur_len;
+        memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes);
+      }
+      gsl::span<int32_t> temp_input(input_ids_data, total_size);
+      gsl::span<const int32_t> temp_sequence(seq_copy_ptr, total_size);
+      ORT_RETURN_IF_ERROR(device_copy_int32_func(
+          temp_input,
+          temp_sequence,
+          stream,
+          DeviceCopyDirection::hostToDevice));
     }
-    gsl::span<int32_t> temp_input(input_ids_data, total_size);
-    gsl::span<const int32_t> temp_sequence(seq_copy_ptr, total_size);
-    ORT_RETURN_IF_ERROR(device_copy_int32_func(
-        temp_input,
-        temp_sequence,
-        stream,
-        DeviceCopyDirection::hostToDevice));
   }
 
   // The ordering is the same as used in Setup.
@@ -230,7 +248,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
                                                    num_beam,
                                                    allocator,
                                                    expanded_hidden_states,
-                                                   true,
+                                                   false,
                                                    0 /*max_sequence_length*/));
   } else {
     ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
@@ -238,7 +256,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
                                                  num_beam,
                                                  allocator,
                                                  expanded_hidden_states,
-                                                 true,
+                                                 false,
                                                  0 /*max_sequence_length*/));
   }
   decoder_feeds.push_back(expanded_hidden_states);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
index 83dae49c7dcbd..a72ce37a93aba 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
@@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph {
                            int cur_len,
                            transformers::Sequences& sequences,
                            int past_present_share_buffer_max_seq_len = -1,
-                           bool need_cache_indir = false);
+                           bool need_cache_indir = false,
+                           bool use_cuda = false);
 
   Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
                   const std::vector<const NodeArg*>& subgraph_outputs) override;
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index e047bd948434d..4e65336665bf7 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds(
     CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
                                          cudaMemcpyHostToDevice, cuda_stream));
   } else {
-    for (int i = 0; i < batch_beam_size; i++) {
-      gsl::span<const int32_t> sequence = sequences.GetSequence(i);
-      const int32_t* sequence_data = sequence.data();
-      CUDA_RETURN_IF_ERROR(
-          cudaMemcpyAsync(input_ids_data + static_cast<ptrdiff_t>(i) * current_length,
-                          sequence_data,
-                          current_length * sizeof(int32_t),
-                          cudaMemcpyHostToDevice,
-                          cuda_stream));
-    }
+    // We expect sequences to point directly to device memory
+    int max_length = sequences.GetMaxLength();
+    auto sequences_buffer = sequences.GetCurrentDeviceSequences();
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t),
+                          sequences_buffer.data(), max_length * sizeof(int32_t),
+                          current_length * sizeof(int32_t), batch_beam_size,
+                          cudaMemcpyDeviceToDevice, cuda_stream));
   }
 
   next_inputs[0] = input_ids;
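[Illustration, not part of the patch] In the generation_device_helper.cc change above, the per-beam `cudaMemcpyAsync` loop is collapsed into one strided `cudaMemcpy2DAsync`: source rows sit `max_length` elements apart in the device sequences buffer, destination rows sit `current_length` elements apart in `input_ids`, and each of the `batch_beam_size` rows copies `current_length` tokens. A host-only sketch of the same copy, with `memcpy` standing in for the device transfer so it compiles without CUDA:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Emulates the cudaMemcpy2DAsync call in the diff:
//   dst pitch = current_length * sizeof(int32_t)  (packed rows in input_ids)
//   src pitch = max_length * sizeof(int32_t)      (padded rows in the sequences buffer)
//   row width = current_length * sizeof(int32_t)
//   rows      = batch_beam_size
void copy_rows(int32_t* input_ids_data, const int32_t* sequences_buffer,
               int batch_beam_size, int current_length, int max_length) {
  for (int i = 0; i < batch_beam_size; i++) {
    std::memcpy(input_ids_data + static_cast<std::size_t>(i) * current_length,
                sequences_buffer + static_cast<std::size_t>(i) * max_length,
                static_cast<std::size_t>(current_length) * sizeof(int32_t));
  }
}

int main() {
  const int batch_beam_size = 2, current_length = 3, max_length = 8;
  std::vector<int32_t> src(batch_beam_size * max_length, 7);        // padded source rows
  std::vector<int32_t> dst(batch_beam_size * current_length, 0);    // packed destination rows
  copy_rows(dst.data(), src.data(), batch_beam_size, current_length, max_length);
  return 0;
}
```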
diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc
index ca600c0700682..8c69e2d9810b8 100644
--- a/onnxruntime/test/contrib_ops/beam_search_test.cc
+++ b/onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -424,5 +424,19 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
   tester.RunWithConfig();
 }
 
+TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+  ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
+  tester.ConfigEp(DefaultCpuExecutionProvider());
+  tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
+  tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13});
+#ifdef USE_CUDA
+  tester.ConfigEp(DefaultCudaExecutionProvider());
+#endif
+  tester.RunWithConfig();
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx
new file mode 100644
index 0000000000000..5a5c302914890
Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx differ
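[Illustration, not part of the patch] The expected `sequences` output in the new test has shape {1, 3, 10}, i.e. (batch_size, num_return_sequences, max_length), so the flat initializer above is three returned sequences of ten token ids each. A trivial sketch of slicing the expected data back into rows:

```cpp
#include <array>
#include <cstdio>

int main() {
  // Expected output copied from the test: 1 batch x 3 returned sequences x 10 tokens.
  const std::array<int, 30> sequences = {2, 19, 18, 3, 8, 8, 8, 8, 8, 8,
                                         2, 19, 18, 3, 10, 19, 18, 3, 8, 8,
                                         2, 19, 18, 15, 13, 13, 13, 13, 13, 13};
  const int num_return_sequences = 3;
  const int max_length = 10;
  for (int s = 0; s < num_return_sequences; s++) {
    std::printf("sequence %d starts with token %d\n", s, sequences[s * max_length]);
  }
  return 0;
}
```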