Skip to content

Commit

Permalink
CallExpr fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Oct 4, 2024
1 parent 431c490 commit 59ea27e
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 51 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ enum opts : unsigned {
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),
enable_aa = 1 << (ORDER_BITS + 5),
enable_va = 1 << (ORDER_BITS + 5),
disable_aa = 1 << (ORDER_BITS + 6),

// Specifying whether we only want the diagonal of the hessian.
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct DiffRequest {
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableActivityAnalysis = false;
bool EnableVariedAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down Expand Up @@ -120,7 +120,7 @@ struct DiffRequest {
RequestedDerivativeOrder == other.RequestedDerivativeOrder &&
CallContext == other.CallContext && Args == other.Args &&
Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis &&
EnableActivityAnalysis == other.EnableActivityAnalysis &&
EnableVariedAnalysis == other.EnableVariedAnalysis &&
DVI == other.DVI && use_enzyme == other.use_enzyme &&
DeclarationOnly == other.DeclarationOnly;
}
Expand All @@ -147,7 +147,7 @@ struct DiffRequest {
/// This is a flag to indicate the default behaviour to enable/disable
/// TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableActivityAnalysis = false;
bool EnableVariedAnalysis = false;
};

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
Expand Down
39 changes: 24 additions & 15 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
}
}

void mergeVarsData(VarsData* targetData, VarsData* mergeData) {
for (const clang::VarDecl* i : *mergeData)
targetData->insert(i);
for (const clang::VarDecl* i : *targetData)
mergeData->insert(i);
}

CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
return *(m_CFG->begin() + ID);
}
Expand Down Expand Up @@ -86,16 +93,18 @@ void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) {
bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) {
Expr* L = BinOp->getLHS();
Expr* R = BinOp->getRHS();

const auto opCode = BinOp->getOpcode();
if (BinOp->isAssignmentOp()) {
m_Varied = false;
TraverseStmt(R);
m_Marking = m_Varied;
TraverseStmt(L);
m_Marking = false;
} else {
TraverseStmt(L);
TraverseStmt(R);
} else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul ||
opCode == BO_Div) {
for (auto* subexpr : BinOp->children())
if (!isa<BinaryOperator>(subexpr))
TraverseStmt(subexpr);
}
return true;
}
Expand All @@ -111,18 +120,15 @@ bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) {
bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
FunctionDecl* FD = CE->getDirectCallee();
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
std::set<const clang::VarDecl*> variedParam;
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
if (m_Varied || 1) {
m_VariedDecls.insert(FDparam[i]);
m_Varied = false;
}
m_VariedDecls.insert(FDparam[i]);
}
}
m_Varied = true;
return true;
}

Expand Down Expand Up @@ -150,8 +156,6 @@ bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) {
m_Marking = true;
}
TraverseStmt(E);
m_Varied = false;
m_Marking = false;
return true;
}

Expand All @@ -161,10 +165,15 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
if (isVaried(dyn_cast<VarDecl>(DRE->getDecl())))
m_Varied = true;

if (const auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
if (m_Varied && m_Marking)
copyVarToCurBlock(VD);
}
auto* VD = dyn_cast<VarDecl>(DRE->getDecl());
if (!VD)
return true;

if (isVaried(VD))
m_Varied = true;

if (m_Varied && m_Marking)
copyVarToCurBlock(VD);
return true;
}
} // namespace clad
16 changes: 7 additions & 9 deletions lib/Differentiator/ActivityAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,14 @@
/// statements in the reverse mode, improving generated codes efficiency.
namespace clad {
using VarsData = std::set<const clang::VarDecl*>;
static inline void mergeVarsData(VarsData* targetData, VarsData* mergeData) {
for (const clang::VarDecl* i : *mergeData)
targetData->insert(i);
for (const clang::VarDecl* i : *targetData)
mergeData->insert(i);
}
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {

bool m_Varied = false;
bool m_Marking = false;

std::set<const clang::VarDecl*>& m_VariedDecls;
// using VarsData = std::set<const clang::VarDecl*>;
/// A helper method to allocate VarsData
/// \param toAssign - Parameter to initialize new VarsData with.
/// \param[in] toAssign - Parameter to initialize new VarsData with.
/// \return Unique pointer to a new object of type Varsdata.
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) {
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign)));
Expand All @@ -47,7 +40,12 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
std::vector<std::unique_ptr<VarsData>> m_BlockData;
unsigned m_CurBlockID{};
std::set<unsigned> m_CFGQueue;
/// Checks if a variable is on the current branch.
/// \param[in] VD - Variable declaration.
/// @return Whether a variable is on the current branch.
bool isVaried(const clang::VarDecl* VD) const;
/// Adds varied variable to current branch.
/// \param[in] VD - Variable declaration.
void copyVarToCurBlock(const clang::VarDecl* VD);
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; }
[[nodiscard]] const VarsData& getCurBlockVarsData() const {
Expand All @@ -71,7 +69,7 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete;

/// Runs Varied analysis.
/// \param FD Function to run the analysis on.
/// \param[in] FD Function to run the analysis on.
void Analyze(const clang::FunctionDecl* FD);
bool VisitBinaryOperator(clang::BinaryOperator* BinOp);
bool VisitCallExpr(clang::CallExpr* CE);
Expand Down
17 changes: 8 additions & 9 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ namespace clad {
}

bool DiffRequest::shouldHaveAdjoint(const VarDecl* VD) const {
if (!EnableActivityAnalysis)
if (!EnableVariedAnalysis)
return true;

if (VD->getType()->isPointerType() || isa<ArrayType>(VD->getType()))
Expand Down Expand Up @@ -667,7 +667,7 @@ namespace clad {
unsigned bitmasked_opts_value = 0;
bool enable_tbr_in_req = false;
bool disable_tbr_in_req = false;
bool enable_aa_in_req = false;
bool enable_va_in_req = false;
bool disable_aa_in_req = false;
if (!A->getAnnotation().equals("E") &&
FD->getTemplateSpecializationArgs()) {
Expand All @@ -685,16 +685,16 @@ namespace clad {
disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
// Set option for Activity analysis.
enable_aa_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_aa);
enable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_va);
disable_aa_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_aa);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_aa_in_req && disable_aa_in_req) {
if (enable_va_in_req && disable_aa_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable AA options are specified.");
return true;
Expand All @@ -705,12 +705,11 @@ namespace clad {
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
if (enable_aa_in_req || disable_aa_in_req) {
if (enable_va_in_req || disable_aa_in_req) {
// override the default value of TBR analysis.
request.EnableActivityAnalysis =
enable_aa_in_req && !disable_aa_in_req;
request.EnableVariedAnalysis = enable_va_in_req && !disable_aa_in_req;
} else {
request.EnableActivityAnalysis = m_Options.EnableActivityAnalysis;
request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1951,8 +1951,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
pullbackRequest.EnableActivityAnalysis =
m_DiffReq.EnableActivityAnalysis;
pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (MD && isLambdaCallOperator(MD)) {
Expand Down
4 changes: 2 additions & 2 deletions test/Analyses/ActivityReverse.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %cladclang %s -I%S/../../include -oActivity.out 2>&1 | %filecheck %s
// RUN: ./Activity.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-aa %s -I%S/../../include -oActivity.out
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-va %s -I%S/../../include -oActivity.out
// RUN: ./Activity.out | %filecheck_exec %s
//CHECK-NOT: {{.*error|warning|note:.*}}

Expand Down Expand Up @@ -244,7 +244,7 @@ double f7(double x){

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_aa>(F);\
auto F##grad = clad::gradient<clad::opts::enable_va>(F);\
F##grad.execute(x, result);\
printf("{%.2f}\n", result[0]); \
}
Expand Down
8 changes: 4 additions & 4 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,11 @@ namespace clad {
static void SetActivityAnalysisOptions(const DifferentiationOptions& DO,
RequestOptions& opts) {
// If user has explicitly specified the mode for AA, use it.
if (DO.EnableActivityAnalysis || DO.DisableActivityAnalysis)
opts.EnableActivityAnalysis =
DO.EnableActivityAnalysis && !DO.DisableActivityAnalysis;
if (DO.EnableVariedAnalysis || DO.DisableActivityAnalysis)
opts.EnableVariedAnalysis =
DO.EnableVariedAnalysis && !DO.DisableActivityAnalysis;
else
opts.EnableActivityAnalysis = false; // Default mode.
opts.EnableVariedAnalysis = false; // Default mode.
}

void CladPlugin::SetRequestOptions(RequestOptions& opts) const {
Expand Down
12 changes: 6 additions & 6 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class CladTimerGroup {
: DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false),
DumpDerivedAST(false), GenerateSourceFile(false),
ValidateClangVersion(true), EnableTBRAnalysis(false),
DisableTBRAnalysis(false), EnableActivityAnalysis(false),
DisableTBRAnalysis(false), EnableVariedAnalysis(false),
DisableActivityAnalysis(false), CustomEstimationModel(false),
PrintNumDiffErrorInfo(false) {}

Expand All @@ -67,7 +67,7 @@ class CladTimerGroup {
bool ValidateClangVersion : 1;
bool EnableTBRAnalysis : 1;
bool DisableTBRAnalysis : 1;
bool EnableActivityAnalysis : 1;
bool EnableVariedAnalysis : 1;
bool DisableActivityAnalysis : 1;
bool CustomEstimationModel : 1;
bool PrintNumDiffErrorInfo : 1;
Expand Down Expand Up @@ -317,8 +317,8 @@ class CladTimerGroup {
m_DO.EnableTBRAnalysis = true;
} else if (args[i] == "-disable-tbr") {
m_DO.DisableTBRAnalysis = true;
} else if (args[i] == "-enable-aa") {
m_DO.EnableActivityAnalysis = true;
} else if (args[i] == "-enable-va") {
m_DO.EnableVariedAnalysis = true;
} else if (args[i] == "-disable-aa") {
m_DO.DisableActivityAnalysis = true;
} else if (args[i] == "-fcustom-estimation-model") {
Expand Down Expand Up @@ -374,8 +374,8 @@ class CladTimerGroup {
"be used together.\n";
return false;
}
if (m_DO.EnableActivityAnalysis && m_DO.DisableActivityAnalysis) {
llvm::errs() << "clad: Error: -enable-aa and -disable-aa cannot "
if (m_DO.EnableVariedAnalysis && m_DO.DisableActivityAnalysis) {
llvm::errs() << "clad: Error: -enable-va and -disable-aa cannot "
"be used together.\n";
return false;
}
Expand Down

0 comments on commit 59ea27e

Please sign in to comment.