
T5 BeamSearch model loading failing if initializers are moved on the outer graph #23043

Closed
amancini-N opened this issue Dec 6, 2024 · 0 comments · Fixed by #23044
Labels
model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.


@amancini-N
Contributor

Describe the issue

There seems to be a bug in the subgraph setup of the BeamSearch op for T5 models whose initializers have been moved into the outer graph. The setup of the encoder/decoder subgraphs does not take into account which outer-scope initializers are actually used inside each subgraph, as is done in the If kernel setup.
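The missing step can be illustrated with a small, self-contained Python sketch. It does not use the onnx package: `Node`, `Graph`, and `implicit_inputs` are hypothetical, simplified stand-ins for `NodeProto`/`GraphProto`, and the code only shows the idea of working out which outer-scope names a subgraph actually references (its implicit inputs), not how ONNX Runtime computes them.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical, simplified stand-in for an ONNX NodeProto."""
    op_type: str
    inputs: list[str]
    outputs: list[str]
    subgraphs: list[Graph] = field(default_factory=list)


@dataclass
class Graph:
    """Hypothetical, simplified stand-in for an ONNX GraphProto."""
    inputs: list[str]
    initializers: list[str]
    nodes: list[Node]


def implicit_inputs(graph: Graph) -> list[str]:
    """Names consumed in `graph` (or its nested subgraphs) but not defined in it.

    Such names must be resolved from an enclosing scope, e.g. initializers
    that were moved out of the subgraph into the main graph.
    The result keeps the order of first use and contains no duplicates.
    """
    defined = set(graph.inputs) | set(graph.initializers)
    for node in graph.nodes:
        defined.update(name for name in node.outputs if name)

    result: list[str] = []
    for node in graph.nodes:
        # Direct inputs first, then whatever nested subgraphs pull from outside.
        candidates = list(node.inputs)
        for sub in node.subgraphs:
            candidates.extend(implicit_inputs(sub))
        for name in candidates:
            # Empty strings mark omitted optional inputs in ONNX.
            if name and name not in defined and name not in result:
                result.append(name)
    return result


if __name__ == "__main__":
    # Encoder-like subgraph after its large initializers were moved out.
    encoder = Graph(
        inputs=["encoder_input_ids"],
        initializers=["num_heads_and_size"],
        nodes=[
            Node("Gather", ["encoder_embeddings", "encoder_input_ids"], ["h"]),
            Node("MatMul", ["h", "init_final_proj"], ["logits"]),
        ],
    )
    print(implicit_inputs(encoder))  # ['encoder_embeddings', 'init_final_proj']
```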

To reproduce

The following Python script creates a dummy T5-like model and runs it, then performs another pass to move the initializers into the outer graph and runs the new model.

```python
import onnx
import onnxruntime as ort
import numpy as np
from onnxruntime.transformers.convert_generation import move_initializers

np.random.seed(42)

VOCAB_SIZE = 20
EMBED_DIM = 8
NUM_HEADS = 2
HEAD_SIZE = 4
BEAM_SIZE = 3
LENGTH_PENALTY = 1.1


def create_model(model_path):
    encoder_graph = create_encoder()
    decoder_graph = create_decoder()

    # Inputs: encoder_input_ids
    encoder_input_ids = onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])

    # Outputs: sequences, scores
    sequences = onnx.helper.make_tensor_value_info('sequences', onnx.TensorProto.INT32, ['batch_size', BEAM_SIZE, 'decode_sequence_length'])
    scores = onnx.helper.make_tensor_value_info('scores', onnx.TensorProto.FLOAT, ['batch_size', BEAM_SIZE])

    # Tensors
    ten_as_tensor = onnx.numpy_helper.from_array(np.array(10, dtype=np.int32), name='ten_as_tensor')
    one_as_tensor = onnx.numpy_helper.from_array(np.array(1, dtype=np.int32), name='one_as_tensor')
    three_as_tensor = onnx.numpy_helper.from_array(np.array(3, dtype=np.int32), name='three_as_tensor')
    length_penalty_as_tensor = onnx.numpy_helper.from_array(np.array(LENGTH_PENALTY, dtype=np.float32), name='length_penalty_as_tensor')

    # Nodes
    beam_search = onnx.helper.make_node(
        'BeamSearch',
        ['encoder_input_ids', 'ten_as_tensor', 'one_as_tensor', 'three_as_tensor', 'three_as_tensor', 'length_penalty_as_tensor'],
        ['sequences', 'scores'],
        decoder_start_token_id=2,
        eos_token_id=2,
        early_stopping=0,
        model_type=1,
        pad_token_id=1,
        decoder=decoder_graph,
        encoder=encoder_graph,
        domain='com.microsoft',
    )

    # Graph
    graph = onnx.helper.make_graph(
        [beam_search],
        'model',
        [encoder_input_ids],
        [sequences, scores],
        [ten_as_tensor, one_as_tensor, three_as_tensor, length_penalty_as_tensor]
    )

    # Model
    model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid('', 17), onnx.helper.make_opsetid('com.microsoft', 1)])

    # Save
    onnx.save(model, model_path)


def create_encoder():
    # Inputs: encoder_input_ids, encoder_attention_mask, decoder_input_ids
    encoder_input_ids = onnx.helper.make_tensor_value_info('encoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])
    encoder_attention_mask = onnx.helper.make_tensor_value_info('encoder_attention_mask', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])
    decoder_input_ids = onnx.helper.make_tensor_value_info('decoder_input_ids', onnx.TensorProto.INT32, ['batch_size', 1])

    # Outputs: logits, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0, encoder_hidden_states
    logits = onnx.helper.make_tensor_value_info('logits', onnx.TensorProto.FLOAT, ['batch_size', 'decode_sequence_length', VOCAB_SIZE])
    present_key_self_0 = onnx.helper.make_tensor_value_info('present_key_self_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 1, HEAD_SIZE])
    present_value_self_0 = onnx.helper.make_tensor_value_info('present_value_self_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 1, HEAD_SIZE])
    present_key_cross_0 = onnx.helper.make_tensor_value_info('present_key_cross_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'encode_sequence_length', HEAD_SIZE])
    present_value_cross_0 = onnx.helper.make_tensor_value_info('present_value_cross_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'encode_sequence_length', HEAD_SIZE])
    encoder_hidden_states = onnx.helper.make_tensor_value_info('encoder_hidden_states', onnx.TensorProto.FLOAT, ['batch_size', 'encode_sequence_length', EMBED_DIM])

    # Tensors
    encoder_embeddings_tensor = onnx.numpy_helper.from_array(np.random.randn(VOCAB_SIZE, EMBED_DIM).astype(np.float32), name='encoder_embeddings')
    num_heads_and_size_tensor = onnx.numpy_helper.from_array(np.array([NUM_HEADS, HEAD_SIZE], dtype=np.int64), name='num_heads_and_size')
    final_proj_tensor = onnx.numpy_helper.from_array(np.random.randn(EMBED_DIM, VOCAB_SIZE).astype(np.float32), name='init_final_proj')
    self_state_before_tranpose_shape_tensor = onnx.numpy_helper.from_array(np.array([-1, 1, NUM_HEADS, HEAD_SIZE], dtype=np.int64), name='self_state_before_tranpose_shape')

    # Nodes
    nodes = [
        onnx.helper.make_node('Gather', ['encoder_embeddings', 'encoder_input_ids'], ['encoder_hidden_states']),
        onnx.helper.make_node('Shape', ['encoder_hidden_states'], ['encoder_batch_seq_len'], end=2),
        onnx.helper.make_node('Concat', ['encoder_batch_seq_len', 'num_heads_and_size'], ['encoder_final_shape'], axis=0),
        onnx.helper.make_node('Reshape', ['encoder_hidden_states', 'encoder_final_shape'], ['encoder_hidden_states_reshaped']),

        onnx.helper.make_node('Transpose', ['encoder_hidden_states_reshaped'], ['present_key_cross_0'], perm=[0, 2, 1, 3]),
        onnx.helper.make_node('Transpose', ['encoder_hidden_states_reshaped'], ['present_value_cross_0'], perm=[0, 2, 1, 3]),

        onnx.helper.make_node('Gather', ['encoder_embeddings', 'decoder_input_ids'], ['decoder_hidden_states']),
        onnx.helper.make_node('ReduceMean', ['encoder_hidden_states'], ['encoder_hidden_states_mean'], axes=[1]),
        onnx.helper.make_node('Add', ['decoder_hidden_states', 'encoder_hidden_states_mean'], ['encoder_decoder_sum']),

        onnx.helper.make_node('MatMul', ['encoder_decoder_sum', 'init_final_proj'], ['logits']),

        onnx.helper.make_node('Reshape', ['encoder_decoder_sum', 'self_state_before_tranpose_shape'], ['self_state_before_tranpose']),
        onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['present_key_self_0'], perm=[0, 2, 1, 3]),
        onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['present_value_self_0'], perm=[0, 2, 1, 3]),
    ]

    # Graph
    graph = onnx.helper.make_graph(
        nodes,
        'encoder',
        [encoder_input_ids, encoder_attention_mask, decoder_input_ids],
        [logits, encoder_hidden_states, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0],
        [encoder_embeddings_tensor, num_heads_and_size_tensor, final_proj_tensor, self_state_before_tranpose_shape_tensor]
    )
    return graph


def create_decoder():
    # Inputs: input_ids, encoder_attention_mask, past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0
    input_ids = onnx.helper.make_tensor_value_info('input_ids', onnx.TensorProto.INT32, ['batch_size', 1])
    encoder_attention_mask = onnx.helper.make_tensor_value_info('encoder_attention_mask', onnx.TensorProto.INT32, ['batch_size', 'encode_sequence_length'])
    past_self_key_0 = onnx.helper.make_tensor_value_info('past_self_key_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'decode_sequence_length', HEAD_SIZE])
    past_self_value_0 = onnx.helper.make_tensor_value_info('past_self_value_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'decode_sequence_length', HEAD_SIZE])
    past_cross_key_0 = onnx.helper.make_tensor_value_info('past_cross_key_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'encode_sequence_length', HEAD_SIZE])
    past_cross_value_0 = onnx.helper.make_tensor_value_info('past_cross_value_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'encode_sequence_length', HEAD_SIZE])

    # Outputs: logits, present_key_self_0, present_value_self_0
    logits = onnx.helper.make_tensor_value_info('logits', onnx.TensorProto.FLOAT, ['batch_size', 1, VOCAB_SIZE])
    present_key_self_0 = onnx.helper.make_tensor_value_info('present_key_self_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'present_decode_sequence_length', HEAD_SIZE])
    present_value_self_0 = onnx.helper.make_tensor_value_info('present_value_self_0', onnx.TensorProto.FLOAT, ['batch_size', NUM_HEADS, 'present_decode_sequence_length', HEAD_SIZE])

    # Tensors
    decoder_embeddings_tensor = onnx.numpy_helper.from_array(np.random.randn(VOCAB_SIZE, EMBED_DIM).astype(np.float32), name='decoder_embeddings')
    final_proj_tensor = onnx.numpy_helper.from_array(np.random.randn(EMBED_DIM, VOCAB_SIZE).astype(np.float32), name='final_proj')
    self_state_before_tranpose_shape_no_batch_tensor = onnx.numpy_helper.from_array(np.array([-1, NUM_HEADS, HEAD_SIZE], dtype=np.int64), name='self_state_before_tranpose_shape_no_batch')
    hidden_states_mean_shape_tensor = onnx.numpy_helper.from_array(np.array([-1, 1, EMBED_DIM], dtype=np.int64), name='hidden_states_mean_shape')

    # Nodes
    nodes = [
        onnx.helper.make_node('Gather', ['decoder_embeddings', 'input_ids'], ['decoder_hidden_states']),
        onnx.helper.make_node('ReduceMean', ['past_cross_key_0'], ['encoder_hidden_states_mean'], axes=[2]),
        onnx.helper.make_node('Reshape', ['encoder_hidden_states_mean', 'hidden_states_mean_shape'], ['encoder_hidden_states_mean_reshaped']),
        onnx.helper.make_node('Add', ['decoder_hidden_states', 'encoder_hidden_states_mean_reshaped'], ['encoder_decoder_sum']),
        onnx.helper.make_node('Shape', ['decoder_hidden_states'], ['decoder_batch'], end=1),
        onnx.helper.make_node('Concat', ['decoder_batch', 'self_state_before_tranpose_shape_no_batch'], ['self_state_before_tranpose_shape_dec'], axis=0),
        onnx.helper.make_node('MatMul', ['encoder_decoder_sum', 'final_proj'], ['logits']),
        onnx.helper.make_node('Reshape', ['encoder_decoder_sum', 'self_state_before_tranpose_shape_dec'], ['self_state_before_tranpose']),
        onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['single_self_key_0'], perm=[0, 2, 1, 3]),
        onnx.helper.make_node('Transpose', ['self_state_before_tranpose'], ['single_self_value_0'], perm=[0, 2, 1, 3]),
        onnx.helper.make_node('Concat', ['past_self_key_0', 'single_self_key_0'], ['present_key_self_0'], axis=2),
        onnx.helper.make_node('Concat', ['past_self_value_0', 'single_self_value_0'], ['present_value_self_0'], axis=2),
    ]

    # Graph
    graph = onnx.helper.make_graph(
        nodes,
        'decoder',
        [input_ids, encoder_attention_mask, past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0],
        [logits, present_key_self_0, present_value_self_0],
        [decoder_embeddings_tensor, final_proj_tensor, self_state_before_tranpose_shape_no_batch_tensor, hidden_states_mean_shape_tensor]
    )
    return graph


def run_model(model_path):
    ort_session = ort.InferenceSession(model_path)
    encoder_input_ids = np.array([[14, 6, 13, 9, 7]]).astype(np.int32)
    print("encoder_input_ids: ", encoder_input_ids)
    sequence, scores = ort_session.run(None, {'encoder_input_ids': encoder_input_ids})
    print("sequence: ", sequence)
    print("scores: ", scores)


def move_initializers_on_outer_scope(input_path, output_path):
    model = onnx.load(input_path)
    main_graph = model.graph
    beam_search_node = model.graph.node[0]
    decoder_graph = [attr for attr in beam_search_node.attribute if attr.name == 'decoder'][0].g
    encoder_graph = [attr for attr in beam_search_node.attribute if attr.name == 'encoder'][0].g
    main_graph.initializer.extend(move_initializers(decoder_graph, min_elements=10))
    main_graph.initializer.extend(move_initializers(encoder_graph, min_elements=10))
    onnx.save(model, output_path)


if __name__ == '__main__':
    create_model('model_with_inner_initializers.onnx')

    # The following runs successfully, as expected
    run_model('model_with_inner_initializers.onnx')
    move_initializers_on_outer_scope('model_with_inner_initializers.onnx', 'model_with_outer_initializers.onnx')

    # The following fails with the error: "Error mapping feeds: Could not find OrtValue with name '<name>'"
    run_model('model_with_outer_initializers.onnx')
```
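To see which tensors end up in the outer scope, the following stdlib-only sketch replays the size-based split on the initializer shapes from the script. It assumes that `move_initializers` moves a tensor when its element count is at least `min_elements`; that comparison is an assumption, not taken from the ONNX Runtime source, but it makes no difference for this model (160 elements versus at most 4). `split_by_size` is a hypothetical helper.

```python
from __future__ import annotations

from math import prod

# Value passed as min_elements to move_initializers() in the script above.
MIN_ELEMENTS = 10

# Initializer shapes from create_encoder()/create_decoder(), with
# VOCAB_SIZE=20, EMBED_DIM=8, NUM_HEADS=2, HEAD_SIZE=4.
ENCODER_INITIALIZERS = {
    "encoder_embeddings": (20, 8),
    "num_heads_and_size": (2,),
    "init_final_proj": (8, 20),
    "self_state_before_tranpose_shape": (4,),
}
DECODER_INITIALIZERS = {
    "decoder_embeddings": (20, 8),
    "final_proj": (8, 20),
    "self_state_before_tranpose_shape_no_batch": (3,),
    "hidden_states_mean_shape": (3,),
}
# Scalars that create_model() already stores in the main graph.
MAIN_GRAPH_INITIALIZERS = [
    "ten_as_tensor", "one_as_tensor", "three_as_tensor", "length_penalty_as_tensor",
]


def split_by_size(shapes: dict[str, tuple[int, ...]], min_elements: int) -> tuple[list[str], list[str]]:
    """Split initializer names into (moved, kept) by element count.

    ASSUMPTION: a tensor is moved when its element count is >= min_elements.
    """
    moved = [name for name, shape in shapes.items() if prod(shape) >= min_elements]
    kept = [name for name, shape in shapes.items() if prod(shape) < min_elements]
    return moved, kept


encoder_moved, encoder_kept = split_by_size(ENCODER_INITIALIZERS, MIN_ELEMENTS)
decoder_moved, decoder_kept = split_by_size(DECODER_INITIALIZERS, MIN_ELEMENTS)
# The script extends the main graph with the decoder's tensors first.
main_graph = MAIN_GRAPH_INITIALIZERS + decoder_moved + encoder_moved

if __name__ == "__main__":
    print("main graph initializers:", main_graph)
    print("encoder reads from outer scope:", encoder_moved)
    print("decoder reads from outer scope:", decoder_moved)
```

Under that assumption, the main graph ends up with eight initializers, while each subgraph references only two of them from the outer scope (its own embedding and projection tensors). This is the situation in which the BeamSearch subgraph setup has to pick out only the outer-scope values that each subgraph actually uses.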

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

09c9843

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Other / Unknown

Execution Provider Library Version

No response

@github-actions github-actions bot added the model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. label Dec 6, 2024
tianleiwu pushed a commit that referenced this issue Dec 9, 2024
### Description
This PR adds the logic needed to consider only the implicit inputs
that are actually used by each subgraph of the BeamSearch op for T5
models (encoder/decoder, two subgraphs). The logic is similar to that
of the _If_ kernel setup.


### Motivation and Context
Fixes #23043
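The selection step described in this PR can be sketched in Python. This is not the C++ code from the PR; `select_used_implicit_inputs` and `filter_by_mask` are hypothetical helpers that assume the node's implicit inputs form one ordered list covering both subgraphs, and that each subgraph exposes the set of names it can resolve. Each subgraph then receives only the entries it uses, in the node's order, and a boolean mask records which positions to forward at run time.

```python
from __future__ import annotations

from typing import Sequence, TypeVar

T = TypeVar("T")


def select_used_implicit_inputs(
    node_implicit_inputs: Sequence[str], subgraph_names: set[str]
) -> tuple[list[str], list[bool]]:
    """Return (feed_names, used_mask) for one subgraph.

    node_implicit_inputs: outer-scope names needed by *any* subgraph of the
    node, in the node's order. subgraph_names: names the subgraph can consume.
    """
    used_mask = [name in subgraph_names for name in node_implicit_inputs]
    feed_names = [name for name, used in zip(node_implicit_inputs, used_mask) if used]
    return feed_names, used_mask


def filter_by_mask(values: Sequence[T], used_mask: Sequence[bool]) -> list[T]:
    """Keep only the runtime values whose position is marked as used."""
    if len(values) != len(used_mask):
        raise ValueError("values and mask must have the same length")
    return [value for value, used in zip(values, used_mask) if used]


if __name__ == "__main__":
    # Names from the reproduction script after the initializers were moved.
    node_implicit = ["decoder_embeddings", "final_proj", "encoder_embeddings", "init_final_proj"]
    encoder_names = {"encoder_input_ids", "encoder_embeddings", "init_final_proj"}
    feeds, mask = select_used_implicit_inputs(node_implicit, encoder_names)
    print(feeds)  # ['encoder_embeddings', 'init_final_proj']
    print(mask)   # [False, False, True, True]
```

Computing the mask once at setup time keeps the per-step work to a simple filter over the values passed in from the outer scope.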
guschmue pushed a commit that referenced this issue Dec 9, 2024
ankitm3k pushed a commit to intel/onnxruntime that referenced this issue Dec 11, 2024
tarekziade pushed a commit to tarekziade/onnxruntime that referenced this issue Jan 10, 2025