diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index c07cbb80e8e22..2e880a5682760 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1795,12 +1795,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { std::vector output_args; for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { - const auto& output = node_def.outputs[output_index]; + if (output_index >= GetSrcNodeInputSize()) { + continue; + } + if (!IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } + const auto& output = node_def.outputs[output_index]; + if (output.find("GI(") == 0) { size_t index = static_cast(std::stoi(output.substr(3, output.length() - 4))); output_args.emplace_back(GI(index));