From d629553ecfb9eabb9d19a12d4643b5d7fb760974 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Mon, 20 May 2024 19:01:01 +0200 Subject: [PATCH] Add support for null (empty) statements This commit removes the "unsupported" warning when differentiating code with ";" in both forward and reverse modes. It also prevents additional semicolons from getting pulled into the derivative code. Fixes: #899 --- .../clad/Differentiator/BaseForwardModeVisitor.h | 1 + include/clad/Differentiator/ReverseModeVisitor.h | 1 + test/FirstDerivative/Simple.C | 2 +- test/Gradient/Gradients.C | 14 ++++++++++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 9e0394c8a..0aa593033 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -109,6 +109,7 @@ class BaseForwardModeVisitor const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE); + StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index eb234b610..976ed698b 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -411,6 +411,7 @@ namespace clad { const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE); + StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/test/FirstDerivative/Simple.C b/test/FirstDerivative/Simple.C index 6e90ad0a3..1074316e1 100644 --- a/test/FirstDerivative/Simple.C +++ b/test/FirstDerivative/Simple.C @@ -7,7 +7,7 @@ extern "C" int printf(const char* fmt, ...); int f(int x) { printf("This is f(x).\n"); - return x*x + x - x*x*x*x; + return x*x + x - x*x*x*x;; } int main () { diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index fc374912b..60049ae55 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -901,6 +901,18 @@ double fn_cond_init(double x) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn_null_stmts(double x) { + ;;;;;;;;;;;;;;;;; + ;;;;;return x;;;; + ;;;;;;;;;;;;;;;;; +} // = x + +//CHECK: void fn_null_stmts_grad(double x, double *_d_x) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: *_d_x += 1; +//CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -971,4 +983,6 @@ int main() { INIT_GRADIENT(fn_cond_init); TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00 TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, -1, &dx); // CHECK-EXEC: 1.00 + + INIT_GRADIENT(fn_null_stmts); }