From 05da570cca240800fbf865b389da54233628738d Mon Sep 17 00:00:00 2001 From: kchristin Date: Mon, 28 Oct 2024 11:42:40 +0200 Subject: [PATCH] Use different break cond for each loop and remove the now unnecessary tape size check --- include/clad/Differentiator/Differentiator.h | 5 --- .../clad/Differentiator/ReverseModeVisitor.h | 6 --- include/clad/Differentiator/VisitorBase.h | 3 +- lib/Differentiator/ReverseModeVisitor.cpp | 34 ++++------------- lib/Differentiator/VisitorBase.cpp | 7 ---- test/Gradient/Loops.C | 37 +++++++++---------- 6 files changed, 26 insertions(+), 66 deletions(-) diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 957e00597..dfb900e1e 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -78,11 +78,6 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { return of.back(); } - /// Return the size of the tape. - template CUDA_HOST_DEVICE std::size_t size(tape& of) { - return of.size(); - } - /// The purpose of this function is to initialize adjoints /// (or all of its differentiable fields) with 0. // FIXME: Add support for objects. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index dd8b9e014..29945fe7e 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -318,8 +318,6 @@ namespace clad { /// (clad::back(Ref)). Since it is required only rarely, it is built on /// demand in the method. clang::Expr* Last(); - /// A request to get the size of the tape (clad::size(Ref)). - clang::Expr* Size(); }; /// Make a clad::tape to store variables. @@ -665,10 +663,6 @@ namespace clad { /// by their actual values respectively clang::Expr* CreateCFTapeBackExprForCurrentCase(); - /// Builds and returns `clad::size(TapeRef) != 0` expression, - /// where `TapeRef` is replaced by its actual value - clang::Expr* CreateCFTapeSizeExprForCurrentCase(); - /// Does final modifications on forward and reverse blocks /// so that `break` and `continue` statements are handled /// accurately. diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 890bf51bb..dba1540a2 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -460,10 +460,9 @@ namespace clad { clang::TemplateDecl* GetCladTapeDecl(); /// Perform a lookup into clad namespace for an entity with given name. clang::LookupResult LookupCladTapeMethod(llvm::StringRef name); - /// Perform lookup into clad namespace for push/pop/back/size. Returns + /// Perform lookup into clad namespace for push/pop/back. Returns /// LookupResult, which is will be resolved later (which is handy since they /// are templates). - clang::LookupResult& GetCladTapeSize(); clang::LookupResult& GetCladTapePush(); clang::LookupResult& GetCladTapePop(); clang::LookupResult& GetCladTapeBack(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 139711e17..4edb74b1a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -71,20 +71,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Call; } - Expr* ReverseModeVisitor::CladTapeResult::Size() { - LookupResult& TapeSize = V.GetCladTapeSize(); - CXXScopeSpec CSS; - CSS.Extend(V.m_Context, V.GetCladNamespace(), noLoc, noLoc); - Expr* SizeDRE = V.m_Sema - .BuildDeclarationNameExpr(CSS, TapeSize, - /*AcceptInvalidDecl=*/false) - .get(); - Expr* Call = - V.m_Sema.ActOnCallExpr(V.getCurrentScope(), SizeDRE, noLoc, Ref, noLoc) - .get(); - return Call; - } - ReverseModeVisitor::CladTapeResult ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) { assert(E && "must be provided"); @@ -3790,6 +3776,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); isInsideLoop = true; + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr); const VarDecl* condVarDecl = WS->getConditionVariable(); @@ -3848,6 +3837,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); isInsideLoop = true; + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr); @@ -4185,11 +4177,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr, tapeBackExprForCurrentCase); } else { - Expr* tapeSizeExprForCurrentCase = - activeBreakContHandler->CreateCFTapeSizeExprForCurrentCase(); - m_CurrentBreakFlagExpr = - BuildOp(BinaryOperatorKind::BO_LAnd, tapeSizeExprForCurrentCase, - tapeBackExprForCurrentCase); + m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase; } } addToCurrentBlock(pushExprToCurrentCase); @@ -4268,14 +4256,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return CreateCFTapePushExpr(m_CaseCounter); } - Expr* ReverseModeVisitor::BreakContStmtHandler:: - CreateCFTapeSizeExprForCurrentCase() { - return m_RMV.BuildOp( - BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), - ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy, - m_RMV.m_Context, /*val=*/0)); - } - void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( StmtDiff& bodyDiff) { if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 2bb7ec973..6cd582270 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -580,13 +580,6 @@ namespace clad { return clad_compat::llvm_Optional_GetValue(Result); } - LookupResult& VisitorBase::GetCladTapeSize() { - static clad_compat::llvm_Optional Result{}; - if (!Result) - Result = LookupCladTapeMethod("size"); - return clad_compat::llvm_Optional_GetValue(Result); - } - QualType VisitorBase::GetCladTapeOfType(QualType T) { return InstantiateTemplate(GetCladTapeDecl(), {T}); } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index af957e51c..5a44d7c50 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1320,7 +1320,7 @@ double fn16(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) // CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1439,13 +1439,12 @@ double fn17(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { +// CHECK-NEXT: for (;; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t5) != 0 && clad::back(_t5) != 1)) -// CHECK-NEXT: --ii; +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; @@ -1561,7 +1560,7 @@ double fn18(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 2)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 2)) // CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1891,7 +1890,7 @@ double fn23(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: _d_res = 0.; @@ -1901,7 +1900,7 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2003,7 +2002,7 @@ double fn25(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) { // CHECK-NEXT: _d_res += 0; // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -2013,7 +2012,7 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t3) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2076,7 +2075,7 @@ double fn26(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2151,7 +2150,7 @@ double fn27(double i, double j) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2160,7 +2159,7 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2477,7 +2476,7 @@ double fn32(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::back(_t8) != 1))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2486,7 +2485,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) +//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::back(_t8) != 1)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2507,7 +2506,7 @@ double fn32(double i, double j) { //CHECK-NEXT: { //CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))) { +//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1))) { //CHECK-NEXT: res = clad::pop(_t4); //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: *_d_i += _r_d1 * j; @@ -2516,7 +2515,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) +//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1)) //CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2629,7 +2628,7 @@ double fn33(double i, double j) { //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2639,7 +2638,7 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -3238,7 +3237,7 @@ double fn41(double u, double v) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) //CHECK-NEXT: i--; //CHECK-NEXT: switch (clad::pop(_t2)) { //CHECK-NEXT: case {{2U|2UL}}: