From e2a570301a4c534e46110cc426409504aa842c44 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Tue, 2 Jul 2024 05:44:21 +0000 Subject: [PATCH 1/2] check require grad before gradient output naming --- orttraining/orttraining/core/graph/gradient_builder.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 22dcf4eb92411..c07cbb80e8e22 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1794,7 +1794,13 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { } std::vector output_args; - for (const auto& output : node_def.outputs) { + for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { + const auto& output = node_def.outputs[output_index]; + if (!IsGradientRequiredForSrcNodeInput(output_index)) { + output_args.emplace_back(ArgDef()); + continue; + } + if (output.find("GI(") == 0) { size_t index = static_cast(std::stoi(output.substr(3, output.length() - 4))); output_args.emplace_back(GI(index)); From 8282c0fdeac0a92fc53583330c13055f9ed0f15d Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 4 Jul 2024 10:52:17 +0000 Subject: [PATCH 2/2] fix --- orttraining/orttraining/core/graph/gradient_builder.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index c07cbb80e8e22..2e880a5682760 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1795,12 +1795,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { std::vector output_args; for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { - const auto& output = node_def.outputs[output_index]; + if (output_index >= GetSrcNodeInputSize()) { + continue; + } + if (!IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } + const auto& output = node_def.outputs[output_index]; + if (output.find("GI(") == 0) { size_t index = static_cast(std::stoi(output.substr(3, output.length() - 4))); output_args.emplace_back(GI(index));