diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index e610bb46a67bc..7576e8814b826 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1113,17 +1113,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) { ArgDef grad = GO(0); if (!keepdims) { size_t numInputs = GetSrcNodeInputSize(); + grad = IA("Unqueezed_Grad") if (attributes.find("axes") != attributes.end()) { std::vector axes_values = RetrieveValues(attributes.at("axes")); if (SrcNodeOpsetVersion() < 13) { // axes is attribute for unsqueeze - result.push_back(NodeDef("Unsqueeze", {GO(0)}, {IA("Unqueezed_Grad")}, {MakeAttribute("axes", axes_values)})); + result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)})); } else { NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values")); result.push_back(axes_values_node); - result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {IA("Unqueezed_Grad")})); + result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad})); } } else if (numInputs == 2) { // optional input 'axes' is available as input I(1) - result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {IA("Unqueezed_Grad")})); + result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad})); } }