From 7272ee7157db4f79edec9a07601d89205d90be3b Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Fri, 20 Oct 2023 14:12:43 +0530 Subject: [PATCH] Set derivative to 0 for fxn calls with literal arguments --- lib/Differentiator/BaseForwardModeVisitor.cpp | 29 +++++++++++++++ lib/Differentiator/ReverseModeVisitor.cpp | 37 ++++++++++++------- test/FirstDerivative/FunctionCalls.C | 7 +--- .../FunctionCallsWithResults.C | 8 ++-- test/Gradient/FunctionCalls.C | 32 ++++++++++++++++ 5 files changed, 91 insertions(+), 22 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f092adcd2..6e89364eb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1055,6 +1055,35 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { .get(); } + // If all arguments are constant literals, then this does not contribute to + // the gradient. + if (!callDiff) { + if (!isa(CE) && !isa(CE)) { + bool allArgsAreConstantLiterals = true; + for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) { + const Expr* arg = CE->getArg(i); + // if it's of type MaterializeTemporaryExpr, then check its + // subexpression. + if (const auto* MTE = dyn_cast(arg)) + arg = MTE->getSubExpr(); + if (!isa(arg) && !isa(arg)) { + allArgsAreConstantLiterals = false; + break; + } + } + if (allArgsAreConstantLiterals) { + Expr* call = + m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, + llvm::MutableArrayRef(CallArgs), noLoc) + .get(); + auto* zero = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + return StmtDiff(call, zero); + } + } + } + if (!callDiff) { // Overloaded derivative was not found, request the CladPlugin to // derive the called function. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index df30acb91..30056411c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1381,6 +1381,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, !isa(CE)) return StmtDiff(Clone(CE)); + // If all arguments are constant literals, then this does not contribute to + // the gradient. + if (!isa(CE) && !isa(CE)) { + bool allArgsAreConstantLiterals = true; + for (const Expr* arg : CE->arguments()) { + // if it's of type MaterializeTemporaryExpr, then check its + // subexpression. + if (const auto* MTE = dyn_cast(arg)) + arg = MTE->getSubExpr(); + if (!isa(arg) && !isa(arg)) { + allArgsAreConstantLiterals = false; + break; + } + } + if (allArgsAreConstantLiterals) + return StmtDiff(Clone(CE)); + } + // Stores the call arguments for the function to be derived llvm::SmallVector CallArgs{}; // Stores the dx of the call arguments for the function to be derived @@ -1419,14 +1437,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // statements there later. std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); - // `CXXOperatorCallExpr` have the `base` expression as the first argument. - size_t skipFirstArg = 0; - - // Here we do not need to check if FD is an instance method or a static - // method because C++ forbids creating operator overloads as static methods. - if (isa(CE) && isa(FD)) - skipFirstArg = 1; - // FIXME: We should add instructions for handling non-differentiable // arguments. Currently we are implicitly assuming function call only // contains differentiable arguments. @@ -1665,7 +1675,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // is required because the pullback function expects `clad::array_ref` // type for representing array derivatives. Currently, only constant // array data members have derivatives of constant array types. - if (isa(argDerivative->getType())) { + if ((argDerivative != nullptr) && + isa(argDerivative->getType())) { Expr* init = utils::BuildCladArrayInitByConstArray(m_Sema, argDerivative); auto derivativeArrayRefVD = BuildVarDecl( @@ -1676,11 +1687,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ArgDeclStmts.push_back(BuildDeclStmt(derivativeArrayRefVD)); argDerivative = BuildDeclRef(derivativeArrayRefVD); } - if (isCladArrayType(argDerivative->getType())) { + if ((argDerivative != nullptr) && + isCladArrayType(argDerivative->getType())) gradArgExpr = argDerivative; - } else { + else gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); - } } else { // Declare: diffArgType _grad = 0; gradVarDecl = BuildVarDecl( @@ -1721,7 +1732,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (pullback) pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - - static_cast(skipFirstArg), + static_cast(isCXXOperatorCall), pullback); // Try to find it in builtin derivatives diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index c070b281c..276d44fca 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -86,14 +86,9 @@ float test_4(int x) { return overloaded(); } -// CHECK: {{(clad::)?}}ValueAndPushforward overloaded_pushforward() { -// CHECK-NEXT: return {3, 0}; -// CHECK-NEXT: } - // CHECK: float test_4_darg0(int x) { // CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward _t0 = overloaded_pushforward(); -// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: return 0; // CHECK-NEXT: } float test_5(int x) { diff --git a/test/FirstDerivative/FunctionCallsWithResults.C b/test/FirstDerivative/FunctionCallsWithResults.C index e7e36615b..aa99017fc 100644 --- a/test/FirstDerivative/FunctionCallsWithResults.C +++ b/test/FirstDerivative/FunctionCallsWithResults.C @@ -297,7 +297,7 @@ double sum(double* arr, int n) { double fn8(double i, double j) { double arr[5] = {}; modifyArr(arr, 5, i*j); - return sum(arr, 5); + return sum(arr, 5) * std::tanh(1.0); } // CHECK: double fn8_darg0(double i, double j) { @@ -307,7 +307,9 @@ double fn8(double i, double j) { // CHECK-NEXT: double arr[5] = {}; // CHECK-NEXT: modifyArr_pushforward(arr, 5, i * j, _d_arr, 0, _d_i * j + i * _d_j); // CHECK-NEXT: clad::ValueAndPushforward _t0 = sum_pushforward(arr, 5, _d_arr, 0); -// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: double &_t1 = _t0.value; +// CHECK-NEXT: double _t2 = std::tanh(1.); +// CHECK-NEXT: return _t0.pushforward * _t2 + _t1 * 0; // CHECK-NEXT: } float test_1_darg0(float x); @@ -346,6 +348,6 @@ int main () { TEST(fn5, 3, 5); // CHECK-EXEC: {1.00} TEST(fn6, 3, 5, 7); // CHECK-EXEC: {3.00} TEST(fn7, 3, 5); // CHECK-EXEC: {8.00} - TEST(fn8, 3, 5); // CHECK-EXEC: {25.00} + TEST(fn8, 3, 5); // CHECK-EXEC: {19.04} return 0; } diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 93d135756..3f440c36d 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -460,6 +460,36 @@ double fn7(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn8(double x, double y) { + return x*y*std::tanh(1.0)*std::max(1.0, 2.0); +} + +// CHECK: void fn8_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: _t3 = x; +// CHECK-NEXT: _t2 = y; +// CHECK-NEXT: _t4 = _t3 * _t2; +// CHECK-NEXT: _t1 = std::tanh(1.); +// CHECK-NEXT: _t5 = _t4 * _t1; +// CHECK-NEXT: _t0 = std::max(1., 2.); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 1 * _t0; +// CHECK-NEXT: double _r1 = _r0 * _t1; +// CHECK-NEXT: double _r2 = _r1 * _t2; +// CHECK-NEXT: * _d_x += _r2; +// CHECK-NEXT: double _r3 = _t3 * _r1; +// CHECK-NEXT: * _d_y += _r3; +// CHECK-NEXT: double _r4 = _t4 * _r0; +// CHECK-NEXT: double _r5 = _t5 * 1; +// CHECK-NEXT: } +// CHECK-NEXT: } template void reset(T* arr, int n) { @@ -513,6 +543,7 @@ int main() { INIT(fn5); INIT(fn6); INIT(fn7); + INIT(fn8); TEST1_float(fn1, 11); // CHECK-EXEC: {3.00} TEST2(fn2, 3, 5); // CHECK-EXEC: {1.00, 3.00} @@ -522,4 +553,5 @@ int main() { TEST_ARR5(fn5, arr, 5); // CHECK-EXEC: {5.00, 1.00, 0.00, 0.00, 0.00} TEST2(fn6, 3, 5); // CHECK-EXEC: {5.00, 3.00} TEST2(fn7, 3, 5); // CHECK-EXEC: {10.00, 71.00} + TEST2(fn8, 3, 5); // CHECK-EXEC: {7.62, 4.57} }