Skip to content

Commit

Permalink
check require grad before gradient output naming
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Jul 2, 2024
1 parent beb2496 commit e2a5703
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,13 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
}

std::vector<ArgDef> output_args;
for (const auto& output : node_def.outputs) {
for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
const auto& output = node_def.outputs[output_index];
if (!IsGradientRequiredForSrcNodeInput(output_index)) {
output_args.emplace_back(ArgDef());
continue;
}

if (output.find("GI(") == 0) {
size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
output_args.emplace_back(GI(index));
Expand Down

0 comments on commit e2a5703

Please sign in to comment.