diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index b34de9c8e..9bcf4d44f 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -108,6 +108,7 @@ class BaseForwardModeVisitor const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE); + StmtDiff VisitStaticAssertDecl(const clang::StaticAssertDecl* SAD); virtual clang::QualType GetPushForwardDerivativeType(clang::QualType ParamType); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 6c7d1fe71..3c8a5c717 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -411,6 +411,7 @@ namespace clad { const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE); + StmtDiff VisitStaticAssertDecl(const clang::StaticAssertDecl* SAD); /// A helper method to differentiate a single Stmt in the reverse mode. /// Internally, calls Visit(S, expr). Its result is wrapped into a diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 9ce8d9a2e..ea1d3c8c1 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -11,7 +11,6 @@ #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" -#include "clad/Differentiator/StmtClone.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -1454,6 +1453,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { m_DeclReplacements[VD] = VDDiff.getDecl(); decls.push_back(VDDiff.getDecl()); declsDiff.push_back(VDDiff.getDecl_dx()); + } else if (auto SAD = dyn_cast(D)) { + return VisitStaticAssertDecl(SAD); } else { diag(DiagnosticsEngine::Warning, D->getEndLoc(), "Unsupported declaration"); @@ -2022,4 +2023,9 @@ StmtDiff BaseForwardModeVisitor::VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP) { return Visit(NTTP->getReplacement()); } + +StmtDiff BaseForwardModeVisitor::VisitStaticAssertDecl( + const clang::StaticAssertDecl* SAD) { + return nullptr; +} } // end namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index db2255cdf..c012f465f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2878,6 +2878,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, localDeclsDiff.push_back(VDDiff.getDecl_dx()); else declsDiff.push_back(VDDiff.getDecl_dx()); + } else if (auto SAD = dyn_cast(D)) { + return VisitStaticAssertDecl(SAD); } else { diag(DiagnosticsEngine::Warning, D->getEndLoc(), @@ -3831,6 +3833,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Visit(NTTP->getReplacement()); } + StmtDiff ReverseModeVisitor::VisitStaticAssertDecl( + const clang::StaticAssertDecl* SAD) { + return nullptr; + } + QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType, QualType xType) { diff --git a/test/FirstDerivative/FunctionCallsWithResults.C b/test/FirstDerivative/FunctionCallsWithResults.C index 1eef5b388..205c139aa 100644 --- a/test/FirstDerivative/FunctionCallsWithResults.C +++ b/test/FirstDerivative/FunctionCallsWithResults.C @@ -2,6 +2,7 @@ // RUN: ./FunctionCallsWithResults.out | FileCheck -check-prefix=CHECK-EXEC %s #include "clad/Differentiator/Differentiator.h" +#include int printf(const char* fmt, ...); @@ -289,6 +290,25 @@ double fn9 (double i, double j) { // CHECK-NEXT: return _t0.pushforward * _t3 + _t2 * _t1.pushforward; // CHECK-NEXT: } +double fn10(double x) { + std::mt19937 gen64; + std::uniform_real_distribution distribution(0.0,1.0); + double rand = distribution(gen64); + return x+rand; +} + +// CHECK: double f_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: std::mt19937 _d_gen64; +// CHECK-NEXT: std::mt19937 gen64; +// CHECK-NEXT: std::uniform_real_distribution _d_distribution(0., 0.); +// CHECK-NEXT: std::uniform_real_distribution distribution(0., 1.); +// CHECK-NEXT: clad::ValueAndPushforward _t0 = distribution.operator_call_pushforward(gen64, &_d_distribution, _d_gen64); +// CHECK-NEXT: double _d_rand = _t0.pushforward; +// CHECK-NEXT: double rand0 = _t0.value; +// CHECK-NEXT: return _d_x * x + x * _d_x + _d_rand; +// CHECK-NEXT: } + float test_1_darg0(float x); float test_2_darg0(float x); float test_4_darg0(float x); @@ -318,6 +338,7 @@ int main () { INIT(fn7, "i"); INIT(fn8, "i"); INIT(fn9, "i"); + INIT(fn10, "x"); TEST(fn1, 3, 5); // CHECK-EXEC: {12.00} TEST(fn2, 3, 5); // CHECK-EXEC: {181.00} @@ -328,6 +349,7 @@ int main () { TEST(fn7, 3, 5); // CHECK-EXEC: {8.00} TEST(fn8, 3, 5); // CHECK-EXEC: {19.04} TEST(fn9, 3, 5); // CHECK-EXEC: {5.00} + TEST(fn10, 3); // CHECK-EXEC: {1.00} return 0; // CHECK: clad::ValueAndPushforward sum_of_squares_pushforward(double u, double v, double _d_u, double _d_v) {