From 2a5ac8bd2df394366df88e128574f50c9cf6d521 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 18 Jul 2024 09:45:23 +0200 Subject: [PATCH] Improve fwd mode for calling fxn with zero/null derivatives --- include/clad/Differentiator/CladUtils.h | 1 + lib/Differentiator/BaseForwardModeVisitor.cpp | 17 ++++++-------- lib/Differentiator/CladUtils.cpp | 16 ++++++++++++++ test/FirstDerivative/CallArguments.C | 22 +++++++++++++++++++ 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index ae8b55813..05899cad7 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -317,6 +317,7 @@ namespace clad { void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); bool IsLiteral(const clang::Expr* E); + bool IsZeroOrNullValue(const clang::Expr* E); bool IsMemoryFunction(const clang::FunctionDecl* FD); bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 7853205a7..1d39622f6 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1071,7 +1071,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // Returning the function call and zero derivative return StmtDiff(Call, zero); } - // Find the built-in derivatives namespace. std::string s = std::to_string(m_DerivativeOrder); if (m_DerivativeOrder == 1) @@ -1200,19 +1199,17 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // FIXME: revert this when this is integrated in the activity analysis pass. if (!callDiff) { if (!isa(CE) && !isa(CE)) { - bool allArgsAreConstantLiterals = true; + bool allArgsHaveZeroDerivatives = true; for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) { - const Expr* arg = CE->getArg(i); - // if it's of type MaterializeTemporaryExpr, then check its - // subexpression. - if (const auto* MTE = dyn_cast(arg)) - arg = clad_compat::GetSubExpr(MTE); - if (!arg->isEvaluatable(m_Context)) { - allArgsAreConstantLiterals = false; + Expr* dArg = diffArgs[i]; + // If argDiff.expr_dx is nullptr or is a constant 0, then the derivative + // of the function call is 0. + if (!clad::utils::IsZeroOrNullValue(dArg->IgnoreParenImpCasts())) { + allArgsHaveZeroDerivatives = false; break; } } - if (allArgsAreConstantLiterals) { + if (allArgsHaveZeroDerivatives) { Expr* call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 96b35d6aa..d2784a2df 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -658,6 +658,22 @@ namespace clad { isa(E); } + bool IsZeroOrNullValue(const clang::Expr* E) { + if (!E) + return true; + if (isa(E)) + return true; + if (auto* FL = dyn_cast(E)) + return FL->getValue().isZero(); + if (auto* IL = dyn_cast(E)) + return IL->getValue() == 0; + if (auto* CL = dyn_cast(E)) + return CL->getValue() == 0; + if (auto* SL = dyn_cast(E)) + return SL->getLength() == 0; + return false; + } + bool IsMemoryFunction(const clang::FunctionDecl* FD) { #if CLANG_VERSION_MAJOR > 12 diff --git a/test/FirstDerivative/CallArguments.C b/test/FirstDerivative/CallArguments.C index 9c543a655..0f2a7b49a 100644 --- a/test/FirstDerivative/CallArguments.C +++ b/test/FirstDerivative/CallArguments.C @@ -132,6 +132,25 @@ float f_const_args_func_8(const float x, float y) { // CHECK-NEXT: return _t0.pushforward + _t1.pushforward - _d_y; // CHECK-NEXT: } +float f_literal_helper(float x, char ch, float* p, float* q) { + if (ch == 'a') + return x * x; + return -x * x; +} + +float f_literal_args_func(float x, float y) { + printf("hello world "); + return x * f_literal_helper(0.5, 'a', nullptr, nullptr); +} + +// CHECK: float f_literal_args_func_darg0(float x, float y) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: float _d_y = 0; +// CHECK-NEXT: printf("hello world "); +// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', nullptr, nullptr); +// CHECK-NEXT: return _d_x * _t0 + x * 0.F; +// CHECK-NEXT: } + extern "C" int printf(const char* fmt, ...); int main () { // expected-no-diagnostics auto f = clad::differentiate(g, 0); @@ -165,6 +184,9 @@ int main () { // expected-no-diagnostics const float f8x = 1.F; printf("f8_darg0=%f\n", f8.execute(f8x,2.F)); //CHECK-EXEC: f8_darg0=2.000000 + auto f9 = clad::differentiate(f_literal_args_func, 0); + printf("f9_darg0=%.2f\n", f9.execute(1.F,2.F)); + //CHECK-EXEC: hello world f9_darg0=0.25 // CHECK: clad::ValueAndPushforward f_const_helper_pushforward(const float x, const float _d_x) { // CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};