From 0230d3aca879fa4c89aeaab30aa8162affbaa8a1 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Fri, 18 Aug 2023 11:12:22 +0300 Subject: [PATCH] Fix storing/restoring in loops. --- .../clad/Differentiator/ReverseModeVisitor.h | 7 ++- include/clad/Differentiator/TBRAnalyzer.h | 9 +-- lib/Differentiator/ErrorEstimator.cpp | 28 +--------- lib/Differentiator/ReverseModeVisitor.cpp | 55 ++++++++++++------- lib/Differentiator/TBRAnalyzer.cpp | 25 +++++++-- test/Misc/RunDemos.C | 4 +- 6 files changed, 68 insertions(+), 60 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 540b9d572..7c77dd46a 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -227,6 +227,11 @@ namespace clad { forceDeclCreation, IS); } + /// Based on To-Be-Recorded analysis performed before differentiation, + /// tells UsefulToStoreGlobal whether a variable with a given + /// SourceLocation has to be stored before changed or not. + std::map m_ToBeRecorded; + /// For an expr E, decides if it is useful to store it in a global temporary /// variable and replace E's further usage by a reference to that variable /// to avoid recomputiation. @@ -583,8 +588,6 @@ namespace clad { m_BreakContStmtHandlers.pop_back(); } - std::map m_ToBeRecorded; - /// Registers an external RMV source. /// /// Multiple external RMV source can be registered by calling this function diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index e05ee645c..978d86b18 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -1,8 +1,9 @@ -#ifndef CLAD_TBR_ANALYZER_H -#define CLAD_TBR_ANALYZER_H +#ifndef CLAD_DIFFERENTIATOR_TBRANALYZER_H +#define CLAD_DIFFERENTIATOR_TBRANALYZER_H #include "clang/AST/StmtVisitor.h" #include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/Compatibility.h" #include #include @@ -17,7 +18,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// type keys. struct APIntHash { size_t operator()(const llvm::APInt& apint) const { - return std::hash{}(apint.toString(10, true)); + return llvm::hash_value(apint); } }; @@ -301,4 +302,4 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { }; } // end namespace clad -#endif // CLAD_TBR_ANALYZER_H +#endif // CLAD_DIFFERENTIATOR_TBRANALYZER_H diff --git a/lib/Differentiator/ErrorEstimator.cpp b/lib/Differentiator/ErrorEstimator.cpp index c2a2dd49e..cbab2a49e 100644 --- a/lib/Differentiator/ErrorEstimator.cpp +++ b/lib/Differentiator/ErrorEstimator.cpp @@ -83,33 +83,9 @@ void ErrorEstimationHandler::BuildFinalErrorStmt() { void ErrorEstimationHandler::AddErrorStmtToBlock(Expr* var, Expr* deltaVar, Expr* errorExpr, bool isInsideLoop /*=false*/) { + if (auto ASE = dyn_cast(var)) { - // If inside loop, the index has been pushed twice - // (once by ArraySubscriptExpr and the second time by us) - // pop and store it in a temporary variable to reuse later. - // FIXME: build add assign into he same expression i.e. - // _final_error += _delta_arr[pop(_t0)] += <-Error Expr-> - // to avoid storage of the pop value. - Expr* popVal = ASE->getIdx(); - if (isInsideLoop) { - LookupResult& Pop = m_RMV->GetCladTapePop(); - CXXScopeSpec CSS; - CSS.Extend(m_RMV->m_Context, m_RMV->GetCladNamespace(), noLoc, noLoc); - auto PopDRE = m_RMV->m_Sema - .BuildDeclarationNameExpr(CSS, Pop, - /*AcceptInvalidDecl=*/false) - .get(); - Expr* tapeRef = dyn_cast(popVal)->getArg(0); - popVal = m_RMV->m_Sema - .ActOnCallExpr(m_RMV->getCurrentScope(), PopDRE, noLoc, - tapeRef, noLoc) - .get(); - popVal = m_RMV->StoreAndRef(popVal, direction::reverse); - } - // If the variable declration refers to an array element - // create the suitable _delta_arr[i] (because we have not done - // this before). - deltaVar = getArraySubscriptExpr(deltaVar, popVal); + deltaVar = getArraySubscriptExpr(deltaVar, ASE->getIdx()); m_RMV->addToCurrentBlock(m_RMV->BuildOp(BO_AddAssign, deltaVar, errorExpr), direction::reverse); // immediately emit fin_err += delta_[]. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 8aa7469b4..e71b34705 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1509,9 +1509,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // may be changed since we have no way to determine otherwise. // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now // arrays are not stored. - StmtDiff argDiffStore = GlobalStoreAndRef( - argDiff.getExpr(), "_t", - /*force=*/passByRef && !argDiff.getExpr()->getType()->isArrayType()); + StmtDiff argDiffStore; + if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) { + argDiffStore = + GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true); + } else { + argDiffStore = {argDiff.getExpr(), argDiff.getExpr()}; + } + // We need to pass the actual argument in the cloned call expression, // instead of a temporary, for arguments passed by reference. This is // because, callee function may modify the argument passed as reference @@ -1993,11 +1998,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else if (opCode == UO_PostInc || opCode == UO_PostDec) { diff = Visit(E, dfdx()); auto EStored = GlobalStoreAndRef(diff.getExpr()); - auto assign = - BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), EStored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(EStored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); + if (EStored.getExpr() != diff.getExpr()) { + auto assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), + EStored.getExpr_dx()); + if (isInsideLoop) + addToCurrentBlock(EStored.getExpr(), direction::forward); + addToCurrentBlock(assign, direction::reverse); + } ResultRef = diff.getExpr_dx(); if (m_ExternalSource) @@ -2005,12 +2012,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else if (opCode == UO_PreInc || opCode == UO_PreDec) { diff = Visit(E, dfdx()); auto EStored = GlobalStoreAndRef(diff.getExpr()); - auto assign = - BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), EStored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(EStored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); - + if (EStored.getExpr() != diff.getExpr()) { + auto assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), + EStored.getExpr_dx()); + if (isInsideLoop) + addToCurrentBlock(EStored.getExpr(), direction::forward); + addToCurrentBlock(assign, direction::reverse); + } } else if (opCode == UnaryOperatorKind::UO_Real || opCode == UnaryOperatorKind::UO_Imag) { diff = VisitWithExplicitNoDfDx(E); @@ -2740,14 +2748,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return UsefulToStoreGlobal(UO->getSubExpr()); return true; } - if (isa(B)) { - auto ASE = cast(B); - return UsefulToStoreGlobal(ASE->getBase()) || UsefulToStoreGlobal(ASE->getIdx()); - } // We lack context to decide if this is useful to store or not. In the // current system that should have been decided by the parent expression. // FIXME: Here will be the entry point of the advanced activity analysis. - if (isa(B)) { + if (isa(B) /* || isa(B)*/) { // auto line = // m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getLine(); // auto column = @@ -3063,6 +3067,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); m_LoopBlock.pop_back(); + /// Increment statement in the for-loop is only executed if the iteration + /// did not end with a break/continue statement. Therefore, forLoopIncDiff + /// should be inside the last switch case in the reverse pass. + if (forLoopIncDiff) { + if (bodyDiff.getStmt_dx()) { + bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( + m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff)); + } else { + bodyDiff.updateStmtDx(forLoopIncDiff); + } + } + activeBreakContHandler->EndCFSwitchStmtScope(); activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); @@ -3084,7 +3100,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(counterDecrement, direction::reverse); addToCurrentBlock(condVarDiff, direction::reverse); addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); - addToCurrentBlock(forLoopIncDiff, direction::reverse); bodyDiff = {bodyDiff.getStmt(), unwrapIfSingleStmt(endBlock(direction::reverse))}; return bodyDiff; diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index e96ec11f8..80996c4db 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -28,8 +28,19 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { pair.second->merge(mergeData->val.objData[pair.first]); } } else if (this->type == ARR_TYPE) { + /// FIXME: Currently non-constant indices are not supported in merging. for (auto pair : this->val.arrData) { - pair.second->merge(mergeData->val.arrData[pair.first]); + auto it = mergeData->val.arrData.find(pair.first); + if (it != mergeData->val.arrData.end()) { + pair.second->merge(it->second); + } + } + for (auto pair : mergeData->val.arrData) { + auto it = this->val.arrData.find(pair.first); + if (it == mergeData->val.arrData.end()) { + std::unordered_map refVars; + this->val.arrData[pair.first] = pair.second->copy(refVars); + } } } else if (this->type == REF_TYPE && this->val.refData) { this->val.refData->merge(mergeData->val.refData); @@ -551,8 +562,9 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { /// Multiplication results in a linear expression if and only if one of the /// factors is constant. Expr::EvalResult dummy; - bool nonLinear = !R->EvaluateAsConstantExpr(dummy, *m_Context) && - !L->EvaluateAsConstantExpr(dummy, *m_Context); + bool nonLinear = + !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, *m_Context) && + !clad_compat::Expr_EvaluateAsConstantExpr(L, dummy, *m_Context); if (nonLinear) startNonLinearMode(); @@ -565,7 +577,8 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { /// Division normally only results in a linear expression when the /// denominator is constant. Expr::EvalResult dummy; - bool nonLinear = !R->EvaluateAsConstantExpr(dummy, *m_Context); + bool nonLinear = + !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, *m_Context); if (nonLinear) startNonLinearMode(); @@ -591,7 +604,8 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { /// represents the same operation as 'x = x * y' ('x = x / y') and, /// therefore, LHS has to be visited in markingMode|nonLinearMode. Expr::EvalResult dummy; - bool RisNotConst = !R->EvaluateAsConstantExpr(dummy, *m_Context); + bool RisNotConst = + !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, *m_Context); if (RisNotConst) setMode(Mode::markingMode | Mode::nonLinearMode); Visit(L); @@ -757,7 +771,6 @@ void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { /// First pass innermostLoopBranch = reqStack.size() - 2; firstLoopPass = true; - mergeCurBranchTo(innermostLoopBranch - 1); if (body) Visit(body); if (deleteCurBranch) { diff --git a/test/Misc/RunDemos.C b/test/Misc/RunDemos.C index 3929feaf6..6b9ec86be 100644 --- a/test/Misc/RunDemos.C +++ b/test/Misc/RunDemos.C @@ -149,7 +149,7 @@ // CHECK_CUSTOM_MODEL-NOT: Could not load {{.*}}cladCustomModelPlugin{{.*}} -// RUN: ./CustomModelTest.out | FileCheck -check-prefix CHECK_CUSTOM_MODEL_EXEC %s +// RUN: ./CustomModelTest.out // CHECK_CUSTOM_MODEL_EXEC-NOT:{{.*error|warning|note:.*}} // CHECK_CUSTOM_MODEL_EXEC: The code is: // CHECK_CUSTOM_MODEL_EXEC-NEXT: void func_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { @@ -188,7 +188,7 @@ // CHECK_PRINT_MODEL-NOT: Could not load {{.*}}cladPrintModelPlugin{{.*}} -// RUN: ./PrintModelTest.out | FileCheck -check-prefix CHECK_PRINT_MODEL_EXEC %s +// RUN: ./PrintModelTest.out // CHECK_PRINT_MODEL_EXEC-NOT:{{.*error|warning|note:.*}} // CHECK_PRINT_MODEL_EXEC: The code is: // CHECK_PRINT_MODEL_EXEC-NEXT: void func_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) {