-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Varied analysis to the reverse mode (#1084)
Partially addresses #716
- Loading branch information
Showing
12 changed files
with
648 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
#include "ActivityAnalyzer.h" | ||
|
||
using namespace clang; | ||
|
||
namespace clad { | ||
|
||
void VariedAnalyzer::Analyze(const FunctionDecl* FD) { | ||
// Build the CFG (control-flow graph) of FD. | ||
clang::CFG::BuildOptions Options; | ||
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); | ||
|
||
m_BlockData.resize(m_CFG->size()); | ||
// Set current block ID to the ID of entry the block. | ||
CFGBlock* entry = &m_CFG->getEntry(); | ||
m_CurBlockID = entry->getBlockID(); | ||
m_BlockData[m_CurBlockID] = createNewVarsData({}); | ||
for (const VarDecl* i : m_VariedDecls) | ||
m_BlockData[m_CurBlockID]->insert(i); | ||
// Add the entry block to the queue. | ||
m_CFGQueue.insert(m_CurBlockID); | ||
|
||
// Visit CFG blocks in the queue until it's empty. | ||
while (!m_CFGQueue.empty()) { | ||
auto IDIter = std::prev(m_CFGQueue.end()); | ||
m_CurBlockID = *IDIter; | ||
m_CFGQueue.erase(IDIter); | ||
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); | ||
AnalyzeCFGBlock(nextBlock); | ||
} | ||
} | ||
|
||
void mergeVarsData(std::set<const clang::VarDecl*>* targetData, | ||
std::set<const clang::VarDecl*>* mergeData) { | ||
for (const clang::VarDecl* i : *mergeData) | ||
targetData->insert(i); | ||
*mergeData = *targetData; | ||
} | ||
|
||
CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { | ||
return *(m_CFG->begin() + ID); | ||
} | ||
|
||
void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) { | ||
// Visit all the statements inside the block. | ||
for (const clang::CFGElement& Element : block) { | ||
if (Element.getKind() == clang::CFGElement::Statement) { | ||
const clang::Stmt* S = Element.castAs<clang::CFGStmt>().getStmt(); | ||
// The const_cast is inevitable, since there is no | ||
// ConstRecusiveASTVisitor. | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||
TraverseStmt(const_cast<clang::Stmt*>(S)); | ||
} | ||
} | ||
|
||
for (const clang::CFGBlock::AdjacentBlock succ : block.succs()) { | ||
if (!succ) | ||
continue; | ||
auto& succData = m_BlockData[succ->getBlockID()]; | ||
|
||
if (!succData) | ||
succData = createNewVarsData(*m_BlockData[block.getBlockID()]); | ||
|
||
bool shouldPushSucc = true; | ||
if (succ->getBlockID() > block.getBlockID()) { | ||
if (m_LoopMem == *m_BlockData[block.getBlockID()]) | ||
shouldPushSucc = false; | ||
|
||
for (const VarDecl* i : *m_BlockData[block.getBlockID()]) | ||
m_LoopMem.insert(i); | ||
} | ||
|
||
if (shouldPushSucc) | ||
m_CFGQueue.insert(succ->getBlockID()); | ||
|
||
mergeVarsData(succData.get(), m_BlockData[block.getBlockID()].get()); | ||
} | ||
// FIXME: Information about the varied variables is stored in the last block, | ||
// so we should be able to get it form there | ||
for (const VarDecl* i : *m_BlockData[block.getBlockID()]) | ||
m_VariedDecls.insert(i); | ||
} | ||
|
||
bool VariedAnalyzer::isVaried(const VarDecl* VD) const { | ||
const VarsData& curBranch = getCurBlockVarsData(); | ||
return curBranch.find(VD) != curBranch.end(); | ||
} | ||
|
||
void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { | ||
VarsData& curBranch = getCurBlockVarsData(); | ||
curBranch.insert(VD); | ||
} | ||
|
||
bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { | ||
Expr* L = BinOp->getLHS(); | ||
Expr* R = BinOp->getRHS(); | ||
const auto opCode = BinOp->getOpcode(); | ||
if (BinOp->isAssignmentOp()) { | ||
m_Varied = false; | ||
TraverseStmt(R); | ||
m_Marking = m_Varied; | ||
TraverseStmt(L); | ||
m_Marking = false; | ||
} else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul || | ||
opCode == BO_Div) { | ||
for (auto* subexpr : BinOp->children()) | ||
if (!isa<BinaryOperator>(subexpr)) | ||
TraverseStmt(subexpr); | ||
} | ||
return true; | ||
} | ||
|
||
// add branching merging | ||
bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) { | ||
TraverseStmt(CO->getCond()); | ||
TraverseStmt(CO->getTrueExpr()); | ||
TraverseStmt(CO->getFalseExpr()); | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { | ||
FunctionDecl* FD = CE->getDirectCallee(); | ||
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); | ||
if (noHiddenParam) { | ||
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters(); | ||
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { | ||
clang::Expr* par = CE->getArg(i); | ||
TraverseStmt(par); | ||
m_VariedDecls.insert(FDparam[i]); | ||
} | ||
} | ||
m_Varied = true; | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { | ||
for (Decl* D : DS->decls()) { | ||
if (Expr* init = cast<VarDecl>(D)->getInit()) { | ||
m_Varied = false; | ||
TraverseStmt(init); | ||
m_Marking = true; | ||
if (m_Varied) | ||
copyVarToCurBlock(cast<VarDecl>(D)); | ||
m_Marking = false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { | ||
const auto opCode = UnOp->getOpcode(); | ||
Expr* E = UnOp->getSubExpr(); | ||
if (opCode == UO_AddrOf || opCode == UO_Deref) { | ||
m_Varied = true; | ||
m_Marking = true; | ||
} | ||
TraverseStmt(E); | ||
m_Marking = false; | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { | ||
auto* VD = dyn_cast<VarDecl>(DRE->getDecl()); | ||
if (!VD) | ||
return true; | ||
|
||
if (isVaried(VD)) | ||
m_Varied = true; | ||
|
||
if (m_Varied && m_Marking) | ||
copyVarToCurBlock(VD); | ||
return true; | ||
} | ||
} // namespace clad |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#ifndef CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H | ||
#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H | ||
|
||
#include "clang/AST/RecursiveASTVisitor.h" | ||
#include "clang/Analysis/CFG.h" | ||
|
||
#include "clad/Differentiator/CladUtils.h" | ||
#include "clad/Differentiator/Compatibility.h" | ||
|
||
#include <algorithm> | ||
#include <iterator> | ||
#include <memory> | ||
#include <set> | ||
#include <utility> | ||
|
||
namespace clad { | ||
|
||
/// Class that implemets Varied part of the Activity analysis. | ||
/// By performing static data-flow analysis, so called Varied variables | ||
/// are determined, meaning variables that depend on input parameters | ||
/// in a differentiable way. That result enables us to remove redundant | ||
/// statements in the reverse mode, improving generated codes efficiency. | ||
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> { | ||
bool m_Varied = false; | ||
bool m_Marking = false; | ||
using VarsData = std::set<const clang::VarDecl*>; | ||
VarsData& m_VariedDecls; | ||
/// A helper method to allocate VarsData | ||
/// \param[in] toAssign - Parameter to initialize new VarsData with. | ||
/// \return Unique pointer to a new object of type Varsdata. | ||
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) { | ||
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign))); | ||
} | ||
VarsData m_LoopMem; | ||
|
||
clang::CFGBlock* getCFGBlockByID(unsigned ID); | ||
|
||
clang::ASTContext& m_Context; | ||
std::unique_ptr<clang::CFG> m_CFG; | ||
std::vector<std::unique_ptr<VarsData>> m_BlockData; | ||
unsigned m_CurBlockID{}; | ||
std::set<unsigned> m_CFGQueue; | ||
/// Checks if a variable is on the current branch. | ||
/// \param[in] VD - Variable declaration. | ||
/// @return Whether a variable is on the current branch. | ||
bool isVaried(const clang::VarDecl* VD) const; | ||
/// Adds varied variable to current branch. | ||
/// \param[in] VD - Variable declaration. | ||
void copyVarToCurBlock(const clang::VarDecl* VD); | ||
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } | ||
[[nodiscard]] const VarsData& getCurBlockVarsData() const { | ||
return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData(); | ||
} | ||
void AnalyzeCFGBlock(const clang::CFGBlock& block); | ||
|
||
public: | ||
/// Constructor | ||
VariedAnalyzer(clang::ASTContext& Context, | ||
std::set<const clang::VarDecl*>& Decls) | ||
: m_VariedDecls(Decls), m_Context(Context) {} | ||
|
||
/// Destructor | ||
~VariedAnalyzer() = default; | ||
|
||
/// Delete copy/move operators and constructors. | ||
VariedAnalyzer(const VariedAnalyzer&) = delete; | ||
VariedAnalyzer& operator=(const VariedAnalyzer&) = delete; | ||
VariedAnalyzer(const VariedAnalyzer&&) = delete; | ||
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; | ||
|
||
/// Runs Varied analysis. | ||
/// \param[in] FD Function to run the analysis on. | ||
void Analyze(const clang::FunctionDecl* FD); | ||
bool VisitBinaryOperator(clang::BinaryOperator* BinOp); | ||
bool VisitCallExpr(clang::CallExpr* CE); | ||
bool VisitConditionalOperator(clang::ConditionalOperator* CO); | ||
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); | ||
bool VisitDeclStmt(clang::DeclStmt* DS); | ||
bool VisitUnaryOperator(clang::UnaryOperator* UnOp); | ||
}; | ||
} // namespace clad | ||
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.