Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Dec 23, 2023
1 parent 7f455a7 commit 514e58c
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,101 @@ double pointerParam(const double* arr, size_t n) {
return sum;
}

// CHECK: void pointerParam_grad_0(const double *arr, size_t n, clad::array_ref<double> _d_arr) {
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_sum = 0;
// CHECK-NEXT: unsigned long _t0;
// CHECK-NEXT: size_t _d_i = 0;
// CHECK-NEXT: clad::tape<size_t *> _t1 = {};
// CHECK-NEXT: size_t *_d_j = 0;
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: clad::tape<const double *> _t4 = {};
// CHECK-NEXT: clad::tape<clad::array_ref<double> > _t5 = {};
// CHECK-NEXT: double sum = 0;
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: _d_j = &_d_i;
// CHECK-NEXT: clad::push(_t1, _d_j);
// CHECK-NEXT: size_t *j = &i;
// CHECK-NEXT: clad::push(_t3, sum);
// CHECK-NEXT: sum += arr[0] * (*j);
// CHECK-NEXT: clad::push(_t4, arr);
// CHECK-NEXT: clad::push(_t5, _d_arr);
// CHECK-NEXT: _d_arr.ptr_ref() = _d_arr.ptr_ref() + 1;
// CHECK-NEXT: arr = arr + 1;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_sum += 1;
// CHECK-NEXT: for (; _t0; _t0--) {
// CHECK-NEXT: --i;
// CHECK-NEXT: size_t *_t2 = clad::pop(_t1);
// CHECK-NEXT: {
// CHECK-NEXT: arr = clad::pop(_t4);
// CHECK-NEXT: _d_arr = clad::pop(_t5);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: sum = clad::pop(_t3);
// CHECK-NEXT: double _r_d0 = _d_sum;
// CHECK-NEXT: _d_sum += _r_d0;
// CHECK-NEXT: double _r0 = _r_d0 * (*j);
// CHECK-NEXT: _d_arr[0] += _r0;
// CHECK-NEXT: double _r1 = arr[0] * _r_d0;
// CHECK-NEXT: *_t2 += _r1;
// CHECK-NEXT: _d_sum -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

double pointerMultipleParams(const double* a, const double* b) {
double sum = b[2];
a = 1+a;
b = a;
sum += a[0] + b[0];
return sum; // 2*a[1] + b[2]
}

// CHECK: void pointerMultipleParams_grad(const double *a, const double *b, clad::array_ref<double> _d_a, clad::array_ref<double> _d_b) {
// CHECK-NEXT: double _d_sum = 0;
// CHECK-NEXT: const double *_t0;
// CHECK-NEXT: clad::array_ref<double> _t1;
// CHECK-NEXT: const double *_t2;
// CHECK-NEXT: clad::array_ref<double> _t3;
// CHECK-NEXT: double _t4;
// CHECK-NEXT: double sum = b[2];
// CHECK-NEXT: _t0 = a;
// CHECK-NEXT: _t1 = _d_a;
// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref();
// CHECK-NEXT: a = 1 + a;
// CHECK-NEXT: _t2 = b;
// CHECK-NEXT: _t3 = _d_b;
// CHECK-NEXT: _d_b.ptr_ref() = _d_a.ptr_ref();
// CHECK-NEXT: b = a;
// CHECK-NEXT: _t4 = sum;
// CHECK-NEXT: sum += a[0] + b[0];
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_sum += 1;
// CHECK-NEXT: {
// CHECK-NEXT: sum = _t4;
// CHECK-NEXT: double _r_d0 = _d_sum;
// CHECK-NEXT: _d_sum += _r_d0;
// CHECK-NEXT: _d_a[0] += _r_d0;
// CHECK-NEXT: _d_b[0] += _r_d0;
// CHECK-NEXT: _d_sum -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: b = _t2;
// CHECK-NEXT: _d_b = _t3;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: a = _t0;
// CHECK-NEXT: _d_a = _t1;
// CHECK-NEXT: }
// CHECK-NEXT: _d_b[2] += _d_sum;
// CHECK-NEXT: }

#define NON_MEM_FN_TEST(var)\
res[0]=0;\
var.execute(5,res);\
Expand Down Expand Up @@ -265,4 +360,13 @@ int main() {
d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0;
d_pointerParam.execute(arr, 5, d_arr_ref);
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 1.00 2.00 3.00 4.00

auto d_pointerMultipleParams = clad::gradient(pointerMultipleParams);
double b_arr[5] = {1, 2, 3, 4, 5};
double d_b_arr[5] = {0, 0, 0, 0, 0};
clad::array_ref<double> d_b_arr_ref(d_b_arr, 5);
d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0;
d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref);
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 2.00 0.00 0.00 0.00
printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00
}

0 comments on commit 514e58c

Please sign in to comment.