From fe044beb5187d1eb887c6ed54f94a3d3abc7966d Mon Sep 17 00:00:00 2001
From: Baiju Meswani
Date: Fri, 22 Mar 2024 10:28:44 -0700
Subject: [PATCH] Support model with multiple SCE loss nodes (#20016)

---
 .../orttraining/core/framework/gradient_graph_builder.cc | 5 +++++
 .../orttraining/core/optimizer/insert_output_rewriter.cc | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
index d66591318d5c7..2ee4b5e1a173d 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
@@ -210,6 +210,11 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c
         continue;
       }
       const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()];
+      if (!node_arg) {
+        LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
+                               << " of node: " << n->Name() << " because it is not found in the graph.";
+        continue;
+      }
       const auto [is_tensor_type, is_allowed_type_for_grad, type] = IsAllowedForGradient(graph_, node_arg);
       if (is_tensor_type) {
         if (!is_allowed_type_for_grad) {
diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc
index 2aade8c9bc1f9..61fc8d5492c2b 100644
--- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc
+++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc
@@ -44,7 +44,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
       t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape());  // log probability should have the same shape as logits.
     }
 
-    NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);
+    NodeArg& node_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(X->Name() + "_log_prob"), &t);
 
     outputs.push_back(&node_arg);
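
Reviewer note (illustration only, not part of the patch): the second hunk wraps the requested output name in graph.GenerateNodeArgName() so that, when a model contains more than one SoftmaxCrossEntropyLoss node, each inserted log-prob output receives a name that is unique within the graph rather than colliding in GetOrCreateNodeArg. The standalone sketch below mimics that de-duplication with a hypothetical NameRegistry class; it is not onnxruntime code, and the exact suffix scheme used by Graph::GenerateNodeArgName may differ.

// Illustration only: a simplified stand-in for the graph's NodeArg name bookkeeping.
#include <iostream>
#include <string>
#include <unordered_set>

class NameRegistry {
 public:
  // Returns `base` unchanged if it is unused, otherwise appends an increasing
  // numeric suffix until the result is unique, then records it as taken.
  std::string GenerateUnique(const std::string& base) {
    std::string candidate = base;
    int suffix = 0;
    while (used_.count(candidate) != 0) {
      candidate = base + "_" + std::to_string(suffix++);
    }
    used_.insert(candidate);
    return candidate;
  }

 private:
  std::unordered_set<std::string> used_;
};

int main() {
  NameRegistry names;
  // Two SCE loss nodes asking for a log-prob output derived from the same base name:
  std::cout << names.GenerateUnique("logits_log_prob") << "\n";  // prints "logits_log_prob"
  std::cout << names.GenerateUnique("logits_log_prob") << "\n";  // prints "logits_log_prob_0"
  return 0;
}

With a fixed name, the second request would resolve to the NodeArg already created for the first, which appears to be the collision the patch avoids.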