From 1bfc466c100d2af1d45eea8c76646ca47335d7a7 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sun, 15 Oct 2023 01:50:10 +0530 Subject: [PATCH] Fix tape push expr for clad array in reverse mode --- lib/Differentiator/ReverseModeVisitor.cpp | 7 +- test/Arrays/ArrayInputsReverseMode.C | 80 +++++++++++++++++++++-- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6c5d09cc2..ea176a8a7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -95,7 +95,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* exprToPush = E; if (auto AT = dyn_cast(E->getType())) { Expr* init = getArraySizeExpr(AT, m_Context, *this); - exprToPush = BuildOp(BO_Comma, E, init); + Expr* pushArgs[] = {E, init}; + SourceLocation loc = E->getExprLoc(); + TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(EQt, loc); + exprToPush = + m_Sema.BuildCXXTypeConstructExpr(TSI, loc, pushArgs, loc, false) + .get(); } Expr* CallArgs[] = {TapeRef, exprToPush}; Expr* PushExpr = diff --git a/test/Arrays/ArrayInputsReverseMode.C b/test/Arrays/ArrayInputsReverseMode.C index c0e6859c2..aaa2d5a63 100644 --- a/test/Arrays/ArrayInputsReverseMode.C +++ b/test/Arrays/ArrayInputsReverseMode.C @@ -218,7 +218,7 @@ double func4(double x) { //CHECK-NEXT: _t3 = 0; //CHECK-NEXT: for (int i = 0; i < 3; i++) { //CHECK-NEXT: _t3++; -//CHECK-NEXT: clad::push(_t4, arr , 3UL); +//CHECK-NEXT: clad::push(_t4, clad::array(arr, 3UL)); //CHECK-NEXT: sum += addArr(arr, 3); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; @@ -283,7 +283,7 @@ double func5(int k) { //CHECK-NEXT: _t3 = 0; //CHECK-NEXT: for (int i = 0; i < 3; i++) { //CHECK-NEXT: _t3++; -//CHECK-NEXT: clad::push(_t4, arr , n); +//CHECK-NEXT: clad::push(_t4, clad::array(arr, n)); //CHECK-NEXT: sum += addArr(arr, clad::push(_t5, n)); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; @@ -336,7 +336,7 @@ double func6(double seed) { //CHECK-NEXT: for (int i = 0; i < 3; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: double arr[3] = {seed, clad::push(_t2, seed) * clad::push(_t1, i), seed + i}; -//CHECK-NEXT: clad::push(_t3, arr , 3UL); +//CHECK-NEXT: clad::push(_t3, clad::array(arr, 3UL)); //CHECK-NEXT: sum += addArr(arr, 3); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; @@ -366,6 +366,72 @@ double func6(double seed) { //CHECK-NEXT: } //CHECK-NEXT: } +double inv_square(double *params) { + return 1 / (params[0] * params[0]); +} + +//CHECK: void inv_square_pullback(double *params, double _d_y, clad::array_ref _d_params) { +//CHECK-NEXT: double _t0; +//CHECK-NEXT: double _t1; +//CHECK-NEXT: double _t2; +//CHECK-NEXT: _t2 = params[0]; +//CHECK-NEXT: _t1 = params[0]; +//CHECK-NEXT: _t0 = (_t2 * _t1); +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = _d_y / _t0; +//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0); +//CHECK-NEXT: double _r2 = _r1 * _t1; +//CHECK-NEXT: _d_params[0] += _r2; +//CHECK-NEXT: double _r3 = _t2 * _r1; +//CHECK-NEXT: _d_params[0] += _r3; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double func7(double *params) { + double out = 0.0; + for (std::size_t i = 0; i < 1; ++i) { + double paramsPrime[1] = {params[0]}; + out = out + inv_square(paramsPrime); + } + return out; +} + +//CHECK: void func7_grad(double *params, clad::array_ref _d_params) { +//CHECK-NEXT: double _d_out = 0; +//CHECK-NEXT: unsigned long _t0; +//CHECK-NEXT: std::size_t _d_i = 0; +//CHECK-NEXT: clad::array _d_paramsPrime(1UL); +//CHECK-NEXT: clad::tape > _t2 = {}; +//CHECK-NEXT: double out = 0.; +//CHECK-NEXT: _t0 = 0; +//CHECK-NEXT: for (std::size_t i = 0; i < 1; ++i) { +//CHECK-NEXT: _t0++; +//CHECK-NEXT: double paramsPrime[1] = {params[0]}; +//CHECK-NEXT: clad::push(_t2, clad::array(paramsPrime, 1UL)); +//CHECK-NEXT: out = out + inv_square(paramsPrime); +//CHECK-NEXT: } +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: _d_out += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: double _r_d0 = _d_out; +//CHECK-NEXT: _d_out += _r_d0; +//CHECK-NEXT: clad::array _r1 = clad::pop(_t2); +//CHECK-NEXT: inv_square_pullback(_r1, _r_d0, _d_paramsPrime); +//CHECK-NEXT: clad::array _r0(_d_paramsPrime); +//CHECK-NEXT: _d_out -= _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: int _t1 = 0; +//CHECK-NEXT: _d_params[_t1] += _d_paramsPrime[0]; +//CHECK-NEXT: _d_paramsPrime = {}; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } + int main() { double arr[] = {1, 2, 3}; auto f_dx = clad::gradient(f); @@ -406,5 +472,11 @@ int main() { auto localArray = clad::gradient(func6); double dseed = 0; localArray.execute(1, &dseed); - printf("Result = {%.2f}", dseed); // CHECK-EXEC: Result = {9.00} + printf("Result = {%.2f}\n", dseed); // CHECK-EXEC: Result = {9.00} + + auto func7grad = clad::gradient(func7); + double params = 2.0; + double dparams = 0.0; + func7grad.execute(¶ms, &dparams); + printf("Result = {%.2f}\n", dparams); // CHECK-EXEC: Result = {-0.25} }