From e092f267f66df616b6ab40ee968d48378a13debc Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Thu, 12 Dec 2024 16:21:19 +0100 Subject: [PATCH] Remove static set from DiffReq --- .../clad/Differentiator/DerivativeBuilder.h | 1 + include/clad/Differentiator/DiffPlanner.h | 25 +++++++++++---- lib/Differentiator/DiffPlanner.cpp | 32 ++++++++++++------- lib/Differentiator/ReverseModeVisitor.cpp | 10 +++--- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 5e9d54ac2..eb517666c 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -72,6 +72,7 @@ namespace clad { class DerivativeBuilder { private: friend class VisitorBase; + friend class DiffRequest; friend class BaseForwardModeVisitor; friend class PushForwardModeVisitor; friend class VectorForwardModeVisitor; diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index a33e10f79..dc12d9918 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -3,12 +3,12 @@ #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" +#include +#include +#include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" - -#include -#include namespace clang { class CallExpr; class CompilerInstance; @@ -21,6 +21,7 @@ class Type; } // namespace clang namespace clad { +class DerivativeBuilder; /// A struct containing information about request to differentiate a function. struct DiffRequest { @@ -34,12 +35,13 @@ struct DiffRequest { } m_TbrRunInfo; mutable struct ActivityRunInfo { + std::set VariedDecls; bool HasAnalysisRun = false; } m_ActivityRunInfo; public: - /// All varied declarations. - static std::set AllVariedDecls; + const DerivativeBuilder* Builder = nullptr; + // static std::set AllVariedDecls; /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -128,7 +130,8 @@ struct DiffRequest { Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && EnableVariedAnalysis == other.EnableVariedAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && - DeclarationOnly == other.DeclarationOnly; + DeclarationOnly == other.DeclarationOnly && + getVariedDecls() == other.getVariedDecls(); } const clang::FunctionDecl* operator->() const { return Function; } @@ -145,6 +148,16 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + + void setVariedDecls(std::set init) { + for (auto* vd : init) + this->m_ActivityRunInfo.VariedDecls.insert(vd); + } + std::set getVariedDecls() const { + return this->m_ActivityRunInfo.VariedDecls; + } + DiffRequest() : Builder(nullptr) {} + DiffRequest(DerivativeBuilder& builder) : Builder(&builder) {} }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index f404391a3..ec5aaf6ed 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -22,8 +22,8 @@ using namespace clang; namespace clad { -std::set DiffRequest::AllVariedDecls; -static SourceLocation noLoc; +// std::set DiffRequest::AllVariedDecls; +static SourceLocation noloc; /// Returns `DeclRefExpr` node corresponding to the function, method or /// functor argument which is to be differentiated. @@ -62,7 +62,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto callOperatorDeclName = m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( OverloadedOperatorKind::OO_Call); - LookupResult R(m_SemaRef, callOperatorDeclName, noLoc, + LookupResult R(m_SemaRef, callOperatorDeclName, noloc, Sema::LookupNameKind::LookupMemberName); // We do not want diagnostics that would fire because of this lookup. R.suppressDiagnostics(); @@ -149,7 +149,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto* newFnDRE = clad_compat::GetResult(m_SemaRef.BuildDeclRefExpr( callOperator, callOperator->getType(), - CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noLoc, &CSS)); + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noloc, &CSS)); m_FnDRE = cast(newFnDRE); } return false; @@ -198,7 +198,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto kernelArgIdx = numArgs - 1; auto* cudaKernelFlag = SemaRef - .ActOnCXXBoolLiteral(noLoc, + .ActOnCXXBoolLiteral(noloc, replacementFD->hasAttr() ? tok::kw_true : tok::kw_false) @@ -209,7 +209,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { // Create ref to generated FD. DeclRefExpr* DRE = - DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD, + DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noloc, replacementFD, /*RefersToEnclosingVariableOrCapture=*/false, replacementFD->getNameInfo(), replacementFD->getType(), oldDRE->getValueKind()); @@ -225,7 +225,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { // Add the "&" operator auto* newUnOp = SemaRef - .BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE) + .BuildUnaryOp(nullptr, noloc, UnaryOperatorKind::UO_AddrOf, DRE) .get(); call->setArg(derivedFnArgIdx, newUnOp); } @@ -618,15 +618,25 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { return true; if (!m_ActivityRunInfo.HasAnalysisRun) { + if (Builder) + for (auto diffreq : this->Builder->m_DiffRequestGraph.getNodes()) + for (auto vd : diffreq.getVariedDecls()) + m_ActivityRunInfo.VariedDecls.insert(vd); + if (Args) for (const auto& dParam : DVI) - AllVariedDecls.insert(cast(dParam.param)); - VariedAnalyzer analyzer(Function->getASTContext(), AllVariedDecls); + m_ActivityRunInfo.VariedDecls.insert(cast(dParam.param)); + VariedAnalyzer analyzer(Function->getASTContext(), + m_ActivityRunInfo.VariedDecls); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; + if (Builder) + this->Builder->m_DiffRequestGraph.addNode(*this); } - auto found = AllVariedDecls.find(VD); - return found != AllVariedDecls.end(); + auto found = m_ActivityRunInfo.VariedDecls.find(VD); + return found != m_ActivityRunInfo.VariedDecls.end(); + + return false; } bool DiffCollector::VisitCallExpr(CallExpr* E) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 61532f656..5b9ee3a54 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -216,6 +216,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); auto derivativeBaseName = m_DiffReq.BaseFunctionName; + // llvm::errs() << "\nBaseFunctionName: " << derivativeBaseName << "\n"; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if // we differentiate w.r.t. all the parameters at once. @@ -1944,10 +1945,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActBeforeDifferentiatingCallExpr( pullbackCallArgs, PreCallStmts, dfdx()); - // Overloaded derivative was not found, request the CladPlugin to - // derive the called function. - DiffRequest pullbackRequest{}; - pullbackRequest.Function = FD; + // Overloaded derivative was not found, request the CladPlugin to + // derive the called function. + DiffRequest pullbackRequest(m_Builder); + pullbackRequest.Function = FD; // Mark the indexes of the global args. Necessary if the argument of the // call has a different name than the function's signature parameter. @@ -1960,6 +1961,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + pullbackRequest.setVariedDecls(m_DiffReq.getVariedDecls()); bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) {