From 7de91af1e4943b0ada0c4084a8a707d5dd8c2712 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 15 Feb 2024 18:17:08 +0100 Subject: [PATCH] Fix pointer arithmetic for array types required for #762 --- include/clad/Differentiator/Array.h | 3 +- test/Gradient/FunctionCalls.C | 65 ++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/include/clad/Differentiator/Array.h b/include/clad/Differentiator/Array.h index 48c02f776..b8e291eb3 100644 --- a/include/clad/Differentiator/Array.h +++ b/include/clad/Differentiator/Array.h @@ -102,7 +102,8 @@ template class array { /// Returns the size of the underlying array CUDA_HOST_DEVICE std::size_t size() const { return m_size; } /// Returns the ptr of the underlying array - CUDA_HOST_DEVICE T* ptr() { return m_arr; } + CUDA_HOST_DEVICE T* ptr() const { return m_arr; } + CUDA_HOST_DEVICE T*& ptr_ref() { return m_arr; } /// Returns the reference to the location at the index of the underlying /// array CUDA_HOST_DEVICE T& operator[](std::ptrdiff_t i) { return m_arr[i]; } diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 62062f813..828bca861 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -672,6 +672,61 @@ double fn12(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +double multiply(double* a, double* b) { + return a[0] * b[0]; +} + +// CHECK: void multiply_pullback(double *a, double *b, double _d_y, clad::array_ref _d_a, clad::array_ref _d_b) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: _d_a[0] += _d_y * b[0]; +// CHECK-NEXT: _d_b[0] += a[0] * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn13(double* x, const double* w) { + double wCopy[2]; + for(std::size_t i = 0; i < 2; ++i) { wCopy[i] = w[i]; } + return multiply(x, wCopy + 1); +} + +// CHECK: void fn13_grad_0(double *x, const double *w, clad::array_ref _d_x) { +// CHECK-NEXT: clad::array _d_wCopy(2UL); +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: std::size_t _d_i = 0; +// CHECK-NEXT: std::size_t i = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double *_t2; +// CHECK-NEXT: double *_t3; +// CHECK-NEXT: double wCopy[2]; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: for (i = 0; i < 2; ++i) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: clad::push(_t1, wCopy[i]); +// CHECK-NEXT: wCopy[i] = w[i]; +// CHECK-NEXT: } +// CHECK-NEXT: _t2 = x; +// CHECK-NEXT: _t3 = wCopy + 1; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: x = _t2; +// CHECK-NEXT: multiply_pullback(_t2, _t3, 1, _d_x, _d_wCopy.ptr_ref() + 1); +// CHECK-NEXT: double *_r0 = _d_x; +// CHECK-NEXT: double *_r1 = _d_wCopy.ptr_ref() + 1; +// CHECK-NEXT: } +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: --i; +// CHECK-NEXT: { +// CHECK-NEXT: wCopy[i] = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_wCopy[i]; +// CHECK-NEXT: _d_wCopy[i] -= _r_d0; +// CHECK-NEXT: _d_wCopy[i]; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i