Skip to content

Commit

Permalink
Differentiate global variables in the reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 23, 2024
1 parent 693b1bd commit 3acb66a
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 25 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace clad {
class CladPlugin;
clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,
DiffRequest& request);
void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D);
} // namespace plugin

} // namespace clad
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ namespace clad {
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool keepLocal = false);
clang::Expr* DifferentiateGlobalVarDecl(clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
8 changes: 5 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,9 @@ namespace clad {
/// \returns The newly built variable declaration.
clang::VarDecl*
BuildVarDecl(clang::QualType Type, clang::IdentifierInfo* Identifier,
clang::Scope* scope, clang::Expr* Init = nullptr,
bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
clang::Scope* scope, clang::DeclContext* DeclCtx,
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
Expand Down Expand Up @@ -334,7 +335,8 @@ namespace clad {
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl::InitializationStyle::CInit,
clang::DeclContext* DeclCtx = nullptr);
/// Creates a namespace declaration and enters its context. All subsequent
/// Stmts are built inside that namespace, until
/// m_Sema.PopDeclContextIsUsed.
Expand Down
44 changes: 42 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1413,8 +1413,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Check DeclRefExpr is a reference to an independent variable.
auto it = m_Variables.find(VD);
if (it == std::end(m_Variables)) {
// Is not an independent variable, ignored.
return StmtDiff(clonedDRE);
if (VD->isFileVarDecl()) {
Expr* DREDiff = DifferentiateGlobalVarDecl(VD);
it = m_Variables.emplace(VD, DREDiff).first;
} else
// Is not an independent variable, ignored.
return StmtDiff(clonedDRE);
}
// Create the (_d_param[idx] += dfdx) statement.
if (dfdx()) {
Expand All @@ -1440,6 +1444,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(clonedDRE);
}

Expr* ReverseModeVisitor::DifferentiateGlobalVarDecl(VarDecl* VD) {
assert(VD->isFileVarDecl() && "Must be a global variable");
std::string nameDiff_str = "_d_" + VD->getNameAsString();
DeclarationName nameDiff = &m_Context.Idents.get(nameDiff_str);
DeclContext* DC = VD->getDeclContext();

// Attempt to find the adjoint of VD in case it has already been created.
LookupResult result(m_Sema, nameDiff, noLoc, Sema::LookupOrdinaryName);
m_Sema.LookupQualifiedName(result, DC);
if (!result.empty()) {
// Found, return a reference
Expr* foundExpr = m_Sema
.BuildDeclarationNameExpr(CXXScopeSpec{}, result,
/*ADL=*/false)
.get();
return foundExpr;
}
// Not found, construct the adjoint and register it.
VarDecl* VDDiff =
BuildVarDecl(VD->getType(), CreateUniqueIdentifier(nameDiff_str),
m_DerivativeFnScope->getParent()->getParent(), DC,
getZeroInit(VD->getType()));

DC->addDecl(VDDiff);
DC->makeDeclVisibleInContext(VDDiff);
plugin::ProcessTopLevelDecl(m_CladPlugin, VDDiff);
// diag(DiagnosticsEngine::Warning,
// VD->getLocation(),
// "The gradient utilizes a global variable '%0' and its adjoint
// '%1'"
// ". Please make sure to properly reset '%0' and '%1' before
// re-running the gradient.",
// {VD->getNameAsString(), nameDiff_str});
return BuildDeclRef(VDDiff);
}

StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
auto* Constant0 =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down
20 changes: 12 additions & 8 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,19 @@ namespace clad {
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit,
TSI, IS);
return BuildVarDecl(Type, Identifier, getCurrentScope(), m_Sema.CurContext,
Init, DirectInit, TSI, IS);
}
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Scope* Scope, Expr* Init, bool DirectInit,
Scope* Scope, DeclContext* DeclCtx,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
auto* VD = VarDecl::Create(
m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);
auto* VD = VarDecl::Create(m_Context, DeclCtx, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI,
SC_None);

if (Init) {
m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
Expand Down Expand Up @@ -149,9 +150,12 @@ namespace clad {
VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
llvm::StringRef prefix, Expr* Init,
bool DirectInit, TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
VarDecl::InitializationStyle IS,
DeclContext* DeclCtx) {
DeclCtx = DeclCtx ? DeclCtx : m_Sema.CurContext;
return BuildVarDecl(Type, CreateUniqueIdentifier(prefix),
m_DerivativeFnScope, Init, DirectInit, TSI, IS);
m_DerivativeFnScope, DeclCtx, Init, DirectInit, TSI,
IS);
}

NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
Expand Down
1 change: 1 addition & 0 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ int main() {

// CHECK: inline void operator_call_grad(double ii, double j, double *_d_ii, double *_d_j) const {
// CHECK-NEXT: {
// CHECK-NEXT: _d_x += 1 * j * ii;
// CHECK-NEXT: *_d_ii += x * 1 * j;
// CHECK-NEXT: *_d_j += x * ii * 1;
// CHECK-NEXT: }
Expand Down
62 changes: 61 additions & 1 deletion test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -666,14 +666,16 @@ float running_sum(float* p, int n) {
// CHECK-NEXT: }

double global = 7;
// CHECK: double _d_global = 0.;
// expected-warning {{The gradient utilizes a global variable 'global' and its adjoint '_d_global'. Please make sure to properly reset 'global' and '_d_global' before re-running the gradient.}}

double fn_global_var_use(double i, double j) {
double& ref = global;
return ref * i;
}

// CHECK: void fn_global_var_use_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_ref = 0.;
// CHECK-NEXT: double &_d_ref = _d_global;
// CHECK-NEXT: double &ref = global;
// CHECK-NEXT: {
// CHECK-NEXT: _d_ref += 1 * i;
Expand Down Expand Up @@ -1143,6 +1145,49 @@ double f_ref_in_rhs(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double glob1 = 5;

double g(double a, double b) {
glob1 = b;
return a;
}

//CHECK: void g_pullback(double a, double b, double _d_y, double *_d_a, double *_d_b);

//CHECK: double _d_glob1 = 0.;
// expected-warning {{The gradient utilizes a global variable 'glob1' and its adjoint '_d_glob1'. Please make sure to properly reset 'glob1' and '_d_glob1' before re-running the gradient.}}

double f_reuse_global(double x, double t) {
t = g(t, x);
glob1 *= t;
return -glob1;
} // -x * t

//CHECK: void f_reuse_global_grad(double x, double t, double *_d_x, double *_d_t) {
//CHECK-NEXT: double _t0 = t;
//CHECK-NEXT: t = g(t, x);
//CHECK-NEXT: double _t1 = glob1;
//CHECK-NEXT: glob1 *= t;
//CHECK-NEXT: _d_glob1 += -1;
//CHECK-NEXT: {
//CHECK-NEXT: glob1 = _t1;
//CHECK-NEXT: double _r_d1 = _d_glob1;
//CHECK-NEXT: _d_glob1 = 0.;
//CHECK-NEXT: _d_glob1 += _r_d1 * t;
//CHECK-NEXT: *_d_t += glob1 * _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t = _t0;
//CHECK-NEXT: double _r_d0 = *_d_t;
//CHECK-NEXT: *_d_t = 0.;
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: g_pullback(t, x, _r_d0, &_r0, &_r1);
//CHECK-NEXT: *_d_t += _r0;
//CHECK-NEXT: *_d_x += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -1239,4 +1284,19 @@ int main() {

INIT_GRADIENT(f_ref_in_rhs);
TEST_GRADIENT(f_ref_in_rhs, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {5.00, 13.00}

INIT_GRADIENT(f_reuse_global);
TEST_GRADIENT(f_reuse_global, /*numOfDerivativeArgs=*/2, -3, 4, &d_i, &d_j); // CHECK-EXEC: {-4.00, 3.00}
}

//CHECK-NEXT: void g_pullback(double a, double b, double _d_y, double *_d_a, double *_d_b) {
//CHECK-NEXT: double _t0 = glob1;
//CHECK-NEXT: glob1 = b;
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: {
//CHECK-NEXT: glob1 = _t0;
//CHECK-NEXT: double _r_d0 = _d_glob1;
//CHECK-NEXT: _d_glob1 = 0.;
//CHECK-NEXT: *_d_b += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
5 changes: 4 additions & 1 deletion test/ValidCodeGen/ValidCodeGen.C
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ int main() {
//CHECK-NEXT: }

//CHECK: void fn_grad(double x, double *_d_x) {
//CHECK-NEXT: *_d_x += 1 * TN::coefficient;
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += 1 * TN::coefficient;
//CHECK-NEXT: _d_coefficient += x * 1;
//CHECK-NEXT: }
//CHECK-NEXT: }

//CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) {
Expand Down
32 changes: 22 additions & 10 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,24 @@ class CladTimerGroup {
// handling of the differentiation plans.
clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request);

void ProcessTopLevelDecl(clang::Decl* D) {
if (llvm::isa<clang::VarDecl>(D) && m_DO.DumpDerivedFn) {
clang::LangOptions LangOpts;
LangOpts.CPlusPlus = true;
clang::PrintingPolicy Policy(LangOpts);
Policy.Bool = true;
D->print(llvm::outs(), Policy);
llvm::outs() << ";\n";
}
DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D};
assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!");
AppendDelayed(DCI);
// We could not delay the process due to some strange way of
// initialization, inform the consumers now.
if (!m_Multiplexer)
m_CI.getASTConsumer().HandleTopLevelDecl(DCI.m_DGR);
}

private:
void AppendDelayed(DelayedCallInfo DCI) {
// Incremental processing handles the translation unit in chunks and it is
Expand All @@ -268,16 +286,6 @@ class CladTimerGroup {
void SendToMultiplexer();
bool CheckBuiltins();
void SetRequestOptions(RequestOptions& opts) const;

void ProcessTopLevelDecl(clang::Decl* D) {
DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D};
assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!");
AppendDelayed(DCI);
// We could not delay the process due to some strange way of
// initialization, inform the consumers now.
if (!m_Multiplexer)
m_CI.getASTConsumer().HandleTopLevelDecl(DCI.m_DGR);
}
void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR);
};

Expand All @@ -286,6 +294,10 @@ class CladTimerGroup {
return P.ProcessDiffRequest(request);
}

void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D) {
P.ProcessTopLevelDecl(D);
}

template <typename ConsumerType>
class Action : public clang::PluginASTAction {
private:
Expand Down

0 comments on commit 3acb66a

Please sign in to comment.