Skip to content

Commit

Permalink
undo IA changes
Browse files Browse the repository at this point in the history
  • Loading branch information
prathikr committed Mar 18, 2024
1 parent 7840801 commit 771cb5b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1113,17 +1113,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
ArgDef grad = GO(0);
if (!keepdims) {
size_t numInputs = GetSrcNodeInputSize();
grad = IA("Unqueezed_Grad")
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
if (SrcNodeOpsetVersion() < 13) { // axes is attribute for unsqueeze
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {IA("Unqueezed_Grad")}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
} else {
NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
result.push_back(axes_values_node);
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {IA("Unqueezed_Grad")}));
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
}
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {IA("Unqueezed_Grad")}));
result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
}
}

Expand Down

0 comments on commit 771cb5b

Please sign in to comment.