Skip to content

Commit

Permalink
Remove static set from DiffReq
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Dec 12, 2024
1 parent fac1aee commit e092f26
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 21 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ namespace clad {
class DerivativeBuilder {
private:
friend class VisitorBase;
friend class DiffRequest;
friend class BaseForwardModeVisitor;
friend class PushForwardModeVisitor;
friend class VectorForwardModeVisitor;
Expand Down
25 changes: 19 additions & 6 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallSet.h"
#include <iterator>
#include <set>
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include <iterator>
#include <set>
namespace clang {
class CallExpr;
class CompilerInstance;
Expand All @@ -21,6 +21,7 @@ class Type;
} // namespace clang

namespace clad {
class DerivativeBuilder;

/// A struct containing information about request to differentiate a function.
struct DiffRequest {
Expand All @@ -34,12 +35,13 @@ struct DiffRequest {
} m_TbrRunInfo;

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> VariedDecls;
bool HasAnalysisRun = false;
} m_ActivityRunInfo;

public:
/// All varied declarations.
static std::set<const clang::VarDecl*> AllVariedDecls;
const DerivativeBuilder* Builder = nullptr;
// static std::set<const clang::VarDecl*> AllVariedDecls;
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
Expand Down Expand Up @@ -128,7 +130,8 @@ struct DiffRequest {
Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis &&
EnableVariedAnalysis == other.EnableVariedAnalysis &&
DVI == other.DVI && use_enzyme == other.use_enzyme &&
DeclarationOnly == other.DeclarationOnly;
DeclarationOnly == other.DeclarationOnly &&
getVariedDecls() == other.getVariedDecls();
}

const clang::FunctionDecl* operator->() const { return Function; }
Expand All @@ -145,6 +148,16 @@ struct DiffRequest {

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;

void setVariedDecls(std::set<const clang::VarDecl*> init) {
for (auto* vd : init)
this->m_ActivityRunInfo.VariedDecls.insert(vd);
}
std::set<const clang::VarDecl*> getVariedDecls() const {
return this->m_ActivityRunInfo.VariedDecls;
}
DiffRequest() : Builder(nullptr) {}
DiffRequest(DerivativeBuilder& builder) : Builder(&builder) {}
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
32 changes: 21 additions & 11 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
using namespace clang;

namespace clad {
std::set<const clang::VarDecl*> DiffRequest::AllVariedDecls;
static SourceLocation noLoc;
// std::set<const clang::VarDecl*> DiffRequest::AllVariedDecls;
static SourceLocation noloc;

/// Returns `DeclRefExpr` node corresponding to the function, method or
/// functor argument which is to be differentiated.
Expand Down Expand Up @@ -62,7 +62,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
auto callOperatorDeclName =
m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName(
OverloadedOperatorKind::OO_Call);
LookupResult R(m_SemaRef, callOperatorDeclName, noLoc,
LookupResult R(m_SemaRef, callOperatorDeclName, noloc,
Sema::LookupNameKind::LookupMemberName);
// We do not want diagnostics that would fire because of this lookup.
R.suppressDiagnostics();
Expand Down Expand Up @@ -149,7 +149,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
auto* newFnDRE =
clad_compat::GetResult<Expr*>(m_SemaRef.BuildDeclRefExpr(
callOperator, callOperator->getType(),
CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noLoc, &CSS));
CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noloc, &CSS));
m_FnDRE = cast<DeclRefExpr>(newFnDRE);
}
return false;
Expand Down Expand Up @@ -198,7 +198,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
auto kernelArgIdx = numArgs - 1;
auto* cudaKernelFlag =
SemaRef
.ActOnCXXBoolLiteral(noLoc,
.ActOnCXXBoolLiteral(noloc,
replacementFD->hasAttr<CUDAGlobalAttr>()
? tok::kw_true
: tok::kw_false)
Expand All @@ -209,7 +209,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {

// Create ref to generated FD.
DeclRefExpr* DRE =
DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD,
DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noloc, replacementFD,
/*RefersToEnclosingVariableOrCapture=*/false,
replacementFD->getNameInfo(),
replacementFD->getType(), oldDRE->getValueKind());
Expand All @@ -225,7 +225,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
// Add the "&" operator
auto* newUnOp =
SemaRef
.BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE)
.BuildUnaryOp(nullptr, noloc, UnaryOperatorKind::UO_AddrOf, DRE)
.get();
call->setArg(derivedFnArgIdx, newUnOp);
}
Expand Down Expand Up @@ -618,15 +618,25 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
if (Builder)
for (auto diffreq : this->Builder->m_DiffRequestGraph.getNodes())
for (auto vd : diffreq.getVariedDecls())
m_ActivityRunInfo.VariedDecls.insert(vd);

if (Args)
for (const auto& dParam : DVI)
AllVariedDecls.insert(cast<VarDecl>(dParam.param));
VariedAnalyzer analyzer(Function->getASTContext(), AllVariedDecls);
m_ActivityRunInfo.VariedDecls.insert(cast<VarDecl>(dParam.param));
VariedAnalyzer analyzer(Function->getASTContext(),
m_ActivityRunInfo.VariedDecls);
analyzer.Analyze(Function);
m_ActivityRunInfo.HasAnalysisRun = true;
if (Builder)
this->Builder->m_DiffRequestGraph.addNode(*this);
}
auto found = AllVariedDecls.find(VD);
return found != AllVariedDecls.end();
auto found = m_ActivityRunInfo.VariedDecls.find(VD);
return found != m_ActivityRunInfo.VariedDecls.end();

return false;
}

bool DiffCollector::VisitCallExpr(CallExpr* E) {
Expand Down
10 changes: 6 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args);

auto derivativeBaseName = m_DiffReq.BaseFunctionName;
// llvm::errs() << "\nBaseFunctionName: " << derivativeBaseName << "\n";
std::string gradientName = derivativeBaseName + funcPostfix();
// To be consistent with older tests, nothing is appended to 'f_grad' if
// we differentiate w.r.t. all the parameters at once.
Expand Down Expand Up @@ -1944,10 +1945,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
pullbackCallArgs, PreCallStmts, dfdx());

// Overloaded derivative was not found, request the CladPlugin to
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;
// Overloaded derivative was not found, request the CladPlugin to
// derive the called function.
DiffRequest pullbackRequest(m_Builder);
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
Expand All @@ -1960,6 +1961,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis;
pullbackRequest.setVariedDecls(m_DiffReq.getVariedDecls());
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (MD && isLambdaCallOperator(MD)) {
Expand Down

0 comments on commit e092f26

Please sign in to comment.