Skip to content

Commit

Permalink
Improve fwd mode for calling fxn with zero/null derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jul 18, 2024
1 parent e04d04e commit 2a5ac8b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ namespace clad {
void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);

bool IsLiteral(const clang::Expr* E);
bool IsZeroOrNullValue(const clang::Expr* E);

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);
Expand Down
17 changes: 7 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

// Find the built-in derivatives namespace.
std::string s = std::to_string(m_DerivativeOrder);
if (m_DerivativeOrder == 1)
Expand Down Expand Up @@ -1200,19 +1199,17 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// FIXME: revert this when this is integrated in the activity analysis pass.
if (!callDiff) {
if (!isa<CXXOperatorCallExpr>(CE) && !isa<CXXMemberCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
bool allArgsHaveZeroDerivatives = true;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) {
const Expr* arg = CE->getArg(i);
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE);
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
Expr* dArg = diffArgs[i];
// If argDiff.expr_dx is nullptr or is a constant 0, then the derivative
// of the function call is 0.
if (!clad::utils::IsZeroOrNullValue(dArg->IgnoreParenImpCasts())) {
allArgsHaveZeroDerivatives = false;
break;
}
}
if (allArgsAreConstantLiterals) {
if (allArgsHaveZeroDerivatives) {
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
Expand Down
16 changes: 16 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,22 @@ namespace clad {
isa<GNUNullExpr>(E);
}

bool IsZeroOrNullValue(const clang::Expr* E) {
if (!E)
return true;
if (isa<CXXNullPtrLiteralExpr>(E))
return true;
if (auto* FL = dyn_cast<FloatingLiteral>(E))
return FL->getValue().isZero();
if (auto* IL = dyn_cast<IntegerLiteral>(E))
return IL->getValue() == 0;
if (auto* CL = dyn_cast<CharacterLiteral>(E))
return CL->getValue() == 0;
if (auto* SL = dyn_cast<StringLiteral>(E))
return SL->getLength() == 0;
return false;
}

bool IsMemoryFunction(const clang::FunctionDecl* FD) {

#if CLANG_VERSION_MAJOR > 12
Expand Down
22 changes: 22 additions & 0 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,25 @@ float f_const_args_func_8(const float x, float y) {
// CHECK-NEXT: return _t0.pushforward + _t1.pushforward - _d_y;
// CHECK-NEXT: }

float f_literal_helper(float x, char ch, float* p, float* q) {
if (ch == 'a')
return x * x;
return -x * x;
}

float f_literal_args_func(float x, float y) {
printf("hello world ");
return x * f_literal_helper(0.5, 'a', nullptr, nullptr);
}

// CHECK: float f_literal_args_func_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: printf("hello world ");
// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', nullptr, nullptr);
// CHECK-NEXT: return _d_x * _t0 + x * 0.F;
// CHECK-NEXT: }

extern "C" int printf(const char* fmt, ...);
int main () { // expected-no-diagnostics
auto f = clad::differentiate(g, 0);
Expand Down Expand Up @@ -165,6 +184,9 @@ int main () { // expected-no-diagnostics
const float f8x = 1.F;
printf("f8_darg0=%f\n", f8.execute(f8x,2.F));
//CHECK-EXEC: f8_darg0=2.000000
auto f9 = clad::differentiate(f_literal_args_func, 0);
printf("f9_darg0=%.2f\n", f9.execute(1.F,2.F));
//CHECK-EXEC: hello world f9_darg0=0.25

// CHECK: clad::ValueAndPushforward<float, float> f_const_helper_pushforward(const float x, const float _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
Expand Down

0 comments on commit 2a5ac8b

Please sign in to comment.