Skip to content

Commit

Permalink
[CUDA EP] Fix BeamSearch on T5 with sequence_as_input_ids (#20667) (#20668)
Browse files Browse the repository at this point in the history

### Description
Change the implementation of the BeamSearch op when using the CUDA EP: in the
case of a T5 model whose decoder input_ids are sequences, copy the
sequences device-to-device instead of host-to-device.

### Motivation and Context
- Fixes #20667
  • Loading branch information
amancini-N authored Dec 11, 2024
1 parent 02f0af0 commit d8de3c4
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 37 deletions.
26 changes: 17 additions & 9 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
current_length,
cpu_state.sequences,
parameters->max_length,
decoder_subgraph_.has_decoder_masked_attention_));
decoder_subgraph_.has_decoder_masked_attention_,
this->cuda_device_prop_ != nullptr));

if (decoder_subgraph_.past_present_share_buffer_) {
decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
Expand Down Expand Up @@ -302,17 +303,24 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
auto cur_len = std::to_string(current_length);
dumper->Print("***CurrentLength", cur_len, true);

for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) {
for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) {
dumper->Print("decoder_feeds", i, true);
dumper->Print("", decoder_feeds[i]);
}
auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers;
dumper->Print("past_sequence_length", offset, true);
dumper->Print("", decoder_feeds[offset]);
dumper->Print("beam_width", offset + 1, true);
dumper->Print("", decoder_feeds[offset + 1]);
dumper->Print("cache_redir", offset + 2, true);
dumper->Print("", decoder_feeds[offset + 2]);
for (int i = 0; i < decoder_subgraph_.num_layers; i++) {
int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i;
int self_value_idx = self_key_idx + 1;
dumper->Print("past_key_self", i, true);
dumper->Print("", decoder_feeds[self_key_idx]);
dumper->Print("past_value_self", i + 1, true);
dumper->Print("", decoder_feeds[self_value_idx]);
int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i;
int cross_value_idx = cross_key_idx + 1;
dumper->Print("past_key_cross", i, true);
dumper->Print("", decoder_feeds[cross_key_idx]);
dumper->Print("past_value_cross", i, true);
dumper->Print("", decoder_feeds[cross_value_idx]);
}
#endif

#ifdef DEBUG_NODE_INPUTS_OUTPUTS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ struct ISequences {
virtual gsl::span<const int32_t> GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA)
virtual gsl::span<int32_t> GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA)
virtual int GetSequenceLength() const = 0;
virtual int GetMaxLength() const = 0;
};

struct ILogitsProcessorList {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/sequences.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const {
return current_length_;
}

// Returns the maximum sequence length this Sequences object was configured
// with (max_length_). Callers use this as the per-beam stride of the
// contiguous device sequence buffer, in contrast to GetSequenceLength(),
// which returns the current (growing) length.
int Sequences::GetMaxLength() const {
return max_length_;
}

#ifdef DEBUG_GENERATION
void Sequences::PrintSequences(const IConsoleDumper* dumper) const {
for (int i = 0; i < batch_beam_size_; i++) {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/sequences.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class Sequences : public ISequences {
// Returns current sequence length.
int GetSequenceLength() const override;

// Returns max sequence length.
int GetMaxLength() const override;

#ifdef DEBUG_GENERATION
// Print the sequences to StdOut in debug mode
void PrintSequences(const IConsoleDumper* dumper) const;
Expand Down
52 changes: 35 additions & 17 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
int cur_len,
transformers::Sequences& sequences,
int past_present_share_buffer_max_seq_len,
bool need_cache_indir) {
bool need_cache_indir,
bool use_cuda) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");

// Allocate subgraph inputs from same device as inputs of encoder subgraph.
Expand All @@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
size_t total_size = static_cast<size_t>(static_cast<long long>(cur_len) * batch_beam_size * sizeof(int));
auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size, false, stream);
size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
size_t total_size_bytes = total_size * sizeof(int);
auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
int* seq_copy_ptr = seq_copy.get();

if (!use_sequence_as_input_ids_) {
Expand All @@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
stream,
DeviceCopyDirection::hostToDevice));
} else {
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
const int32_t* sequence_data = sequence.data();
long long seq_index = (long long)i * cur_len;
memcpy(seq_copy_ptr + seq_index, sequence_data, total_size);
if (use_cuda) {
auto sequences_buffer = sequences.GetCurrentDeviceSequences();
for (int i = 0; i < batch_beam_size; i++) {
size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
int seq_size = sequences.GetSequenceLength();
gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
gsl::span<int> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
temp_input,
sequence,
stream,
DeviceCopyDirection::deviceToDevice));
}
} else {
const size_t cur_len_bytes = cur_len * sizeof(int);
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
const int32_t* sequence_data = sequence.data();
ptrdiff_t seq_index = static_cast<ptrdiff_t>(i) * cur_len;
memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes);
}
gsl::span<int> temp_input(input_ids_data, total_size);
gsl::span<int> temp_sequence(seq_copy_ptr, total_size);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
temp_input,
temp_sequence,
stream,
DeviceCopyDirection::hostToDevice));
}
gsl::span<int> temp_input(input_ids_data, total_size);
gsl::span<int> temp_sequence(seq_copy_ptr, total_size);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
temp_input,
temp_sequence,
stream,
DeviceCopyDirection::hostToDevice));
}

// The ordering is the same as used in Setup.
Expand Down Expand Up @@ -230,15 +248,15 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
num_beam,
allocator,
expanded_hidden_states,
true,
false,
0 /*max_sequence_length*/));
} else {
ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
encoder_fetches[j],
num_beam,
allocator,
expanded_hidden_states,
true,
false,
0 /*max_sequence_length*/));
}
decoder_feeds.push_back(expanded_hidden_states);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph {
int cur_len,
transformers::Sequences& sequences,
int past_present_share_buffer_max_seq_len = -1,
bool need_cache_indir = false);
bool need_cache_indir = false,
bool use_cuda = false);

Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds(
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));
} else {
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
const int32_t* sequence_data = sequence.data();
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(input_ids_data + static_cast<ptrdiff_t>(i) * current_length,
sequence_data,
current_length * sizeof(int32_t),
cudaMemcpyHostToDevice,
cuda_stream));
}
// We expect sequences to point directly to device memory
int max_length = sequences.GetMaxLength();
auto sequences_buffer = sequences.GetCurrentDeviceSequences();
CUDA_RETURN_IF_ERROR(
cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t),
sequences_buffer.data(), max_length * sizeof(int32_t),
current_length * sizeof(int32_t), batch_beam_size,
cudaMemcpyDeviceToDevice, cuda_stream));
}
next_inputs[0] = input_ids;

Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,5 +424,19 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
tester.RunWithConfig();
}

// Regression test for #20667: BeamSearch on a dummy T5 model whose decoder
// input_ids are full sequences (sequence_as_input_ids). Runs with the CPU EP
// and, when built with CUDA, also with the CUDA EP to exercise the
// device-to-device sequence copy path.
TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
#if defined(USE_CUDA) && defined(USE_DML)
// CUDA and DML EPs cannot be tested together in the same build configuration.
SKIP_CUDA_TEST_WITH_DML;
#endif
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
// Expected output: batch=1, num_return_sequences=3, max_length=10.
tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13});
#ifdef USE_CUDA
// Re-run the same expectations on the CUDA EP when available.
tester.ConfigEp(DefaultCudaExecutionProvider());
#endif
tester.RunWithConfig();
}

} // namespace test
} // namespace onnxruntime
Binary file not shown.

0 comments on commit d8de3c4

Please sign in to comment.