diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index 7ab8670ba..eb249d7e7 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -29,7 +29,7 @@ enum opts : unsigned { // 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid enable_tbr = 1 << (ORDER_BITS + 2), disable_tbr = 1 << (ORDER_BITS + 3), - enable_aa = 1 << (ORDER_BITS + 5), + enable_va = 1 << (ORDER_BITS + 5), disable_aa = 1 << (ORDER_BITS + 6), // Specifying whether we only want the diagonal of the hessian. diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index aabff5210..5be53c439 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -62,7 +62,7 @@ struct DiffRequest { bool VerboseDiags = false; /// A flag to enable TBR analysis during reverse-mode differentiation. bool EnableTBRAnalysis = false; - bool EnableActivityAnalysis = false; + bool EnableVariedAnalysis = false; /// Puts the derived function and its code in the diff call void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD, clang::Sema& SemaRef); @@ -120,7 +120,7 @@ struct DiffRequest { RequestedDerivativeOrder == other.RequestedDerivativeOrder && CallContext == other.CallContext && Args == other.Args && Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && - EnableActivityAnalysis == other.EnableActivityAnalysis && + EnableVariedAnalysis == other.EnableVariedAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && DeclarationOnly == other.DeclarationOnly; } @@ -147,7 +147,7 @@ struct DiffRequest { /// This is a flag to indicate the default behaviour to enable/disable /// TBR analysis during reverse-mode differentiation. bool EnableTBRAnalysis = false; - bool EnableActivityAnalysis = false; + bool EnableVariedAnalysis = false; }; class DiffCollector: public clang::RecursiveASTVisitor { diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index 881d1b930..2dc89acdf 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -29,6 +29,13 @@ void VariedAnalyzer::Analyze(const FunctionDecl* FD) { } } +void mergeVarsData(VarsData* targetData, VarsData* mergeData) { + for (const clang::VarDecl* i : *mergeData) + targetData->insert(i); + for (const clang::VarDecl* i : *targetData) + mergeData->insert(i); +} + CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { return *(m_CFG->begin() + ID); } @@ -86,16 +93,18 @@ void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { Expr* L = BinOp->getLHS(); Expr* R = BinOp->getRHS(); - + const auto opCode = BinOp->getOpcode(); if (BinOp->isAssignmentOp()) { m_Varied = false; TraverseStmt(R); m_Marking = m_Varied; TraverseStmt(L); m_Marking = false; - } else { - TraverseStmt(L); - TraverseStmt(R); + } else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul || + opCode == BO_Div) { + for (auto* subexpr : BinOp->children()) + if (!isa(subexpr)) + TraverseStmt(subexpr); } return true; } @@ -111,18 +120,15 @@ bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) { bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { FunctionDecl* FD = CE->getDirectCallee(); bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); - std::set variedParam; if (noHiddenParam) { MutableArrayRef FDparam = FD->parameters(); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); TraverseStmt(par); - if (m_Varied || 1) { - m_VariedDecls.insert(FDparam[i]); - m_Varied = false; - } + m_VariedDecls.insert(FDparam[i]); } } + m_Varied = true; return true; } @@ -150,8 +156,6 @@ bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { m_Marking = true; } TraverseStmt(E); - m_Varied = false; - m_Marking = false; return true; } @@ -161,10 +165,15 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { if (isVaried(dyn_cast(DRE->getDecl()))) m_Varied = true; - if (const auto* VD = dyn_cast(DRE->getDecl())) { - if (m_Varied && m_Marking) - copyVarToCurBlock(VD); - } + auto* VD = dyn_cast(DRE->getDecl()); + if (!VD) + return true; + + if (isVaried(VD)) + m_Varied = true; + + if (m_Varied && m_Marking) + copyVarToCurBlock(VD); return true; } } // namespace clad diff --git a/lib/Differentiator/ActivityAnalyzer.h b/lib/Differentiator/ActivityAnalyzer.h index e6abe249c..ce1de04bc 100644 --- a/lib/Differentiator/ActivityAnalyzer.h +++ b/lib/Differentiator/ActivityAnalyzer.h @@ -19,21 +19,14 @@ /// statements in the reverse mode, improving generated codes efficiency. namespace clad { using VarsData = std::set; -static inline void mergeVarsData(VarsData* targetData, VarsData* mergeData) { - for (const clang::VarDecl* i : *mergeData) - targetData->insert(i); - for (const clang::VarDecl* i : *targetData) - mergeData->insert(i); -} class VariedAnalyzer : public clang::RecursiveASTVisitor { bool m_Varied = false; bool m_Marking = false; std::set& m_VariedDecls; - // using VarsData = std::set; /// A helper method to allocate VarsData - /// \param toAssign - Parameter to initialize new VarsData with. + /// \param[in] toAssign - Parameter to initialize new VarsData with. /// \return Unique pointer to a new object of type Varsdata. static std::unique_ptr createNewVarsData(VarsData toAssign) { return std::unique_ptr(new VarsData(std::move(toAssign))); @@ -47,7 +40,12 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor { std::vector> m_BlockData; unsigned m_CurBlockID{}; std::set m_CFGQueue; + /// Checks if a variable is on the current branch. + /// \param[in] VD - Variable declaration. + /// @return Whether a variable is on the current branch. bool isVaried(const clang::VarDecl* VD) const; + /// Adds varied variable to current branch. + /// \param[in] VD - Variable declaration. void copyVarToCurBlock(const clang::VarDecl* VD); VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } [[nodiscard]] const VarsData& getCurBlockVarsData() const { @@ -71,7 +69,7 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor { VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; /// Runs Varied analysis. - /// \param FD Function to run the analysis on. + /// \param[in] FD Function to run the analysis on. void Analyze(const clang::FunctionDecl* FD); bool VisitBinaryOperator(clang::BinaryOperator* BinOp); bool VisitCallExpr(clang::CallExpr* CE); diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 8cbb8ac35..44a52ab93 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -617,7 +617,7 @@ namespace clad { } bool DiffRequest::shouldHaveAdjoint(const VarDecl* VD) const { - if (!EnableActivityAnalysis) + if (!EnableVariedAnalysis) return true; if (VD->getType()->isPointerType() || isa(VD->getType())) @@ -667,7 +667,7 @@ namespace clad { unsigned bitmasked_opts_value = 0; bool enable_tbr_in_req = false; bool disable_tbr_in_req = false; - bool enable_aa_in_req = false; + bool enable_va_in_req = false; bool disable_aa_in_req = false; if (!A->getAnnotation().equals("E") && FD->getTemplateSpecializationArgs()) { @@ -685,8 +685,8 @@ namespace clad { disable_tbr_in_req = clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr); // Set option for Activity analysis. - enable_aa_in_req = - clad::HasOption(bitmasked_opts_value, clad::opts::enable_aa); + enable_va_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_va); disable_aa_in_req = clad::HasOption(bitmasked_opts_value, clad::opts::disable_aa); if (enable_tbr_in_req && disable_tbr_in_req) { @@ -694,7 +694,7 @@ namespace clad { "Both enable and disable TBR options are specified."); return true; } - if (enable_aa_in_req && disable_aa_in_req) { + if (enable_va_in_req && disable_aa_in_req) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, "Both enable and disable AA options are specified."); return true; @@ -705,12 +705,11 @@ namespace clad { } else { request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; } - if (enable_aa_in_req || disable_aa_in_req) { + if (enable_va_in_req || disable_aa_in_req) { // override the default value of TBR analysis. - request.EnableActivityAnalysis = - enable_aa_in_req && !disable_aa_in_req; + request.EnableVariedAnalysis = enable_va_in_req && !disable_aa_in_req; } else { - request.EnableActivityAnalysis = m_Options.EnableActivityAnalysis; + request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis; } if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { if (!A->getAnnotation().equals("H")) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index cf3b23ecd..46b29c479 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1951,8 +1951,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Silence diag outputs in nested derivation process. pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; - pullbackRequest.EnableActivityAnalysis = - m_DiffReq.EnableActivityAnalysis; + pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) { diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index ace87f23b..b6593f98a 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -1,6 +1,6 @@ // RUN: %cladclang %s -I%S/../../include -oActivity.out 2>&1 | %filecheck %s // RUN: ./Activity.out | %filecheck_exec %s -// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-aa %s -I%S/../../include -oActivity.out +// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-va %s -I%S/../../include -oActivity.out // RUN: ./Activity.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -244,7 +244,7 @@ double f7(double x){ #define TEST(F, x) { \ result[0] = 0; \ - auto F##grad = clad::gradient(F);\ + auto F##grad = clad::gradient(F);\ F##grad.execute(x, result);\ printf("{%.2f}\n", result[0]); \ } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 19a5e7d1b..2a9749b0b 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -401,11 +401,11 @@ namespace clad { static void SetActivityAnalysisOptions(const DifferentiationOptions& DO, RequestOptions& opts) { // If user has explicitly specified the mode for AA, use it. - if (DO.EnableActivityAnalysis || DO.DisableActivityAnalysis) - opts.EnableActivityAnalysis = - DO.EnableActivityAnalysis && !DO.DisableActivityAnalysis; + if (DO.EnableVariedAnalysis || DO.DisableActivityAnalysis) + opts.EnableVariedAnalysis = + DO.EnableVariedAnalysis && !DO.DisableActivityAnalysis; else - opts.EnableActivityAnalysis = false; // Default mode. + opts.EnableVariedAnalysis = false; // Default mode. } void CladPlugin::SetRequestOptions(RequestOptions& opts) const { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 43750828e..f17f7e41b 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -55,7 +55,7 @@ class CladTimerGroup { : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), DumpDerivedAST(false), GenerateSourceFile(false), ValidateClangVersion(true), EnableTBRAnalysis(false), - DisableTBRAnalysis(false), EnableActivityAnalysis(false), + DisableTBRAnalysis(false), EnableVariedAnalysis(false), DisableActivityAnalysis(false), CustomEstimationModel(false), PrintNumDiffErrorInfo(false) {} @@ -67,7 +67,7 @@ class CladTimerGroup { bool ValidateClangVersion : 1; bool EnableTBRAnalysis : 1; bool DisableTBRAnalysis : 1; - bool EnableActivityAnalysis : 1; + bool EnableVariedAnalysis : 1; bool DisableActivityAnalysis : 1; bool CustomEstimationModel : 1; bool PrintNumDiffErrorInfo : 1; @@ -317,8 +317,8 @@ class CladTimerGroup { m_DO.EnableTBRAnalysis = true; } else if (args[i] == "-disable-tbr") { m_DO.DisableTBRAnalysis = true; - } else if (args[i] == "-enable-aa") { - m_DO.EnableActivityAnalysis = true; + } else if (args[i] == "-enable-va") { + m_DO.EnableVariedAnalysis = true; } else if (args[i] == "-disable-aa") { m_DO.DisableActivityAnalysis = true; } else if (args[i] == "-fcustom-estimation-model") { @@ -374,8 +374,8 @@ class CladTimerGroup { "be used together.\n"; return false; } - if (m_DO.EnableActivityAnalysis && m_DO.DisableActivityAnalysis) { - llvm::errs() << "clad: Error: -enable-aa and -disable-aa cannot " + if (m_DO.EnableVariedAnalysis && m_DO.DisableActivityAnalysis) { + llvm::errs() << "clad: Error: -enable-va and -disable-aa cannot " "be used together.\n"; return false; }