From 01ae553eb773c30da8df20090ad8c90e9321be1a Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sun, 14 Apr 2024 15:19:34 +0200 Subject: [PATCH] Fix static asserts in generated code --- .../Differentiator/BaseForwardModeVisitor.h | 3 +- include/clad/Differentiator/ErrorEstimator.h | 2 +- .../clad/Differentiator/ExternalRMVSource.h | 4 ++- .../clad/Differentiator/ReverseModeVisitor.h | 3 +- .../Differentiator/VectorForwardModeVisitor.h | 2 +- include/clad/Differentiator/VisitorBase.h | 15 ++++---- lib/Differentiator/BaseForwardModeVisitor.cpp | 35 +++++++++++++------ lib/Differentiator/ErrorEstimator.cpp | 6 ++-- lib/Differentiator/ReverseModeVisitor.cpp | 24 ++++++++++--- .../VectorForwardModeVisitor.cpp | 5 +-- .../FunctionCallsWithResults.C | 22 ++++++++++++ test/Gradient/FunctionCalls.C | 33 +++++++++++++++++ 12 files changed, 122 insertions(+), 32 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index b34de9c8e..3cc8542b0 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -73,7 +73,7 @@ class BaseForwardModeVisitor StmtDiff VisitStmt(const clang::Stmt* S); StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); // Decl is not Stmt, so it cannot be visited directly. - virtual VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + virtual DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); /// Shorthand for warning on differentiation of unsupported operators void unsupportedOpWarn(clang::SourceLocation loc, llvm::ArrayRef args = {}) { @@ -108,6 +108,7 @@ class BaseForwardModeVisitor const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE); + static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); virtual clang::QualType GetPushForwardDerivativeType(clang::QualType ParamType); diff --git a/include/clad/Differentiator/ErrorEstimator.h b/include/clad/Differentiator/ErrorEstimator.h index f0ce288d3..381fe42af 100644 --- a/include/clad/Differentiator/ErrorEstimator.h +++ b/include/clad/Differentiator/ErrorEstimator.h @@ -161,7 +161,7 @@ class ErrorEstimationHandler : public ExternalRMVSource { /// \param[in] VDDiff The variable declaration to calculate the error in. /// \param[in] isInsideLoop A flag to keep track of if we are inside a /// loop. - void EmitDeclErrorStmts(VarDeclDiff VDDiff, bool isInsideLoop); + void EmitDeclErrorStmts(DeclDiff VDDiff, bool isInsideLoop); /// This function returns the size expression for a given variable /// (`var.size()` for clad::array/clad::array_ref diff --git a/include/clad/Differentiator/ExternalRMVSource.h b/include/clad/Differentiator/ExternalRMVSource.h index ff3d7cda9..b6c967a9a 100644 --- a/include/clad/Differentiator/ExternalRMVSource.h +++ b/include/clad/Differentiator/ExternalRMVSource.h @@ -16,7 +16,9 @@ namespace clad { struct DiffRequest; class StmtDiff; -class VarDeclDiff; + +template +class DeclDiff; using direction = rmv::direction; diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 6c7d1fe71..8a3d76e31 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -406,11 +406,12 @@ namespace clad { StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); - VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE); + static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); /// A helper method to differentiate a single Stmt in the reverse mode. /// Internally, calls Visit(S, expr). Its result is wrapped into a diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index 2a34e8fc1..5a7d145cb 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -77,7 +77,7 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE) override; StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; // Decl is not Stmt, so it cannot be visited directly. - VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override; + DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override; clang::QualType GetPushForwardDerivativeType(clang::QualType ParamType) override; diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 60aaa356f..1a3caecc4 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -80,21 +80,22 @@ namespace clad { void setForwSweepStmt_dx(clang::Stmt* S) { m_DerivativeForForwSweep = S; } }; - class VarDeclDiff { + template + class DeclDiff { private: - std::array data; + std::array data; public: - VarDeclDiff(clang::VarDecl* orig = nullptr, - clang::VarDecl* diff = nullptr) { + DeclDiff(T* orig = nullptr, + T* diff = nullptr) { data[1] = orig; data[0] = diff; } - clang::VarDecl* getDecl() { return data[1]; } - clang::VarDecl* getDecl_dx() { return data[0]; } + T* getDecl() { return data[1]; } + T* getDecl_dx() { return data[0]; } // Decl_dx goes first! - std::array& getBothDecls() { return data; } + std::array& getBothDecls() { return data; } }; /// A base class for all common functionality for visitors diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 9ce8d9a2e..fadd85599 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -11,7 +11,6 @@ #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" -#include "clad/Differentiator/StmtClone.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -574,7 +573,7 @@ StmtDiff BaseForwardModeVisitor::VisitIfStmt(const IfStmt* If) { VarDecl* condVarClone = nullptr; if (const VarDecl* condVarDecl = If->getConditionVariable()) { - VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); + DeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); condVarClone = condVarDeclDiff.getDecl(); if (condVarDeclDiff.getDecl_dx()) addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx())); @@ -672,7 +671,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { VarDecl* condVarDecl = FS->getConditionVariable(); VarDecl* condVarClone = nullptr; if (condVarDecl) { - VarDeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); + DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); condVarClone = condVarResult.getDecl(); if (condVarResult.getDecl_dx()) addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); @@ -1380,7 +1379,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { return StmtDiff(op, opDiff); } -VarDeclDiff BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { +DeclDiff +BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { StmtDiff initDiff = VD->getInit() ? Visit(VD->getInit()) : StmtDiff{}; // Here we are assuming that derived type and the original type are same. // This may not necessarily be true in the future. @@ -1392,7 +1392,7 @@ VarDeclDiff BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(), VD->isDirectInit(), nullptr, VD->getInitStyle()); m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); - return VarDeclDiff(VDClone, VDDerived); + return DeclDiff(VDClone, VDDerived); } StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { @@ -1431,7 +1431,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { // double _d_y = _d_x; double y = x; for (auto D : DS->decls()) { if (auto VD = dyn_cast(D)) { - VarDeclDiff VDDiff = DifferentiateVarDecl(VD); + DeclDiff VDDiff = DifferentiateVarDecl(VD); // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. // This can happen in rare cases, e.g. when the original function @@ -1454,14 +1454,23 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { m_DeclReplacements[VD] = VDDiff.getDecl(); decls.push_back(VDDiff.getDecl()); declsDiff.push_back(VDDiff.getDecl_dx()); + } else if (auto* SAD = dyn_cast(D)) { + DeclDiff SADDiff = DifferentiateStaticAssertDecl(SAD); + if (SADDiff.getDecl()) + decls.push_back(SADDiff.getDecl()); + if (SADDiff.getDecl_dx()) + declsDiff.push_back(SADDiff.getDecl_dx()); } else { diag(DiagnosticsEngine::Warning, D->getEndLoc(), "Unsupported declaration"); } } - Stmt* DSClone = BuildDeclStmt(decls); - Stmt* DSDiff = BuildDeclStmt(declsDiff); + Stmt *DSClone = nullptr, *DSDiff = nullptr; + if (!decls.empty()) + DSClone = BuildDeclStmt(decls); + if (!declsDiff.empty()) + DSDiff = BuildDeclStmt(declsDiff); return StmtDiff(DSClone, DSDiff); } @@ -1534,7 +1543,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) { const VarDecl* condVar = WS->getConditionVariable(); VarDecl* condVarClone = nullptr; - VarDeclDiff condVarRes; + DeclDiff condVarRes; if (condVar) { condVarRes = DifferentiateVarDecl(condVar); condVarClone = condVarRes.getDecl(); @@ -1659,7 +1668,7 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) { const VarDecl* condVarDecl = SS->getConditionVariable(); VarDecl* condVarClone = nullptr; if (condVarDecl) { - VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); + DeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); condVarClone = condVarDeclDiff.getDecl(); addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx())); } @@ -2022,4 +2031,10 @@ StmtDiff BaseForwardModeVisitor::VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP) { return Visit(NTTP->getReplacement()); } + +DeclDiff +BaseForwardModeVisitor::DifferentiateStaticAssertDecl( + const clang::StaticAssertDecl* SAD) { + return DeclDiff(); +} } // end namespace clad diff --git a/lib/Differentiator/ErrorEstimator.cpp b/lib/Differentiator/ErrorEstimator.cpp index 5b6438e97..25c00f71e 100644 --- a/lib/Differentiator/ErrorEstimator.cpp +++ b/lib/Differentiator/ErrorEstimator.cpp @@ -248,7 +248,7 @@ void ErrorEstimationHandler::EmitBinaryOpErrorStmts(Expr* LExpr, EmitErrorEstimationStmts(direction::reverse); } -void ErrorEstimationHandler::EmitDeclErrorStmts(VarDeclDiff VDDiff, +void ErrorEstimationHandler::EmitDeclErrorStmts(DeclDiff VDDiff, bool isInsideLoop) { auto VD = VDDiff.getDecl(); if (!ShouldEstimateErrorFor(VD)) @@ -481,8 +481,8 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitDeclStmt( // For all dependent variables, we register them for estimation // here. for (size_t i = 0; i < decls.size(); i++) { - VarDeclDiff VDDiff(static_cast(decls[0]), - static_cast(declsDiff[0])); + DeclDiff VDDiff(static_cast(decls[0]), + static_cast(declsDiff[0])); EmitDeclErrorStmts(VDDiff, m_RMV->isInsideLoop); } } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index db2255cdf..e5b08e58f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -860,7 +860,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VarDecl* condVarClone = nullptr; if (const VarDecl* condVarDecl = If->getConditionVariable()) { - VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); + DeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); condVarClone = condVarDeclDiff.getDecl(); if (condVarDeclDiff.getDecl_dx()) addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals); @@ -2549,7 +2549,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } - VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { + DeclDiff + ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; // Local declarations are promoted to the function global scope. This @@ -2745,7 +2746,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } m_Variables.emplace(VDClone, derivedVDE); - return VarDeclDiff(VDClone, VDDerived); + return DeclDiff(VDClone, VDDerived); } // TODO: 'shouldEmit' parameter should be removed after converting @@ -2812,7 +2813,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = _d_x; double y = x; for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { - VarDeclDiff VDDiff = DifferentiateVarDecl(VD); + DeclDiff VDDiff = DifferentiateVarDecl(VD); // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. @@ -2878,6 +2879,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, localDeclsDiff.push_back(VDDiff.getDecl_dx()); else declsDiff.push_back(VDDiff.getDecl_dx()); + } else if (auto* SAD = dyn_cast(D)) { + DeclDiff SADDiff = DifferentiateStaticAssertDecl(SAD); + if (SADDiff.getDecl()) + decls.push_back(SADDiff.getDecl()); + if (SADDiff.getDecl_dx()) + declsDiff.push_back(SADDiff.getDecl_dx()); } else { diag(DiagnosticsEngine::Warning, D->getEndLoc(), @@ -2885,7 +2892,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - Stmt* DSClone = BuildDeclStmt(decls); + Stmt* DSClone = nullptr; + if (!decls.empty()) + DSClone = BuildDeclStmt(decls); if (!localDeclsDiff.empty()) { Stmt* localDSDIff = BuildDeclStmt(localDeclsDiff); @@ -3831,6 +3840,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Visit(NTTP->getReplacement()); } + DeclDiff ReverseModeVisitor::DifferentiateStaticAssertDecl( + const clang::StaticAssertDecl* SAD) { + return DeclDiff(nullptr, nullptr); + } + QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType, QualType xType) { diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index f07458814..2c8cabe21 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -576,7 +576,8 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { return StmtDiff(returnStmt); } -VarDeclDiff VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { +DeclDiff +VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { StmtDiff initDiff = VD->getInit() ? Visit(VD->getInit()) : StmtDiff{}; // Here we are assuming that derived type and the original type are same. // This may not necessarily be true in the future. @@ -610,7 +611,7 @@ VarDeclDiff VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { false, nullptr, VarDecl::InitializationStyle::CallInit); m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); - return VarDeclDiff(VDClone, VDDerived); + return DeclDiff(VDClone, VDDerived); } } // namespace clad diff --git a/test/FirstDerivative/FunctionCallsWithResults.C b/test/FirstDerivative/FunctionCallsWithResults.C index 1eef5b388..1e1bb45f4 100644 --- a/test/FirstDerivative/FunctionCallsWithResults.C +++ b/test/FirstDerivative/FunctionCallsWithResults.C @@ -2,6 +2,7 @@ // RUN: ./FunctionCallsWithResults.out | FileCheck -check-prefix=CHECK-EXEC %s #include "clad/Differentiator/Differentiator.h" +#include int printf(const char* fmt, ...); @@ -289,6 +290,25 @@ double fn9 (double i, double j) { // CHECK-NEXT: return _t0.pushforward * _t3 + _t2 * _t1.pushforward; // CHECK-NEXT: } +double fn10(double x) { + std::mt19937 gen64; + std::uniform_real_distribution distribution(0.0,1.0); + double rand = distribution(gen64); + return x+rand; +} + +// CHECK: double fn10_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: std::mt19937 _d_gen64; +// CHECK-NEXT: std::mt19937 gen64; +// CHECK-NEXT: std::uniform_real_distribution _d_distribution(0., 0.); +// CHECK-NEXT: std::uniform_real_distribution distribution(0., 1.); +// CHECK-NEXT: clad::ValueAndPushforward _t0 = distribution.operator_call_pushforward(gen64, &_d_distribution, _d_gen64); +// CHECK-NEXT: double _d_rand = _t0.pushforward; +// CHECK-NEXT: double rand0 = _t0.value; +// CHECK-NEXT: return _d_x + _d_rand; +// CHECK-NEXT: } + float test_1_darg0(float x); float test_2_darg0(float x); float test_4_darg0(float x); @@ -318,6 +338,7 @@ int main () { INIT(fn7, "i"); INIT(fn8, "i"); INIT(fn9, "i"); + INIT(fn10, "x"); TEST(fn1, 3, 5); // CHECK-EXEC: {12.00} TEST(fn2, 3, 5); // CHECK-EXEC: {181.00} @@ -328,6 +349,7 @@ int main () { TEST(fn7, 3, 5); // CHECK-EXEC: {8.00} TEST(fn8, 3, 5); // CHECK-EXEC: {19.04} TEST(fn9, 3, 5); // CHECK-EXEC: {5.00} + TEST(fn10, 3); // CHECK-EXEC: {1.00} return 0; // CHECK: clad::ValueAndPushforward sum_of_squares_pushforward(double u, double v, double _d_u, double _d_v) { diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 6ffe45adf..471925081 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -6,6 +6,7 @@ // CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" +#include namespace A { template T constantFn(T i) { return 3; } @@ -647,6 +648,29 @@ double fn18(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +template +T templated_fn(double x) { + static_assert(std::is_floating_point::value, + "template argument must be a floating point type"); + return x; +} + +// CHECK: void templated_fn_pullback(double x, double _d_y, double *_d_x); + +double fn19(double x) { + return templated_fn(x); +} + +// CHECK: void fn19_grad(double x, double *_d_x) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: templated_fn_pullback(x, 1, &_r0); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i