Skip to content

Commit

Permalink
Add CallExpr and basic ConditionalOperator support
Browse files Browse the repository at this point in the history
fix

fix
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Sep 10, 2024
1 parent db2e2be commit f00bbd0
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 18 deletions.
4 changes: 1 addition & 3 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include <iterator>
#include <set>
namespace clang {
class CallExpr;
class CompilerInstance;
Expand All @@ -35,7 +33,7 @@ struct DiffRequest {

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> ToBeRecorded;
bool HasAnalysisRun = false;
bool HasNoAnalysisRun = true;
} m_ActivityRunInfo;

public:
Expand Down
8 changes: 3 additions & 5 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);

m_BlockData.resize(m_CFG->size());
m_BlockPassCounter.resize(m_CFG->size(), 0);

// Set current block ID to the ID of entry the block.
CFGBlock* entry = &m_CFG->getEntry();
m_CurBlockID = entry->getBlockID();
Expand All @@ -27,15 +25,15 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
m_CurBlockID = *IDIter;
m_CFGQueue.erase(IDIter);
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
VisitCFGBlock(nextBlock);
AnalyzeCFGBlock(nextBlock);
}
}

CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
return *(m_CFG->begin() + ID);
}

void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) {
void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {
// Visit all the statements inside the block.
for (const clang::CFGElement& Element : block) {
if (Element.getKind() == clang::CFGElement::Statement) {
Expand Down Expand Up @@ -166,4 +164,4 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
}
return true;
}
} // namespace clad
} // namespace clad
32 changes: 25 additions & 7 deletions lib/Differentiator/ActivityAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <utility>
#include <iterator>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>

using namespace clang;

/// @brief Class that implemets Varied part of the Activity analysis.
/// By performing static data-flow analysis, so called Varied variables
/// are determined, meaning variables that depend on input parameters
/// in a differentiable way. That result enables us to remove redundant
/// statements in the reverse mode, improving generated codes efficiency.
namespace clad {
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {

Expand All @@ -27,26 +34,36 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign)));
}

using VarsData = std::set<const clang::VarDecl*>;
/// @brief A helper method to allocate VarsData
/// @param toAssign Parameter to initialize new VarsData with.
/// @return Unique pointer to a new object of type Varsdata.
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) {
return std::make_unique<VarsData>(std::move(toAssign));
}
VarsData m_LoopMem;

clang::CFGBlock* getCFGBlockByID(unsigned ID);

static void merge(VarsData* targetData, VarsData* mergeData);
ASTContext& m_Context;
static void merge(VarsData* targetData, VarsData* mergeData);
clang::ASTContext& m_Context;
std::unique_ptr<clang::CFG> m_CFG;
std::vector<std::unique_ptr<VarsData>> m_BlockData;
std::vector<short> m_BlockPassCounter;
std::vector<std::unique_ptr<VarsData>> m_BlockData;
unsigned m_CurBlockID{};
std::set<unsigned> m_CFGQueue;

void addToVaried(const clang::VarDecl* VD);
bool isVaried(const clang::VarDecl* VD);

bool isVaried(const clang::VarDecl* VD) const;
void copyVarToCurBlock(const clang::VarDecl* VD);
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; }
const VarsData& getCurBlockVarsData() const { return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData();}
void AnalyzeCFGBlock(const clang::CFGBlock& block);

public:
/// Constructor
VariedAnalyzer(ASTContext& Context, std::set<const clang::VarDecl*>& Decls)
VariedAnalyzer(clang::ASTContext& Context, std::set<const clang::VarDecl*>& Decls)
: m_VariedDecls(Decls), m_Context(Context) {}

/// Destructor
Expand All @@ -58,7 +75,8 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
VariedAnalyzer(const VariedAnalyzer&&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete;

/// Visitors
/// @brief Runs Varied analysis.
/// @param FD Function to run the analysis on.
void Analyze(const clang::FunctionDecl* FD);
void VisitCFGBlock(const clang::CFGBlock& block);
bool VisitBinaryOperator(clang::BinaryOperator* BinOp);
Expand All @@ -69,4 +87,4 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
bool VisitUnaryOperator(clang::UnaryOperator* UnOp);
};
} // namespace clad
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
2 changes: 1 addition & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY
# (Ab)use llvm facilities for adding libraries.
llvm_add_library(cladDifferentiator
STATIC
ActivityAnalyzer.cpp
BaseForwardModeVisitor.cpp
CladUtils.cpp
ConstantFolder.cpp
Expand All @@ -36,7 +37,6 @@ llvm_add_library(cladDifferentiator
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
ActivityAnalyzer.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,15 @@ namespace clad {
return found != m_TbrRunInfo.ToBeRecorded.end();
}

bool DiffRequest::shouldHaveAdjoint(const VarDecl* VD) const {
bool DiffRequest::shouldHaveAdjoint(const VarDecl* VD) const {
if (!EnableActivityAnalysis)
return true;

if (VD->getType()->isPointerType())
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
if (m_ActivityRunInfo.HasNoAnalysisRun) {
if (!DVI.empty()) {
for (const auto& dParam : DVI)
m_ActivityRunInfo.ToBeRecorded.insert(cast<VarDecl>(dParam.param));
Expand All @@ -636,7 +637,7 @@ namespace clad {
VariedAnalyzer analyzer(Function->getASTContext(),
m_ActivityRunInfo.ToBeRecorded);
analyzer.Analyze(Function);
m_ActivityRunInfo.HasAnalysisRun = true;
m_ActivityRunInfo.HasNoAnalysisRun = true;
}
auto found = m_ActivityRunInfo.ToBeRecorded.find(VD);
return found != m_ActivityRunInfo.ToBeRecorded.end();
Expand Down

0 comments on commit f00bbd0

Please sign in to comment.