From dee9a903e397f9dbde345ec91da7e17c4b26bc42 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 25 Oct 2023 02:22:26 +0530 Subject: [PATCH] Fix gradient computation of higher order functions --- lib/Differentiator/ReverseModeVisitor.cpp | 23 ++++-- test/Gradient/Functors.C | 89 +++++++++++++++++++++++ 2 files changed, 105 insertions(+), 7 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 67ed800f2..542f052e6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1345,9 +1345,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } // Create the (_d_param[idx] += dfdx) statement. if (dfdx()) { - Expr* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); - // Add it to the body statements. - addToCurrentBlock(add_assign, direction::reverse); + // FIXME: not sure if this is generic. + // Don't update derivatives of non-record types. + if (!isa(decl->getType())) { + auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); + } } return StmtDiff(clonedDRE, it->second, it->second); } @@ -1694,10 +1698,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, else gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); } else { - // Declare: diffArgType _grad = 0; - gradVarDecl = BuildVarDecl( - PVD->getType(), gradVarII, - ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0)); + // Declare: diffArgType _grad; + Expr* initVal = nullptr; + if (!PVD->getType()->isRecordType()) { + // If the argument is not a class type, then initialize the grad + // variable with 0. + initVal = + ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0); + } + gradVarDecl = BuildVarDecl(PVD->getType(), gradVarII, initVal); // Pass the address of the declared variable gradVarExpr = BuildDeclRef(gradVarDecl); gradArgExpr = diff --git a/test/Gradient/Functors.C b/test/Gradient/Functors.C index 0fa4832e7..e663fe409 100644 --- a/test/Gradient/Functors.C +++ b/test/Gradient/Functors.C @@ -179,6 +179,18 @@ double CallFunctor(double i, double j) { return E(i, j); } +// A function taking functor as an argument. +template +double FunctorAsArg(Func fn, double i, double j) { + return fn(i, j); +} + +// A wrapper for function taking functor as an argument. +double FunctorAsArgWrapper(double i, double j) { + Experiment E(3, 5); + return FunctorAsArg(E, i, j); +} + #define INIT(E) \ auto E##_grad = clad::gradient(&E); \ auto E##Ref_grad = clad::gradient(E); @@ -332,4 +344,81 @@ int main() { double di = 0, dj = 0; CallFunctor_grad.execute(7, 9, &di, &dj); printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00 + + // CHECK: void FunctorAsArg_grad(Experiment fn, double i, double j, clad::array_ref _d_fn, clad::array_ref _d_i, clad::array_ref _d_j) { + // CHECK-NEXT: double _t0; + // CHECK-NEXT: double _t1; + // CHECK-NEXT: Experiment _t2; + // CHECK-NEXT: _t0 = i; + // CHECK-NEXT: _t1 = j; + // CHECK-NEXT: _t2 = fn; + // CHECK-NEXT: goto _label0; + // CHECK-NEXT: _label0: + // CHECK-NEXT: { + // CHECK-NEXT: double _grad0 = 0.; + // CHECK-NEXT: double _grad1 = 0.; + // CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, 1, &(* _d_fn), &_grad0, &_grad1); + // CHECK-NEXT: double _r0 = _grad0; + // CHECK-NEXT: * _d_i += _r0; + // CHECK-NEXT: double _r1 = _grad1; + // CHECK-NEXT: * _d_j += _r1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // testing differentiating a function taking functor as an argument + auto FunctorAsArg_grad = clad::gradient(FunctorAsArg); + di = 0, dj = 0; + Experiment E_temp(3, 5), dE_temp; + FunctorAsArg_grad.execute(E_temp, 7, 9, &dE_temp, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00 + + // CHECK: void FunctorAsArg_pullback(Experiment fn, double i, double j, double _d_y, clad::array_ref _d_fn, clad::array_ref _d_i, clad::array_ref _d_j) { + // CHECK-NEXT: double _t0; + // CHECK-NEXT: double _t1; + // CHECK-NEXT: Experiment _t2; + // CHECK-NEXT: _t0 = i; + // CHECK-NEXT: _t1 = j; + // CHECK-NEXT: _t2 = fn; + // CHECK-NEXT: goto _label0; + // CHECK-NEXT: _label0: + // CHECK-NEXT: { + // CHECK-NEXT: double _grad0 = 0.; + // CHECK-NEXT: double _grad1 = 0.; + // CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, _d_y, &(* _d_fn), &_grad0, &_grad1); + // CHECK-NEXT: double _r0 = _grad0; + // CHECK-NEXT: * _d_i += _r0; + // CHECK-NEXT: double _r1 = _grad1; + // CHECK-NEXT: * _d_j += _r1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void FunctorAsArgWrapper_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { + // CHECK-NEXT: Experiment _d_E({}); + // CHECK-NEXT: Experiment _t0; + // CHECK-NEXT: double _t1; + // CHECK-NEXT: double _t2; + // CHECK-NEXT: Experiment E(3, 5); + // CHECK-NEXT: _t0 = E + // CHECK-NEXT: _t1 = i; + // CHECK-NEXT: _t2 = j; + // CHECK-NEXT: goto _label0; + // CHECK-NEXT: _label0: + // CHECK-NEXT: { + // CHECK-NEXT: Experiment _grad0; + // CHECK-NEXT: double _grad1 = 0.; + // CHECK-NEXT: double _grad2 = 0.; + // CHECK-NEXT: FunctorAsArg_pullback(_t0, _t1, _t2, 1, &_grad0, &_grad1, &_grad2); + // CHECK-NEXT: Experiment _r0(_grad0); + // CHECK-NEXT: double _r1 = _grad1; + // CHECK-NEXT: * _d_i += _r1; + // CHECK-NEXT: double _r2 = _grad2; + // CHECK-NEXT: * _d_j += _r2; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // testing differentiating a wrapper for function taking functor as an argument + auto FunctorAsArgWrapper_grad = clad::gradient(FunctorAsArgWrapper); + di = 0, dj = 0; + FunctorAsArgWrapper_grad.execute(7, 9, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00 }