Skip to content

Commit

Permalink
Add Varied analysis to the reverse mode (#1084)
Browse files Browse the repository at this point in the history
Partially addresses #716
  • Loading branch information
ovdiiuv authored Oct 15, 2024
1 parent 332358e commit 2d08ce1
Show file tree
Hide file tree
Showing 12 changed files with 648 additions and 28 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ enum opts : unsigned {
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),
enable_va = 1 << (ORDER_BITS + 5),
disable_va = 1 << (ORDER_BITS + 6),

// Specifying whether we only want the diagonal of the hessian.
diagonal_only = 1 << (ORDER_BITS + 4),
Expand Down
11 changes: 11 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include <iterator>
#include <set>
namespace clang {
class CallExpr;
class CompilerInstance;
Expand All @@ -31,6 +33,11 @@ struct DiffRequest {
bool HasAnalysisRun = false;
} m_TbrRunInfo;

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> ToBeRecorded;
bool HasAnalysisRun = false;
} m_ActivityRunInfo;

public:
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
Expand All @@ -57,6 +64,7 @@ struct DiffRequest {
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down Expand Up @@ -114,6 +122,7 @@ struct DiffRequest {
RequestedDerivativeOrder == other.RequestedDerivativeOrder &&
CallContext == other.CallContext && Args == other.Args &&
Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis &&
EnableVariedAnalysis == other.EnableVariedAnalysis &&
DVI == other.DVI && use_enzyme == other.use_enzyme &&
DeclarationOnly == other.DeclarationOnly;
}
Expand All @@ -131,6 +140,7 @@ struct DiffRequest {
}

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand All @@ -139,6 +149,7 @@ struct DiffRequest {
/// This is a flag to indicate the default behaviour to enable/disable
/// TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
Expand Down
173 changes: 173 additions & 0 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#include "ActivityAnalyzer.h"

using namespace clang;

namespace clad {

void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
// Build the CFG (control-flow graph) of FD.
clang::CFG::BuildOptions Options;
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);

m_BlockData.resize(m_CFG->size());
// Set current block ID to the ID of entry the block.
CFGBlock* entry = &m_CFG->getEntry();
m_CurBlockID = entry->getBlockID();
m_BlockData[m_CurBlockID] = createNewVarsData({});
for (const VarDecl* i : m_VariedDecls)
m_BlockData[m_CurBlockID]->insert(i);
// Add the entry block to the queue.
m_CFGQueue.insert(m_CurBlockID);

// Visit CFG blocks in the queue until it's empty.
while (!m_CFGQueue.empty()) {
auto IDIter = std::prev(m_CFGQueue.end());
m_CurBlockID = *IDIter;
m_CFGQueue.erase(IDIter);
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
AnalyzeCFGBlock(nextBlock);
}
}

void mergeVarsData(std::set<const clang::VarDecl*>* targetData,
std::set<const clang::VarDecl*>* mergeData) {
for (const clang::VarDecl* i : *mergeData)
targetData->insert(i);
*mergeData = *targetData;
}

CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
return *(m_CFG->begin() + ID);
}

void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {
// Visit all the statements inside the block.
for (const clang::CFGElement& Element : block) {
if (Element.getKind() == clang::CFGElement::Statement) {
const clang::Stmt* S = Element.castAs<clang::CFGStmt>().getStmt();
// The const_cast is inevitable, since there is no
// ConstRecusiveASTVisitor.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
TraverseStmt(const_cast<clang::Stmt*>(S));
}
}

for (const clang::CFGBlock::AdjacentBlock succ : block.succs()) {
if (!succ)
continue;
auto& succData = m_BlockData[succ->getBlockID()];

if (!succData)
succData = createNewVarsData(*m_BlockData[block.getBlockID()]);

bool shouldPushSucc = true;
if (succ->getBlockID() > block.getBlockID()) {
if (m_LoopMem == *m_BlockData[block.getBlockID()])
shouldPushSucc = false;

for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_LoopMem.insert(i);
}

if (shouldPushSucc)
m_CFGQueue.insert(succ->getBlockID());

mergeVarsData(succData.get(), m_BlockData[block.getBlockID()].get());
}
// FIXME: Information about the varied variables is stored in the last block,
// so we should be able to get it form there
for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_VariedDecls.insert(i);
}

bool VariedAnalyzer::isVaried(const VarDecl* VD) const {
const VarsData& curBranch = getCurBlockVarsData();
return curBranch.find(VD) != curBranch.end();
}

void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) {
VarsData& curBranch = getCurBlockVarsData();
curBranch.insert(VD);
}

bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) {
Expr* L = BinOp->getLHS();
Expr* R = BinOp->getRHS();
const auto opCode = BinOp->getOpcode();
if (BinOp->isAssignmentOp()) {
m_Varied = false;
TraverseStmt(R);
m_Marking = m_Varied;
TraverseStmt(L);
m_Marking = false;
} else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul ||
opCode == BO_Div) {
for (auto* subexpr : BinOp->children())
if (!isa<BinaryOperator>(subexpr))
TraverseStmt(subexpr);
}
return true;
}

// add branching merging
bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) {
TraverseStmt(CO->getCond());
TraverseStmt(CO->getTrueExpr());
TraverseStmt(CO->getFalseExpr());
return true;
}

bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
FunctionDecl* FD = CE->getDirectCallee();
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
}
m_Varied = true;
return true;
}

bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
for (Decl* D : DS->decls()) {
if (Expr* init = cast<VarDecl>(D)->getInit()) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied)
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
}
return true;
}

bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) {
const auto opCode = UnOp->getOpcode();
Expr* E = UnOp->getSubExpr();
if (opCode == UO_AddrOf || opCode == UO_Deref) {
m_Varied = true;
m_Marking = true;
}
TraverseStmt(E);
m_Marking = false;
return true;
}

bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
auto* VD = dyn_cast<VarDecl>(DRE->getDecl());
if (!VD)
return true;

if (isVaried(VD))
m_Varied = true;

if (m_Varied && m_Marking)
copyVarToCurBlock(VD);
return true;
}
} // namespace clad
82 changes: 82 additions & 0 deletions lib/Differentiator/ActivityAnalyzer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H

#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Analysis/CFG.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <utility>

namespace clad {

/// Class that implemets Varied part of the Activity analysis.
/// By performing static data-flow analysis, so called Varied variables
/// are determined, meaning variables that depend on input parameters
/// in a differentiable way. That result enables us to remove redundant
/// statements in the reverse mode, improving generated codes efficiency.
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
bool m_Varied = false;
bool m_Marking = false;
using VarsData = std::set<const clang::VarDecl*>;
VarsData& m_VariedDecls;
/// A helper method to allocate VarsData
/// \param[in] toAssign - Parameter to initialize new VarsData with.
/// \return Unique pointer to a new object of type Varsdata.
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) {
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign)));
}
VarsData m_LoopMem;

clang::CFGBlock* getCFGBlockByID(unsigned ID);

clang::ASTContext& m_Context;
std::unique_ptr<clang::CFG> m_CFG;
std::vector<std::unique_ptr<VarsData>> m_BlockData;
unsigned m_CurBlockID{};
std::set<unsigned> m_CFGQueue;
/// Checks if a variable is on the current branch.
/// \param[in] VD - Variable declaration.
/// @return Whether a variable is on the current branch.
bool isVaried(const clang::VarDecl* VD) const;
/// Adds varied variable to current branch.
/// \param[in] VD - Variable declaration.
void copyVarToCurBlock(const clang::VarDecl* VD);
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; }
[[nodiscard]] const VarsData& getCurBlockVarsData() const {
return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData();
}
void AnalyzeCFGBlock(const clang::CFGBlock& block);

public:
/// Constructor
VariedAnalyzer(clang::ASTContext& Context,
std::set<const clang::VarDecl*>& Decls)
: m_VariedDecls(Decls), m_Context(Context) {}

/// Destructor
~VariedAnalyzer() = default;

/// Delete copy/move operators and constructors.
VariedAnalyzer(const VariedAnalyzer&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&) = delete;
VariedAnalyzer(const VariedAnalyzer&&) = delete;
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete;

/// Runs Varied analysis.
/// \param[in] FD Function to run the analysis on.
void Analyze(const clang::FunctionDecl* FD);
bool VisitBinaryOperator(clang::BinaryOperator* BinOp);
bool VisitCallExpr(clang::CallExpr* CE);
bool VisitConditionalOperator(clang::ConditionalOperator* CO);
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
bool VisitDeclStmt(clang::DeclStmt* DS);
bool VisitUnaryOperator(clang::UnaryOperator* UnOp);
};
} // namespace clad
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY
# (Ab)use llvm facilities for adding libraries.
llvm_add_library(cladDifferentiator
STATIC
ActivityAnalyzer.cpp
BaseForwardModeVisitor.cpp
CladUtils.cpp
ConstantFolder.cpp
Expand Down
Loading

0 comments on commit 2d08ce1

Please sign in to comment.