Skip to content

Commit

Permalink
Support model with multiple SCE loss nodes (#20016)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Mar 22, 2024
1 parent 6238e9c commit 2bc2924
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c
continue;
}
const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()];
if (!node_arg) {
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name() << " because it is not found in the graph.";
continue;
}
const auto [is_tensor_type, is_allowed_type_for_grad, type] = IsAllowedForGradient(graph_, node_arg);
if (is_tensor_type) {
if (!is_allowed_type_for_grad) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
}

NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);
NodeArg& node_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(X->Name() + "_log_prob"), &t);

outputs.push_back(&node_arg);

Expand Down

0 comments on commit 2bc2924

Please sign in to comment.