From b721b161afae241a4dd531c12fd4e924209acc81 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Mon, 20 May 2024 19:01:01 +0200 Subject: [PATCH] Add support for null (empty) statements This commit removes the "unsupported" warning when differentiating code with ";" in both forward and reverse modes. It also prevents additional semicolons from getting pulled into the derivative code. Fixes: #899 --- .../clad/Differentiator/BaseForwardModeVisitor.h | 1 + include/clad/Differentiator/ReverseModeVisitor.h | 1 + test/Gradient/Gradients.C | 14 ++++++++++++++ 3 files changed, 16 insertions(+) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 9e0394c8a..0aa593033 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -109,6 +109,7 @@ class BaseForwardModeVisitor const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE); + StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index eb234b610..976ed698b 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -411,6 +411,7 @@ namespace clad { const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE); + StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index fc374912b..62c948747 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -901,6 +901,18 @@ double fn_cond_init(double x) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn_null_stmts(double x) { + ;;;;;;;;;;;;;;;;; + ;;;;;return x;;;; + ;;;;;;;;;;;;;;;;; +} // = x + +//CHECK: void fn_empty_lines_grad(double x, double *_d_x) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: *_d_x += 1; +//CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -971,4 +983,6 @@ int main() { INIT_GRADIENT(fn_cond_init); TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00 TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, -1, &dx); // CHECK-EXEC: 1.00 + + INIT_GRADIENT(fn_empty_lines); }