Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for null (empty) statements
Browse files Browse the repository at this point in the history
This commit removes the "unsupported" warning when differentiating
 code with ";" in both forward and reverse modes. It also prevents
additional semicolons from getting pulled into the derivative code.

Fixes: vgvassilev#899
gojakuch committed May 20, 2024
1 parent 19c6573 commit 6bae634
Showing 4 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
@@ -109,6 +109,7 @@ class BaseForwardModeVisitor
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
@@ -411,6 +411,7 @@ namespace clad {
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE);
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

2 changes: 1 addition & 1 deletion test/FirstDerivative/Simple.C
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ extern "C" int printf(const char* fmt, ...);

int f(int x) {
printf("This is f(x).\n");
return x*x + x - x*x*x*x;
return x*x + x - x*x*x*x;;
}

int main () {
14 changes: 14 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
@@ -901,6 +901,18 @@ double fn_cond_init(double x) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn_null_stmts(double x) {
;;;;;;;;;;;;;;;;;
;;;;;return x;;;;
;;;;;;;;;;;;;;;;;
} // = x

//CHECK: void fn_null_stmts_grad(double x, double *_d_x) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
@@ -971,4 +983,6 @@ int main() {
INIT_GRADIENT(fn_cond_init);
TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00
TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, -1, &dx); // CHECK-EXEC: 1.00

INIT_GRADIENT(fn_null_stmts);
}

0 comments on commit 6bae634

Please sign in to comment.