From 35ee97d6c3da9883397c6776483f615aecb6bb89 Mon Sep 17 00:00:00 2001
From: Vaibhav Thakkar
Date: Sun, 18 Feb 2024 20:13:26 +0100
Subject: [PATCH] Initial support for new and delete operations in reverse mode

---
Reviewer notes (text between the three-dash line and the diffstat is not
part of the commit message):
 * This patch was rebuilt from a copy whose line breaks had been collapsed
   and whose angle-bracket text had been stripped. Hunk line counts were
   re-checked against every @@ header and against the diffstat below.
 * Restored from surrounding code: dyn_cast<CompoundStmt>, isa<CXXNewExpr>,
   clad::array_ref<double>. Restored on context lines as best guesses, TODO
   confirm against the tree: std::vector<Stmts>, std::stack<clang::Expr*>,
   isa<CXXConstructExpr>.
 * Leading whitespace is reconstructed, not recovered. TODO confirm against
   the tree, or apply with `git apply --ignore-whitespace`.
 * The author e-mail on the From: line could not be recovered.
 * Follow-up: in the m_DeallocExprs loop the inner `Stmt* S` shadows the
   outer `auto* S`.
 * Follow-up: the last hunk removes the trailing newline of
   test/Gradient/Pointers.C.

 .../clad/Differentiator/ReverseModeVisitor.h  |  4 ++
 lib/Differentiator/ReverseModeVisitor.cpp     | 63 +++++++++++++++-
 test/Gradient/Pointers.C                      | 72 ++++++++++++++++++-
 3 files changed, 137 insertions(+), 2 deletions(-)

diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h
index 1cd1c0bfa..b38749bf1 100644
--- a/include/clad/Differentiator/ReverseModeVisitor.h
+++ b/include/clad/Differentiator/ReverseModeVisitor.h
@@ -41,6 +41,8 @@ namespace clad {
     /// the reverse mode we also accumulate Stmts for the reverse pass which
     /// will be executed on return.
     std::vector<Stmts> m_Reverse;
+    /// Storing expressions to delete/free memory in the reverse pass.
+    Stmts m_DeallocExprs;
     /// Stack is used to pass the arguments (dfdx) to further nodes
     /// in the Visit method.
     std::stack<clang::Expr*> m_Stack;
@@ -370,6 +372,8 @@ namespace clad {
     StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
     StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
     StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
+    StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
+    StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
     StmtDiff VisitCXXConstructExpr(const clang::CXXConstructExpr* CE);
     StmtDiff
     VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE);
diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp
index 6e3f1af48..a9529270c 100644
--- a/lib/Differentiator/ReverseModeVisitor.cpp
+++ b/lib/Differentiator/ReverseModeVisitor.cpp
@@ -610,6 +610,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         addToCurrentBlock(S, direction::forward);
     else
       addToCurrentBlock(Reverse, direction::forward);
+    // Add delete statements present in m_DeallocExprs to the current block.
+    for (auto* S : m_DeallocExprs)
+      if (auto* CS = dyn_cast<CompoundStmt>(S))
+        for (Stmt* S : CS->body())
+          addToCurrentBlock(S, direction::forward);
+      else
+        addToCurrentBlock(S, direction::forward);
 
     if (m_ExternalSource)
       m_ExternalSource->ActOnEndOfDerivedFnBody();
@@ -2573,6 +2580,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     bool isDerivativeOfRefType = VD->getType()->isReferenceType();
     VarDecl* VDDerived = nullptr;
     bool isPointerType = VD->getType()->isPointerType();
+    bool isInitializedByNewExpr = false;
+    // Check if the variable is pointer type and initialized by new expression
+    if (isPointerType && VD->getInit()) {
+      if (isa<CXXNewExpr>(VD->getInit())) {
+        isInitializedByNewExpr = true;
+      }
+    }
 
     // VDDerivedInit now serves two purposes -- as the initial derivative value
     // or the size of the derivative array -- depending on the primal type.
@@ -2663,8 +2677,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     // differentiated and should not be differentiated again.
     // If `VD` is a reference to a non-local variable then also there's no
     // need to call `Visit` since non-local variables are not differentiated.
-    if (!isDerivativeOfRefType && !isPointerType) {
+    if (!isDerivativeOfRefType && !(isPointerType && !isInitializedByNewExpr)) {
       Expr* derivedE = BuildDeclRef(VDDerived);
+      if (isInitializedByNewExpr) {
+        // derivedE should be dereferenced.
+        derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
+      }
       if (VD->getInit()) {
         if (isa<CXXConstructExpr>(VD->getInit()))
           initDiff = Visit(VD->getInit());
@@ -3703,6 +3721,49 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     return {clonedCTE, m_ThisExprDerivative};
   }
 
+  StmtDiff ReverseModeVisitor::VisitCXXNewExpr(const clang::CXXNewExpr* CNE) {
+  StmtDiff initializerDiff;
+  if (CNE->hasInitializer())
+    initializerDiff = Visit(CNE->getInitializer(), dfdx());
+
+  Expr* clonedArraySizeE = nullptr;
+  Expr* derivedArraySizeE = nullptr;
+  if (CNE->getArraySize()) {
+    clonedArraySizeE =
+        Visit(clad_compat::ArraySize_GetValue(CNE->getArraySize())).getExpr();
+    // Array size is a non-differentiable expression, thus the original value
+    // should be used in both the cloned and the derived statements.
+    derivedArraySizeE = Clone(clonedArraySizeE);
+  }
+  Expr* clonedNewE = utils::BuildCXXNewExpr(
+      m_Sema, CNE->getAllocatedType(), clonedArraySizeE,
+      initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo());
+  Expr* derivedNewE = utils::BuildCXXNewExpr(
+      m_Sema, CNE->getAllocatedType(), derivedArraySizeE,
+      initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo());
+  return {clonedNewE, derivedNewE};
+}
+
+StmtDiff
+ReverseModeVisitor::VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE) {
+  StmtDiff argDiff = Visit(CDE->getArgument());
+  Expr* clonedDeleteE =
+      m_Sema
+          .ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
+                          argDiff.getExpr())
+          .get();
+  Expr* derivedDeleteE =
+      m_Sema
+          .ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
+                          argDiff.getExpr_dx())
+          .get();
+  // create a compound statement containing both the cloned and the derived
+  // delete expressions.
+  CompoundStmt* CS = MakeCompoundStmt({clonedDeleteE, derivedDeleteE});
+  m_DeallocExprs.push_back(CS);
+  return {nullptr, nullptr};
+}
+
   // FIXME: Add support for differentiating calls to constructors.
   // We currently assume that constructor arguments are non-differentiable.
   StmtDiff
diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C
index 107254ae0..11a66705c 100644
--- a/test/Gradient/Pointers.C
+++ b/test/Gradient/Pointers.C
@@ -339,6 +339,71 @@ double pointerMultipleParams(const double* a, const double* b) {
 // CHECK-NEXT:     _d_b[2] += _d_sum;
 // CHECK-NEXT: }
 
+double newAndDeletePointer(double i, double j) {
+  double *p = new double(i);
+  double *q = new double(j);
+  double *r = new double[2];
+  r[0] = i + j;
+  r[1] = i*j;
+  double sum = *p + *q + r[0] + r[1];
+  delete p;
+  delete q;
+  delete [] r;
+  return sum;
+}
+
+// CHECK: void newAndDeletePointer_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
+// CHECK-NEXT:     double *_d_p = 0;
+// CHECK-NEXT:     double *_d_q = 0;
+// CHECK-NEXT:     double *_d_r = 0;
+// CHECK-NEXT:     double _t0;
+// CHECK-NEXT:     double _t1;
+// CHECK-NEXT:     double _d_sum = 0;
+// CHECK-NEXT:     _d_p = new double(* _d_i);
+// CHECK-NEXT:     double *p = new double(i);
+// CHECK-NEXT:     _d_q = new double(* _d_j);
+// CHECK-NEXT:     double *q = new double(j);
+// CHECK-NEXT:     _d_r = new double [2];
+// CHECK-NEXT:     double *r = new double [2];
+// CHECK-NEXT:     _t0 = r[0];
+// CHECK-NEXT:     r[0] = i + j;
+// CHECK-NEXT:     _t1 = r[1];
+// CHECK-NEXT:     r[1] = i * j;
+// CHECK-NEXT:     double sum = *p + *q + r[0] + r[1];
+// CHECK-NEXT:     goto _label0;
+// CHECK-NEXT:   _label0:
+// CHECK-NEXT:     _d_sum += 1;
+// CHECK-NEXT:     {
+// CHECK-NEXT:         *_d_p += _d_sum;
+// CHECK-NEXT:         *_d_q += _d_sum;
+// CHECK-NEXT:         _d_r[0] += _d_sum;
+// CHECK-NEXT:         _d_r[1] += _d_sum;
+// CHECK-NEXT:     }
+// CHECK-NEXT:     {
+// CHECK-NEXT:         r[1] = _t1;
+// CHECK-NEXT:         double _r_d1 = _d_r[1];
+// CHECK-NEXT:         _d_r[1] -= _r_d1;
+// CHECK-NEXT:         * _d_i += _r_d1 * j;
+// CHECK-NEXT:         * _d_j += i * _r_d1;
+// CHECK-NEXT:     }
+// CHECK-NEXT:     {
+// CHECK-NEXT:         r[0] = _t0;
+// CHECK-NEXT:         double _r_d0 = _d_r[0];
+// CHECK-NEXT:         _d_r[0] -= _r_d0;
+// CHECK-NEXT:         * _d_i += _r_d0;
+// CHECK-NEXT:         * _d_j += _r_d0;
+// CHECK-NEXT:     }
+// CHECK-NEXT:     * _d_j += *_d_q;
+// CHECK-NEXT:     * _d_i += *_d_p;
+// CHECK-NEXT:     delete [] r;
+// CHECK-NEXT:     delete [] _d_r;
+// CHECK-NEXT:     delete q;
+// CHECK-NEXT:     delete _d_q;
+// CHECK-NEXT:     delete p;
+// CHECK-NEXT:     delete _d_p;
+// CHECK-NEXT: }
+
+
 #define NON_MEM_FN_TEST(var)\
 res[0]=0;\
 var.execute(5,res);\
@@ -433,4 +498,9 @@ int main() {
   d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref);
   printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 2.00 4.00 2.00 0.00 0.00
   printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00
-}
+
+  auto d_newAndDeletePointer = clad::gradient(newAndDeletePointer);
+  double d_i = 0, d_j = 0;
+  d_newAndDeletePointer.execute(5, 7, &d_i, &d_j);
+  printf("%.2f %.2f\n", d_i, d_j); // CHECK-EXEC: 9.00 7.00
+}
\ No newline at end of file