Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
format
  • Loading branch information
Max Andriychuk committed Oct 19, 2024
1 parent 2d08ce1 commit 9a31722
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 283 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ enum opts : unsigned {
disable_tbr = 1 << (ORDER_BITS + 3),
enable_va = 1 << (ORDER_BITS + 5),
disable_va = 1 << (ORDER_BITS + 6),
enable_ua = 1 << (ORDER_BITS + 7),
disable_ua = 1 << (ORDER_BITS + 8),

// Specifying whether we only want the diagonal of the hessian.
diagonal_only = 1 << (ORDER_BITS + 4),
Expand Down
11 changes: 11 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,16 @@ struct DiffRequest {
bool HasAnalysisRun = false;
} m_ActivityRunInfo;

/// Cached results of the useful-variable analysis for this request.
/// Declared mutable because it is filled in lazily from the const member
/// function shouldHaveAdjointForw().
mutable struct UsefulRunInfo {
/// Variables the analysis reported as useful; shouldHaveAdjointForw()
/// answers by looking a declaration up in this set.
std::set<const clang::VarDecl*> UsefulDecls;
/// Functions collected by the analysis while visiting call expressions.
std::set<const clang::FunctionDecl*> UsefulFuncs;
/// Becomes true once the analysis has been run, so it runs at most once.
bool HasAnalysisRun = false;
} m_UsefulRunInfo;

public:
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
bool ReqAdj = true;
/// Name of the base function to be differentiated. Can be different from
/// function->getNameAsString() when higher-order derivatives are computed.
std::string BaseFunctionName = {};
Expand All @@ -65,6 +72,7 @@ struct DiffRequest {
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
bool EnableUsefulAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down Expand Up @@ -123,6 +131,7 @@ struct DiffRequest {
CallContext == other.CallContext && Args == other.Args &&
Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis &&
EnableVariedAnalysis == other.EnableVariedAnalysis &&
EnableUsefulAnalysis == other.EnableUsefulAnalysis &&
DVI == other.DVI && use_enzyme == other.use_enzyme &&
DeclarationOnly == other.DeclarationOnly;
}
Expand All @@ -141,6 +150,7 @@ struct DiffRequest {

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
bool shouldHaveAdjointForw(const clang::VarDecl* VD) const;
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand All @@ -150,6 +160,7 @@ struct DiffRequest {
/// TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
bool EnableUsefulAnalysis = false;
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
Expand Down
24 changes: 21 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,14 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
// If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);

if (auto* i = cast<VarDecl>(DRE->getDecl())) {
if (!m_DiffReq.shouldHaveAdjointForw(i))
return StmtDiff(clonedDRE, nullptr);
}

QualType literalTy = utils::GetValueType(clonedDRE->getType());

return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
literalTy, m_Context, /*val=*/0));
}
Expand Down Expand Up @@ -1208,6 +1215,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}

llvm::SmallVector<Expr*, 16> pushforwardFnArgs;
//
pushforwardFnArgs.insert(pushforwardFnArgs.end(), CallArgs.begin(),
CallArgs.end());
pushforwardFnArgs.insert(pushforwardFnArgs.end(), diffArgs.begin(),
Expand Down Expand Up @@ -1284,6 +1292,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
pushforwardFnRequest.BaseFunctionName = utils::ComputeEffectiveFnName(FD);
// Silence diag outputs in nested derivation process.
pushforwardFnRequest.VerboseDiags = false;
pushforwardFnRequest.EnableUsefulAnalysis = m_DiffReq.EnableUsefulAnalysis;

// Check if request already derived in DerivedFunctions.
FunctionDecl* pushforwardFD =
Expand Down Expand Up @@ -1446,7 +1455,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
derivedR = BuildParens(derivedR);
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
if (Ldiff.getExpr_dx() &&
Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
Expand Down Expand Up @@ -1575,11 +1585,16 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
VarDecl* VDDerived =
BuildVarDecl(VD->getType(), "_d_" + VD->getNameAsString(), initDx,
VD->isDirectInit(), /*TSI=*/nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));

if (!m_DiffReq.shouldHaveAdjointForw(VD))
VDDerived = nullptr;
else
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
// llvm::errs() << "\nVisitDeclStmt";
llvm::SmallVector<Decl*, 4> decls;
llvm::SmallVector<Decl*, 4> declsDiff;
// If the type is marked as non_differentiable, skip generating its derivative
Expand Down Expand Up @@ -1642,7 +1657,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
if (m_DiffReq.shouldHaveAdjointForw(VD))
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand All @@ -1661,6 +1677,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
DSClone = BuildDeclStmt(decls);
if (!declsDiff.empty())
DSDiff = BuildDeclStmt(declsDiff);
// llvm::errs() << "\n=====";
// DSDiff->dump();
return StmtDiff(DSClone, DSDiff);
}

Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ llvm_add_library(cladDifferentiator
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
StmtClone.cpp
UsefulAnalyzer.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Version.cpp
Expand Down
31 changes: 31 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ActivityAnalyzer.h"
#include "TBRAnalyzer.h"
#include "UsefulAnalyzer.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/RecursiveASTVisitor.h"
Expand Down Expand Up @@ -636,6 +637,26 @@ namespace clad {
return found != m_ActivityRunInfo.ToBeRecorded.end();
}

bool DiffRequest::shouldHaveAdjointForw(const VarDecl* VD) const {
if (!EnableUsefulAnalysis)
return true;

if (!m_UsefulRunInfo.HasAnalysisRun) {

UsefulAnalyzer analyzer(Function->getASTContext(),
m_UsefulRunInfo.UsefulDecls,
m_UsefulRunInfo.UsefulFuncs);
analyzer.Analyze(Function);
m_UsefulRunInfo.HasAnalysisRun = true;
// llvm::errs() << "ToBeRecorded: ";
// for (auto* i : m_UsefulRunInfo.UsefulDecls)
// llvm::errs() << i->getNameAsString() << " ";
// llvm::errs() << "\n";
}
auto found = m_UsefulRunInfo.UsefulDecls.find(VD);
return found != m_UsefulRunInfo.UsefulDecls.end();
}

bool DiffCollector::VisitCallExpr(CallExpr* E) {
// Check if we should look into this.
// FIXME: Generated code does not usually have valid source locations.
Expand Down Expand Up @@ -669,6 +690,8 @@ namespace clad {
bool disable_tbr_in_req = false;
bool enable_va_in_req = false;
bool disable_va_in_req = false;
bool enable_ua_in_req = false;
bool disable_ua_in_req = false;
if (!A->getAnnotation().equals("E") &&
FD->getTemplateSpecializationArgs()) {
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
Expand All @@ -689,6 +712,10 @@ namespace clad {
clad::HasOption(bitmasked_opts_value, clad::opts::enable_va);
disable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_va);
enable_ua_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_ua);
disable_ua_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_ua);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
Expand All @@ -711,6 +738,10 @@ namespace clad {
} else {
request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis;
}
if (enable_ua_in_req || disable_ua_in_req)
request.EnableUsefulAnalysis = enable_ua_in_req && !disable_ua_in_req;
else
request.EnableUsefulAnalysis = m_Options.EnableUsefulAnalysis;
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
Expand Down
1 change: 0 additions & 1 deletion lib/Differentiator/PushForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
// If there is no return value, we must not attempt to differentiate
if (!RS->getRetValue())
return nullptr;

StmtDiff retValDiff = Visit(RS->getRetValue());
Expr* retVal = retValDiff.getExpr();
Expr* retVal_dx = retValDiff.getExpr_dx();
Expand Down
156 changes: 156 additions & 0 deletions lib/Differentiator/UsefulAnalyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#include "UsefulAnalyzer.h"

using namespace clang;

namespace clad {

/// Runs the useful-variable analysis on the body of \p FD.
///
/// The CFG of \p FD is built and processed starting from its exit block: the
/// exit block's data is seeded with the declarations already present in
/// m_UsefulDecls, and blocks are then taken from the work queue and analyzed
/// until the queue is empty. If no CFG can be built, the analysis is a no-op
/// and m_UsefulDecls is left untouched.
void UsefulAnalyzer::Analyze(const FunctionDecl* FD) {
  // Build the CFG (control-flow graph) of FD.
  clang::CFG::BuildOptions Options;
  m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);
  // buildCFG hands back a null pointer when CFG construction fails; there is
  // nothing to analyze in that case and dereferencing m_CFG would crash.
  if (!m_CFG)
    return;

  m_BlockData.resize(m_CFG->size());
  // Start from the exit block and seed its data with the declarations that
  // are already known to be useful.
  CFGBlock* exit = &m_CFG->getExit();
  m_CurBlockID = exit->getBlockID();
  m_BlockData[m_CurBlockID] = createNewVarsData({});
  for (const VarDecl* i : m_UsefulDecls)
    m_BlockData[m_CurBlockID]->insert(i);
  // Add the exit block to the queue.
  m_CFGQueue.insert(m_CurBlockID);

  // Visit CFG blocks in the queue until it's empty.
  while (!m_CFGQueue.empty()) {
    auto IDIter = m_CFGQueue.begin();
    m_CurBlockID = *IDIter;
    m_CFGQueue.erase(IDIter);
    CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
    AnalyzeCFGBlock(nextBlock);
  }
}

/// Returns the block stored \p ID positions after the first block of m_CFG.
CFGBlock* UsefulAnalyzer::getCFGBlockByID(unsigned ID) {
  auto blockIt = m_CFG->begin();
  blockIt += ID;
  return *blockIt;
}

bool UsefulAnalyzer::isUseful(const VarDecl* VD) const {
const VarsData& curBranch = getCurBlockVarsData();
return curBranch.find(VD) != curBranch.end();
}

/// Records \p VD in the data of the block currently being analyzed.
void UsefulAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) {
  getCurBlockVarsData().insert(VD);
}

/// Makes both sets equal to their union: every element of \p mergeData is
/// added to \p targetData, and the result is then copied back into
/// \p mergeData.
void mergeVarsData(std::set<const clang::VarDecl*>* targetData,
                   std::set<const clang::VarDecl*>* mergeData) {
  std::set<const clang::VarDecl*>& target = *targetData;
  std::set<const clang::VarDecl*>& source = *mergeData;
  // Element-wise insertion stays well defined even if both pointers alias.
  for (auto it = source.begin(); it != source.end(); ++it)
    target.insert(*it);
  source = target;
}

void UsefulAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {

for (auto ib = block.end(); ib != block.begin() - 1; ib--) {
if (ib->getKind() == clang::CFGElement::Statement) {

const clang::Stmt* S = ib->castAs<clang::CFGStmt>().getStmt();
// The const_cast is inevitable, since there is no
// ConstRecusiveASTVisitor.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
TraverseStmt(const_cast<clang::Stmt*>(S));
}
}

for (const clang::CFGBlock::AdjacentBlock pred : block.preds()) {
if (!pred)
continue;
auto& predData = m_BlockData[pred->getBlockID()];
if (!predData)
predData = createNewVarsData(*m_BlockData[block.getBlockID()]);

bool shouldPushPred = true;
if (pred->getBlockID() < block.getBlockID()) {
if (m_LoopMem == *m_BlockData[block.getBlockID()])
shouldPushPred = false;

for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_LoopMem.insert(i);
}

if (shouldPushPred)
m_CFGQueue.insert(pred->getBlockID());

mergeVarsData(predData.get(), m_BlockData[block.getBlockID()].get());
}

for (const VarDecl* i : *m_BlockData[block.getBlockID()])
m_UsefulDecls.insert(i);
}

/// Handles assignments and the four arithmetic binary operators.
///
/// For an assignment the left-hand side is traversed with m_Useful reset;
/// the right-hand side is then traversed with m_Marking set to whatever
/// m_Useful became, and m_Marking is cleared afterwards. For +, -, * and /
/// every direct operand that is not itself a BinaryOperator is traversed.
/// Other operators are left alone. Always returns true.
bool UsefulAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) {
  if (BinOp->isAssignmentOp()) {
    m_Useful = false;
    TraverseStmt(BinOp->getLHS());
    // Mark variables on the right-hand side only if the left-hand side
    // turned out to be useful.
    m_Marking = m_Useful;
    TraverseStmt(BinOp->getRHS());
    m_Marking = false;
    return true;
  }

  switch (BinOp->getOpcode()) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
    for (auto* operand : BinOp->children()) {
      if (isa<BinaryOperator>(operand))
        continue;
      TraverseStmt(operand);
    }
    break;
  default:
    break;
  }
  return true;
}

/// Processes every variable declared by \p DS.
///
/// When the declared variable is already useful, m_Useful and m_Marking are
/// switched on before its initializer (if any) is traversed. m_Marking is
/// cleared again after each variable. Always returns true.
bool UsefulAnalyzer::VisitDeclStmt(DeclStmt* DS) {
  for (Decl* D : DS->decls()) {
    auto* VD = dyn_cast<VarDecl>(D);
    if (!VD)
      continue;

    if (isUseful(VD)) {
      m_Useful = true;
      m_Marking = true;
    }
    if (Expr* init = VD->getInit())
      TraverseStmt(init);
    m_Marking = false;
  }
  return true;
}

/// Switches on m_Useful and m_Marking and traverses the returned expression,
/// so variables referenced by it get recorded. Always returns true.
bool UsefulAnalyzer::VisitReturnStmt(ReturnStmt* RS) {
  m_Marking = true;
  m_Useful = true;
  TraverseStmt(RS->getRetValue());
  return true;
}

/// Records the directly called function of \p CE in m_UsefulFuncs, unless
/// m_Useful is currently set. Calls without a direct callee are ignored.
/// Always returns true.
bool UsefulAnalyzer::VisitCallExpr(CallExpr* CE) {
  // NOTE(review): the callee is recorded only while m_Useful is *false*,
  // which looks inverted for a set named "useful functions" -- confirm the
  // intended condition.
  if (m_Useful)
    return true;
  // getDirectCallee() returns null when the callee is not a plain function
  // declaration; keep null pointers out of the set.
  if (FunctionDecl* FD = CE->getDirectCallee())
    m_UsefulFuncs.insert(FD);
  return true;
}

/// Handles a reference to a declaration. References to anything other than a
/// variable are ignored. A reference to a variable that is already useful
/// sets m_Useful; when both m_Useful and m_Marking are set, the variable is
/// recorded in the current block's data. Always returns true.
bool UsefulAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
  if (auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
    if (isUseful(VD))
      m_Useful = true;

    if (m_Marking && m_Useful)
      copyVarToCurBlock(VD);
  }
  return true;
}

} // namespace clad
Loading

0 comments on commit 9a31722

Please sign in to comment.