diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index d675ba742e03b..7757435990a65 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -31,6 +31,7 @@ Subgraph::Subgraph( allocator_(nullptr), is_output_float16_(false) { num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + used_implicit_inputs = std::vector(num_implicit_inputs, true); auto& subgraph_inputs = subgraph.GetInputs(); auto& subgraph_outputs = subgraph.GetOutputs(); @@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state, // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); + const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); + + const auto& implicit_input_defs = node.ImplicitInputDefs(); + for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) { + const auto* entry = implicit_input_defs[i]; + int idx; + if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) { + feed_names.push_back(entry->Name()); + } else { + --num_implicit_inputs; + used_implicit_inputs[i] = false; + } } InlinedVector feed_locations; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index bde591626bb83..8ec9c9cbdc20f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -31,6 +31,7 @@ class Subgraph { const GraphViewer& subgraph; // The subgraph int num_implicit_inputs; + std::vector used_implicit_inputs; int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience. int num_subgraph_outputs; // Same as subgraph_output_names.size() diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 9037e58aaf31f..6c66bfc2816e4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -281,8 +281,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds( } // Pass through implicit inputs. - for (const auto* entry : implicit_inputs) { - decoder_feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + decoder_feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 51473c0c931b9..d59db4afac2c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds( pinned_allocator, location)); - for (const auto* entry : implicit_inputs) { - feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index f6fc9ea7662cb..ca600c0700682 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,6 +7,8 @@ #include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/model_tester.h" +#include "test/util/include/current_test_name.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -394,5 +396,33 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } +TEST(BeamSearchTest, DummyT5) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/dummy_t5.onnx b/onnxruntime/test/testdata/dummy_t5.onnx new file mode 100644 index 0000000000000..3a3bbf4767523 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx new file mode 100644 index 0000000000000..4b36cc9b6eca0 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx differ