Skip to content

Commit

Permalink
Add Gradient for Atan
Browse files Browse the repository at this point in the history
  • Loading branch information
cocotdf committed Dec 20, 2024
1 parent 4aca8f3 commit 153ec61
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
13 changes: 13 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2227,5 +2227,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) {
SrcNodeAttributes())};
}

IMPLEMENT_GRADIENT_BUILDER(GetAtanGradient) {
// dl/dx = dl/dy * (1/(1+x^2))
NodeDef one_const_node = OneConstantNode(IElemType(0));
ArgDef one = one_const_node.output_args[0];
std::vector<NodeDef> result;
result.push_back(one_const_node);
result.push_back(NodeDef("Mul", {I(0), I(0)}, {IA("Square_I0")}));
result.push_back(NodeDef("Add", {IA("Square_I0"), one}, {IA("One_Plus_Square_I0")}));
result.push_back(NodeDef("Div", {GO(0), IA("One_Plus_Square_I0")}, {GI(0)}));
return result;
}


} // namespace training
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)


DECLARE_GRADIENT_BUILDER(GetExternalGradient)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
Expand Down
2 changes: 2 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3352,6 +3352,8 @@ TEST(GradientCheckerTest, ResizeGrad) {

#endif // USE_CUDA

TEST(GradientCheckerTest, AtanGrad) { UnaryOpGradientTest("Atan"); }

} // namespace test
} // namespace onnxruntime

Expand Down

0 comments on commit 153ec61

Please sign in to comment.