From 3a7b269b977c122679e893082f5cf1321c77e7e8 Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Wed, 20 Mar 2024 23:13:01 +0200 Subject: [PATCH 01/11] Add original loop variable's step prior to switch in reverse pass --- lib/Differentiator/ReverseModeVisitor.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 04f626286..118766f6e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4105,6 +4105,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); m_LoopBlock.pop_back(); + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); + // Increment statement in the for-loop is only executed if the iteration // did not end with a break/continue statement. Therefore, forLoopIncDiff // should be inside the last switch case in the reverse pass. @@ -4117,10 +4121,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - activeBreakContHandler->EndCFSwitchStmtScope(); - activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); - PopBreakContStmtHandler(); - Expr* counterDecrement = loopCounter.getCounterDecrement(); // Create reverse pass loop body statements by arranging various From 6d6ee8540dbbfd66d7b1580da2679f40ab14be07 Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Thu, 4 Apr 2024 11:24:05 +0300 Subject: [PATCH 02/11] Update comment on step increment --- lib/Differentiator/ReverseModeVisitor.cpp | 4 +-- test/Analyses/TBR.cpp | 4 +-- test/Gradient/Loops.C | 36 +++++++++++------------ 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 118766f6e..c8f457f50 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4109,9 +4109,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); - // Increment statement in the for-loop is only executed if the iteration - // did not end with a break/continue statement. Therefore, forLoopIncDiff - // should be inside the last switch case in the reverse pass. + // Increment statement in the for-loop is executed for every case if (forLoopIncDiff) { if (bodyDiff.getStmt_dx()) { bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( diff --git a/test/Analyses/TBR.cpp b/test/Analyses/TBR.cpp index 21f4c5b9f..b16ee6f12 100644 --- a/test/Analyses/TBR.cpp +++ b/test/Analyses/TBR.cpp @@ -82,10 +82,10 @@ double f2(double val) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: --i; //CHECK-NEXT: switch (clad::pop(_t1)) { //CHECK-NEXT: case {{2U|2UL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --i; //CHECK-NEXT: { //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_i += _r_d0 * val; @@ -167,6 +167,6 @@ double f3 (double x){ int main() { double result[3] = {}; TEST(f1, 3); // CHECK-EXEC: {27.00} - TEST(f2, 3); // CHECK-EXEC: {9.00} + TEST(f2, 3); // CHECK-EXEC: {7.00} TEST(f3, 3); // CHECK-EXEC: {2.00} } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 8586aac60..77eed0f2c 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1320,10 +1320,10 @@ double fn16(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --ii; // CHECK-NEXT: { // CHECK-NEXT: res = clad::pop(_t4); // CHECK-NEXT: double _r_d2 = _d_res; @@ -1443,10 +1443,10 @@ double fn17(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --ii; // CHECK-NEXT: { // CHECK-NEXT: while (clad::back(_t3)) // CHECK-NEXT: { @@ -1559,10 +1559,10 @@ double fn18(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --counter; // CHECK-NEXT: if (clad::back(_cond0)) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -1898,10 +1898,10 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --c; // CHECK-NEXT: { // CHECK-NEXT: if (clad::back(_cond0)) // CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2009,10 +2009,10 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --c; // CHECK-NEXT: { // CHECK-NEXT: if (clad::back(_cond0)) { // CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2080,19 +2080,19 @@ double fn26(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_i += 7 * _r_d1 * j; +// CHECK-NEXT: *_d_j += 7 * i * _r_d1; +// CHECK-NEXT: _d_c += 0; +// CHECK-NEXT: --c; +// CHECK-NEXT: } // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; // CHECK-NEXT: { -// CHECK-NEXT: res = clad::pop(_t2); -// CHECK-NEXT: double _r_d1 = _d_res; -// CHECK-NEXT: _d_res = 0.; -// CHECK-NEXT: *_d_i += 7 * _r_d1 * j; -// CHECK-NEXT: *_d_j += 7 * i * _r_d1; -// CHECK-NEXT: _d_c += 0; -// CHECK-NEXT: --c; -// CHECK-NEXT: } -// CHECK-NEXT: { // CHECK-NEXT: if (clad::back(_cond0)) // CHECK-NEXT: case {{1U|1UL|1ULL}}: // CHECK-NEXT: ; @@ -2155,10 +2155,10 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --c; // CHECK-NEXT: { // CHECK-NEXT: res = clad::pop(_t3); // CHECK-NEXT: double _r_d1 = _d_res; @@ -2480,10 +2480,10 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --c; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond1)) { //CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2509,10 +2509,10 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --d; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond0)) { //CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2631,10 +2631,10 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --c; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond5)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: From d713ef43b38fa0cc46c686f8a6988ba919d092be Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Thu, 4 Apr 2024 11:52:04 +0300 Subject: [PATCH 03/11] Fix format conflicts --- include/clad/Differentiator/Compatibility.h | 10 +++++----- lib/Differentiator/ReverseModeVisitor.cpp | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index efd3d629c..4d2b9b22a 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -284,7 +284,7 @@ static inline SwitchStmt* SwitchStmt_Create(const ASTContext &Ctx, { #if CLANG_VERSION_MAJOR < 12 - return SwitchStmt::Create(Ctx, Init, Var, Cond); + return SwitchStmt::Create(Ctx, Init, Var, Cond); #elif CLANG_VERSION_MAJOR >= 12 return SwitchStmt::Create(Ctx, Init, Var, Cond, LParenLoc, RParenLoc); #endif @@ -342,7 +342,7 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy, #if CLANG_VERSION_MAJOR < 10 return Ctx.getConstantArrayType(EltTy, ArySize, ASM, IndexTypeQuals); #elif CLANG_VERSION_MAJOR >= 10 - return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, + return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, IndexTypeQuals); #endif } @@ -351,7 +351,7 @@ static inline QualType getConstantArrayType(const ASTContext& Ctx, QualType EltTy, const APInt& ArySize, const Expr* SizeExpr, clang::ArraySizeModifier ASM, unsigned IndexTypeQuals) { - return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, + return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, IndexTypeQuals); } #endif @@ -433,8 +433,8 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy, /// Clang < 9, do not provide `Sema::BuildCXXThisExpr` function. static inline CXXThisExpr* Sema_BuildCXXThisExpr(Sema& SemaRef, const CXXMethodDecl* method) { - auto thisType = method->getThisType(); - SourceLocation noLoc; + auto thisType = method->getThisType(); + SourceLocation noLoc; #if CLANG_VERSION_MAJOR >= 9 return cast( SemaRef.BuildCXXThisExpr(noLoc, thisType, /*IsImplicit=*/true)); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c8f457f50..9cc5a2791 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4108,7 +4108,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->EndCFSwitchStmtScope(); activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); - + // Increment statement in the for-loop is executed for every case if (forLoopIncDiff) { if (bodyDiff.getStmt_dx()) { From 8cd23967455ee4d890eb8d2293d90a64f2be74d2 Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Thu, 4 Apr 2024 12:10:09 +0300 Subject: [PATCH 04/11] Revert formatting changes by running git clang-format locally --- include/clad/Differentiator/Compatibility.h | 10 ++--- include/clad/Differentiator/Differentiator.h | 7 ++- .../clad/Differentiator/ReverseModeVisitor.h | 6 +++ include/clad/Differentiator/VisitorBase.h | 3 +- lib/Differentiator/ReverseModeVisitor.cpp | 42 +++++++++++++++--- lib/Differentiator/VisitorBase.cpp | 7 +++ test/Gradient/Loops.C | 43 +++++++++++-------- 7 files changed, 89 insertions(+), 29 deletions(-) diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index 4d2b9b22a..efd3d629c 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -284,7 +284,7 @@ static inline SwitchStmt* SwitchStmt_Create(const ASTContext &Ctx, { #if CLANG_VERSION_MAJOR < 12 - return SwitchStmt::Create(Ctx, Init, Var, Cond); + return SwitchStmt::Create(Ctx, Init, Var, Cond); #elif CLANG_VERSION_MAJOR >= 12 return SwitchStmt::Create(Ctx, Init, Var, Cond, LParenLoc, RParenLoc); #endif @@ -342,7 +342,7 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy, #if CLANG_VERSION_MAJOR < 10 return Ctx.getConstantArrayType(EltTy, ArySize, ASM, IndexTypeQuals); #elif CLANG_VERSION_MAJOR >= 10 - return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, + return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, IndexTypeQuals); #endif } @@ -351,7 +351,7 @@ static inline QualType getConstantArrayType(const ASTContext& Ctx, QualType EltTy, const APInt& ArySize, const Expr* SizeExpr, clang::ArraySizeModifier ASM, unsigned IndexTypeQuals) { - return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, + return Ctx.getConstantArrayType(EltTy, ArySize, SizeExpr, ASM, IndexTypeQuals); } #endif @@ -433,8 +433,8 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy, /// Clang < 9, do not provide `Sema::BuildCXXThisExpr` function. static inline CXXThisExpr* Sema_BuildCXXThisExpr(Sema& SemaRef, const CXXMethodDecl* method) { - auto thisType = method->getThisType(); - SourceLocation noLoc; + auto thisType = method->getThisType(); + SourceLocation noLoc; #if CLANG_VERSION_MAJOR >= 9 return cast( SemaRef.BuildCXXThisExpr(noLoc, thisType, /*IsImplicit=*/true)); diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index cca1cd5cf..957e00597 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -78,6 +78,11 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { return of.back(); } + /// Return the size of the tape. + template CUDA_HOST_DEVICE std::size_t size(tape& of) { + return of.size(); + } + /// The purpose of this function is to initialize adjoints /// (or all of its differentiable fields) with 0. // FIXME: Add support for objects. @@ -88,7 +93,7 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { /// N. template CUDA_HOST_DEVICE void zero_init(T* x, std::size_t N) { for (std::size_t i = 0; i < N; ++i) - zero_init(x[i]); + zero_init(x[i]); } /// Initialize a const sized array. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c9fd1296f..4d0cb1409 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -318,6 +318,8 @@ namespace clad { /// (clad::back(Ref)). Since it is required only rarely, it is built on /// demand in the method. clang::Expr* Last(); + /// A request to get the size of the tape (clad::size(Ref)). + clang::Expr* Size(); }; /// Make a clad::tape to store variables. @@ -655,6 +657,10 @@ namespace clad { /// by their actual values respectively clang::Expr* CreateCFTapeBackExprForCurrentCase(); + /// Builds and returns `clad::size(TapeRef) != 0` expression, + /// where `TapeRef` is replaced by its actual value + clang::Expr* CreateCFTapeSizeExprForCurrentCase(); + /// Does final modifications on forward and reverse blocks /// so that `break` and `continue` statements are handled /// accurately. diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 210f82112..78cf720a7 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -460,9 +460,10 @@ namespace clad { clang::TemplateDecl* GetCladTapeDecl(); /// Perform a lookup into clad namespace for an entity with given name. clang::LookupResult LookupCladTapeMethod(llvm::StringRef name); - /// Perform lookup into clad namespace for push/pop/back. Returns + /// Perform lookup into clad namespace for push/pop/back/size. Returns /// LookupResult, which is will be resolved later (which is handy since they /// are templates). + clang::LookupResult& GetCladTapeSize(); clang::LookupResult& GetCladTapePush(); clang::LookupResult& GetCladTapePop(); clang::LookupResult& GetCladTapeBack(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9cc5a2791..d7d5c6e26 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -71,6 +71,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Call; } + Expr* ReverseModeVisitor::CladTapeResult::Size() { + LookupResult& TapeSize = V.GetCladTapeSize(); + CXXScopeSpec CSS; + CSS.Extend(V.m_Context, V.GetCladNamespace(), noLoc, noLoc); + Expr* SizeDRE = V.m_Sema + .BuildDeclarationNameExpr(CSS, TapeSize, + /*AcceptInvalidDecl=*/false) + .get(); + Expr* Call = + V.m_Sema.ActOnCallExpr(V.getCurrentScope(), SizeDRE, noLoc, Ref, noLoc) + .get(); + return Call; + } + ReverseModeVisitor::CladTapeResult ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) { assert(E && "must be provided"); @@ -4108,14 +4122,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->EndCFSwitchStmtScope(); activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); - + // Increment statement in the for-loop is executed for every case if (forLoopIncDiff) { + Stmt* forLoopIncDiffExpr = forLoopIncDiff; + if (m_CurrentBreakFlagExpr) { + forLoopIncDiffExpr = clad_compat::IfStmt_Create( + m_Context, noLoc, false, nullptr, nullptr, m_CurrentBreakFlagExpr, + noLoc, noLoc, forLoopIncDiff, noLoc, nullptr); + } if (bodyDiff.getStmt_dx()) { bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( - m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff)); + m_Context, bodyDiff.getStmt_dx(), forLoopIncDiffExpr)); } else { - bodyDiff.updateStmtDx(forLoopIncDiff); + bodyDiff.updateStmtDx(forLoopIncDiffExpr); } } @@ -4163,13 +4183,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) { Expr* tapeBackExprForCurrentCase = activeBreakContHandler->CreateCFTapeBackExprForCurrentCase(); + Expr* tapeSizeExprForCurrentCase = + activeBreakContHandler->CreateCFTapeSizeExprForCurrentCase(); + Expr* currentBreakFlagExpr = + BuildOp(BinaryOperatorKind::BO_LAnd, tapeSizeExprForCurrentCase, + tapeBackExprForCurrentCase); if (m_CurrentBreakFlagExpr) { m_CurrentBreakFlagExpr = BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr, - tapeBackExprForCurrentCase); + currentBreakFlagExpr); } else { - m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase; + m_CurrentBreakFlagExpr = currentBreakFlagExpr; } } addToCurrentBlock(pushExprToCurrentCase); @@ -4248,6 +4273,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return CreateCFTapePushExpr(m_CaseCounter); } + Expr* ReverseModeVisitor::BreakContStmtHandler:: + CreateCFTapeSizeExprForCurrentCase() { + return m_RMV.BuildOp(BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), + ConstantFolder::synthesizeLiteral( + m_RMV.m_Context.IntTy, m_RMV.m_Context, 0)); + } + void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( StmtDiff& bodyDiff) { if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index e8fce3628..37c4bd0ea 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -580,6 +580,13 @@ namespace clad { return clad_compat::llvm_Optional_GetValue(Result); } + LookupResult& VisitorBase::GetCladTapeSize() { + static clad_compat::llvm_Optional Result{}; + if (!Result) + Result = LookupCladTapeMethod("size"); + return clad_compat::llvm_Optional_GetValue(Result); + } + QualType VisitorBase::GetCladTapeOfType(QualType T) { return InstantiateTemplate(GetCladTapeDecl(), {T}); } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 77eed0f2c..f73ab4850 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1320,7 +1320,8 @@ double fn16(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: --ii; +// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: // CHECK-NEXT: ; @@ -1443,7 +1444,8 @@ double fn17(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: --ii; +// CHECK-NEXT: if (clad::size(_t5) != 0 && clad::back(_t5) != 1) +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; @@ -1559,7 +1561,8 @@ double fn18(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: --counter; +// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 2) +// CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: // CHECK-NEXT: ; @@ -1888,7 +1891,7 @@ double fn23(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (;; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: _d_res = 0.; @@ -1898,7 +1901,8 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: --c; +// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; @@ -1999,7 +2003,7 @@ double fn25(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (;; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) { // CHECK-NEXT: _d_res += 0; // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -2009,7 +2013,8 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: --c; +// CHECK-NEXT: if (clad::size(_t3) != 0 && clad::back(_t3) != 1) +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; @@ -2071,7 +2076,7 @@ double fn26(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (;; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2146,7 +2151,7 @@ double fn27(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (;; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2155,7 +2160,8 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: --c; +// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; @@ -2471,7 +2477,7 @@ double fn32(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (;; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::back(_t8) != 1)) { +//CHECK-NEXT: if (!_t0 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2480,7 +2486,8 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: --c; +//CHECK-NEXT: if (clad::size(_t8) != 0 && clad::back(_t8) != 1) +//CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: //CHECK-NEXT: ; @@ -2500,7 +2507,7 @@ double fn32(double i, double j) { //CHECK-NEXT: { //CHECK-NEXT: for (;; clad::back(_t2)--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t6) != 1)) { +//CHECK-NEXT: if (!clad::back(_t2) || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) { //CHECK-NEXT: res = clad::pop(_t4); //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: *_d_i += _r_d1 * j; @@ -2509,7 +2516,8 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: --d; +//CHECK-NEXT: if (clad::size(_t6) != 0 && clad::back(_t6) != 1) +//CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: //CHECK-NEXT: ; @@ -2621,7 +2629,7 @@ double fn33(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (;; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2)) { +//CHECK-NEXT: if (!_t0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::size(_t4) != 0 && clad::back(_t4) != 2)) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2631,7 +2639,8 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: --c; +//CHECK-NEXT: if (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::size(_t4) != 0 && clad::back(_t4) != 2) +//CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: //CHECK-NEXT: ; @@ -3189,7 +3198,7 @@ int main() { TEST_2(fn16, 3, 5); // CHECK-EXEC: {10.00, 6.00} TEST_2(fn17, 3, 5); // CHECK-EXEC: {15.00, 9.00} TEST_2(fn18, 3, 5); // CHECK-EXEC: {4.00, 4.00} - + INIT_GRADIENT(fn19, "arr"); double arr[5] = {}; From cd4f32b363dac30dbdc0ed0fd964d9f37bad3033 Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 14 Sep 2024 14:10:45 +0300 Subject: [PATCH 05/11] Add test cases of issues 710 and 851 --- test/Gradient/Loops.C | 119 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index f73ab4850..e1f21c2ac 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -3142,6 +3142,123 @@ double fn39(double x) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn40(double u, double v) { + double res = 11 * u; + for (int i = 0; i < 3; i++) { + res += u * i; + continue; + } + return res; +} + +// CHECK: void fn40_grad(double u, double v, double *_d_u, double *_d_v) { +//CHECK-NEXT: int _d_i = 0; +//CHECK-NEXT: int i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: double _d_res = 0.; +//CHECK-NEXT: double res = 11 * u; +//CHECK-NEXT: unsigned long _t0 = 0UL; +//CHECK-NEXT: for (i = 0; ; i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i < 3)) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += u * i; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t2, 1UL); +//CHECK-NEXT: continue; +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t2, 2UL); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: i--; +//CHECK-NEXT: switch (clad::pop(_t2)) { +//CHECK-NEXT: case 2UL: +//CHECK-NEXT: ; +//CHECK-NEXT: case 1UL: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: *_d_u += _r_d0 * i; +//CHECK-NEXT: _d_i += u * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: *_d_u += 11 * _d_res; +//CHECK-NEXT:} + +double fn41(double u, double v) { + double res = 0; + for (int i = 1; i < 3; i++) { + res += i * u; + if (i == 1) + break; + } + return res; +} + +//CHECK: void fn41_grad(double u, double v, double *_d_u, double *_d_v) { +//CHECK-NEXT: int _d_i = 0; +//CHECK-NEXT: int i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: double _d_res = 0.; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: unsigned long _t0 = 0UL; +//CHECK-NEXT: for (i = 1; ; i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i < 3)) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += i * u; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, i == 1); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: clad::push(_t2, 1UL); +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t2, 2UL); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +//CHECK-NEXT: i--; +//CHECK-NEXT: switch (clad::pop(_t2)) { +//CHECK-NEXT: case 2UL: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) +//CHECK-NEXT: case 1UL: +//CHECK-NEXT: ; +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: _d_i += _r_d0 * u; +//CHECK-NEXT: *_d_u += i * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT:} + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -3232,6 +3349,8 @@ int main() { TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} TEST_2(fn38, 6, 3); // CHECK-EXEC: {1.00, 1.00} TEST(fn39, 9); // CHECK-EXEC: {6.00} + TEST_2(fn40, 2, 3); // CHECK-EXEC: {14.00, 0.00} + TEST_2(fn41, 2, 3); // CHECK-EXEC: {1.00, 0.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) { From 71dfc9314c4f72f9602b86e7461b36d9bf2d03ce Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 14 Sep 2024 14:28:49 +0300 Subject: [PATCH 06/11] Fix size_t values for Arch tests --- test/Gradient/Loops.C | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index e1f21c2ac..98ab93da6 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -3158,7 +3158,7 @@ double fn40(double u, double v) { //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: double _d_res = 0.; //CHECK-NEXT: double res = 11 * u; -//CHECK-NEXT: unsigned long _t0 = 0UL; +//CHECK-NEXT: unsigned long _t0 = {{0U|0UL}}; //CHECK-NEXT: for (i = 0; ; i++) { //CHECK-NEXT: { //CHECK-NEXT: if (!(i < 3)) @@ -3168,10 +3168,10 @@ double fn40(double u, double v) { //CHECK-NEXT: clad::push(_t1, res); //CHECK-NEXT: res += u * i; //CHECK-NEXT: { -//CHECK-NEXT: clad::push(_t2, 1UL); +//CHECK-NEXT: clad::push(_t2, {{1U|1UL}}); //CHECK-NEXT: continue; //CHECK-NEXT: } -//CHECK-NEXT: clad::push(_t2, 2UL); +//CHECK-NEXT: clad::push(_t2, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (;; _t0--) { @@ -3181,9 +3181,9 @@ double fn40(double u, double v) { //CHECK-NEXT: } //CHECK-NEXT: i--; //CHECK-NEXT: switch (clad::pop(_t2)) { -//CHECK-NEXT: case 2UL: +//CHECK-NEXT: case {{2U|2UL}}: //CHECK-NEXT: ; -//CHECK-NEXT: case 1UL: +//CHECK-NEXT: case {{1U|1UL}}: //CHECK-NEXT: ; //CHECK-NEXT: { //CHECK-NEXT: res = clad::pop(_t1); @@ -3214,7 +3214,7 @@ double fn41(double u, double v) { //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: double _d_res = 0.; //CHECK-NEXT: double res = 0; -//CHECK-NEXT: unsigned long _t0 = 0UL; +//CHECK-NEXT: unsigned long _t0 = {{0U|0UL}}; //CHECK-NEXT: for (i = 1; ; i++) { //CHECK-NEXT: { //CHECK-NEXT: if (!(i < 3)) @@ -3226,11 +3226,11 @@ double fn41(double u, double v) { //CHECK-NEXT: { //CHECK-NEXT: clad::push(_cond0, i == 1); //CHECK-NEXT: if (clad::back(_cond0)) { -//CHECK-NEXT: clad::push(_t2, 1UL); +//CHECK-NEXT: clad::push(_t2, {{1U|1UL}}); //CHECK-NEXT: break; //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: clad::push(_t2, 2UL); +//CHECK-NEXT: clad::push(_t2, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (;; _t0--) { @@ -3241,11 +3241,11 @@ double fn41(double u, double v) { //CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) //CHECK-NEXT: i--; //CHECK-NEXT: switch (clad::pop(_t2)) { -//CHECK-NEXT: case 2UL: +//CHECK-NEXT: case {{2U|2UL}}: //CHECK-NEXT: ; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond0)) -//CHECK-NEXT: case 1UL: +//CHECK-NEXT: case {{1U|1UL}}: //CHECK-NEXT: ; //CHECK-NEXT: clad::pop(_cond0); //CHECK-NEXT: } From 6582cb119ee5862c67a095e7b1d166c53d5a62da Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 14 Sep 2024 14:36:24 +0300 Subject: [PATCH 07/11] Fix unsigned var declarations for Arch tests --- test/Gradient/Loops.C | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 98ab93da6..58716687c 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -3155,10 +3155,10 @@ double fn40(double u, double v) { //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; -//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: double _d_res = 0.; //CHECK-NEXT: double res = 11 * u; -//CHECK-NEXT: unsigned long _t0 = {{0U|0UL}}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; //CHECK-NEXT: for (i = 0; ; i++) { //CHECK-NEXT: { //CHECK-NEXT: if (!(i < 3)) @@ -3211,10 +3211,10 @@ double fn41(double u, double v) { //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _cond0 = {}; -//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: double _d_res = 0.; //CHECK-NEXT: double res = 0; -//CHECK-NEXT: unsigned long _t0 = {{0U|0UL}}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; //CHECK-NEXT: for (i = 1; ; i++) { //CHECK-NEXT: { //CHECK-NEXT: if (!(i < 3)) From 906814c8fde1e28acc80c5b3783171ec349fb4ad Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 26 Oct 2024 22:08:16 +0300 Subject: [PATCH 08/11] Avoid duplicate of tape size checking when there's a break stmt --- lib/Differentiator/ReverseModeVisitor.cpp | 21 ++++++++++----------- test/Gradient/Loops.C | 4 ++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d7d5c6e26..f76a90374 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4183,18 +4183,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) { Expr* tapeBackExprForCurrentCase = activeBreakContHandler->CreateCFTapeBackExprForCurrentCase(); - Expr* tapeSizeExprForCurrentCase = - activeBreakContHandler->CreateCFTapeSizeExprForCurrentCase(); - Expr* currentBreakFlagExpr = - BuildOp(BinaryOperatorKind::BO_LAnd, tapeSizeExprForCurrentCase, - tapeBackExprForCurrentCase); if (m_CurrentBreakFlagExpr) { m_CurrentBreakFlagExpr = BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr, - currentBreakFlagExpr); - + tapeBackExprForCurrentCase); } else { - m_CurrentBreakFlagExpr = currentBreakFlagExpr; + Expr* tapeSizeExprForCurrentCase = + activeBreakContHandler->CreateCFTapeSizeExprForCurrentCase(); + m_CurrentBreakFlagExpr = + BuildOp(BinaryOperatorKind::BO_LAnd, tapeSizeExprForCurrentCase, + tapeBackExprForCurrentCase); } } addToCurrentBlock(pushExprToCurrentCase); @@ -4275,9 +4273,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* ReverseModeVisitor::BreakContStmtHandler:: CreateCFTapeSizeExprForCurrentCase() { - return m_RMV.BuildOp(BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), - ConstantFolder::synthesizeLiteral( - m_RMV.m_Context.IntTy, m_RMV.m_Context, 0)); + return m_RMV.BuildOp( + BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), + ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy, + m_RMV.m_Context, /*val=*/0)); } void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 58716687c..9d5dff546 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2629,7 +2629,7 @@ double fn33(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (;; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::size(_t4) != 0 && clad::back(_t4) != 2)) { +//CHECK-NEXT: if (!_t0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2639,7 +2639,7 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::size(_t4) != 0 && clad::back(_t4) != 2) +//CHECK-NEXT: if (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: From 64493d115b0956b1e83166b523b1444859f5a393 Mon Sep 17 00:00:00 2001 From: kchristin Date: Sun, 27 Oct 2024 09:59:03 +0200 Subject: [PATCH 09/11] Optimize check of break branch using first iteration check in reverse pass --- .../clad/Differentiator/ReverseModeVisitor.h | 9 +++ lib/Differentiator/ReverseModeVisitor.cpp | 17 +++++- test/Gradient/Loops.C | 56 +++++++++---------- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 4d0cb1409..7a291e95a 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -519,6 +519,7 @@ namespace clad { clang::Expr *m_Pop = nullptr; clang::Expr *m_Push = nullptr; ReverseModeVisitor& m_RMV; + clang::VarDecl* numRevIterations = nullptr; public: LoopCounter(ReverseModeVisitor& RMV); @@ -551,6 +552,14 @@ namespace clad { m_Ref, clang::Sema::ConditionKind::Boolean); } + + /// Sets the number of reverse iterations to be executed. + clang::VarDecl* setNumRevIterations(clang::VarDecl* numRevIterations) { + return this->numRevIterations = numRevIterations; + } + + /// Returns the number of reverse iterations to be executed. + clang::VarDecl* getNumRevIterations() const { return numRevIterations; } }; /// Helper function to differentiate a loop body. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f76a90374..0f4f696a2 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1386,8 +1386,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts)); } + Stmt* revInit = loopCounter.getNumRevIterations() + ? BuildDeclStmt(loopCounter.getNumRevIterations()) + : nullptr; Stmt* Reverse = new (m_Context) - ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement, + ForStmt(m_Context, revInit, nullptr, nullptr, CounterDecrement, BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc); addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); @@ -4123,10 +4126,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); + Expr* revCounter = loopCounter.getCounterConditionResult().get().second; + if (m_CurrentBreakFlagExpr) { + VarDecl* numRevIterations = BuildVarDecl(m_Context.getSizeType(), + "_numRevIterations", revCounter); + loopCounter.setNumRevIterations(numRevIterations); + } + // Increment statement in the for-loop is executed for every case if (forLoopIncDiff) { Stmt* forLoopIncDiffExpr = forLoopIncDiff; if (m_CurrentBreakFlagExpr) { + m_CurrentBreakFlagExpr = + BuildOp(BinaryOperatorKind::BO_LOr, + BuildOp(BinaryOperatorKind::BO_NE, revCounter, + BuildDeclRef(loopCounter.getNumRevIterations())), + BuildParens(m_CurrentBreakFlagExpr)); forLoopIncDiffExpr = clad_compat::IfStmt_Create( m_Context, noLoc, false, nullptr, nullptr, m_CurrentBreakFlagExpr, noLoc, noLoc, forLoopIncDiff, noLoc, nullptr); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 9d5dff546..af957e51c 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1315,12 +1315,12 @@ double fn16(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) // CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1439,12 +1439,12 @@ double fn17(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t5) != 0 && clad::back(_t5) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t5) != 0 && clad::back(_t5) != 1)) // CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -1556,12 +1556,12 @@ double fn18(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 2) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 2)) // CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1889,9 +1889,9 @@ double fn23(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: _d_res = 0.; @@ -1901,7 +1901,7 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2001,9 +2001,9 @@ double fn25(double i, double j) { // CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { // CHECK-NEXT: _d_res += 0; // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -2013,7 +2013,7 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t3) != 0 && clad::back(_t3) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2074,9 +2074,9 @@ double fn26(double i, double j) { // CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2149,9 +2149,9 @@ double fn27(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2160,7 +2160,7 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2475,9 +2475,9 @@ double fn32(double i, double j) { //CHECK-NEXT: clad::push(_t8, {{2U|2UL|2ULL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2486,7 +2486,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t8) != 0 && clad::back(_t8) != 1) +//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2505,9 +2505,9 @@ double fn32(double i, double j) { //CHECK-NEXT: clad::pop(_cond1); //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: for (;; clad::back(_t2)--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!clad::back(_t2) || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) { +//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))) { //CHECK-NEXT: res = clad::pop(_t4); //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: *_d_i += _r_d1 * j; @@ -2516,7 +2516,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t6) != 0 && clad::back(_t6) != 1) +//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) //CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2627,9 +2627,9 @@ double fn33(double i, double j) { //CHECK-NEXT: clad::push(_t4, {{3U|3UL|3ULL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2639,7 +2639,7 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -3233,12 +3233,12 @@ double fn41(double u, double v) { //CHECK-NEXT: clad::push(_t2, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) //CHECK-NEXT: i--; //CHECK-NEXT: switch (clad::pop(_t2)) { //CHECK-NEXT: case {{2U|2UL}}: From 9ffefc1f9cb1a19de18fe2acaa9bbe6be3aa8ebe Mon Sep 17 00:00:00 2001 From: kchristin Date: Sun, 27 Oct 2024 19:24:04 +0200 Subject: [PATCH 10/11] Improve set function of numRevIterations --- include/clad/Differentiator/ReverseModeVisitor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 7a291e95a..79fd1c65f 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -519,7 +519,7 @@ namespace clad { clang::Expr *m_Pop = nullptr; clang::Expr *m_Push = nullptr; ReverseModeVisitor& m_RMV; - clang::VarDecl* numRevIterations = nullptr; + clang::VarDecl* m_numRevIterations = nullptr; public: LoopCounter(ReverseModeVisitor& RMV); @@ -554,12 +554,12 @@ namespace clad { } /// Sets the number of reverse iterations to be executed. - clang::VarDecl* setNumRevIterations(clang::VarDecl* numRevIterations) { - return this->numRevIterations = numRevIterations; + void setNumRevIterations(clang::VarDecl* numRevIterations) { + m_numRevIterations = numRevIterations; } /// Returns the number of reverse iterations to be executed. - clang::VarDecl* getNumRevIterations() const { return numRevIterations; } + clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; } }; /// Helper function to differentiate a loop body. From 1bae2cd3246854507940152443d588922ad005e1 Mon Sep 17 00:00:00 2001 From: kchristin Date: Mon, 28 Oct 2024 11:42:40 +0200 Subject: [PATCH 11/11] Use different break cond for each loop and remove the now unnecessary tape size check --- include/clad/Differentiator/Differentiator.h | 5 --- .../clad/Differentiator/ReverseModeVisitor.h | 6 --- include/clad/Differentiator/VisitorBase.h | 3 +- lib/Differentiator/ReverseModeVisitor.cpp | 34 ++++------------- lib/Differentiator/VisitorBase.cpp | 7 ---- test/Gradient/Loops.C | 37 +++++++++---------- 6 files changed, 26 insertions(+), 66 deletions(-) diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 957e00597..dfb900e1e 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -78,11 +78,6 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { return of.back(); } - /// Return the size of the tape. - template CUDA_HOST_DEVICE std::size_t size(tape& of) { - return of.size(); - } - /// The purpose of this function is to initialize adjoints /// (or all of its differentiable fields) with 0. // FIXME: Add support for objects. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 79fd1c65f..9104ef79d 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -318,8 +318,6 @@ namespace clad { /// (clad::back(Ref)). Since it is required only rarely, it is built on /// demand in the method. clang::Expr* Last(); - /// A request to get the size of the tape (clad::size(Ref)). - clang::Expr* Size(); }; /// Make a clad::tape to store variables. @@ -666,10 +664,6 @@ namespace clad { /// by their actual values respectively clang::Expr* CreateCFTapeBackExprForCurrentCase(); - /// Builds and returns `clad::size(TapeRef) != 0` expression, - /// where `TapeRef` is replaced by its actual value - clang::Expr* CreateCFTapeSizeExprForCurrentCase(); - /// Does final modifications on forward and reverse blocks /// so that `break` and `continue` statements are handled /// accurately. diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 78cf720a7..210f82112 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -460,10 +460,9 @@ namespace clad { clang::TemplateDecl* GetCladTapeDecl(); /// Perform a lookup into clad namespace for an entity with given name. clang::LookupResult LookupCladTapeMethod(llvm::StringRef name); - /// Perform lookup into clad namespace for push/pop/back/size. Returns + /// Perform lookup into clad namespace for push/pop/back. Returns /// LookupResult, which is will be resolved later (which is handy since they /// are templates). - clang::LookupResult& GetCladTapeSize(); clang::LookupResult& GetCladTapePush(); clang::LookupResult& GetCladTapePop(); clang::LookupResult& GetCladTapeBack(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0f4f696a2..11c0c1981 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -71,20 +71,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Call; } - Expr* ReverseModeVisitor::CladTapeResult::Size() { - LookupResult& TapeSize = V.GetCladTapeSize(); - CXXScopeSpec CSS; - CSS.Extend(V.m_Context, V.GetCladNamespace(), noLoc, noLoc); - Expr* SizeDRE = V.m_Sema - .BuildDeclarationNameExpr(CSS, TapeSize, - /*AcceptInvalidDecl=*/false) - .get(); - Expr* Call = - V.m_Sema.ActOnCallExpr(V.getCurrentScope(), SizeDRE, noLoc, Ref, noLoc) - .get(); - return Call; - } - ReverseModeVisitor::CladTapeResult ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) { assert(E && "must be provided"); @@ -3808,6 +3794,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); isInsideLoop = true; + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr); const VarDecl* condVarDecl = WS->getConditionVariable(); @@ -3866,6 +3855,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); isInsideLoop = true; + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr); @@ -4203,11 +4195,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr, tapeBackExprForCurrentCase); } else { - Expr* tapeSizeExprForCurrentCase = - activeBreakContHandler->CreateCFTapeSizeExprForCurrentCase(); - m_CurrentBreakFlagExpr = - BuildOp(BinaryOperatorKind::BO_LAnd, tapeSizeExprForCurrentCase, - tapeBackExprForCurrentCase); + m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase; } } addToCurrentBlock(pushExprToCurrentCase); @@ -4286,14 +4274,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return CreateCFTapePushExpr(m_CaseCounter); } - Expr* ReverseModeVisitor::BreakContStmtHandler:: - CreateCFTapeSizeExprForCurrentCase() { - return m_RMV.BuildOp( - BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), - ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy, - m_RMV.m_Context, /*val=*/0)); - } - void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( StmtDiff& bodyDiff) { if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 37c4bd0ea..e8fce3628 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -580,13 +580,6 @@ namespace clad { return clad_compat::llvm_Optional_GetValue(Result); } - LookupResult& VisitorBase::GetCladTapeSize() { - static clad_compat::llvm_Optional Result{}; - if (!Result) - Result = LookupCladTapeMethod("size"); - return clad_compat::llvm_Optional_GetValue(Result); - } - QualType VisitorBase::GetCladTapeOfType(QualType T) { return InstantiateTemplate(GetCladTapeDecl(), {T}); } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index af957e51c..5a44d7c50 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1320,7 +1320,7 @@ double fn16(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) // CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1439,13 +1439,12 @@ double fn17(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { +// CHECK-NEXT: for (;; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t5) != 0 && clad::back(_t5) != 1)) -// CHECK-NEXT: --ii; +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; @@ -1561,7 +1560,7 @@ double fn18(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 2)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 2)) // CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1891,7 +1890,7 @@ double fn23(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: _d_res = 0.; @@ -1901,7 +1900,7 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2003,7 +2002,7 @@ double fn25(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) { // CHECK-NEXT: _d_res += 0; // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -2013,7 +2012,7 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t3) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2076,7 +2075,7 @@ double fn26(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2151,7 +2150,7 @@ double fn27(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2160,7 +2159,7 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2477,7 +2476,7 @@ double fn32(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::back(_t8) != 1))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2486,7 +2485,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) +//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::back(_t8) != 1)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2507,7 +2506,7 @@ double fn32(double i, double j) { //CHECK-NEXT: { //CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))) { +//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1))) { //CHECK-NEXT: res = clad::pop(_t4); //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: *_d_i += _r_d1 * j; @@ -2516,7 +2515,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) +//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1)) //CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2629,7 +2628,7 @@ double fn33(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2639,7 +2638,7 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -3238,7 +3237,7 @@ double fn41(double u, double v) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) //CHECK-NEXT: i--; //CHECK-NEXT: switch (clad::pop(_t2)) { //CHECK-NEXT: case {{2U|2UL}}: