Skip to content

Commit

Permalink
Enable pointer-generator T5 models in BeamSearch (#23134)
Browse files Browse the repository at this point in the history
### Description
Introduces a new optional input (encoder_input_ids) in the decoder
graph of the T5 implementation for BeamSearch. This allows usage of
pointer-generator networks in the decoder graph.

### Motivation and Context
- Fixes #23123
  • Loading branch information
amancini-N authored Dec 23, 2024
1 parent ebdbbb7 commit c6ba7ed
Show file tree
Hide file tree
Showing 5 changed files with 448 additions and 26 deletions.
65 changes: 45 additions & 20 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ namespace transformers {
Inputs:
input_ids: int32 (B, 1)
encoder_input_ids: int32 (B, encode_sequence_length) (optional)
encoder_attention_mask: int32 (B, encode_sequence_length)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)
past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
Expand Down Expand Up @@ -49,11 +50,9 @@ namespace transformers {

Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
SetPastInputIndex(has_hidden_state);

ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
"kFirstPastInputIndex currently only supports 2 or 3");
bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
SetPastInputIndex(has_hidden_state, has_encoder_input_ids);

if (!past_present_share_buffer_) {
ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
Expand All @@ -75,13 +74,17 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
"decoder subgraph input 1 shall be named as encoder_attention_mask, got: ",
subgraph_inputs[1]->Name());
if (first_past_input_index_ == 3) {
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
"decoder subgraph input 2 shall be named as encoder_hidden_states, got: ",
subgraph_inputs[2]->Name());
const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
const int enc_hidden_state_index = enc_attn_mask_index + 1;
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" shall be named as encoder_attention_mask, got: ",
subgraph_inputs[enc_attn_mask_index]->Name());
if (has_hidden_state_) {
ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states",
"decoder subgraph input ", std::to_string(enc_hidden_state_index),
" shall be named as encoder_hidden_states, got: ",
subgraph_inputs[enc_hidden_state_index]->Name());
}

// check subgraph outputs
Expand All @@ -108,12 +111,19 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");

auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
if (has_encoder_input_ids_) {
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_input_ids) shall have int32 type");
}
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" (encoder_attention_mask) shall have int32 type");

auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type();
if (has_hidden_state_) {
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type");
}

for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
Expand Down Expand Up @@ -219,6 +229,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
decoder_feeds.push_back(input_ids);

if (has_encoder_input_ids_) {
// The encoder_input_ids is copied from the first input of encoder.
OrtValue expanded_encoder_input_ids;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
encoder_feeds[0],
num_beam,
allocator,
expanded_encoder_input_ids,
false,
0 /*max_sequence_length*/));
decoder_feeds.push_back(expanded_encoder_input_ids);
}

// The encoder_attention_mask is copied from the second input of encoder.
OrtValue expanded_decoder_attention_masks;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
Expand All @@ -238,7 +261,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
// TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
// What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
if (j == 1) {
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
OrtValue expanded_hidden_states;
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph {
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;

void SetPastInputIndex(bool has_hidden_state) {
void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
has_hidden_state_ = has_hidden_state;
if (!has_hidden_state_) {
first_past_input_index_ = 2;
} else {
first_past_input_index_ = 3;
}
has_encoder_input_ids_ = has_encoder_input_ids;
first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_;
}

int GetFirstPastInputIndex() const {
Expand All @@ -79,6 +76,7 @@ class T5DecoderSubgraph : public Subgraph {
int first_past_input_index_;
int first_present_output_index_;
bool has_hidden_state_;
bool has_encoder_input_ids_;
bool use_sequence_as_input_ids_;
};

Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ TEST(BeamSearchTest, DummyT5) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5.onnx
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
Expand All @@ -408,6 +410,8 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_with_outer_scope_initializers.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_outer_scope_initializers.onnx --move-initializers
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
Expand All @@ -422,6 +426,8 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_with_sequence_input_ids.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
Expand All @@ -432,5 +438,21 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
tester.RunWithConfig();
}

// Verifies BeamSearch with a T5 decoder subgraph that consumes the optional
// encoder_input_ids input (pointer-generator style decoder, see #23123).
TEST(BeamSearchTest, DummyT5PointerGenerator) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_pointer_generator.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_pointer_generator.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
// Single batch of 5 encoder token ids.
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
// Expected output shape {1, 3, 10}: 3 beams of length 10 for the one input batch.
tester.AddOutput("sequences", {1, 3, 10}, {2, 3, 6, 7, 3, 6, 7, 18, 3, 6, 2, 3, 6, 7, 18, 3, 6, 7, 18, 3, 2, 3, 6, 7, 3, 6, 7, 3, 6, 7});
#ifdef USE_CUDA
// When built with CUDA, also run the same check on the CUDA EP
// (the CUDA+DML combination is skipped above).
tester.ConfigEp(DefaultCudaExecutionProvider());
#endif
tester.RunWithConfig();
}

} // namespace test
} // namespace onnxruntime
Loading

0 comments on commit c6ba7ed

Please sign in to comment.