From 81d6eca0f7a07e2b1fff9d2946d86e57e9a3be1c Mon Sep 17 00:00:00 2001
From: Vaibhav Thakkar <vaibhav.thakkar.22.12.99@gmail.com>
Date: Thu, 22 Feb 2024 15:42:11 +0100
Subject: [PATCH] Add support for C-style memory allocations in reverse mode AD

---
 include/clad/Differentiator/CladUtils.h   |  2 ++
 lib/Differentiator/CladUtils.cpp          | 10 +++++++
 lib/Differentiator/ReverseModeVisitor.cpp | 19 ++++++++++++
 test/Gradient/Pointers.C                  | 35 +++++++++++++++++++++++
 4 files changed, 66 insertions(+)

diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h
index 5690c3913..1b3a149e1 100644
--- a/include/clad/Differentiator/CladUtils.h
+++ b/include/clad/Differentiator/CladUtils.h
@@ -328,6 +328,8 @@ namespace clad {
     void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);
 
     bool IsLiteral(const clang::Expr* E);
+
+    bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD);
     } // namespace utils
     } // namespace clad
 
diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp
index fbddd535b..b350aa6d6 100644
--- a/lib/Differentiator/CladUtils.cpp
+++ b/lib/Differentiator/CladUtils.cpp
@@ -641,5 +641,15 @@ namespace clad {
              isa<ObjCBoolLiteralExpr>(E) || isa<CXXBoolLiteralExpr>(E) ||
              isa<GNUNullExpr>(E);
     }
+
+    bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD) {
+      if (FD->getNameAsString() == "malloc")
+        return true;
+      if (FD->getNameAsString() == "calloc")
+        return true;
+      if (FD->getNameAsString() == "realloc")
+        return true;
+      return false;
+    }
   } // namespace utils
 } // namespace clad
diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp
index b792c29a8..49c8c2c80 100644
--- a/lib/Differentiator/ReverseModeVisitor.cpp
+++ b/lib/Differentiator/ReverseModeVisitor.cpp
@@ -1441,6 +1441,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     // Stores tape decl and pushes for multiarg numerically differentiated
     // calls.
     llvm::SmallVector<Stmt*, 16> NumericalDiffMultiArg{};
+
+    // For calls to C-style memory allocation functions, we do not need to
+    // differentiate the call. We just need to visit the arguments to the
+    // function.
+    if (utils::IsMemoryAllocationFunction(FD)) {
+      for (const Expr* Arg : CE->arguments()) {
+        StmtDiff ArgDiff = Visit(Arg, dfdx());
+        CallArgs.push_back(ArgDiff.getExpr());
+      }
+      Expr* call = m_Sema
+                       .ActOnCallExpr(getCurrentScope(),
+                                      Clone(CE->getCallee()),
+                                      noLoc,
+                                      llvm::MutableArrayRef<Expr*>(CallArgs),
+                                      noLoc)
+                       .get();
+      return StmtDiff(call, call);
+    }
+
     // If the result does not depend on the result of the call, just clone
     // the call and visit arguments (since they may contain side-effects like
     // f(x = y))
diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C
index cb2b66ee9..dd488b0fe 100644
--- a/test/Gradient/Pointers.C
+++ b/test/Gradient/Pointers.C
@@ -430,6 +430,35 @@ double structPointer (double x) {
 // CHECK-NEXT:     delete _d_t;
 // CHECK-NEXT: }
 
+double cStyleMemoryAlloc(double x, size_t n) {
+  T* t = (T*)malloc(n * sizeof(T));
+  t->x = x;
+  double res = t->x;
+  return res;
+}
+
+// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, clad::array_ref<double> _d_x) {
+// CHECK-NEXT:     size_t _d_n = 0;
+// CHECK-NEXT:     T *_d_t = 0;
+// CHECK-NEXT:     double _t0;
+// CHECK-NEXT:     double _d_res = 0;
+// CHECK-NEXT:     _d_t = (T *)malloc(n * sizeof(T));
+// CHECK-NEXT:     T *t = (T *)malloc(n * sizeof(T));
+// CHECK-NEXT:     _t0 = t->x;
+// CHECK-NEXT:     t->x = x;
+// CHECK-NEXT:     double res = t->x;
+// CHECK-NEXT:     goto _label0;
+// CHECK-NEXT:   _label0:
+// CHECK-NEXT:     _d_res += 1;
+// CHECK-NEXT:     _d_t->x += _d_res;
+// CHECK-NEXT:     {
+// CHECK-NEXT:         t->x = _t0;
+// CHECK-NEXT:         double _r_d0 = _d_t->x;
+// CHECK-NEXT:         _d_t->x -= _r_d0;
+// CHECK-NEXT:         * _d_x += _r_d0;
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
 #define NON_MEM_FN_TEST(var)\
 res[0]=0;\
 var.execute(5,res);\
@@ -533,4 +562,10 @@ int main() {
   auto d_structPointer = clad::gradient(structPointer);
   double d_x = 0;
   d_structPointer.execute(5, &d_x);
+  printf("%.2f\n", d_x); // CHECK-EXEC: 1.00
+
+  auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x");
+  d_x = 0;
+  d_cStyleMemoryAlloc.execute(5, 7, &d_x);
+  printf("%.2f\n", d_x); // CHECK-EXEC: 1.00
 }
\ No newline at end of file