Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Sep 13, 2024
1 parent db2e2be commit 792cd6b
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 52 deletions.
44 changes: 18 additions & 26 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);

m_BlockData.resize(m_CFG->size());
m_BlockPassCounter.resize(m_CFG->size(), 0);

// Set current block ID to the ID of entry the block.
CFGBlock* entry = &m_CFG->getEntry();
m_CurBlockID = entry->getBlockID();
Expand All @@ -27,19 +25,22 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
m_CurBlockID = *IDIter;
m_CFGQueue.erase(IDIter);
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
VisitCFGBlock(nextBlock);
AnalyzeCFGBlock(nextBlock);
}
}

CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
return *(m_CFG->begin() + ID);
}

void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) {
void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {
// Visit all the statements inside the block.
for (const clang::CFGElement& Element : block) {
if (Element.getKind() == clang::CFGElement::Statement) {
const clang::Stmt* S = Element.castAs<clang::CFGStmt>().getStmt();
// The const_cast is inevitable, since there is no
// ConstRecusiveASTVisitor.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
TraverseStmt(const_cast<clang::Stmt*>(S));
}
}
Expand All @@ -64,26 +65,19 @@ void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) {
if (shouldPushSucc)
m_CFGQueue.insert(succ->getBlockID());

merge(succData.get(), m_BlockData[block.getBlockID()].get());
mergeVarsData(succData.get(), m_BlockData[block.getBlockID()].get());
}
// FIXME: Information about the varied variables is stored in the last block,
// so we should be able to get it form there
for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_VariedDecls.insert(i);
}

bool VariedAnalyzer::isVaried(const VarDecl* VD) {
VarsData& curBranch = getCurBlockVarsData();
bool VariedAnalyzer::isVaried(const VarDecl* VD) const {
const VarsData& curBranch = getCurBlockVarsData();
return curBranch.find(VD) != curBranch.end();
}

void VariedAnalyzer::merge(VarsData* targetData, VarsData* mergeData) {
for (const VarDecl* i : *mergeData)
targetData->insert(i);
for (const VarDecl* i : *targetData)
mergeData->insert(i);
}

void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) {
VarsData& curBranch = getCurBlockVarsData();
curBranch.insert(VD);
Expand Down Expand Up @@ -134,16 +128,15 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {

bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
for (Decl* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D)) {
if (Expr* init = VD->getInit()) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
VarsData& curBranch = getCurBlockVarsData();
if (m_Varied && curBranch.find(VD) == curBranch.end())
copyVarToCurBlock(VD);
m_Marking = false;
}
if (!isa<VarDecl>(D))
continue;

Check warning on line 132 in lib/Differentiator/ActivityAnalyzer.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ActivityAnalyzer.cpp#L132

Added line #L132 was not covered by tests
if (Expr* init = cast<VarDecl>(D)->getInit()) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied)
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
}
return true;
Expand All @@ -160,8 +153,7 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
m_Varied = true;

if (const auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
VarsData& curBranch = getCurBlockVarsData();
if (m_Varied && m_Marking && curBranch.find(VD) == curBranch.end())
if (m_Varied && m_Marking)
copyVarToCurBlock(VD);
}
return true;
Expand Down
46 changes: 29 additions & 17 deletions lib/Differentiator/ActivityAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,56 @@

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>

using namespace clang;

/// Class that implemets Varied part of the Activity analysis.
/// By performing static data-flow analysis, so called Varied variables
/// are determined, meaning variables that depend on input parameters
/// in a differentiable way. That result enables us to remove redundant
/// statements in the reverse mode, improving generated codes efficiency.
namespace clad {
using VarsData = std::set<const clang::VarDecl*>;
static inline void mergeVarsData(VarsData* targetData, VarsData* mergeData) {
for (const clang::VarDecl* i : *mergeData)
targetData->insert(i);
for (const clang::VarDecl* i : *targetData)
mergeData->insert(i);
}
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {

bool m_Varied = false;
bool m_Marking = false;

std::set<const clang::VarDecl*>& m_VariedDecls;
using VarsData = std::set<const clang::VarDecl*>;
// using VarsData = std::set<const clang::VarDecl*>;
/// A helper method to allocate VarsData
/// \param toAssign - Parameter to initialize new VarsData with.
/// \return Unique pointer to a new object of type Varsdata.
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) {
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign)));
}

VarsData m_LoopMem;

clang::CFGBlock* getCFGBlockByID(unsigned ID);

static void merge(VarsData* targetData, VarsData* mergeData);
ASTContext& m_Context;
clang::ASTContext& m_Context;
std::unique_ptr<clang::CFG> m_CFG;
std::vector<std::unique_ptr<VarsData>> m_BlockData;
std::vector<short> m_BlockPassCounter;
unsigned m_CurBlockID{};
std::set<unsigned> m_CFGQueue;

void addToVaried(const clang::VarDecl* VD);
bool isVaried(const clang::VarDecl* VD);

bool isVaried(const clang::VarDecl* VD) const;
void copyVarToCurBlock(const clang::VarDecl* VD);
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; }
[[nodiscard]] const VarsData& getCurBlockVarsData() const {
return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData();
}
void AnalyzeCFGBlock(const clang::CFGBlock& block);

public:
/// Constructor
VariedAnalyzer(ASTContext& Context, std::set<const clang::VarDecl*>& Decls)
VariedAnalyzer(clang::ASTContext& Context,
std::set<const clang::VarDecl*>& Decls)
: m_VariedDecls(Decls), m_Context(Context) {}

/// Destructor
Expand All @@ -58,9 +70,9 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
VariedAnalyzer(const VariedAnalyzer&&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete;

/// Visitors
/// Runs Varied analysis.
/// \param FD Function to run the analysis on.
void Analyze(const clang::FunctionDecl* FD);
void VisitCFGBlock(const clang::CFGBlock& block);
bool VisitBinaryOperator(clang::BinaryOperator* BinOp);
bool VisitCallExpr(clang::CallExpr* CE);
bool VisitConditionalOperator(clang::ConditionalOperator* CO);
Expand All @@ -69,4 +81,4 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
bool VisitUnaryOperator(clang::UnaryOperator* UnOp);
};
} // namespace clad
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
2 changes: 1 addition & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY
# (Ab)use llvm facilities for adding libraries.
llvm_add_library(cladDifferentiator
STATIC
ActivityAnalyzer.cpp
BaseForwardModeVisitor.cpp
CladUtils.cpp
ConstantFolder.cpp
Expand All @@ -36,7 +37,6 @@ llvm_add_library(cladDifferentiator
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
ActivityAnalyzer.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Expand Down
10 changes: 2 additions & 8 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,14 +624,8 @@ namespace clad {
return true;

Check warning on line 624 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L624

Added line #L624 was not covered by tests

if (!m_ActivityRunInfo.HasAnalysisRun) {
if (!DVI.empty()) {
for (const auto& dParam : DVI)
m_ActivityRunInfo.ToBeRecorded.insert(cast<VarDecl>(dParam.param));
} else {
std::copy(Function->param_begin(), Function->param_end(),
std::inserter(m_ActivityRunInfo.ToBeRecorded,
m_ActivityRunInfo.ToBeRecorded.end()));
}
for (const auto& dParam : DVI)
m_ActivityRunInfo.ToBeRecorded.insert(cast<VarDecl>(dParam.param));

VariedAnalyzer analyzer(Function->getASTContext(),
m_ActivityRunInfo.ToBeRecorded);
Expand Down
21 changes: 21 additions & 0 deletions test/Analyses/ActivityReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,26 @@ double f5(double x){
// CHECK-NEXT: _d_g += 1;
// CHECK-NEXT: }

double f6(double x){
double a = 0;
if(0){
a=x;
}
return a;
}

// CHECK: void f6_grad(double x, double *_d_x) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double a = 0;
// CHECK-NEXT: if (0) {
// CHECK-NEXT: _t0 = a;
// CHECK-NEXT: a = x;
// CHECK-NEXT: }
// CHECK-NEXT: if (0) {
// CHECK-NEXT: a = _t0;
// CHECK-NEXT: }
// CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_aa>(F);\
Expand All @@ -207,6 +227,7 @@ int main(){
TEST(f3, 3);// CHECK-EXEC: {0.00}
TEST(f4, 3);// CHECK-EXEC: {4.00}
TEST(f5, 3);// CHECK-EXEC: {0.00}
TEST(f6, 3);// CHECK-EXEC: {0.00}
}

// CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v) {
Expand Down

0 comments on commit 792cd6b

Please sign in to comment.