diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index 8c0eb3b5f..39d47efd8 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -29,6 +29,8 @@ enum opts : unsigned { // 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid enable_tbr = 1 << (ORDER_BITS + 2), disable_tbr = 1 << (ORDER_BITS + 3), + enable_va = 1 << (ORDER_BITS + 5), + disable_va = 1 << (ORDER_BITS + 6), // Specifying whether we only want the diagonal of the hessian. diagonal_only = 1 << (ORDER_BITS + 4), diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index a4b06a148..2116d5ea0 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -7,6 +7,8 @@ #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" +#include +#include namespace clang { class CallExpr; class CompilerInstance; @@ -31,6 +33,11 @@ struct DiffRequest { bool HasAnalysisRun = false; } m_TbrRunInfo; + mutable struct ActivityRunInfo { + std::set ToBeRecorded; + bool HasAnalysisRun = false; + } m_ActivityRunInfo; + public: /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; @@ -57,6 +64,7 @@ struct DiffRequest { bool VerboseDiags = false; /// A flag to enable TBR analysis during reverse-mode differentiation. bool EnableTBRAnalysis = false; + bool EnableVariedAnalysis = false; /// Puts the derived function and its code in the diff call void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD, clang::Sema& SemaRef); @@ -114,6 +122,7 @@ struct DiffRequest { RequestedDerivativeOrder == other.RequestedDerivativeOrder && CallContext == other.CallContext && Args == other.Args && Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && + EnableVariedAnalysis == other.EnableVariedAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && DeclarationOnly == other.DeclarationOnly; } @@ -131,6 +140,7 @@ struct DiffRequest { } bool shouldBeRecorded(clang::Expr* E) const; + bool shouldHaveAdjoint(const clang::VarDecl* VD) const; }; using DiffInterval = std::vector; @@ -139,6 +149,7 @@ struct DiffRequest { /// This is a flag to indicate the default behaviour to enable/disable /// TBR analysis during reverse-mode differentiation. bool EnableTBRAnalysis = false; + bool EnableVariedAnalysis = false; }; class DiffCollector: public clang::RecursiveASTVisitor { diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp new file mode 100644 index 000000000..ac233700d --- /dev/null +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -0,0 +1,173 @@ +#include "ActivityAnalyzer.h" + +using namespace clang; + +namespace clad { + +void VariedAnalyzer::Analyze(const FunctionDecl* FD) { + // Build the CFG (control-flow graph) of FD. + clang::CFG::BuildOptions Options; + m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); + + m_BlockData.resize(m_CFG->size()); + // Set current block ID to the ID of entry the block. + CFGBlock* entry = &m_CFG->getEntry(); + m_CurBlockID = entry->getBlockID(); + m_BlockData[m_CurBlockID] = createNewVarsData({}); + for (const VarDecl* i : m_VariedDecls) + m_BlockData[m_CurBlockID]->insert(i); + // Add the entry block to the queue. + m_CFGQueue.insert(m_CurBlockID); + + // Visit CFG blocks in the queue until it's empty. + while (!m_CFGQueue.empty()) { + auto IDIter = std::prev(m_CFGQueue.end()); + m_CurBlockID = *IDIter; + m_CFGQueue.erase(IDIter); + CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); + AnalyzeCFGBlock(nextBlock); + } +} + +void mergeVarsData(std::set* targetData, + std::set* mergeData) { + for (const clang::VarDecl* i : *mergeData) + targetData->insert(i); + *mergeData = *targetData; +} + +CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { + return *(m_CFG->begin() + ID); +} + +void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) { + // Visit all the statements inside the block. + for (const clang::CFGElement& Element : block) { + if (Element.getKind() == clang::CFGElement::Statement) { + const clang::Stmt* S = Element.castAs().getStmt(); + // The const_cast is inevitable, since there is no + // ConstRecusiveASTVisitor. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + TraverseStmt(const_cast(S)); + } + } + + for (const clang::CFGBlock::AdjacentBlock succ : block.succs()) { + if (!succ) + continue; + auto& succData = m_BlockData[succ->getBlockID()]; + + if (!succData) + succData = createNewVarsData(*m_BlockData[block.getBlockID()]); + + bool shouldPushSucc = true; + if (succ->getBlockID() > block.getBlockID()) { + if (m_LoopMem == *m_BlockData[block.getBlockID()]) + shouldPushSucc = false; + + for (const VarDecl* i : *m_BlockData[block.getBlockID()]) + m_LoopMem.insert(i); + } + + if (shouldPushSucc) + m_CFGQueue.insert(succ->getBlockID()); + + mergeVarsData(succData.get(), m_BlockData[block.getBlockID()].get()); + } + // FIXME: Information about the varied variables is stored in the last block, + // so we should be able to get it form there + for (const VarDecl* i : *m_BlockData[block.getBlockID()]) + m_VariedDecls.insert(i); +} + +bool VariedAnalyzer::isVaried(const VarDecl* VD) const { + const VarsData& curBranch = getCurBlockVarsData(); + return curBranch.find(VD) != curBranch.end(); +} + +void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { + VarsData& curBranch = getCurBlockVarsData(); + curBranch.insert(VD); +} + +bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { + Expr* L = BinOp->getLHS(); + Expr* R = BinOp->getRHS(); + const auto opCode = BinOp->getOpcode(); + if (BinOp->isAssignmentOp()) { + m_Varied = false; + TraverseStmt(R); + m_Marking = m_Varied; + TraverseStmt(L); + m_Marking = false; + } else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul || + opCode == BO_Div) { + for (auto* subexpr : BinOp->children()) + if (!isa(subexpr)) + TraverseStmt(subexpr); + } + return true; +} + +// add branching merging +bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) { + TraverseStmt(CO->getCond()); + TraverseStmt(CO->getTrueExpr()); + TraverseStmt(CO->getFalseExpr()); + return true; +} + +bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { + FunctionDecl* FD = CE->getDirectCallee(); + bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); + if (noHiddenParam) { + MutableArrayRef FDparam = FD->parameters(); + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + clang::Expr* par = CE->getArg(i); + TraverseStmt(par); + m_VariedDecls.insert(FDparam[i]); + } + } + m_Varied = true; + return true; +} + +bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { + for (Decl* D : DS->decls()) { + if (Expr* init = cast(D)->getInit()) { + m_Varied = false; + TraverseStmt(init); + m_Marking = true; + if (m_Varied) + copyVarToCurBlock(cast(D)); + m_Marking = false; + } + } + return true; +} + +bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { + const auto opCode = UnOp->getOpcode(); + Expr* E = UnOp->getSubExpr(); + if (opCode == UO_AddrOf || opCode == UO_Deref) { + m_Varied = true; + m_Marking = true; + } + TraverseStmt(E); + m_Marking = false; + return true; +} + +bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { + auto* VD = dyn_cast(DRE->getDecl()); + if (!VD) + return true; + + if (isVaried(VD)) + m_Varied = true; + + if (m_Varied && m_Marking) + copyVarToCurBlock(VD); + return true; +} +} // namespace clad diff --git a/lib/Differentiator/ActivityAnalyzer.h b/lib/Differentiator/ActivityAnalyzer.h new file mode 100644 index 000000000..3e0c62693 --- /dev/null +++ b/lib/Differentiator/ActivityAnalyzer.h @@ -0,0 +1,82 @@ +#ifndef CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H +#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H + +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Analysis/CFG.h" + +#include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/Compatibility.h" + +#include +#include +#include +#include +#include + +namespace clad { + +/// Class that implemets Varied part of the Activity analysis. +/// By performing static data-flow analysis, so called Varied variables +/// are determined, meaning variables that depend on input parameters +/// in a differentiable way. That result enables us to remove redundant +/// statements in the reverse mode, improving generated codes efficiency. +class VariedAnalyzer : public clang::RecursiveASTVisitor { + bool m_Varied = false; + bool m_Marking = false; + using VarsData = std::set; + VarsData& m_VariedDecls; + /// A helper method to allocate VarsData + /// \param[in] toAssign - Parameter to initialize new VarsData with. + /// \return Unique pointer to a new object of type Varsdata. + static std::unique_ptr createNewVarsData(VarsData toAssign) { + return std::unique_ptr(new VarsData(std::move(toAssign))); + } + VarsData m_LoopMem; + + clang::CFGBlock* getCFGBlockByID(unsigned ID); + + clang::ASTContext& m_Context; + std::unique_ptr m_CFG; + std::vector> m_BlockData; + unsigned m_CurBlockID{}; + std::set m_CFGQueue; + /// Checks if a variable is on the current branch. + /// \param[in] VD - Variable declaration. + /// @return Whether a variable is on the current branch. + bool isVaried(const clang::VarDecl* VD) const; + /// Adds varied variable to current branch. + /// \param[in] VD - Variable declaration. + void copyVarToCurBlock(const clang::VarDecl* VD); + VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } + [[nodiscard]] const VarsData& getCurBlockVarsData() const { + return const_cast(this)->getCurBlockVarsData(); + } + void AnalyzeCFGBlock(const clang::CFGBlock& block); + +public: + /// Constructor + VariedAnalyzer(clang::ASTContext& Context, + std::set& Decls) + : m_VariedDecls(Decls), m_Context(Context) {} + + /// Destructor + ~VariedAnalyzer() = default; + + /// Delete copy/move operators and constructors. + VariedAnalyzer(const VariedAnalyzer&) = delete; + VariedAnalyzer& operator=(const VariedAnalyzer&) = delete; + VariedAnalyzer(const VariedAnalyzer&&) = delete; + VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; + + /// Runs Varied analysis. + /// \param[in] FD Function to run the analysis on. + void Analyze(const clang::FunctionDecl* FD); + bool VisitBinaryOperator(clang::BinaryOperator* BinOp); + bool VisitCallExpr(clang::CallExpr* CE); + bool VisitConditionalOperator(clang::ConditionalOperator* CO); + bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); + bool VisitDeclStmt(clang::DeclStmt* DS); + bool VisitUnaryOperator(clang::UnaryOperator* UnOp); +}; +} // namespace clad +#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 4d9731d85..7f928b2ac 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -21,6 +21,7 @@ set_property(SOURCE Version.cpp APPEND PROPERTY # (Ab)use llvm facilities for adding libraries. llvm_add_library(cladDifferentiator STATIC + ActivityAnalyzer.cpp BaseForwardModeVisitor.cpp CladUtils.cpp ConstantFolder.cpp diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index e25cabf40..a526a6d58 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -1,5 +1,6 @@ #include "clad/Differentiator/DiffPlanner.h" +#include "ActivityAnalyzer.h" #include "TBRAnalyzer.h" #include "clang/AST/ASTContext.h" @@ -615,6 +616,26 @@ namespace clad { return found != m_TbrRunInfo.ToBeRecorded.end(); } + bool DiffRequest::shouldHaveAdjoint(const VarDecl* VD) const { + if (!EnableVariedAnalysis) + return true; + + if (VD->getType()->isPointerType() || isa(VD->getType())) + return true; + + if (!m_ActivityRunInfo.HasAnalysisRun) { + for (const auto& dParam : DVI) + m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); + + VariedAnalyzer analyzer(Function->getASTContext(), + m_ActivityRunInfo.ToBeRecorded); + analyzer.Analyze(Function); + m_ActivityRunInfo.HasAnalysisRun = true; + } + auto found = m_ActivityRunInfo.ToBeRecorded.find(VD); + return found != m_ActivityRunInfo.ToBeRecorded.end(); + } + bool DiffCollector::VisitCallExpr(CallExpr* E) { // Check if we should look into this. // FIXME: Generated code does not usually have valid source locations. @@ -646,6 +667,8 @@ namespace clad { unsigned bitmasked_opts_value = 0; bool enable_tbr_in_req = false; bool disable_tbr_in_req = false; + bool enable_va_in_req = false; + bool disable_va_in_req = false; if (!A->getAnnotation().equals("E") && FD->getTemplateSpecializationArgs()) { const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); @@ -661,17 +684,33 @@ namespace clad { clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr); disable_tbr_in_req = clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr); + // Set option for Activity analysis. + enable_va_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_va); + disable_va_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::disable_va); if (enable_tbr_in_req && disable_tbr_in_req) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, "Both enable and disable TBR options are specified."); return true; } + if (enable_va_in_req && disable_va_in_req) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, + "Both enable and disable VA options are specified."); + return true; + } if (enable_tbr_in_req || disable_tbr_in_req) { // override the default value of TBR analysis. request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req; } else { request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; } + if (enable_va_in_req || disable_va_in_req) { + // override the default value of TBR analysis. + request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req; + } else { + request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis; + } if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { if (!A->getAnnotation().equals("H")) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 578f8bb0f..26746f5b0 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1826,7 +1826,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerArg(arg)) { + if (utils::IsReferenceOrPointerArg(arg) || + !m_DiffReq.shouldHaveAdjoint(PVD)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -1966,7 +1967,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; QualType paramTy = FD->getParamDecl(idx)->getType(); - if (utils::isArrayOrPointerType(paramTy) || + if (!argDerivative || utils::isArrayOrPointerType(paramTy) || isCladArrayType(argDerivative->getType())) gradArgExpr = argDerivative; else @@ -2057,6 +2058,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Silence diag outputs in nested derivation process. pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; + pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) { @@ -2212,7 +2214,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, i != e; ++i) { const Expr* arg = CE->getArg(i); StmtDiff argDiff = Visit(arg); - CallArgs.push_back(argDiff.getExpr_dx()); + // Has to be removed once nondifferentiable arguments are handeled + if (argDiff.getStmt_dx()) + CallArgs.push_back(argDiff.getExpr_dx()); + else + CallArgs.push_back(getZeroInit(arg->getType())); } if (Expr* baseE = baseDiff.getExpr()) { call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), @@ -3025,6 +3031,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, nullptr, VD->getInitStyle()); } + if (!m_DiffReq.shouldHaveAdjoint((VD))) + VDDerived = nullptr; + // If `VD` is a reference to a local variable, then it is already // differentiated and should not be differentiated again. // If `VD` is a reference to a non-local variable then also there's no @@ -3032,7 +3041,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!isRefType && (!isPointerType || isInitializedByNewExpr)) { Expr* derivedE = nullptr; - if (!clad::utils::hasNonDifferentiableAttribute(VD)) { + if (VDDerived && !clad::utils::hasNonDifferentiableAttribute(VD)) { derivedE = BuildDeclRef(VDDerived); if (isInitializedByNewExpr) derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); @@ -3059,7 +3068,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // *_d_i += _d_localVar; // _d_localVar = 0; // } - if (isInsideLoop) { + if (VDDerived && isInsideLoop) { Stmt* assignToZero = nullptr; Expr* declRef = BuildDeclRef(VDDerived); if (!isa(VDDerivedType)) @@ -3074,9 +3083,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VarDecl* VDClone = nullptr; Expr* derivedVDE = nullptr; - if (VDDerived) + if (VDDerived && m_DiffReq.shouldHaveAdjoint(const_cast(VD))) derivedVDE = BuildDeclRef(VDDerived); - // FIXME: Add extra parantheses if derived variable pointer is pointing to a // class type object. if (isRefType && promoteToFnScope) { @@ -3130,8 +3138,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDDerived->setInitStyle(VarDecl::InitializationStyle::CInit); } } + if (derivedVDE) m_Variables.emplace(VDClone, derivedVDE); + // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. // This can happen in rare cases, e.g. when the original function @@ -3167,8 +3177,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // TODO: 'shouldEmit' parameter should be removed after converting // Error estimation framework to callback style. Some more research // need to be done to - StmtDiff - ReverseModeVisitor::DifferentiateSingleStmt(const Stmt* S, Expr* dfdS) { + StmtDiff ReverseModeVisitor::DifferentiateSingleStmt(const Stmt* S, + Expr* dfdS) { if (m_ExternalSource) m_ExternalSource->ActOnStartOfDifferentiateSingleStmt(); beginBlock(direction::reverse); @@ -3195,6 +3205,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* RCS = endBlock(direction::reverse); std::reverse(RCS->body_begin(), RCS->body_end()); Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); + return StmtDiff(SDiff.getStmt(), ReverseResult); } diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp new file mode 100644 index 000000000..e1b5cb35c --- /dev/null +++ b/test/Analyses/ActivityReverse.cpp @@ -0,0 +1,273 @@ +// RUN: %cladclang %s -I%S/../../include -oActivity.out 2>&1 | %filecheck %s +// RUN: ./Activity.out | %filecheck_exec %s +// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-va %s -I%S/../../include -oActivity.out +// RUN: ./Activity.out | %filecheck_exec %s +//CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double f1(double x){ + double a = x*x; + double b = 1; + b = b*b; + return a; +} + +//CHECK: void f1_grad(double x, double *_d_x) { +//CHECK-NEXT: double _d_a = 0.; +//CHECK-NEXT: double a = x * x; +//CHECK-NEXT: double b = 1; +//CHECK-NEXT: double _t0 = b; +//CHECK-NEXT: b = b * b; +//CHECK-NEXT: _d_a += 1; +//CHECK-NEXT: b = _t0; +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_a * x; +//CHECK-NEXT: *_d_x += x * _d_a; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double f2(double x){ + double a = x*x; + double b = 1; + double g; + if(a) + b=x; + else if(b) + double d = b; + else + g = a; + return a; +} + +//CHECK: void f2_grad(double x, double *_d_x) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: double _t0; +//CHECK-NEXT: bool _cond1; +//CHECK-NEXT: double d = 0.; +//CHECK-NEXT: double _t1; +//CHECK-NEXT: double _d_a = 0.; +//CHECK-NEXT: double a = x * x; +//CHECK-NEXT: double _d_b = 0.; +//CHECK-NEXT: double b = 1; +//CHECK-NEXT: double _d_g = 0.; +//CHECK-NEXT: double g; +//CHECK-NEXT: { +//CHECK-NEXT: _cond0 = a; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: _t0 = b; +//CHECK-NEXT: b = x; +//CHECK-NEXT: } else { +//CHECK-NEXT: _cond1 = b; +//CHECK-NEXT: if (_cond1) +//CHECK-NEXT: d = b; +//CHECK-NEXT: else { +//CHECK-NEXT: _t1 = g; +//CHECK-NEXT: g = a; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_a += 1; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: b = _t0; +//CHECK-NEXT: double _r_d0 = _d_b; +//CHECK-NEXT: _d_b = 0.; +//CHECK-NEXT: *_d_x += _r_d0; +//CHECK-NEXT: } else if (!_cond1) { +//CHECK-NEXT: g = _t1; +//CHECK-NEXT: double _r_d1 = _d_g; +//CHECK-NEXT: _d_g = 0.; +//CHECK-NEXT: _d_a += _r_d1; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_a * x; +//CHECK-NEXT: *_d_x += x * _d_a; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double f3(double x){ + double x1, x2, x3, x4, x5 = 0; + while(!x3){ + x5 = x4; + x4 = x3; + x3 = x2; + x2 = x1; + x1 = x; + } + return x5; +} + +//CHECK: void f3_grad(double x, double *_d_x) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; +//CHECK-NEXT: double _d_x1 = 0., _d_x2 = 0., _d_x3 = 0., _d_x4 = 0., _d_x5 = 0.; +//CHECK-NEXT: double x1, x2, x3, x4, x5 = 0; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: while (!x3) +//CHECK-NEXT: { +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, x5); +//CHECK-NEXT: x5 = x4; +//CHECK-NEXT: clad::push(_t2, x4); +//CHECK-NEXT: x4 = x3; +//CHECK-NEXT: clad::push(_t3, x3); +//CHECK-NEXT: x3 = x2; +//CHECK-NEXT: clad::push(_t4, x2); +//CHECK-NEXT: x2 = x1; +//CHECK-NEXT: clad::push(_t5, x1); +//CHECK-NEXT: x1 = x; +//CHECK-NEXT: } +//CHECK-NEXT: _d_x5 += 1; +//CHECK-NEXT: while (_t0) +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: x1 = clad::pop(_t5); +//CHECK-NEXT: double _r_d4 = _d_x1; +//CHECK-NEXT: _d_x1 = 0.; +//CHECK-NEXT: *_d_x += _r_d4; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: x2 = clad::pop(_t4); +//CHECK-NEXT: double _r_d3 = _d_x2; +//CHECK-NEXT: _d_x2 = 0.; +//CHECK-NEXT: _d_x1 += _r_d3; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: x3 = clad::pop(_t3); +//CHECK-NEXT: double _r_d2 = _d_x3; +//CHECK-NEXT: _d_x3 = 0.; +//CHECK-NEXT: _d_x2 += _r_d2; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: x4 = clad::pop(_t2); +//CHECK-NEXT: double _r_d1 = _d_x4; +//CHECK-NEXT: _d_x4 = 0.; +//CHECK-NEXT: _d_x3 += _r_d1; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: x5 = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_x5; +//CHECK-NEXT: _d_x5 = 0.; +//CHECK-NEXT: _d_x4 += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _t0--; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double f4_1(double v, double u){ + double k = 2*u; + double n = 2*v; + return n*k; +} +double f4(double x){ + double c = f4_1(x, 1); + return c; +} +// CHECK-NEXT: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u); + +// CHECK: void f4_grad(double x, double *_d_x) { +// CHECK-NEXT: double _d_c = 0.; +// CHECK-NEXT: double c = f4_1(x, 1); +// CHECK-NEXT: _d_c += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: double _r1 = 0.; +// CHECK-NEXT: f4_1_pullback(x, 1, _d_c, &_r0, &_r1); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f5(double x){ + double g = x ? 1 : 2; + return g; +} +// CHECK: void f5_grad(double x, double *_d_x) { +// CHECK-NEXT: double _cond0 = x; +// CHECK-NEXT: double _d_g = 0.; +// CHECK-NEXT: double g = _cond0 ? 1 : 2; +// CHECK-NEXT: _d_g += 1; +// CHECK-NEXT: } + +double f6(double x){ + double a = 0; + if(0){ + a = x; + } + return a; +} + +// CHECK: void f6_grad(double x, double *_d_x) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double a = 0; +// CHECK-NEXT: if (0) { +// CHECK-NEXT: _t0 = a; +// CHECK-NEXT: a = x; +// CHECK-NEXT: } +// CHECK-NEXT: if (0) { +// CHECK-NEXT: a = _t0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f7(double x){ + double &a = x; + double* b = &a; + double arr[3] = {1,2,3}; + double c = arr[0]*(*b)+arr[1]*a+arr[2]*x; + return a; +} + +// CHECK: void f7_grad(double x, double *_d_x) { +// CHECK-NEXT: double &_d_a = *_d_x; +// CHECK-NEXT: double &a = x; +// CHECK-NEXT: double *_d_b = &_d_a; +// CHECK-NEXT: double *b = &a; +// CHECK-NEXT: double _d_arr[3] = {0}; +// CHECK-NEXT: double arr[3] = {1, 2, 3}; +// CHECK-NEXT: double _d_c = 0.; +// CHECK-NEXT: double c = arr[0] * *b + arr[1] * a + arr[2] * x; +// CHECK-NEXT: _d_a += 1; +// CHECK-NEXT: { +// CHECK-NEXT: _d_arr[0] += _d_c * *b; +// CHECK-NEXT: *_d_b += arr[0] * _d_c; +// CHECK-NEXT: _d_arr[1] += _d_c * a; +// CHECK-NEXT: _d_a += arr[1] * _d_c; +// CHECK-NEXT: _d_arr[2] += _d_c * x; +// CHECK-NEXT: *_d_x += arr[2] * _d_c; +// CHECK-NEXT: } +// CHECK-NEXT: } + +#define TEST(F, x) { \ + result[0] = 0; \ + auto F##grad = clad::gradient(F);\ + F##grad.execute(x, result);\ + printf("{%.2f}\n", result[0]); \ +} + +int main(){ + double result[3] = {}; + TEST(f1, 3);// CHECK-EXEC: {6.00} + TEST(f2, 3);// CHECK-EXEC: {6.00} + TEST(f3, 3);// CHECK-EXEC: {0.00} + TEST(f4, 3);// CHECK-EXEC: {4.00} + TEST(f5, 3);// CHECK-EXEC: {0.00} + TEST(f6, 3);// CHECK-EXEC: {0.00} + TEST(f7, 3);// CHECK-EXEC: {1.00} +} + +// CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { +// CHECK-NEXT: double _d_k = 0.; +// CHECK-NEXT: double k = 2 * u; +// CHECK-NEXT: double _d_n = 0.; +// CHECK-NEXT: double n = 2 * v; +// CHECK-NEXT: { +// CHECK-NEXT: _d_n += _d_y * k; +// CHECK-NEXT: _d_k += n * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: *_d_v += 2 * _d_n; +// CHECK-NEXT: *_d_u += 2 * _d_k; +// CHECK-NEXT: } diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index c460ba94f..02c5ccfd9 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -207,6 +207,7 @@ int main () { clad::differentiate(test_8, "x"); clad::differentiate(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}} clad::differentiate(test_8); // expected-error {{Both enable and disable TBR options are specified.}} + clad::gradient(test_8); // expected-error {{Both enable and disable VA options are specified.}} clad::differentiate(test_8); // expected-error {{Diagonal only option is only valid for Hessian mode.}} clad::differentiate(test_9); clad::differentiate(test_10); diff --git a/test/Misc/Args.C b/test/Misc/Args.C index 35b7c3e5f..d263d782f 100644 --- a/test/Misc/Args.C +++ b/test/Misc/Args.C @@ -29,4 +29,8 @@ // RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-tbr \ // RUN: -Xclang -plugin-arg-clad -Xclang -disable-tbr %s 2>&1 | FileCheck --check-prefix=CHECK_TBR %s -// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together \ No newline at end of file +// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together + +// RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-va \ +// RUN: -Xclang -plugin-arg-clad -Xclang -disable-va %s 2>&1 | FileCheck --check-prefix=CHECK_VA %s +// CHECK_VA: -enable-va and -disable-va cannot be used together \ No newline at end of file diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 31383229e..d228a2dc3 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -404,8 +404,19 @@ namespace clad { opts.EnableTBRAnalysis = false; // Default mode. } + static void SetActivityAnalysisOptions(const DifferentiationOptions& DO, + RequestOptions& opts) { + // If user has explicitly specified the mode for AA, use it. + if (DO.EnableVariedAnalysis || DO.DisableActivityAnalysis) + opts.EnableVariedAnalysis = + DO.EnableVariedAnalysis && !DO.DisableActivityAnalysis; + else + opts.EnableVariedAnalysis = false; // Default mode. + } + void CladPlugin::SetRequestOptions(RequestOptions& opts) const { SetTBRAnalysisOptions(m_DO, opts); + SetActivityAnalysisOptions(m_DO, opts); } void CladPlugin::FinalizeTranslationUnit() { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 2f4e23694..89b62ce8f 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -51,24 +51,27 @@ class CladTimerGroup { namespace plugin { struct DifferentiationOptions { - DifferentiationOptions() - : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), - DumpDerivedAST(false), GenerateSourceFile(false), - ValidateClangVersion(true), EnableTBRAnalysis(false), - DisableTBRAnalysis(false), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false) {} - - bool DumpSourceFn : 1; - bool DumpSourceFnAST : 1; - bool DumpDerivedFn : 1; - bool DumpDerivedAST : 1; - bool GenerateSourceFile : 1; - bool ValidateClangVersion : 1; - bool EnableTBRAnalysis : 1; - bool DisableTBRAnalysis : 1; - bool CustomEstimationModel : 1; - bool PrintNumDiffErrorInfo : 1; - std::string CustomModelName; + DifferentiationOptions() + : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), + DumpDerivedAST(false), GenerateSourceFile(false), + ValidateClangVersion(true), EnableTBRAnalysis(false), + DisableTBRAnalysis(false), EnableVariedAnalysis(false), + DisableActivityAnalysis(false), CustomEstimationModel(false), + PrintNumDiffErrorInfo(false) {} + + bool DumpSourceFn : 1; + bool DumpSourceFnAST : 1; + bool DumpDerivedFn : 1; + bool DumpDerivedAST : 1; + bool GenerateSourceFile : 1; + bool ValidateClangVersion : 1; + bool EnableTBRAnalysis : 1; + bool DisableTBRAnalysis : 1; + bool EnableVariedAnalysis : 1; + bool DisableActivityAnalysis : 1; + bool CustomEstimationModel : 1; + bool PrintNumDiffErrorInfo : 1; + std::string CustomModelName; }; class CladExternalSource : public clang::ExternalSemaSource { @@ -314,6 +317,10 @@ class CladTimerGroup { m_DO.EnableTBRAnalysis = true; } else if (args[i] == "-disable-tbr") { m_DO.DisableTBRAnalysis = true; + } else if (args[i] == "-enable-va") { + m_DO.EnableVariedAnalysis = true; + } else if (args[i] == "-disable-va") { + m_DO.DisableActivityAnalysis = true; } else if (args[i] == "-fcustom-estimation-model") { m_DO.CustomEstimationModel = true; if (++i == e) { @@ -367,6 +374,11 @@ class CladTimerGroup { "be used together.\n"; return false; } + if (m_DO.EnableVariedAnalysis && m_DO.DisableActivityAnalysis) { + llvm::errs() << "clad: Error: -enable-va and -disable-va cannot " + "be used together.\n"; + return false; + } return true; }