Skip to content

Commit

Permalink
Addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amancini-N committed Dec 20, 2024
1 parent c350042 commit 39852df
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ namespace transformers {
Inputs:
input_ids: int32 (B, 1)
encoder_input_ids: int32 (B, encode_sequence_length) (optional)
encoder_attention_mask: int32 (B, encode_sequence_length)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)
past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
Expand Down Expand Up @@ -53,9 +54,6 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
SetPastInputIndex(has_hidden_state, has_encoder_input_ids);

ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3 && first_past_input_index_ != 4,
"kFirstPastInputIndex currently only supports 2, 3 or 4");

if (!past_present_share_buffer_) {
ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
ORT_RETURN_IF(num_subgraph_inputs < 4 + first_past_input_index_ ||
Expand All @@ -78,11 +76,6 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
const int enc_hidden_state_index = enc_attn_mask_index + 1;
if (has_encoder_input_ids_) {
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_input_ids",
"decoder subgraph input 1 shall be named as encoder_input_ids, got: ",
subgraph_inputs[1]->Name());
}
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" shall be named as encoder_attention_mask, got: ",
Expand Down Expand Up @@ -268,6 +261,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
// TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.

Check warning on line 264 in onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc:264: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
if (j == 1) {
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
Expand Down

0 comments on commit 39852df

Please sign in to comment.