Skip to content

Commit

Permalink
Add support for null (empty) statements
Browse files Browse the repository at this point in the history
This commit removes the "unsupported" warning when differentiating
 code with ";" in both forward and reverse modes. It also prevents
additional semicolons from getting pulled into the derivative code.

Fixes: vgvassilev#899
  • Loading branch information
gojakuch authored and vgvassilev committed May 21, 2024
1 parent 1ba8929 commit d629553
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class BaseForwardModeVisitor
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ namespace clad {
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE);
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

Expand Down
2 changes: 1 addition & 1 deletion test/FirstDerivative/Simple.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ extern "C" int printf(const char* fmt, ...);

int f(int x) {
printf("This is f(x).\n");
return x*x + x - x*x*x*x;
return x*x + x - x*x*x*x;;
}

int main () {
Expand Down
14 changes: 14 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,18 @@ double fn_cond_init(double x) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn_null_stmts(double x) {
;;;;;;;;;;;;;;;;;
;;;;;return x;;;;
;;;;;;;;;;;;;;;;;
} // = x

//CHECK: void fn_null_stmts_grad(double x, double *_d_x) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -971,4 +983,6 @@ int main() {
INIT_GRADIENT(fn_cond_init);
TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00
TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, -1, &dx); // CHECK-EXEC: 1.00

INIT_GRADIENT(fn_null_stmts);
}

0 comments on commit d629553

Please sign in to comment.