Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Sep 10, 2024
1 parent db2e2be commit 811a883
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 39 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct DiffRequest {

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> ToBeRecorded;
bool HasAnalysisRun = false;
bool HasNoAnalysisRun = true;
} m_ActivityRunInfo;

public:
Expand Down
35 changes: 16 additions & 19 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);

m_BlockData.resize(m_CFG->size());
m_BlockPassCounter.resize(m_CFG->size(), 0);

// Set current block ID to the ID of entry the block.
CFGBlock* entry = &m_CFG->getEntry();
m_CurBlockID = entry->getBlockID();
Expand All @@ -27,15 +25,15 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
m_CurBlockID = *IDIter;
m_CFGQueue.erase(IDIter);
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
VisitCFGBlock(nextBlock);
AnalyzeCFGBlock(nextBlock);
}
}

CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
return *(m_CFG->begin() + ID);
}

void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) {
void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {
// Visit all the statements inside the block.
for (const clang::CFGElement& Element : block) {
if (Element.getKind() == clang::CFGElement::Statement) {
Expand All @@ -49,8 +47,9 @@ void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) {
continue;
auto& succData = m_BlockData[succ->getBlockID()];

if (!succData)
if (!succData) {
succData = createNewVarsData(*m_BlockData[block.getBlockID()]);
}

bool shouldPushSucc = true;
if (succ->getBlockID() > block.getBlockID()) {
Expand All @@ -66,14 +65,13 @@ void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) {

merge(succData.get(), m_BlockData[block.getBlockID()].get());
}
// FIXME: Information about the varied variables is stored in the last block,
// so we should be able to get it form there
// FIXME: Information about the varied variables is stored in the last block, so we should be able to get it form there
for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_VariedDecls.insert(i);
}

bool VariedAnalyzer::isVaried(const VarDecl* VD) {
VarsData& curBranch = getCurBlockVarsData();
bool VariedAnalyzer::isVaried(const VarDecl* VD) const{
const VarsData& curBranch = getCurBlockVarsData();
return curBranch.find(VD) != curBranch.end();
}

Expand Down Expand Up @@ -115,21 +113,21 @@ bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) {
}

bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
FunctionDecl* FD = CE->getDirectCallee();
FunctionDecl* FD = CE->getDirectCallee();
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
std::set<const clang::VarDecl*> variedParam;
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
if(noHiddenParam){
MutableArrayRef<ParmVarDecl *> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i){
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
if (m_Varied) {
if(m_Varied){
m_VariedDecls.insert(FDparam[i]);
m_Varied = false;
}
}
}
return true;
return true;
}

bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
Expand All @@ -139,8 +137,7 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
VarsData& curBranch = getCurBlockVarsData();
if (m_Varied && curBranch.find(VD) == curBranch.end())
if (m_Varied)
copyVarToCurBlock(VD);
m_Marking = false;
}
Expand All @@ -160,10 +157,10 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
m_Varied = true;

if (const auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
VarsData& curBranch = getCurBlockVarsData();
if (m_Varied && m_Marking && curBranch.find(VD) == curBranch.end())
if (m_Varied && m_Marking)
copyVarToCurBlock(VD);
}
return true;
}
} // namespace clad

36 changes: 20 additions & 16 deletions lib/Differentiator/ActivityAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <utility>
#include <iterator>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>

using namespace clang;

/// @brief Class that implemets Varied part of the Activity analysis.
/// By performing static data-flow analysis, so called Varied variables
/// are determined, meaning variables that depend on input parameters
/// in a differentiable way. That result enables us to remove redundant
/// statements in the reverse mode, improving generated codes efficiency.
namespace clad {
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {

Expand All @@ -23,30 +27,30 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {

std::set<const clang::VarDecl*>& m_VariedDecls;
using VarsData = std::set<const clang::VarDecl*>;
/// @brief A helper method to allocate VarsData
/// @param toAssign Parameter to initialize new VarsData with.
/// @return Unique pointer to a new object of type Varsdata.
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) {
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign)));
return std::make_unique<VarsData>(std::move(toAssign));
}

VarsData m_LoopMem;
clang::CFGBlock* getCFGBlockByID(unsigned ID);

clang::CFGBlock* getCFGBlockByID(unsigned ID);
static void merge(VarsData* targetData, VarsData* mergeData);
ASTContext& m_Context;
clang::ASTContext& m_Context;
std::unique_ptr<clang::CFG> m_CFG;
std::vector<std::unique_ptr<VarsData>> m_BlockData;
std::vector<short> m_BlockPassCounter;
unsigned m_CurBlockID{};
std::set<unsigned> m_CFGQueue;

void addToVaried(const clang::VarDecl* VD);
bool isVaried(const clang::VarDecl* VD);

bool isVaried(const clang::VarDecl* VD) const;
void copyVarToCurBlock(const clang::VarDecl* VD);
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; }
const VarsData& getCurBlockVarsData() const { return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData();}
void AnalyzeCFGBlock(const clang::CFGBlock& block);

public:
/// Constructor
VariedAnalyzer(ASTContext& Context, std::set<const clang::VarDecl*>& Decls)
VariedAnalyzer(clang::ASTContext& Context, std::set<const clang::VarDecl*>& Decls)
: m_VariedDecls(Decls), m_Context(Context) {}

/// Destructor
Expand All @@ -58,9 +62,9 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
VariedAnalyzer(const VariedAnalyzer&&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete;

/// Visitors
/// @brief Runs Varied analysis.
/// @param FD Function to run the analysis on.
void Analyze(const clang::FunctionDecl* FD);
void VisitCFGBlock(const clang::CFGBlock& block);
bool VisitBinaryOperator(clang::BinaryOperator* BinOp);
bool VisitCallExpr(clang::CallExpr* CE);
bool VisitConditionalOperator(clang::ConditionalOperator* CO);
Expand All @@ -69,4 +73,4 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
bool VisitUnaryOperator(clang::UnaryOperator* UnOp);
};
} // namespace clad
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
2 changes: 1 addition & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY
# (Ab)use llvm facilities for adding libraries.
llvm_add_library(cladDifferentiator
STATIC
ActivityAnalyzer.cpp
BaseForwardModeVisitor.cpp
CladUtils.cpp
ConstantFolder.cpp
Expand All @@ -36,7 +37,6 @@ llvm_add_library(cladDifferentiator
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
ActivityAnalyzer.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ namespace clad {
if (VD->getType()->isPointerType())
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
if (m_ActivityRunInfo.HasNoAnalysisRun) {
if (!DVI.empty()) {
for (const auto& dParam : DVI)
m_ActivityRunInfo.ToBeRecorded.insert(cast<VarDecl>(dParam.param));
Expand All @@ -636,7 +636,7 @@ namespace clad {
VariedAnalyzer analyzer(Function->getASTContext(),
m_ActivityRunInfo.ToBeRecorded);
analyzer.Analyze(Function);
m_ActivityRunInfo.HasAnalysisRun = true;
m_ActivityRunInfo.HasNoAnalysisRun = false;
}
auto found = m_ActivityRunInfo.ToBeRecorded.find(VD);
return found != m_ActivityRunInfo.ToBeRecorded.end();
Expand Down

0 comments on commit 811a883

Please sign in to comment.