From ca1c47466a8585dad512ebb2886bd1cb906f36e3 Mon Sep 17 00:00:00 2001 From: amancini-N Date: Tue, 17 Dec 2024 11:25:21 +0000 Subject: [PATCH 1/3] Enable pointer-generator T5 models in BeamSearch --- .../cpu/transformers/subgraph_t5_decoder.cc | 66 +++-- .../cpu/transformers/subgraph_t5_decoder.h | 10 +- .../test/contrib_ops/beam_search_test.cc | 22 ++ .../test/testdata/dummy_t5_model_generator.py | 235 ++++++++++++++++++ .../testdata/dummy_t5_pointer_generator.onnx | Bin 0 -> 7100 bytes 5 files changed, 309 insertions(+), 24 deletions(-) create mode 100644 onnxruntime/test/testdata/dummy_t5_model_generator.py create mode 100644 onnxruntime/test/testdata/dummy_t5_pointer_generator.onnx diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index f4e7173c917c1..8db69150919d5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -49,11 +49,12 @@ namespace transformers { Status T5DecoderSubgraph::Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) { - bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false; - SetPastInputIndex(has_hidden_state); + bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids"; + bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states"; + SetPastInputIndex(has_hidden_state, has_encoder_input_ids); - ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3, - "kFirstPastInputIndex currently only supports 2 or 3"); + ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3 && first_past_input_index_ != 4, + "kFirstPastInputIndex currently only supports 2, 3 or 4"); if (!past_present_share_buffer_) { ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer"); @@ -75,13 +76,22 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask", - "decoder subgraph input 1 shall be named as encoder_attention_mask, got: ", - subgraph_inputs[1]->Name()); - if (first_past_input_index_ == 3) { - ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states", - "decoder subgraph input 2 shall be named as encoder_hidden_states, got: ", - subgraph_inputs[2]->Name()); + const int enc_attn_mask_index = 1 + has_encoder_input_ids_; + const int enc_hidden_state_index = enc_attn_mask_index + 1; + if (has_encoder_input_ids_) { + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_input_ids", + "decoder subgraph input 1 shall be named as encoder_input_ids, got: ", + subgraph_inputs[1]->Name()); + } + ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask", + "decoder subgraph input ", std::to_string(enc_attn_mask_index), + " shall be named as encoder_attention_mask, got: ", + subgraph_inputs[enc_attn_mask_index]->Name()); + if (has_hidden_state_) { + ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states", + "decoder subgraph input ", std::to_string(enc_hidden_state_index), + " shall be named as encoder_hidden_states, got: ", + subgraph_inputs[enc_hidden_state_index]->Name()); } // check subgraph outputs @@ -108,12 +118,19 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, "decoder subgraph input 0 (input_ids) shall have int32 type"); - ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, - "decoder subgraph input 1 (encoder_attention_mask) shall have int32 type"); - - auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type(); - ORT_RETURN_IF(float_type != float32_type && float_type != float16_type, - "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type"); + if (has_encoder_input_ids_) { + ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "decoder subgraph input 1 (encoder_input_ids) shall have int32 type"); + } + ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "decoder subgraph input ", std::to_string(enc_attn_mask_index), + " (encoder_attention_mask) shall have int32 type"); + + auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type(); + if (has_hidden_state_) { + ORT_RETURN_IF(float_type != float32_type && float_type != float16_type, + "decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type"); + } for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, @@ -219,6 +236,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds( decoder_feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); decoder_feeds.push_back(input_ids); + if (has_encoder_input_ids_) { + // The encoder_input_ids is copied from the first input of encoder. + OrtValue expanded_encoder_input_ids; + ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream, + encoder_feeds[0], + num_beam, + allocator, + expanded_encoder_input_ids, + false, + 0 /*max_sequence_length*/)); + decoder_feeds.push_back(expanded_encoder_input_ids); + } + // The encoder_attention_mask is copied from the second input of encoder. OrtValue expanded_decoder_attention_masks; ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream, @@ -238,7 +268,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output // of encoder. // When first_past_input_index_ == 2, the past states are copied from the second output of encoder. - for (size_t j = static_cast(4) - first_past_input_index_; j < encoder_fetches.size(); j++) { + for (size_t j = static_cast(2) - has_hidden_state_; j < encoder_fetches.size(); j++) { if (j == 1) { ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false"); OrtValue expanded_hidden_states; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index a72ce37a93aba..b5d727b67924c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph { Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; - void SetPastInputIndex(bool has_hidden_state) { + void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) { has_hidden_state_ = has_hidden_state; - if (!has_hidden_state_) { - first_past_input_index_ = 2; - } else { - first_past_input_index_ = 3; - } + has_encoder_input_ids_ = has_encoder_input_ids; + first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_; } int GetFirstPastInputIndex() const { @@ -79,6 +76,7 @@ class T5DecoderSubgraph : public Subgraph { int first_past_input_index_; int first_present_output_index_; bool has_hidden_state_; + bool has_encoder_input_ids_; bool use_sequence_as_input_ids_; }; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 8c69e2d9810b8..468ca9083f703 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -400,6 +400,8 @@ TEST(BeamSearchTest, DummyT5) { #if defined(USE_CUDA) && defined(USE_DML) SKIP_CUDA_TEST_WITH_DML; #endif + // dummy_t5.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5.onnx ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); @@ -414,6 +416,8 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { #if defined(USE_CUDA) && defined(USE_DML) SKIP_CUDA_TEST_WITH_DML; #endif + // dummy_t5_with_outer_scope_initializers.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_outer_scope_initializers.onnx --move-initializers ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); @@ -428,6 +432,8 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { #if defined(USE_CUDA) && defined(USE_DML) SKIP_CUDA_TEST_WITH_DML; #endif + // dummy_t5_with_sequence_input_ids.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); @@ -438,5 +444,21 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { tester.RunWithConfig(); } +TEST(BeamSearchTest, DummyT5PointerGenerator) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + // dummy_t5_pointer_generator.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_pointer_generator.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 3, 6, 7, 3, 6, 7, 18, 3, 6, 2, 3, 6, 7, 18, 3, 6, 7, 18, 3, 2, 3, 6, 7, 3, 6, 7, 3, 6, 7}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/dummy_t5_model_generator.py b/onnxruntime/test/testdata/dummy_t5_model_generator.py new file mode 100644 index 0000000000000..e4ac59ef71bd5 --- /dev/null +++ b/onnxruntime/test/testdata/dummy_t5_model_generator.py @@ -0,0 +1,235 @@ +""" Script to generate a dummy ONNX model emulating T5 model with BeamSearch op. """ +import argparse + +import onnx +import onnxruntime as ort +from onnxruntime.transformers.convert_generation import move_initializers +import numpy as np + + +def create_model( + vocab_size: int, + embed_dim: int, + num_heads: int, + head_size: int, + beam_size: int, + min_length: int, + max_length: int, + length_penalty: float, + sequence_as_input: bool, + decoder_needs_input_ids: bool +) -> onnx.ModelProto: + encoder_graph = create_encoder(vocab_size, embed_dim, num_heads, head_size) + decoder_graph = create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids) + + # Inputs: encoder_input_ids + encoder_input_ids = onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length']) + + # Outputs: sequences, scores + sequences = onnx.helper.make_tensor_value_info('sequences', onnx.TensorProto.INT32, ['batch_size', beam_size, 'decode_sequence_length']) + scores = onnx.helper.make_tensor_value_info('scores', onnx.TensorProto.FLOAT, ['batch_size', beam_size]) + + # Tensors + max_length_t = onnx.numpy_helper.from_array(np.array(max_length, dtype=np.int32), name='max_length') + min_length_t = onnx.numpy_helper.from_array(np.array(min_length, dtype=np.int32), name='min_length') + num_beams_t = onnx.numpy_helper.from_array(np.array(beam_size, dtype=np.int32), name='num_beams') + length_penalty_t = onnx.numpy_helper.from_array(np.array(length_penalty, dtype=np.float32), name='length_penalty_as_tensor') + + # Nodes + beam_search = onnx.helper.make_node( + 'BeamSearch', + ['encoder_input_ids', 'max_length', 'min_length', 'num_beams', 'num_beams', 'length_penalty_as_tensor'], + ['sequences', 'scores'], + decoder_start_token_id=2, + eos_token_id=2, + early_stopping=0, + model_type=1, + pad_token_id=1, + decoder=decoder_graph, + encoder=encoder_graph, + domain='com.microsoft', + ) + + # Graph + graph = onnx.helper.make_graph( + [beam_search], + 'model', + [encoder_input_ids], + [sequences, scores], + [max_length_t, min_length_t, num_beams_t, length_penalty_t] + ) + + # Model + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid('', 17), onnx.helper.make_opsetid('com.microsoft', 1)]) + + return model + + +def create_encoder(vocab_size, embed_dim, num_heads, head_size) -> onnx.GraphProto: + # Inputs: encoder_input_ids, encoder_attention_mask, decoder_input_ids + encoder_input_ids = onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length']) + encoder_attention_mask = onnx.helper.make_tensor_value_info('encoder_attention_mask', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length']) + decoder_input_ids = onnx.helper.make_tensor_value_info('decoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 1]) + + # Outputs: logits, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0, encoder_hidden_states + logits = onnx.helper.make_tensor_value_info('logits', onnx.TensorProto.FLOAT, ['batch_size', 'decode_sequence_length', vocab_size]) + present_key_self_0 = onnx.helper.make_tensor_value_info('present_key_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 1, head_size]) + present_value_self_0 = onnx.helper.make_tensor_value_info('present_value_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 1, head_size]) + present_key_cross_0 = onnx.helper.make_tensor_value_info('present_key_cross_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size]) + present_value_cross_0 = onnx.helper.make_tensor_value_info('present_value_cross_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size]) + encoder_hidden_states = onnx.helper.make_tensor_value_info('encoder_hidden_states', onnx.TensorProto.FLOAT, ['batch_size', 'encode_sequence_length', embed_dim]) + + # Tensors + encoder_embeddings_tensor = onnx.numpy_helper.from_array(np.random.randn(vocab_size, embed_dim).astype(np.float32), name='encoder_embeddings') + num_heads_and_size_tensor = onnx.numpy_helper.from_array(np.array([num_heads, head_size], dtype=np.int64), name='num_heads_and_size') + final_proj_tensor = onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name='init_final_proj') + self_state_before_tranpose_shape_tensor = onnx.numpy_helper.from_array(np.array([-1, 1, num_heads, head_size], dtype=np.int64), name='self_state_before_tranpose_shape') + + # Nodes + nodes = [ + onnx.helper.make_node('Gather', ['encoder_embeddings', 'encoder_input_ids'], ['encoder_hidden_states']), + onnx.helper.make_node('Shape', ['encoder_hidden_states'], ['encoder_batch_seq_len'], end=2), + onnx.helper.make_node('Concat', ['encoder_batch_seq_len', 'num_heads_and_size'], ['encoder_final_shape'], axis=0), + onnx.helper.make_node('Reshape', ['encoder_hidden_states', 'encoder_final_shape'], ['encoder_hidden_states_reshaped']), + + onnx.helper.make_node('Transpose', ['encoder_hidden_states_reshaped'], ['present_key_cross_0'], perm=[0, 2, 1, 3]), + onnx.helper.make_node('Transpose', ['encoder_hidden_states_reshaped'], ['present_value_cross_0'], perm=[0, 2, 1, 3]), + + onnx.helper.make_node('Gather', ['encoder_embeddings', 'decoder_input_ids'], ['decoder_hidden_states']), + onnx.helper.make_node('ReduceMean', ['encoder_hidden_states'], ['encoder_hidden_states_mean'], axes=[1]), + onnx.helper.make_node('Add', ['decoder_hidden_states', 'encoder_hidden_states_mean'], ['encoder_decoder_sum']), + + onnx.helper.make_node('MatMul', ['encoder_decoder_sum', 'init_final_proj'], ['logits']), + + onnx.helper.make_node('Reshape', ['encoder_decoder_sum', 'self_state_before_tranpose_shape'], ['self_state_before_tranpose']), + onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['present_key_self_0'], perm=[0, 2, 1, 3]), + onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['present_value_self_0'], perm=[0, 2, 1, 3]), + ] + + # Graph + graph = onnx.helper.make_graph( + nodes, + 'encoder', + [encoder_input_ids, encoder_attention_mask, decoder_input_ids], + [logits, encoder_hidden_states, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0], + [encoder_embeddings_tensor, num_heads_and_size_tensor, final_proj_tensor, self_state_before_tranpose_shape_tensor] + ) + return graph + + +def create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids) -> onnx.GraphProto: + # Inputs: input_ids, encoder_input_ids (optional), encoder_attention_mask, past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0 + inputs = [] + inputs.append(onnx.helper.make_tensor_value_info('input_ids', onnx.TensorProto.INT32, ['batch_size', 'decode_sequence_length' if sequence_as_input else 1])) + if decoder_needs_input_ids: + inputs.append(onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])) + inputs.append(onnx.helper.make_tensor_value_info('encoder_attention_mask', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])) + inputs.append(onnx.helper.make_tensor_value_info('past_self_key_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'decode_sequence_length', head_size])) + inputs.append(onnx.helper.make_tensor_value_info('past_self_value_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'decode_sequence_length', head_size])) + inputs.append(onnx.helper.make_tensor_value_info('past_cross_key_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size])) + inputs.append(onnx.helper.make_tensor_value_info('past_cross_value_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size])) + + # Outputs: logits, present_key_self_0, present_value_self_0 + outputs = [ + onnx.helper.make_tensor_value_info('logits', onnx.TensorProto.FLOAT, ['batch_size', 1, vocab_size]), + onnx.helper.make_tensor_value_info('present_key_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'present_decode_sequence_length', head_size]), + onnx.helper.make_tensor_value_info('present_value_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'present_decode_sequence_length', head_size]), + ] + + # Tensors: decoder_embeddings, final_proj, self_state_before_tranpose_shape_no_batch, hidden_states_mean + initializers = [ + onnx.numpy_helper.from_array(np.random.randn(vocab_size, embed_dim).astype(np.float32), name='decoder_embeddings'), + onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name='final_proj'), + onnx.numpy_helper.from_array(np.array([-1, num_heads, head_size], dtype=np.int64), name='self_state_before_tranpose_shape_no_batch'), + onnx.numpy_helper.from_array(np.array([-1, 1, embed_dim], dtype=np.int64), name='hidden_states_mean_shape'), + ] + + # Nodes + nodes = [] + nodes.append(onnx.helper.make_node('Gather', ['decoder_embeddings', 'input_ids'], ['decoder_hidden_states'])) + if decoder_needs_input_ids: + nodes.append(onnx.helper.make_node('Gather', ['decoder_embeddings', 'encoder_input_ids'], ['encoder_input_embeddings'])) + nodes.append(onnx.helper.make_node('ReduceMean', ['encoder_input_embeddings'], ['encoder_input_embeddings_mean'], axes=[1])) + nodes.append(onnx.helper.make_node('Mul', ['decoder_hidden_states', 'encoder_input_embeddings_mean'], ['combined_hidden_states'])) + else: + nodes.append(onnx.helper.make_node('Identity', ['decoder_hidden_states'], ['combined_hidden_states'])) + nodes.append(onnx.helper.make_node('ReduceMean', ['past_cross_key_0'], ['encoder_hidden_states_mean'], axes=[2])) + nodes.append(onnx.helper.make_node('Reshape', ['encoder_hidden_states_mean', 'hidden_states_mean_shape'], ['encoder_hidden_states_mean_reshaped'])) + if sequence_as_input: + nodes.append(onnx.helper.make_node('ReduceMean', ['combined_hidden_states'], ['decoder_hidden_states_mean'], axes=[1])) + nodes.append(onnx.helper.make_node('Add', ['decoder_hidden_states_mean', 'encoder_hidden_states_mean_reshaped'], ['encoder_decoder_sum'])) + else: + nodes.append(onnx.helper.make_node('Add', ['combined_hidden_states', 'encoder_hidden_states_mean_reshaped'], ['encoder_decoder_sum'])) + nodes.append(onnx.helper.make_node('Shape', ['combined_hidden_states'], ['decoder_batch'], end=1)) + nodes.append(onnx.helper.make_node('Concat', ['decoder_batch', 'self_state_before_tranpose_shape_no_batch'], ['self_state_before_tranpose_shape_dec'], axis=0)) + nodes.append(onnx.helper.make_node('MatMul', ['encoder_decoder_sum', 'final_proj'], ['logits'])) + nodes.append(onnx.helper.make_node('Reshape', ['encoder_decoder_sum', 'self_state_before_tranpose_shape_dec'], ['self_state_before_tranpose'])) + nodes.append(onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['single_self_key_0'], perm=[0, 2, 1, 3])) + nodes.append(onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['single_self_value_0'], perm=[0, 2, 1, 3])) + nodes.append(onnx.helper.make_node('Concat', ['past_self_key_0', 'single_self_key_0'], ['present_key_self_0'], axis=2)) + nodes.append(onnx.helper.make_node('Concat', ['past_self_value_0', 'single_self_value_0'], ['present_value_self_0'], axis=2)) + + + # Graph + graph = onnx.helper.make_graph(nodes, 'decoder', inputs, outputs, initializers) + return graph + + +def run_model(model_path): + ort_session = ort.InferenceSession(model_path) + encoder_input_ids = np.array([[14, 6, 13, 9, 7]]).astype(np.int32) + print("encoder_input_ids: ", encoder_input_ids) + sequence, scores = ort_session.run(None, {'encoder_input_ids': encoder_input_ids}) + print("sequence: ", sequence) + print("scores: ", scores) + + +def move_initializers_on_outer_scope(model) -> None: + main_graph = model.graph + beam_search_node = model.graph.node[0] + decoder_graph = [attr for attr in beam_search_node.attribute if attr.name == 'decoder'][0].g + encoder_graph = [attr for attr in beam_search_node.attribute if attr.name == 'encoder'][0].g + main_graph.initializer.extend(move_initializers(decoder_graph, min_elements=10)) + main_graph.initializer.extend(move_initializers(encoder_graph, min_elements=10)) + + +def arg_parser(): + parser = argparse.ArgumentParser(description='Generate a dummy ONNX model emulating T5 model with BeamSearch op.') + parser.add_argument('--output-path', type=str, default='model.onnx', help='Model output path') + parser.add_argument('--seed', type=int, default=42, help='Random seed') + parser.add_argument('--vocab-size', type=int, default=20, help='Vocab size') + parser.add_argument('--embed-dim', type=int, default=8, help='Embedding dimension') + parser.add_argument('--num-heads', type=int, default=2, help='Number of heads') + parser.add_argument('--head-size', type=int, default=4, help='Head size') + parser.add_argument('--beam-size', type=int, default=3, help='Beam size') + parser.add_argument('--min-length', type=int, default=1, help='Min length') + parser.add_argument('--max-length', type=int, default=10, help='Max length') + parser.add_argument('--length-penalty', type=float, default=1.1, help='Length penalty') + parser.add_argument('--move-initializers', action='store_true', help='Move initializers to outer scope') + parser.add_argument('--sequence-as-input', action='store_true', help='Use sequence as input') + parser.add_argument('--decoder-needs-input-ids', action='store_true', help='Decoder needs model/encoder input ids') + + return parser.parse_args() + + +if __name__ == '__main__': + args = arg_parser() + np.random.seed(args.seed) + + model = create_model( + args.vocab_size, + args.embed_dim, + args.num_heads, + args.head_size, + args.beam_size, + args.min_length, + args.max_length, + args.length_penalty, + args.sequence_as_input, + args.decoder_needs_input_ids, + ) + if args.move_initializers: + move_initializers_on_outer_scope(model) + onnx.save(model, args.output_path) + + run_model(args.output_path) diff --git a/onnxruntime/test/testdata/dummy_t5_pointer_generator.onnx b/onnxruntime/test/testdata/dummy_t5_pointer_generator.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f7fee773cffe17cafa21af369ad06887d7cb38b4 GIT binary patch literal 7100 zcmd5>c|cRg)`zf#D=J(RktZ%FsIjh9R3vi(A}F#6xD=3>kgGu4>AWWQtDPKR&jaOU7w0ualy8}cLPK~0{Z%We|-5TH@WBh=FH4FGr!-A zknzYJ$}AYd*m0_u1O=BUja8|Ww9;6GhGF7m$qXKPhbyj0G~;x$Zz zN5@)AshU&C;^~OOyc1=4Qr{1Nl4^0E{r$t9>&QMW0a!HPK=d; zt7&}DNhgLdY~3#@7cb`&idfY&zM;2avF#17m9Yv1r;=*4GA$2j=_k`FxkP4+QU9jh zV;vj!xW2KsP@zn-Hmq}V-qLtZrgC9IIYm+o7sx-0S~2FbWKJV7-C}CRgfYhBXIchi z9byvV<*_PGA#L1kv%sV{#*eX8%QRYPOk#pYBb~veNC&V^h5#GcHy@oqADx0}!Hscj z3MJJjWonM?-h5+He!lfK3Ky%;+9zfr(@bxe55ZyW4UHKRL6a2kV&0%M44@tx@W}Hgl$WM%bNn#q?FugTg+|=6a;(qRx^%E4xv|z^CY50jlR4nCUclF+ zT|Hb~1dOq;*mefUTB{a7S;Lwp+GHVYCluP6dK+2S{L?JX=9P-wqa5L1v&P|`e_2TS z9!tY7rf(M7Yqjv$4uHj=L% z2BF!j33zg5N4!^QReWQuB`%oYjpt09$oAKJNwnJ&cpapH=lkq2{SAlbuKZ4>{ka24 z{C&9nat)jq`h>pmyNutAEr6>l@9Rp`xiG}xByLxC#EIp*M;um3#J{WVP)AJ{_%Z(r ztUO^ymrQ;k9{H^qa38sYZ`D$~G`R%tXMF{>uDi&Yjeu_1rBHt7q1gSmHPFp3OFZ_b znMAPf8YDfLN3VUC3{gW}A+-NK7;Z5e*{%ISo0DKclp)VxHF*-lzdxAUTjcf zS+*Glugb(blhbrVGPJOG(p(&9Z3-``6&|Yg$D%=(b!R`%z!!zB(Kb8>o#*}z8ERXx zd$*7AOXo-;vALtWJW5WjJe~5eUY`3yKm&5MDBN-ik4*|8gSU8Ltj$3@ zwZIic>MZoGx&W>RWUzO$py(%+4tzDa;x3;?Vd3+mAb4R7b4oI>$31Vlpw|>!{qOx) z-7y|Yrd7d6^8;Ai*%sz6bHg-cHSLz)0n_?T#!b{$>~=UD$E9SE)Y3rR$-P?YI6;YH z^6fEH<^`QRO`xq;Di*0ui0?&6fZg{C=67rb2eQ)IhVBu9a?Bwb#ZtW9=vu~Bq!Kc$;0H1FU9*+ml^~(^p zT0!W7?;x&!=i+eF05mDIhi(so=`WU)l;%3&pX0v~UyAKQZ=|dS&B)J*_oY1~>-t%` zFu$7mB>$6cd;TNwTYP~^C$A#W<>nGAffSZJ{F=&+@D+{953;;q9krmlc#!389D%Q^by&u$-o+hzE1W+%zgP;(+c7xs7jcil%|Qu{Hl`xXtWfohh%)O(uNc?j#iBU~FYPe;Zm74P)}fZa@cU;iEw1C$5u=&9H;Gc2mgmme z*aVd{UZ$DxUn7lavOlu!%wVA@JK&vQ1VRDB2UJUr-VKQFnR&!mBYWdt2J<(7*F13SQ8%s z)0ORPuy1~lctHeI7bnoz(R8r}f5QRwwly8#5-*=RS7 zD4j6oJ*KR^AwBd&x>movkNu8+FKBxMXl+@17;H0>Y!FT#EWiWmXQz`?bg!S#2Je`*v8q_Dv<1x-W!$f8jRL4SfiSKSKTK~0wJ4B`tT^w{ zMh>1mqg7&b>EBVdz7XyX zoJKtYvczuD2XS{)E0p)SDLxvu66Zx0lceuE;*CkGX`;0gYCBZI?Cz6`*qBTVS`i8x zd^Y05YvW1s`PXP7*oALak0Gu$V(2rcKM`Dfs2eF-Plrf!c+hVcR#ZHoWx<)Cd$f-P z?!Jx_9_5kr-pA8r7NrT8B4F?R1xS2)xMNgx(L7L>rP1W=BuqPFpdFYP*0eb{!4V z!gKMk`A&SX$pWUdSx0gtX=JDD8{*{95AT=9lXb1vkfLAR@xafIV8r@oWZT?zc*njM zh(tmG&l-(v2L859XyJN;@_D!X`9#)bq~Mf!W8g}_DiAw{;?J9Bf#sxfvM4bWoR2>P zOTRhvQk4XU9g4xH6(N{*catt<_g*q%*Oz4U(J;7aG2T>f^>dM?C>r+}JUkoH~zBuc*Md;djD0Y6`7lNxjFf@NM$=c~59{rOW%G|9b8KYg`$p+Kn z>y?C@pWYkU;SsR?hCe=eeP4Ic*A!f~O@@%Kau7%EroxYYqiZBRQTh5heE-Q$Bu}vw zO{%AZ@@Ff_*`672DD^02R=%XJqHR##dk6-4-Xt@Ed|}+!c)W5(qpMU07rS(skDEf$ zY1E<3c=_2G__fU;oF90dUb$vL^5%R8AfLGZ~}DdT+E#>l$c$u#xLXlMxSqRfIe5363^H!XwrE+&7J-S z-FVFb7k>5$6dgZE7Rm==-+-wwdBP@A*dYWW`*#EL;o)ehY>V5106e@p!P()vizbeF zB3{TJH4LBG0av!MfwirI>Djx{^tO3;s z$AXGR3(fgNe(wh3-_@Xc4BF-0H2?o5PWm!-hOM`rQ#32B Date: Tue, 17 Dec 2024 11:48:22 +0000 Subject: [PATCH 2/3] Linting changes --- .../test/testdata/dummy_t5_model_generator.py | 348 ++++++++++++------ 1 file changed, 245 insertions(+), 103 deletions(-) diff --git a/onnxruntime/test/testdata/dummy_t5_model_generator.py b/onnxruntime/test/testdata/dummy_t5_model_generator.py index e4ac59ef71bd5..1ecd8b9ee9c92 100644 --- a/onnxruntime/test/testdata/dummy_t5_model_generator.py +++ b/onnxruntime/test/testdata/dummy_t5_model_generator.py @@ -1,10 +1,12 @@ """ Script to generate a dummy ONNX model emulating T5 model with BeamSearch op. """ + import argparse +import numpy as np import onnx + import onnxruntime as ort from onnxruntime.transformers.convert_generation import move_initializers -import numpy as np def create_model( @@ -17,29 +19,37 @@ def create_model( max_length: int, length_penalty: float, sequence_as_input: bool, - decoder_needs_input_ids: bool + decoder_needs_input_ids: bool, ) -> onnx.ModelProto: encoder_graph = create_encoder(vocab_size, embed_dim, num_heads, head_size) - decoder_graph = create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids) + decoder_graph = create_decoder( + vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids + ) # Inputs: encoder_input_ids - encoder_input_ids = onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length']) + encoder_input_ids = onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) # Outputs: sequences, scores - sequences = onnx.helper.make_tensor_value_info('sequences', onnx.TensorProto.INT32, ['batch_size', beam_size, 'decode_sequence_length']) - scores = onnx.helper.make_tensor_value_info('scores', onnx.TensorProto.FLOAT, ['batch_size', beam_size]) + sequences = onnx.helper.make_tensor_value_info( + "sequences", onnx.TensorProto.INT32, ["batch_size", beam_size, "decode_sequence_length"] + ) + scores = onnx.helper.make_tensor_value_info("scores", onnx.TensorProto.FLOAT, ["batch_size", beam_size]) # Tensors - max_length_t = onnx.numpy_helper.from_array(np.array(max_length, dtype=np.int32), name='max_length') - min_length_t = onnx.numpy_helper.from_array(np.array(min_length, dtype=np.int32), name='min_length') - num_beams_t = onnx.numpy_helper.from_array(np.array(beam_size, dtype=np.int32), name='num_beams') - length_penalty_t = onnx.numpy_helper.from_array(np.array(length_penalty, dtype=np.float32), name='length_penalty_as_tensor') + max_length_t = onnx.numpy_helper.from_array(np.array(max_length, dtype=np.int32), name="max_length") + min_length_t = onnx.numpy_helper.from_array(np.array(min_length, dtype=np.int32), name="min_length") + num_beams_t = onnx.numpy_helper.from_array(np.array(beam_size, dtype=np.int32), name="num_beams") + length_penalty_t = onnx.numpy_helper.from_array( + np.array(length_penalty, dtype=np.float32), name="length_penalty_as_tensor" + ) # Nodes beam_search = onnx.helper.make_node( - 'BeamSearch', - ['encoder_input_ids', 'max_length', 'min_length', 'num_beams', 'num_beams', 'length_penalty_as_tensor'], - ['sequences', 'scores'], + "BeamSearch", + ["encoder_input_ids", "max_length", "min_length", "num_beams", "num_beams", "length_penalty_as_tensor"], + ["sequences", "scores"], decoder_start_token_id=2, eos_token_id=2, early_stopping=0, @@ -47,131 +57,263 @@ def create_model( pad_token_id=1, decoder=decoder_graph, encoder=encoder_graph, - domain='com.microsoft', + domain="com.microsoft", ) # Graph graph = onnx.helper.make_graph( [beam_search], - 'model', + "model", [encoder_input_ids], [sequences, scores], - [max_length_t, min_length_t, num_beams_t, length_penalty_t] + [max_length_t, min_length_t, num_beams_t, length_penalty_t], ) # Model - model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid('', 17), onnx.helper.make_opsetid('com.microsoft', 1)]) + model = onnx.helper.make_model( + graph, opset_imports=[onnx.helper.make_opsetid("", 17), onnx.helper.make_opsetid("com.microsoft", 1)] + ) return model def create_encoder(vocab_size, embed_dim, num_heads, head_size) -> onnx.GraphProto: # Inputs: encoder_input_ids, encoder_attention_mask, decoder_input_ids - encoder_input_ids = onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length']) - encoder_attention_mask = onnx.helper.make_tensor_value_info('encoder_attention_mask', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length']) - decoder_input_ids = onnx.helper.make_tensor_value_info('decoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 1]) + encoder_input_ids = onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + encoder_attention_mask = onnx.helper.make_tensor_value_info( + "encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + decoder_input_ids = onnx.helper.make_tensor_value_info( + "decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", 1] + ) # Outputs: logits, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0, encoder_hidden_states - logits = onnx.helper.make_tensor_value_info('logits', onnx.TensorProto.FLOAT, ['batch_size', 'decode_sequence_length', vocab_size]) - present_key_self_0 = onnx.helper.make_tensor_value_info('present_key_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 1, head_size]) - present_value_self_0 = onnx.helper.make_tensor_value_info('present_value_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 1, head_size]) - present_key_cross_0 = onnx.helper.make_tensor_value_info('present_key_cross_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size]) - present_value_cross_0 = onnx.helper.make_tensor_value_info('present_value_cross_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size]) - encoder_hidden_states = onnx.helper.make_tensor_value_info('encoder_hidden_states', onnx.TensorProto.FLOAT, ['batch_size', 'encode_sequence_length', embed_dim]) + logits = onnx.helper.make_tensor_value_info( + "logits", onnx.TensorProto.FLOAT, ["batch_size", "decode_sequence_length", vocab_size] + ) + present_key_self_0 = onnx.helper.make_tensor_value_info( + "present_key_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_value_self_0 = onnx.helper.make_tensor_value_info( + "present_value_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_key_cross_0 = onnx.helper.make_tensor_value_info( + "present_key_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + present_value_cross_0 = onnx.helper.make_tensor_value_info( + "present_value_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + encoder_hidden_states = onnx.helper.make_tensor_value_info( + "encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", embed_dim] + ) # Tensors - encoder_embeddings_tensor = onnx.numpy_helper.from_array(np.random.randn(vocab_size, embed_dim).astype(np.float32), name='encoder_embeddings') - num_heads_and_size_tensor = onnx.numpy_helper.from_array(np.array([num_heads, head_size], dtype=np.int64), name='num_heads_and_size') - final_proj_tensor = onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name='init_final_proj') - self_state_before_tranpose_shape_tensor = onnx.numpy_helper.from_array(np.array([-1, 1, num_heads, head_size], dtype=np.int64), name='self_state_before_tranpose_shape') + encoder_embeddings_tensor = onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="encoder_embeddings" + ) + num_heads_and_size_tensor = onnx.numpy_helper.from_array( + np.array([num_heads, head_size], dtype=np.int64), name="num_heads_and_size" + ) + final_proj_tensor = onnx.numpy_helper.from_array( + np.random.randn(embed_dim, vocab_size).astype(np.float32), name="init_final_proj" + ) + self_state_before_tranpose_shape_tensor = onnx.numpy_helper.from_array( + np.array([-1, 1, num_heads, head_size], dtype=np.int64), name="self_state_before_tranpose_shape" + ) # Nodes nodes = [ - onnx.helper.make_node('Gather', ['encoder_embeddings', 'encoder_input_ids'], ['encoder_hidden_states']), - onnx.helper.make_node('Shape', ['encoder_hidden_states'], ['encoder_batch_seq_len'], end=2), - onnx.helper.make_node('Concat', ['encoder_batch_seq_len', 'num_heads_and_size'], ['encoder_final_shape'], axis=0), - onnx.helper.make_node('Reshape', ['encoder_hidden_states', 'encoder_final_shape'], ['encoder_hidden_states_reshaped']), - - onnx.helper.make_node('Transpose', ['encoder_hidden_states_reshaped'], ['present_key_cross_0'], perm=[0, 2, 1, 3]), - onnx.helper.make_node('Transpose', ['encoder_hidden_states_reshaped'], ['present_value_cross_0'], perm=[0, 2, 1, 3]), - - onnx.helper.make_node('Gather', ['encoder_embeddings', 'decoder_input_ids'], ['decoder_hidden_states']), - onnx.helper.make_node('ReduceMean', ['encoder_hidden_states'], ['encoder_hidden_states_mean'], axes=[1]), - onnx.helper.make_node('Add', ['decoder_hidden_states', 'encoder_hidden_states_mean'], ['encoder_decoder_sum']), - - onnx.helper.make_node('MatMul', ['encoder_decoder_sum', 'init_final_proj'], ['logits']), - - onnx.helper.make_node('Reshape', ['encoder_decoder_sum', 'self_state_before_tranpose_shape'], ['self_state_before_tranpose']), - onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['present_key_self_0'], perm=[0, 2, 1, 3]), - onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['present_value_self_0'], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Gather", ["encoder_embeddings", "encoder_input_ids"], ["encoder_hidden_states"]), + onnx.helper.make_node("Shape", ["encoder_hidden_states"], ["encoder_batch_seq_len"], end=2), + onnx.helper.make_node( + "Concat", ["encoder_batch_seq_len", "num_heads_and_size"], ["encoder_final_shape"], axis=0 + ), + onnx.helper.make_node( + "Reshape", ["encoder_hidden_states", "encoder_final_shape"], ["encoder_hidden_states_reshaped"] + ), + onnx.helper.make_node( + "Transpose", ["encoder_hidden_states_reshaped"], ["present_key_cross_0"], perm=[0, 2, 1, 3] + ), + onnx.helper.make_node( + "Transpose", ["encoder_hidden_states_reshaped"], ["present_value_cross_0"], perm=[0, 2, 1, 3] + ), + onnx.helper.make_node("Gather", ["encoder_embeddings", "decoder_input_ids"], ["decoder_hidden_states"]), + onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["encoder_hidden_states_mean"], axes=[1]), + onnx.helper.make_node("Add", ["decoder_hidden_states", "encoder_hidden_states_mean"], ["encoder_decoder_sum"]), + onnx.helper.make_node("MatMul", ["encoder_decoder_sum", "init_final_proj"], ["logits"]), + onnx.helper.make_node( + "Reshape", ["encoder_decoder_sum", "self_state_before_tranpose_shape"], ["self_state_before_tranpose"] + ), + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["present_key_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["present_value_self_0"], perm=[0, 2, 1, 3]), ] # Graph graph = onnx.helper.make_graph( nodes, - 'encoder', + "encoder", [encoder_input_ids, encoder_attention_mask, decoder_input_ids], - [logits, encoder_hidden_states, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0], - [encoder_embeddings_tensor, num_heads_and_size_tensor, final_proj_tensor, self_state_before_tranpose_shape_tensor] + [ + logits, + encoder_hidden_states, + present_key_self_0, + present_value_self_0, + present_key_cross_0, + present_value_cross_0, + ], + [ + encoder_embeddings_tensor, + num_heads_and_size_tensor, + final_proj_tensor, + self_state_before_tranpose_shape_tensor, + ], ) return graph -def create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids) -> onnx.GraphProto: +def create_decoder( + vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids +) -> onnx.GraphProto: # Inputs: input_ids, encoder_input_ids (optional), encoder_attention_mask, past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0 inputs = [] - inputs.append(onnx.helper.make_tensor_value_info('input_ids', onnx.TensorProto.INT32, ['batch_size', 'decode_sequence_length' if sequence_as_input else 1])) + inputs.append( + onnx.helper.make_tensor_value_info( + "input_ids", onnx.TensorProto.INT32, ["batch_size", "decode_sequence_length" if sequence_as_input else 1] + ) + ) if decoder_needs_input_ids: - inputs.append(onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])) - inputs.append(onnx.helper.make_tensor_value_info('encoder_attention_mask', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])) - inputs.append(onnx.helper.make_tensor_value_info('past_self_key_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'decode_sequence_length', head_size])) - inputs.append(onnx.helper.make_tensor_value_info('past_self_value_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'decode_sequence_length', head_size])) - inputs.append(onnx.helper.make_tensor_value_info('past_cross_key_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size])) - inputs.append(onnx.helper.make_tensor_value_info('past_cross_value_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'encode_sequence_length', head_size])) + inputs.append( + onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_self_key_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "decode_sequence_length", head_size] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_self_value_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "decode_sequence_length", head_size] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_cross_key_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_cross_value_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + ) # Outputs: logits, present_key_self_0, present_value_self_0 outputs = [ - onnx.helper.make_tensor_value_info('logits', onnx.TensorProto.FLOAT, ['batch_size', 1, vocab_size]), - onnx.helper.make_tensor_value_info('present_key_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'present_decode_sequence_length', head_size]), - onnx.helper.make_tensor_value_info('present_value_self_0', onnx.TensorProto.FLOAT, ['batch_size', num_heads, 'present_decode_sequence_length', head_size]), + onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", 1, vocab_size]), + onnx.helper.make_tensor_value_info( + "present_key_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "present_value_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), ] # Tensors: decoder_embeddings, final_proj, self_state_before_tranpose_shape_no_batch, hidden_states_mean initializers = [ - onnx.numpy_helper.from_array(np.random.randn(vocab_size, embed_dim).astype(np.float32), name='decoder_embeddings'), - onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name='final_proj'), - onnx.numpy_helper.from_array(np.array([-1, num_heads, head_size], dtype=np.int64), name='self_state_before_tranpose_shape_no_batch'), - onnx.numpy_helper.from_array(np.array([-1, 1, embed_dim], dtype=np.int64), name='hidden_states_mean_shape'), + onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="decoder_embeddings" + ), + onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name="final_proj"), + onnx.numpy_helper.from_array( + np.array([-1, num_heads, head_size], dtype=np.int64), name="self_state_before_tranpose_shape_no_batch" + ), + onnx.numpy_helper.from_array(np.array([-1, 1, embed_dim], dtype=np.int64), name="hidden_states_mean_shape"), ] # Nodes nodes = [] - nodes.append(onnx.helper.make_node('Gather', ['decoder_embeddings', 'input_ids'], ['decoder_hidden_states'])) + nodes.append(onnx.helper.make_node("Gather", ["decoder_embeddings", "input_ids"], ["decoder_hidden_states"])) if decoder_needs_input_ids: - nodes.append(onnx.helper.make_node('Gather', ['decoder_embeddings', 'encoder_input_ids'], ['encoder_input_embeddings'])) - nodes.append(onnx.helper.make_node('ReduceMean', ['encoder_input_embeddings'], ['encoder_input_embeddings_mean'], axes=[1])) - nodes.append(onnx.helper.make_node('Mul', ['decoder_hidden_states', 'encoder_input_embeddings_mean'], ['combined_hidden_states'])) + nodes.append( + onnx.helper.make_node("Gather", ["decoder_embeddings", "encoder_input_ids"], ["encoder_input_embeddings"]) + ) + nodes.append( + onnx.helper.make_node( + "ReduceMean", ["encoder_input_embeddings"], ["encoder_input_embeddings_mean"], axes=[1] + ) + ) + nodes.append( + onnx.helper.make_node( + "Mul", ["decoder_hidden_states", "encoder_input_embeddings_mean"], ["combined_hidden_states"] + ) + ) else: - nodes.append(onnx.helper.make_node('Identity', ['decoder_hidden_states'], ['combined_hidden_states'])) - nodes.append(onnx.helper.make_node('ReduceMean', ['past_cross_key_0'], ['encoder_hidden_states_mean'], axes=[2])) - nodes.append(onnx.helper.make_node('Reshape', ['encoder_hidden_states_mean', 'hidden_states_mean_shape'], ['encoder_hidden_states_mean_reshaped'])) + nodes.append(onnx.helper.make_node("Identity", ["decoder_hidden_states"], ["combined_hidden_states"])) + nodes.append(onnx.helper.make_node("ReduceMean", ["past_cross_key_0"], ["encoder_hidden_states_mean"], axes=[2])) + nodes.append( + onnx.helper.make_node( + "Reshape", + ["encoder_hidden_states_mean", "hidden_states_mean_shape"], + ["encoder_hidden_states_mean_reshaped"], + ) + ) if sequence_as_input: - nodes.append(onnx.helper.make_node('ReduceMean', ['combined_hidden_states'], ['decoder_hidden_states_mean'], axes=[1])) - nodes.append(onnx.helper.make_node('Add', ['decoder_hidden_states_mean', 'encoder_hidden_states_mean_reshaped'], ['encoder_decoder_sum'])) + nodes.append( + onnx.helper.make_node("ReduceMean", ["combined_hidden_states"], ["decoder_hidden_states_mean"], axes=[1]) + ) + nodes.append( + onnx.helper.make_node( + "Add", ["decoder_hidden_states_mean", "encoder_hidden_states_mean_reshaped"], ["encoder_decoder_sum"] + ) + ) else: - nodes.append(onnx.helper.make_node('Add', ['combined_hidden_states', 'encoder_hidden_states_mean_reshaped'], ['encoder_decoder_sum'])) - nodes.append(onnx.helper.make_node('Shape', ['combined_hidden_states'], ['decoder_batch'], end=1)) - nodes.append(onnx.helper.make_node('Concat', ['decoder_batch', 'self_state_before_tranpose_shape_no_batch'], ['self_state_before_tranpose_shape_dec'], axis=0)) - nodes.append(onnx.helper.make_node('MatMul', ['encoder_decoder_sum', 'final_proj'], ['logits'])) - nodes.append(onnx.helper.make_node('Reshape', ['encoder_decoder_sum', 'self_state_before_tranpose_shape_dec'], ['self_state_before_tranpose'])) - nodes.append(onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['single_self_key_0'], perm=[0, 2, 1, 3])) - nodes.append(onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['single_self_value_0'], perm=[0, 2, 1, 3])) - nodes.append(onnx.helper.make_node('Concat', ['past_self_key_0', 'single_self_key_0'], ['present_key_self_0'], axis=2)) - nodes.append(onnx.helper.make_node('Concat', ['past_self_value_0', 'single_self_value_0'], ['present_value_self_0'], axis=2)) - + nodes.append( + onnx.helper.make_node( + "Add", ["combined_hidden_states", "encoder_hidden_states_mean_reshaped"], ["encoder_decoder_sum"] + ) + ) + nodes.append(onnx.helper.make_node("Shape", ["combined_hidden_states"], ["decoder_batch"], end=1)) + nodes.append( + onnx.helper.make_node( + "Concat", + ["decoder_batch", "self_state_before_tranpose_shape_no_batch"], + ["self_state_before_tranpose_shape_dec"], + axis=0, + ) + ) + nodes.append(onnx.helper.make_node("MatMul", ["encoder_decoder_sum", "final_proj"], ["logits"])) + nodes.append( + onnx.helper.make_node( + "Reshape", ["encoder_decoder_sum", "self_state_before_tranpose_shape_dec"], ["self_state_before_tranpose"] + ) + ) + nodes.append( + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["single_self_key_0"], perm=[0, 2, 1, 3]) + ) + nodes.append( + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["single_self_value_0"], perm=[0, 2, 1, 3]) + ) + nodes.append( + onnx.helper.make_node("Concat", ["past_self_key_0", "single_self_key_0"], ["present_key_self_0"], axis=2) + ) + nodes.append( + onnx.helper.make_node("Concat", ["past_self_value_0", "single_self_value_0"], ["present_value_self_0"], axis=2) + ) # Graph - graph = onnx.helper.make_graph(nodes, 'decoder', inputs, outputs, initializers) + graph = onnx.helper.make_graph(nodes, "decoder", inputs, outputs, initializers) return graph @@ -179,7 +321,7 @@ def run_model(model_path): ort_session = ort.InferenceSession(model_path) encoder_input_ids = np.array([[14, 6, 13, 9, 7]]).astype(np.int32) print("encoder_input_ids: ", encoder_input_ids) - sequence, scores = ort_session.run(None, {'encoder_input_ids': encoder_input_ids}) + sequence, scores = ort_session.run(None, {"encoder_input_ids": encoder_input_ids}) print("sequence: ", sequence) print("scores: ", scores) @@ -187,32 +329,32 @@ def run_model(model_path): def move_initializers_on_outer_scope(model) -> None: main_graph = model.graph beam_search_node = model.graph.node[0] - decoder_graph = [attr for attr in beam_search_node.attribute if attr.name == 'decoder'][0].g - encoder_graph = [attr for attr in beam_search_node.attribute if attr.name == 'encoder'][0].g + decoder_graph = next(attr for attr in beam_search_node.attribute if attr.name == "decoder").g + encoder_graph = next(attr for attr in beam_search_node.attribute if attr.name == "encoder").g main_graph.initializer.extend(move_initializers(decoder_graph, min_elements=10)) main_graph.initializer.extend(move_initializers(encoder_graph, min_elements=10)) def arg_parser(): - parser = argparse.ArgumentParser(description='Generate a dummy ONNX model emulating T5 model with BeamSearch op.') - parser.add_argument('--output-path', type=str, default='model.onnx', help='Model output path') - parser.add_argument('--seed', type=int, default=42, help='Random seed') - parser.add_argument('--vocab-size', type=int, default=20, help='Vocab size') - parser.add_argument('--embed-dim', type=int, default=8, help='Embedding dimension') - parser.add_argument('--num-heads', type=int, default=2, help='Number of heads') - parser.add_argument('--head-size', type=int, default=4, help='Head size') - parser.add_argument('--beam-size', type=int, default=3, help='Beam size') - parser.add_argument('--min-length', type=int, default=1, help='Min length') - parser.add_argument('--max-length', type=int, default=10, help='Max length') - parser.add_argument('--length-penalty', type=float, default=1.1, help='Length penalty') - parser.add_argument('--move-initializers', action='store_true', help='Move initializers to outer scope') - parser.add_argument('--sequence-as-input', action='store_true', help='Use sequence as input') - parser.add_argument('--decoder-needs-input-ids', action='store_true', help='Decoder needs model/encoder input ids') + parser = argparse.ArgumentParser(description="Generate a dummy ONNX model emulating T5 model with BeamSearch op.") + parser.add_argument("--output-path", type=str, default="model.onnx", help="Model output path") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--vocab-size", type=int, default=20, help="Vocab size") + parser.add_argument("--embed-dim", type=int, default=8, help="Embedding dimension") + parser.add_argument("--num-heads", type=int, default=2, help="Number of heads") + parser.add_argument("--head-size", type=int, default=4, help="Head size") + parser.add_argument("--beam-size", type=int, default=3, help="Beam size") + parser.add_argument("--min-length", type=int, default=1, help="Min length") + parser.add_argument("--max-length", type=int, default=10, help="Max length") + parser.add_argument("--length-penalty", type=float, default=1.1, help="Length penalty") + parser.add_argument("--move-initializers", action="store_true", help="Move initializers to outer scope") + parser.add_argument("--sequence-as-input", action="store_true", help="Use sequence as input") + parser.add_argument("--decoder-needs-input-ids", action="store_true", help="Decoder needs model/encoder input ids") return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = arg_parser() np.random.seed(args.seed) From 39852dfe73e035f3e8ee937ec58c883b710241eb Mon Sep 17 00:00:00 2001 From: amancini-N Date: Fri, 20 Dec 2024 17:19:38 +0000 Subject: [PATCH 3/3] Addressing review comments --- .../cpu/transformers/subgraph_t5_decoder.cc | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 8db69150919d5..997beb198f450 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -20,8 +20,9 @@ namespace transformers { Inputs: input_ids: int32 (B, 1) + encoder_input_ids: int32 (B, encode_sequence_length) (optional) encoder_attention_mask: int32 (B, encode_sequence_length) - encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional) past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) @@ -53,9 +54,6 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states"; SetPastInputIndex(has_hidden_state, has_encoder_input_ids); - ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3 && first_past_input_index_ != 4, - "kFirstPastInputIndex currently only supports 2, 3 or 4"); - if (!past_present_share_buffer_) { ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer"); ORT_RETURN_IF(num_subgraph_inputs < 4 + first_past_input_index_ || @@ -78,11 +76,6 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); const int enc_attn_mask_index = 1 + has_encoder_input_ids_; const int enc_hidden_state_index = enc_attn_mask_index + 1; - if (has_encoder_input_ids_) { - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_input_ids", - "decoder subgraph input 1 shall be named as encoder_input_ids, got: ", - subgraph_inputs[1]->Name()); - } ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask", "decoder subgraph input ", std::to_string(enc_attn_mask_index), " shall be named as encoder_attention_mask, got: ", @@ -268,6 +261,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output // of encoder. // When first_past_input_index_ == 2, the past states are copied from the second output of encoder. + // TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions. + // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds? for (size_t j = static_cast(2) - has_hidden_state_; j < encoder_fetches.size(); j++) { if (j == 1) { ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");