Skip to content

Commit

Permalink
Initial support for new and delete operations in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 18, 2024
1 parent d305002 commit 31da20c
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ namespace clad {
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
std::vector<Stmts> m_Reverse;
/// A map for storing variables to be deleted/free'd in the reverse pass.
std::unordered_map<const clang::Expr*, clang::Stmt*> m_DeleteExprs;
/// Stack is used to pass the arguments (dfdx) to further nodes
/// in the Visit method.
std::stack<clang::Expr*> m_Stack;
Expand Down Expand Up @@ -370,6 +372,8 @@ namespace clad {
StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
StmtDiff VisitCXXConstructExpr(const clang::CXXConstructExpr* CE);
StmtDiff
VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE);
Expand Down
67 changes: 66 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(S, direction::forward);
else
addToCurrentBlock(Reverse, direction::forward);
// Add delete statements present in m_DeleteExprs map.
// std::unordered_map<const clang::Expr*, clang::Stmt*> m_DeleteExprs;
for (auto& it : m_DeleteExprs)
if (auto* CS = dyn_cast<CompoundStmt>(it.second))
for (Stmt* S : CS->body())
addToCurrentBlock(S, direction::forward);
else
addToCurrentBlock(it.second, direction::forward);

if (m_ExternalSource)
m_ExternalSource->ActOnEndOfDerivedFnBody();
Expand Down Expand Up @@ -2573,6 +2581,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
// Check if the variable is pointer type and initialized by new expression
// FIXME: We should have a more general way to check if a variable is
// initialized by new memory allocation. For ex: malloc can also be used to
// initialize a pointer.
if (isPointerType && VD->getInit()) {
if (const auto* newExpr = dyn_cast<CXXNewExpr>(VD->getInit())) {
isInitializedByNewExpr = true;
}
}

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
Expand Down Expand Up @@ -2663,8 +2681,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && !isPointerType) {
if (!isDerivativeOfRefType && !(isPointerType && !isInitializedByNewExpr)) {
Expr* derivedE = BuildDeclRef(VDDerived);
if (isInitializedByNewExpr) {
// derivedE should be dereferenced.
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
}
if (VD->getInit()) {
if (isa<CXXConstructExpr>(VD->getInit()))
initDiff = Visit(VD->getInit());
Expand Down Expand Up @@ -3703,6 +3725,49 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {clonedCTE, m_ThisExprDerivative};
}

StmtDiff ReverseModeVisitor::VisitCXXNewExpr(const clang::CXXNewExpr* CNE) {
StmtDiff initializerDiff;
if (CNE->hasInitializer())
initializerDiff = Visit(CNE->getInitializer(), dfdx());

Expr* clonedArraySizeE = nullptr;
Expr* derivedArraySizeE = nullptr;
if (CNE->getArraySize()) {
clonedArraySizeE =
Visit(clad_compat::ArraySize_GetValue(CNE->getArraySize())).getExpr();
// Array size is a non-differentiable expression, thus the original value
// should be used in both the cloned and the derived statements.
derivedArraySizeE = Clone(clonedArraySizeE);
}
Expr* clonedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), clonedArraySizeE,
initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo());
Expr* derivedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), derivedArraySizeE,
initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo());
return {clonedNewE, derivedNewE};
}

StmtDiff
ReverseModeVisitor::VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE) {
StmtDiff argDiff = Visit(CDE->getArgument());
Expr* clonedDeleteE =
m_Sema
.ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
argDiff.getExpr())
.get();
Expr* derivedDeleteE =
m_Sema
.ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
argDiff.getExpr_dx())
.get();
// create a compound statement containing both the cloned and the derived
// delete expressions.
CompoundStmt* CS = MakeCompoundStmt({clonedDeleteE, derivedDeleteE});
m_DeleteExprs[CDE->getArgument()] = CS;
return {nullptr, nullptr};
}

// FIXME: Add support for differentiating calls to constructors.
// We currently assume that constructor arguments are non-differentiable.
StmtDiff
Expand Down
72 changes: 71 additions & 1 deletion test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,71 @@ double pointerMultipleParams(const double* a, const double* b) {
// CHECK-NEXT: _d_b[2] += _d_sum;
// CHECK-NEXT: }

double newAndDeletePointer(double i, double j) {
double *p = new double(i);
double *q = new double(j);
double *r = new double[2];
r[0] = i + j;
r[1] = i*j;
double sum = *p + *q + r[0] + r[1];
delete p;
delete q;
delete [] r;
return sum;
}

// CHECK: void newAndDeletePointer_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double *_d_p = 0;
// CHECK-NEXT: double *_d_q = 0;
// CHECK-NEXT: double *_d_r = 0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _d_sum = 0;
// CHECK-NEXT: _d_p = new double(* _d_i);
// CHECK-NEXT: double *p = new double(i);
// CHECK-NEXT: _d_q = new double(* _d_j);
// CHECK-NEXT: double *q = new double(j);
// CHECK-NEXT: _d_r = new double [2];
// CHECK-NEXT: double *r = new double [2];
// CHECK-NEXT: _t0 = r[0];
// CHECK-NEXT: r[0] = i + j;
// CHECK-NEXT: _t1 = r[1];
// CHECK-NEXT: r[1] = i * j;
// CHECK-NEXT: double sum = *p + *q + r[0] + r[1];
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_sum += 1;
// CHECK-NEXT: {
// CHECK-NEXT: *_d_p += _d_sum;
// CHECK-NEXT: *_d_q += _d_sum;
// CHECK-NEXT: _d_r[0] += _d_sum;
// CHECK-NEXT: _d_r[1] += _d_sum;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: r[1] = _t1;
// CHECK-NEXT: double _r_d1 = _d_r[1];
// CHECK-NEXT: _d_r[1] -= _r_d1;
// CHECK-NEXT: * _d_i += _r_d1 * j;
// CHECK-NEXT: * _d_j += i * _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: r[0] = _t0;
// CHECK-NEXT: double _r_d0 = _d_r[0];
// CHECK-NEXT: _d_r[0] -= _r_d0;
// CHECK-NEXT: * _d_i += _r_d0;
// CHECK-NEXT: * _d_j += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: * _d_j += *_d_q;
// CHECK-NEXT: * _d_i += *_d_p;
// CHECK-NEXT: delete [] r;
// CHECK-NEXT: delete [] _d_r;
// CHECK-NEXT: delete q;
// CHECK-NEXT: delete _d_q;
// CHECK-NEXT: delete p;
// CHECK-NEXT: delete _d_p;
// CHECK-NEXT: }


#define NON_MEM_FN_TEST(var)\
res[0]=0;\
var.execute(5,res);\
Expand Down Expand Up @@ -433,4 +498,9 @@ int main() {
d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref);
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 2.00 4.00 2.00 0.00 0.00
printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00
}

auto d_newAndDeletePointer = clad::gradient(newAndDeletePointer);
double d_i = 0, d_j = 0;
d_newAndDeletePointer.execute(5, 7, &d_i, &d_j);
printf("%.2f %.2f\n", d_i, d_j); // CHECK-EXEC: 9.00 7.00
}

0 comments on commit 31da20c

Please sign in to comment.