From fee1a93758d9e424942bdf92fe545b5ebc5635f0 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 22 Feb 2024 15:42:11 +0100 Subject: [PATCH] Add support for C-style memory alloc and free in reverse mode AD --- include/clad/Differentiator/CladUtils.h | 3 ++ lib/Differentiator/CladUtils.cpp | 16 ++++++++ lib/Differentiator/ReverseModeVisitor.cpp | 47 +++++++++++++++++++++++ test/Gradient/Pointers.C | 38 ++++++++++++++++++ 4 files changed, 104 insertions(+) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 5690c3913..9b086ac50 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -328,6 +328,9 @@ namespace clad { void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); bool IsLiteral(const clang::Expr* E); + + bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD); + bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); } // namespace utils } // namespace clad diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index fbddd535b..cf68b10f4 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -641,5 +641,21 @@ namespace clad { isa(E) || isa(E) || isa(E); } + + bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD) { + if (FD->getNameAsString() == "malloc") + return true; + if (FD->getNameAsString() == "calloc") + return true; + if (FD->getNameAsString() == "realloc") + return true; + return false; + } + + bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) { + if (FD->getNameAsString() == "free") + return true; + return false; + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b792c29a8..be3f8731b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1441,6 +1441,53 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Stores tape decl and pushes for multiarg numerically differentiated // calls. llvm::SmallVector NumericalDiffMultiArg{}; + + // For calls to C-style memory allocation functions, we do not need to + // differentiate the call. We just need to visit the arguments to the + // function. + if (utils::IsMemoryAllocationFunction(FD)) { + for (const Expr* Arg : CE->arguments()) { + StmtDiff ArgDiff = Visit(Arg, dfdx()); + CallArgs.push_back(ArgDiff.getExpr()); + } + Expr* call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + llvm::MutableArrayRef(CallArgs), + noLoc) + .get(); + return StmtDiff(call, call); + } + // For calls to C-style memory deallocation functions, we do not need to + // differentiate the call. We just need to visit the arguments to the + // function. Also, don't add any statements either in forward or reverse + // pass. Instead, add it in m_DeallocExprs. + if (utils::IsMemoryDeallocationFunction(FD)) { + for (const Expr* Arg : CE->arguments()) { + StmtDiff ArgDiff = Visit(Arg, dfdx()); + CallArgs.push_back(ArgDiff.getExpr()); + DerivedCallArgs.push_back(ArgDiff.getExpr_dx()); + } + Expr* call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + llvm::MutableArrayRef(CallArgs), + noLoc) + .get(); + Expr* call_dx = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + llvm::MutableArrayRef(DerivedCallArgs), + noLoc) + .get(); + m_DeallocExprs.push_back(call); + m_DeallocExprs.push_back(call_dx); + return StmtDiff(); + } + // If the result does not depend on the result of the call, just clone // the call and visit arguments (since they may contain side-effects like // f(x = y)) diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index cb2b66ee9..f6f1c78ba 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -430,6 +430,38 @@ double structPointer (double x) { // CHECK-NEXT: delete _d_t; // CHECK-NEXT: } +double cStyleMemoryAlloc(double x, size_t n) { + T* t = (T*)malloc(n * sizeof(T)); + t->x = x; + double res = t->x; + free(t); + return res; +} + +// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, clad::array_ref _d_x) { +// CHECK-NEXT: size_t _d_n = 0; +// CHECK-NEXT: T *_d_t = 0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: _d_t = (T *)malloc(n * sizeof(T)); +// CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T)); +// CHECK-NEXT: _t0 = t->x; +// CHECK-NEXT: t->x = x; +// CHECK-NEXT: double res = t->x; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: _d_t->x += _d_res; +// CHECK-NEXT: { +// CHECK-NEXT: t->x = _t0; +// CHECK-NEXT: double _r_d0 = _d_t->x; +// CHECK-NEXT: _d_t->x -= _r_d0; +// CHECK-NEXT: * _d_x += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: free(t); +// CHECK-NEXT: free(_d_t); +// CHECK-NEXT: } + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -533,4 +565,10 @@ int main() { auto d_structPointer = clad::gradient(structPointer); double d_x = 0; d_structPointer.execute(5, &d_x); + printf("%.2f\n", d_x); // CHECK-EXEC: 1.00 + + auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x"); + d_x = 0; + d_cStyleMemoryAlloc.execute(5, 7, &d_x); + printf("%.2f\n", d_x); // CHECK-EXEC: 1.00 } \ No newline at end of file