diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 8db69150919d5..997beb198f450 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -20,8 +20,9 @@ namespace transformers { Inputs: input_ids: int32 (B, 1) + encoder_input_ids: int32 (B, encode_sequence_length) (optional) encoder_attention_mask: int32 (B, encode_sequence_length) - encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional) past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) @@ -53,9 +54,6 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states"; SetPastInputIndex(has_hidden_state, has_encoder_input_ids); - ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3 && first_past_input_index_ != 4, - "kFirstPastInputIndex currently only supports 2, 3 or 4"); - if (!past_present_share_buffer_) { ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer"); ORT_RETURN_IF(num_subgraph_inputs < 4 + first_past_input_index_ || @@ -78,11 +76,6 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); const int enc_attn_mask_index = 1 + has_encoder_input_ids_; const int enc_hidden_state_index = enc_attn_mask_index + 1; - if (has_encoder_input_ids_) { - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_input_ids", - "decoder subgraph input 1 shall be named as encoder_input_ids, got: ", - subgraph_inputs[1]->Name()); - } 
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask", "decoder subgraph input ", std::to_string(enc_attn_mask_index), " shall be named as encoder_attention_mask, got: ", @@ -268,6 +261,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output // of encoder. // When first_past_input_index_ == 2, the past states are copied from the second output of encoder. + // TODO - probably more robust to introduce an encoder_out/decoder_in mapping instead of relying on positions. + // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds? for (size_t j = static_cast(2) - has_hidden_state_; j < encoder_fetches.size(); j++) { if (j == 1) { ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");