Skip to content

Commit

Permalink
Fix tape push expr for clad array in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Oct 14, 2023
1 parent 41699e8 commit 2fe70c4
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 5 deletions.
7 changes: 6 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* exprToPush = E;
if (auto AT = dyn_cast<ArrayType>(E->getType())) {
Expr* init = getArraySizeExpr(AT, m_Context, *this);
exprToPush = BuildOp(BO_Comma, E, init);
llvm::SmallVector<Expr*, 2> pushArgs{E, init};
SourceLocation loc = E->getExprLoc();
TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(EQt, loc);
exprToPush =
m_Sema.BuildCXXTypeConstructExpr(TSI, loc, pushArgs, loc, false)
.get();
}
Expr* CallArgs[] = {TapeRef, exprToPush};
Expr* PushExpr =
Expand Down
80 changes: 76 additions & 4 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ double func4(double x) {
//CHECK-NEXT: _t3 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t3++;
//CHECK-NEXT: clad::push(_t4, arr , 3UL);
//CHECK-NEXT: clad::push(_t4, clad::array<double>(arr, 3UL));
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand Down Expand Up @@ -283,7 +283,7 @@ double func5(int k) {
//CHECK-NEXT: _t3 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t3++;
//CHECK-NEXT: clad::push(_t4, arr , n);
//CHECK-NEXT: clad::push(_t4, clad::array<double>(arr, n));
//CHECK-NEXT: sum += addArr(arr, clad::push(_t5, n));
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand Down Expand Up @@ -336,7 +336,7 @@ double func6(double seed) {
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double arr[3] = {seed, clad::push(_t2, seed) * clad::push(_t1, i), seed + i};
//CHECK-NEXT: clad::push(_t3, arr , 3UL);
//CHECK-NEXT: clad::push(_t3, clad::array<double>(arr, 3UL));
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand Down Expand Up @@ -366,6 +366,72 @@ double func6(double seed) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double inv_square(double *params) {
return 1 / (params[0] * params[0]);
}

//CHECK: void inv_square_pullback(double *params, double _d_y, clad::array_ref<double> _d_params) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: _t2 = params[0];
//CHECK-NEXT: _t1 = params[0];
//CHECK-NEXT: _t0 = (_t2 * _t1);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y / _t0;
//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r2 = _r1 * _t1;
//CHECK-NEXT: _d_params[0] += _r2;
//CHECK-NEXT: double _r3 = _t2 * _r1;
//CHECK-NEXT: _d_params[0] += _r3;
//CHECK-NEXT: }
//CHECK-NEXT: }

double func7(double *params) {
double out = 0.0;
for (std::size_t i = 0; i < 1; ++i) {
double paramsPrime[1] = {params[0]};
out = out + inv_square(paramsPrime);
}
return out;
}

//CHECK: void func7_grad(double *params, clad::array_ref<double> _d_params) {
//CHECK-NEXT: double _d_out = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: std::size_t _d_i = 0;
//CHECK-NEXT: clad::array<double> _d_paramsPrime(1UL);
//CHECK-NEXT: clad::tape<clad::array<double> > _t2 = {};
//CHECK-NEXT: double out = 0.;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (std::size_t i = 0; i < 1; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double paramsPrime[1] = {params[0]};
//CHECK-NEXT: clad::push(_t2, clad::array<double>(paramsPrime, 1UL));
//CHECK-NEXT: out = out + inv_square(paramsPrime);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_out += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: double _r_d0 = _d_out;
//CHECK-NEXT: _d_out += _r_d0;
//CHECK-NEXT: clad::array<double> _r1 = clad::pop(_t2);
//CHECK-NEXT: inv_square_pullback(_r1, _r_d0, _d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r0(_d_paramsPrime);
//CHECK-NEXT: _d_out -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: int _t1 = 0;
//CHECK-NEXT: _d_params[_t1] += _d_paramsPrime[0];
//CHECK-NEXT: _d_paramsPrime = {};
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

int main() {
double arr[] = {1, 2, 3};
auto f_dx = clad::gradient(f);
Expand Down Expand Up @@ -406,5 +472,11 @@ int main() {
auto localArray = clad::gradient(func6);
double dseed = 0;
localArray.execute(1, &dseed);
printf("Result = {%.2f}", dseed); // CHECK-EXEC: Result = {9.00}
printf("Result = {%.2f}\n", dseed); // CHECK-EXEC: Result = {9.00}

auto func7grad = clad::gradient(func7);
double params = 2.0;
double dparams = 0.0;
func7grad.execute(&params, &dparams);
printf("Result = {%.2f}\n", dparams); // CHECK-EXEC: Result = {-0.25}
}

0 comments on commit 2fe70c4

Please sign in to comment.