Fix BeamSearch T5 if initializers are on outer scope (#23044)
### Description
This PR adds the logic needed for the BeamSearch op to pass each subgraph only the implicit inputs that subgraph actually uses, for the T5 case (encoder and decoder, i.e. two subgraphs). The added logic is similar to what the _If_ kernel does during its setup.


### Motivation and Context
Fixes #23043
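
The gist of the fix, as a minimal standalone C++ sketch (toy types and illustrative names only, not ONNX Runtime code): at setup time, keep only the implicit inputs whose names resolve in the subgraph's name-to-index map and remember which positions were kept; at run time, build the feeds by skipping the unused positions so feeds and feed names stay aligned.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for the subgraph's OrtValue name-to-index map.
using NameIdxMap = std::unordered_map<std::string, int>;

int main() {
  // Implicit inputs of the BeamSearch node, e.g. outer-scope initializers.
  std::vector<std::string> implicit_inputs = {"encoder_weight", "decoder_weight", "shared_bias"};

  // The decoder subgraph resolves only a subset of those names.
  NameIdxMap decoder_map = {{"decoder_weight", 0}, {"shared_bias", 1}};

  // Phase 1 (setup): keep only names the subgraph knows about, and record which
  // positions were kept so later feed construction stays index-aligned.
  std::vector<std::string> feed_names;
  std::vector<bool> used_implicit_inputs(implicit_inputs.size(), true);
  for (size_t i = 0; i < implicit_inputs.size(); ++i) {
    if (decoder_map.count(implicit_inputs[i]) > 0) {
      feed_names.push_back(implicit_inputs[i]);
    } else {
      used_implicit_inputs[i] = false;  // e.g. an initializer used only by the other subgraph
    }
  }

  // Phase 2 (per run): pass through only the implicit inputs marked as used,
  // so the feeds line up one-to-one with feed_names from phase 1.
  std::vector<std::string> feeds;  // stand-in for a vector of OrtValue
  for (size_t i = 0; i < implicit_inputs.size(); ++i) {
    if (used_implicit_inputs[i]) {
      feeds.push_back(implicit_inputs[i] + "_value");
    }
  }

  for (const auto& name : feed_names) std::cout << name << '\n';   // decoder_weight, shared_bias
  for (const auto& value : feeds) std::cout << value << '\n';
  return 0;
}
```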
amancini-N authored Dec 9, 2024
1 parent 2f2c73b commit 8f3384b
Showing 7 changed files with 54 additions and 6 deletions.
15 changes: 13 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -31,6 +31,7 @@ Subgraph::Subgraph(
       allocator_(nullptr),
       is_output_float16_(false) {
   num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
+  used_implicit_inputs = std::vector<bool>(num_implicit_inputs, true);

   auto& subgraph_inputs = subgraph.GetInputs();
   auto& subgraph_outputs = subgraph.GetOutputs();
@@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state,
   // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
   feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());

-  for (auto& entry : node.ImplicitInputDefs()) {
-    feed_names.push_back(entry->Name());
+  const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
+
+  const auto& implicit_input_defs = node.ImplicitInputDefs();
+  for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) {
+    const auto* entry = implicit_input_defs[i];
+    int idx;
+    if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) {
+      feed_names.push_back(entry->Name());
+    } else {
+      --num_implicit_inputs;
+      used_implicit_inputs[i] = false;
+    }
   }

   InlinedVector<OrtDevice> feed_locations;
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h
@@ -31,6 +31,7 @@ class Subgraph {
   const GraphViewer& subgraph;  // The subgraph

   int num_implicit_inputs;
+  std::vector<bool> used_implicit_inputs;

   int num_subgraph_inputs;   // Same as subgraph_input_names.size(), keep it for convenience.
   int num_subgraph_outputs;  // Same as subgraph_output_names.size()
@@ -281,8 +281,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   }

   // Pass through implicit inputs.
-  for (const auto* entry : implicit_inputs) {
-    decoder_feeds.push_back(*entry);
+  for (size_t i = 0; i < implicit_inputs.size(); ++i) {
+    const auto* entry = implicit_inputs[i];
+    if (used_implicit_inputs[i]) {
+      decoder_feeds.push_back(*entry);
+    }
   }

   return Status::OK();
@@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
                       pinned_allocator,
                       location));

-  for (const auto* entry : implicit_inputs) {
-    feeds.push_back(*entry);
+  for (size_t i = 0; i < implicit_inputs.size(); ++i) {
+    const auto* entry = implicit_inputs[i];
+    if (used_implicit_inputs[i]) {
+      feeds.push_back(*entry);
+    }
   }

   return Status::OK();
30 changes: 30 additions & 0 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -7,6 +7,8 @@
 #include <gsl/gsl>
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/common/cuda_op_test_utils.h"
+#include "test/providers/model_tester.h"
+#include "test/util/include/current_test_name.h"

 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_provider_options.h"
@@ -394,5 +396,33 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) {
   }
 }

+TEST(BeamSearchTest, DummyT5) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+  ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx"));
+  tester.ConfigEp(DefaultCpuExecutionProvider());
+  tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
+  tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14});
+#ifdef USE_CUDA
+  tester.ConfigEp(DefaultCudaExecutionProvider());
+#endif
+  tester.RunWithConfig();
+}
+
+TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+  ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx"));
+  tester.ConfigEp(DefaultCpuExecutionProvider());
+  tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
+  tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14});
+#ifdef USE_CUDA
+  tester.ConfigEp(DefaultCudaExecutionProvider());
+#endif
+  tester.RunWithConfig();
+}
+
 } // namespace test
 } // namespace onnxruntime
Binary file added onnxruntime/test/testdata/dummy_t5.onnx (binary, not shown)
Binary file added (name not shown in diff; binary, not shown)
