From 72a3890a32060a09d0d07e2dfee54f89689155be Mon Sep 17 00:00:00 2001
From: Vaibhav Thakkar <vaibhav.thakkar.22.12.99@gmail.com>
Date: Wed, 25 Oct 2023 02:22:26 +0530
Subject: [PATCH] Fix gradient computation of higher order functions

---
 lib/Differentiator/ReverseModeVisitor.cpp | 23 ++++--
 test/Gradient/Functors.C                  | 89 +++++++++++++++++++++++
 2 files changed, 105 insertions(+), 7 deletions(-)

diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp
index 67ed800f2..b6eb0691d 100644
--- a/lib/Differentiator/ReverseModeVisitor.cpp
+++ b/lib/Differentiator/ReverseModeVisitor.cpp
@@ -1345,9 +1345,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         }
         // Create the (_d_param[idx] += dfdx) statement.
         if (dfdx()) {
-          Expr* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
-          // Add it to the body statements.
-          addToCurrentBlock(add_assign, direction::reverse);
+          // FIXME: not sure if this is generic.
+          // Don't update derivatives of non-record types.
+          if (!decl->getType()->isRecordType()) {
+            auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
+            // Add it to the body statements.
+            addToCurrentBlock(add_assign, direction::reverse);
+          }
         }
         return StmtDiff(clonedDRE, it->second, it->second);
       }
@@ -1694,10 +1698,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
           else
             gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative);
         } else {
-          // Declare: diffArgType _grad = 0;
-          gradVarDecl = BuildVarDecl(
-              PVD->getType(), gradVarII,
-              ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0));
+          // Declare: diffArgType _grad;
+          Expr* initVal = nullptr;
+          if (!PVD->getType()->isRecordType()) {
+            // If the argument is not a class type, then initialize the grad
+            // variable with 0.
+            initVal =
+                ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0);
+          }
+          gradVarDecl = BuildVarDecl(PVD->getType(), gradVarII, initVal);
           // Pass the address of the declared variable
           gradVarExpr = BuildDeclRef(gradVarDecl);
           gradArgExpr =
diff --git a/test/Gradient/Functors.C b/test/Gradient/Functors.C
index 0fa4832e7..e663fe409 100644
--- a/test/Gradient/Functors.C
+++ b/test/Gradient/Functors.C
@@ -179,6 +179,18 @@ double CallFunctor(double i, double j) {
   return E(i, j);
 }
 
+// A function taking functor as an argument.
+template<typename Func>
+double FunctorAsArg(Func fn, double i, double j) {
+  return fn(i, j);
+}
+
+// A wrapper for function taking functor as an argument.
+double FunctorAsArgWrapper(double i, double j) {
+  Experiment E(3, 5);
+  return FunctorAsArg(E, i, j);
+}
+
 #define INIT(E)                                                                \
   auto E##_grad = clad::gradient(&E);                                          \
   auto E##Ref_grad = clad::gradient(E);
@@ -332,4 +344,81 @@ int main() {
   double di = 0, dj = 0;
   CallFunctor_grad.execute(7, 9, &di, &dj);
   printf("%.2f %.2f\n", di, dj);              // CHECK-EXEC: 27.00 21.00
+
+  // CHECK: void FunctorAsArg_grad(Experiment fn, double i, double j, clad::array_ref<Experiment> _d_fn, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
+  // CHECK-NEXT:     double _t0;
+  // CHECK-NEXT:     double _t1;
+  // CHECK-NEXT:     Experiment _t2;
+  // CHECK-NEXT:     _t0 = i;
+  // CHECK-NEXT:     _t1 = j;
+  // CHECK-NEXT:     _t2 = fn;
+  // CHECK-NEXT:     goto _label0;
+  // CHECK-NEXT:   _label0:
+  // CHECK-NEXT:     {
+  // CHECK-NEXT:         double _grad0 = 0.;
+  // CHECK-NEXT:         double _grad1 = 0.;
+  // CHECK-NEXT:         _t2.operator_call_pullback(_t0, _t1, 1, &(* _d_fn), &_grad0, &_grad1);
+  // CHECK-NEXT:         double _r0 = _grad0;
+  // CHECK-NEXT:         * _d_i += _r0;
+  // CHECK-NEXT:         double _r1 = _grad1;
+  // CHECK-NEXT:         * _d_j += _r1;
+  // CHECK-NEXT:     }
+  // CHECK-NEXT: }
+
+  // testing differentiating a function taking functor as an argument
+  auto FunctorAsArg_grad = clad::gradient(FunctorAsArg<Experiment>);
+  di = 0, dj = 0;
+  Experiment E_temp(3, 5), dE_temp;
+  FunctorAsArg_grad.execute(E_temp, 7, 9, &dE_temp, &di, &dj);
+  printf("%.2f %.2f\n", di, dj);              // CHECK-EXEC: 27.00 21.00
+
+  // CHECK: void FunctorAsArg_pullback(Experiment fn, double i, double j, double _d_y, clad::array_ref<Experiment> _d_fn, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
+  // CHECK-NEXT:     double _t0;
+  // CHECK-NEXT:     double _t1;
+  // CHECK-NEXT:     Experiment _t2;
+  // CHECK-NEXT:     _t0 = i;
+  // CHECK-NEXT:     _t1 = j;
+  // CHECK-NEXT:     _t2 = fn;
+  // CHECK-NEXT:     goto _label0;
+  // CHECK-NEXT:   _label0:
+  // CHECK-NEXT:     {
+  // CHECK-NEXT:         double _grad0 = 0.;
+  // CHECK-NEXT:         double _grad1 = 0.;
+  // CHECK-NEXT:         _t2.operator_call_pullback(_t0, _t1, _d_y, &(* _d_fn), &_grad0, &_grad1);
+  // CHECK-NEXT:         double _r0 = _grad0;
+  // CHECK-NEXT:         * _d_i += _r0;
+  // CHECK-NEXT:         double _r1 = _grad1;
+  // CHECK-NEXT:         * _d_j += _r1;
+  // CHECK-NEXT:     }
+  // CHECK-NEXT: }
+
+  // CHECK: void FunctorAsArgWrapper_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
+  // CHECK-NEXT:     Experiment _d_E({});
+  // CHECK-NEXT:     Experiment _t0;
+  // CHECK-NEXT:     double _t1;
+  // CHECK-NEXT:     double _t2;
+  // CHECK-NEXT:     Experiment E(3, 5);
+  // CHECK-NEXT:     _t0 = E
+  // CHECK-NEXT:     _t1 = i;
+  // CHECK-NEXT:     _t2 = j;
+  // CHECK-NEXT:     goto _label0;
+  // CHECK-NEXT:   _label0:
+  // CHECK-NEXT:     {
+  // CHECK-NEXT:         Experiment _grad0;
+  // CHECK-NEXT:         double _grad1 = 0.;
+  // CHECK-NEXT:         double _grad2 = 0.;
+  // CHECK-NEXT:         FunctorAsArg_pullback(_t0, _t1, _t2, 1, &_grad0, &_grad1, &_grad2);
+  // CHECK-NEXT:         Experiment _r0(_grad0);
+  // CHECK-NEXT:         double _r1 = _grad1;
+  // CHECK-NEXT:         * _d_i += _r1;
+  // CHECK-NEXT:         double _r2 = _grad2;
+  // CHECK-NEXT:         * _d_j += _r2;
+  // CHECK-NEXT:     }
+  // CHECK-NEXT: }
+
+  // testing differentiating a wrapper for function taking functor as an argument
+  auto FunctorAsArgWrapper_grad = clad::gradient(FunctorAsArgWrapper);
+  di = 0, dj = 0;
+  FunctorAsArgWrapper_grad.execute(7, 9, &di, &dj);
+  printf("%.2f %.2f\n", di, dj);              // CHECK-EXEC: 27.00 21.00
 }