From 811a88310b431a59e556f18948a1172b5a0ed280 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Tue, 10 Sep 2024 21:40:16 +0200 Subject: [PATCH] minor changes --- include/clad/Differentiator/DiffPlanner.h | 2 +- lib/Differentiator/ActivityAnalyzer.cpp | 35 ++++++++++------------ lib/Differentiator/ActivityAnalyzer.h | 36 +++++++++++++---------- lib/Differentiator/CMakeLists.txt | 2 +- lib/Differentiator/DiffPlanner.cpp | 4 +-- 5 files changed, 40 insertions(+), 39 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index aabff5210..b01e9603d 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -35,7 +35,7 @@ struct DiffRequest { mutable struct ActivityRunInfo { std::set ToBeRecorded; - bool HasAnalysisRun = false; + bool HasNoAnalysisRun = true; } m_ActivityRunInfo; public: diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index f1688fa80..0b7e34699 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -10,8 +10,6 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) { m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); m_BlockData.resize(m_CFG->size()); - m_BlockPassCounter.resize(m_CFG->size(), 0); - // Set current block ID to the ID of entry the block. CFGBlock* entry = &m_CFG->getEntry(); m_CurBlockID = entry->getBlockID(); @@ -27,7 +25,7 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) { m_CurBlockID = *IDIter; m_CFGQueue.erase(IDIter); CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); - VisitCFGBlock(nextBlock); + AnalyzeCFGBlock(nextBlock); } } @@ -35,7 +33,7 @@ CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { return *(m_CFG->begin() + ID); } -void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) { +void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) { // Visit all the statements inside the block. for (const clang::CFGElement& Element : block) { if (Element.getKind() == clang::CFGElement::Statement) { @@ -49,8 +47,9 @@ void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) { continue; auto& succData = m_BlockData[succ->getBlockID()]; - if (!succData) + if (!succData) { succData = createNewVarsData(*m_BlockData[block.getBlockID()]); + } bool shouldPushSucc = true; if (succ->getBlockID() > block.getBlockID()) { @@ -66,14 +65,13 @@ void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) { merge(succData.get(), m_BlockData[block.getBlockID()].get()); } - // FIXME: Information about the varied variables is stored in the last block, - // so we should be able to get it form there + // FIXME: Information about the varied variables is stored in the last block, so we should be able to get it form there for (const VarDecl* i : *m_BlockData[block.getBlockID()]) m_VariedDecls.insert(i); } -bool VariedAnalyzer::isVaried(const VarDecl* VD) { - VarsData& curBranch = getCurBlockVarsData(); +bool VariedAnalyzer::isVaried(const VarDecl* VD) const{ + const VarsData& curBranch = getCurBlockVarsData(); return curBranch.find(VD) != curBranch.end(); } @@ -115,21 +113,21 @@ bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) { } bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { - FunctionDecl* FD = CE->getDirectCallee(); + FunctionDecl* FD = CE->getDirectCallee(); bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); std::set variedParam; - if (noHiddenParam) { - MutableArrayRef FDparam = FD->parameters(); - for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + if(noHiddenParam){ + MutableArrayRef FDparam = FD->parameters(); + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i){ clang::Expr* par = CE->getArg(i); TraverseStmt(par); - if (m_Varied) { + if(m_Varied){ m_VariedDecls.insert(FDparam[i]); m_Varied = false; } } } - return true; + return true; } bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { @@ -139,8 +137,7 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { m_Varied = false; TraverseStmt(init); m_Marking = true; - VarsData& curBranch = getCurBlockVarsData(); - if (m_Varied && curBranch.find(VD) == curBranch.end()) + if (m_Varied) copyVarToCurBlock(VD); m_Marking = false; } @@ -160,10 +157,10 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { m_Varied = true; if (const auto* VD = dyn_cast(DRE->getDecl())) { - VarsData& curBranch = getCurBlockVarsData(); - if (m_Varied && m_Marking && curBranch.find(VD) == curBranch.end()) + if (m_Varied && m_Marking) copyVarToCurBlock(VD); } return true; } } // namespace clad + diff --git a/lib/Differentiator/ActivityAnalyzer.h b/lib/Differentiator/ActivityAnalyzer.h index d3e3727d6..d42d61c6f 100644 --- a/lib/Differentiator/ActivityAnalyzer.h +++ b/lib/Differentiator/ActivityAnalyzer.h @@ -8,13 +8,17 @@ #include "clad/Differentiator/Compatibility.h" #include +#include #include +#include #include #include #include - -using namespace clang; - +/// @brief Class that implemets Varied part of the Activity analysis. +/// By performing static data-flow analysis, so called Varied variables +/// are determined, meaning variables that depend on input parameters +/// in a differentiable way. That result enables us to remove redundant +/// statements in the reverse mode, improving generated codes efficiency. namespace clad { class VariedAnalyzer : public clang::RecursiveASTVisitor { @@ -23,30 +27,30 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor { std::set& m_VariedDecls; using VarsData = std::set; + /// @brief A helper method to allocate VarsData + /// @param toAssign Parameter to initialize new VarsData with. + /// @return Unique pointer to a new object of type Varsdata. static std::unique_ptr createNewVarsData(VarsData toAssign) { - return std::unique_ptr(new VarsData(std::move(toAssign))); + return std::make_unique(std::move(toAssign)); } - VarsData m_LoopMem; - clang::CFGBlock* getCFGBlockByID(unsigned ID); + clang::CFGBlock* getCFGBlockByID(unsigned ID); static void merge(VarsData* targetData, VarsData* mergeData); - ASTContext& m_Context; + clang::ASTContext& m_Context; std::unique_ptr m_CFG; std::vector> m_BlockData; - std::vector m_BlockPassCounter; unsigned m_CurBlockID{}; std::set m_CFGQueue; - - void addToVaried(const clang::VarDecl* VD); - bool isVaried(const clang::VarDecl* VD); - + bool isVaried(const clang::VarDecl* VD) const; void copyVarToCurBlock(const clang::VarDecl* VD); VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } + const VarsData& getCurBlockVarsData() const { return const_cast(this)->getCurBlockVarsData();} + void AnalyzeCFGBlock(const clang::CFGBlock& block); public: /// Constructor - VariedAnalyzer(ASTContext& Context, std::set& Decls) + VariedAnalyzer(clang::ASTContext& Context, std::set& Decls) : m_VariedDecls(Decls), m_Context(Context) {} /// Destructor @@ -58,9 +62,9 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor { VariedAnalyzer(const VariedAnalyzer&&) = delete; VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; - /// Visitors + /// @brief Runs Varied analysis. + /// @param FD Function to run the analysis on. void Analyze(const clang::FunctionDecl* FD); - void VisitCFGBlock(const clang::CFGBlock& block); bool VisitBinaryOperator(clang::BinaryOperator* BinOp); bool VisitCallExpr(clang::CallExpr* CE); bool VisitConditionalOperator(clang::ConditionalOperator* CO); @@ -69,4 +73,4 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor { bool VisitUnaryOperator(clang::UnaryOperator* UnOp); }; } // namespace clad -#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H \ No newline at end of file +#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index aabbb72f7..7f928b2ac 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY # (Ab)use llvm facilities for adding libraries. llvm_add_library(cladDifferentiator STATIC + ActivityAnalyzer.cpp BaseForwardModeVisitor.cpp CladUtils.cpp ConstantFolder.cpp @@ -36,7 +37,6 @@ llvm_add_library(cladDifferentiator ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp TBRAnalyzer.cpp - ActivityAnalyzer.cpp StmtClone.cpp VectorForwardModeVisitor.cpp VectorPushForwardModeVisitor.cpp diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index e7b6300e0..4d707cd4a 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -623,7 +623,7 @@ namespace clad { if (VD->getType()->isPointerType()) return true; - if (!m_ActivityRunInfo.HasAnalysisRun) { + if (m_ActivityRunInfo.HasNoAnalysisRun) { if (!DVI.empty()) { for (const auto& dParam : DVI) m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); @@ -636,7 +636,7 @@ namespace clad { VariedAnalyzer analyzer(Function->getASTContext(), m_ActivityRunInfo.ToBeRecorded); analyzer.Analyze(Function); - m_ActivityRunInfo.HasAnalysisRun = true; + m_ActivityRunInfo.HasNoAnalysisRun = false; } auto found = m_ActivityRunInfo.ToBeRecorded.find(VD); return found != m_ActivityRunInfo.ToBeRecorded.end();