From 7b6e3b8d72a4ff1728e2810231b5621bc349424b Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Thu, 21 Nov 2024 16:51:19 +0100 Subject: [PATCH] Don't mark nonvaried constant params --- include/clad/Differentiator/DiffPlanner.h | 7 ++++ lib/Differentiator/ActivityAnalyzer.cpp | 41 +++++++++++++++-------- lib/Differentiator/ActivityAnalyzer.h | 1 + lib/Differentiator/DiffPlanner.cpp | 12 ++----- lib/Differentiator/ReverseModeVisitor.cpp | 1 + test/Analyses/ActivityReverse.cpp | 26 ++++++++++++-- 6 files changed, 62 insertions(+), 26 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index d2b74592b..ccce6288c 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -144,6 +144,13 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + + void setToBeRecorded(std::set init) { + this->m_ActivityRunInfo.ToBeRecorded = init; + } + std::set getToBeRecorded() const { + return this->m_ActivityRunInfo.ToBeRecorded; + } }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index f7810baf2..09191ba05 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -126,20 +126,29 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { clang::Expr* par = CE->getArg(i); QualType parType = FDparam[i]->getType(); - while (parType->isPointerType()) - parType = parType->getPointeeType(); - if((parType->isReferenceType() || utils::isArrayOrPointerType(parType)) && !parType.isConstQualified()){ + QualType innermostType = parType; + while (innermostType->isPointerType()) + innermostType = innermostType->getPointeeType(); + + if ((parType->isReferenceType() || + utils::isArrayOrPointerType(parType)) && + !innermostType.isConstQualified()) { m_Marking = true; m_Varied = true; } TraverseStmt(par); + if ((parType->isReferenceType() || + utils::isArrayOrPointerType(parType)) && + !innermostType.isConstQualified()) { + m_Marking = false; //? + m_Varied = false; + } - m_Marking = false; - m_Varied = false; - - if(!parType.isConstQualified()) + if (!(!m_Varied && innermostType.isConstQualified() && + parType->isPointerType())) { m_VariedDecls.insert(FDparam[i]); + } } } return true; @@ -147,16 +156,24 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { for (Decl* D : DS->decls()) { + QualType VDTy = cast(D)->getType(); - if(utils::isArrayOrPointerType(VDTy)){ + QualType innermost = VDTy; + while (innermost->isPointerType()) + innermost = innermost->getPointeeType(); + if (VDTy->isPointerType() && !innermost.isConstQualified()) { + copyVarToCurBlock(cast(D)); + continue; + } else if (VDTy->isArrayType()) { copyVarToCurBlock(cast(D)); continue; } + if (Expr* init = cast(D)->getInit()) { m_Varied = false; TraverseStmt(init); m_Marking = true; - if (m_Varied ) + if (m_Varied) copyVarToCurBlock(cast(D)); m_Marking = false; } @@ -167,10 +184,6 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { const auto opCode = UnOp->getOpcode(); Expr* E = UnOp->getSubExpr(); - if (opCode == UO_AddrOf || opCode == UO_Deref) { - m_Varied = true; - m_Marking = true; - } TraverseStmt(E); m_Marking = false; return true; @@ -181,7 +194,7 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { if (!VD) return true; - if (isVaried(VD)) + if (isVaried(VD) || VD->getType()->isArrayType()) m_Varied = true; if (m_Varied && m_Marking) diff --git a/lib/Differentiator/ActivityAnalyzer.h b/lib/Differentiator/ActivityAnalyzer.h index 3e0c62693..2666ac0e4 100644 --- a/lib/Differentiator/ActivityAnalyzer.h +++ b/lib/Differentiator/ActivityAnalyzer.h @@ -34,6 +34,7 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor { VarsData m_LoopMem; clang::CFGBlock* getCFGBlockByID(unsigned ID); + std::unordered_map m_Dep; clang::ASTContext& m_Context; std::unique_ptr m_CFG; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 8f4cb12f6..18b2325a6 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -639,19 +639,13 @@ namespace clad { if (!m_ActivityRunInfo.HasAnalysisRun) { ArrayRef FDparam = Function->parameters(); std::vector derivedParam; - - for (auto* parameter : FDparam) { - QualType parType = parameter->getType(); - while (parType->isPointerType()) - parType = parType->getPointeeType(); - if (!parType.isConstQualified()) - derivedParam.push_back(parameter); - } + if (Args) + for (const auto& dParam : DVI) + m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); std::copy(derivedParam.begin(), derivedParam.end(), std::inserter(m_ActivityRunInfo.ToBeRecorded, m_ActivityRunInfo.ToBeRecorded.end())); - VariedAnalyzer analyzer(Function->getASTContext(), m_ActivityRunInfo.ToBeRecorded); analyzer.Analyze(Function); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 19b902ce1..aff42d3cb 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2201,6 +2201,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + pullbackRequest.setToBeRecorded(m_DiffReq.getToBeRecorded()); bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) { diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 4f5d3a2d6..57722941b 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -6,6 +6,26 @@ #include "clad/Differentiator/Differentiator.h" +inline double interpolate1d(double low, double high, double val, unsigned int numBins, double const* vals) +{ + double binWidth = (high - low) / numBins; + int idx = val >= high ? numBins - 1 : std::abs((val - low) / binWidth); + + // interpolation + double central = low + (idx + 0.5) * binWidth; + if (val > low + 0.5 * binWidth && val < high - 0.5 * binWidth) { + double slope; + if (val < central) { + slope = vals[idx] - vals[idx - 1]; + } else { + slope = vals[idx + 1] - vals[idx]; + } + return vals[idx] + slope * (val - central) / binWidth; + } + + return vals[idx]; +} + double f1(double x){ double a = x*x; double b = 1; @@ -273,7 +293,7 @@ double f9(double x, double const *obs) return res; } -// CHECK: void f9_grad(double x, const double *obs, double *_d_x, double *_d_obs) { +// CHECK: void f9_grad_0(double x, const double *obs, double *_d_x) { // CHECK-NEXT: int loopIdx0 = 0; // CHECK-NEXT: clad::tape _t1 = {}; // CHECK-NEXT: double _d_res = 0.; @@ -373,8 +393,8 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} - auto grad = clad::gradient(f9); - grad.execute(3, arr, &dx, darr); + auto grad9 = clad::gradient(f9, "x"); + grad9.execute(3, arr, &dx, darr); printf("%.2f\n", dx);// CHECK-EXEC: 2.00 TEST(f10, 3);// CHECK-EXEC: {1.00} TEST(f11, 3);// CHECK-EXEC: {1.00}