Skip to content

Commit

Permalink
Fix pointer arithmetic for array types
Browse files Browse the repository at this point in the history
required for vgvassilev#762
  • Loading branch information
vaithak committed Feb 15, 2024
1 parent d305002 commit 8faca4f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
3 changes: 2 additions & 1 deletion include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ template <typename T> class array {
/// Returns the size of the underlying array
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
/// Returns the ptr of the underlying array
CUDA_HOST_DEVICE T* ptr() { return m_arr; }
CUDA_HOST_DEVICE T* ptr() const { return m_arr; }
CUDA_HOST_DEVICE T*& ptr_ref() { return m_arr; }
/// Returns the reference to the location at the index of the underlying
/// array
CUDA_HOST_DEVICE T& operator[](std::ptrdiff_t i) { return m_arr[i]; }
Expand Down
65 changes: 64 additions & 1 deletion test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,61 @@ double fn12(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double multiply(double* a, double* b) {
return a[0] * b[0];
}

// CHECK: void multiply_pullback(double *a, double *b, double _d_y, clad::array_ref<double> _d_a, clad::array_ref<double> _d_b) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: _d_a[0] += _d_y * b[0];
// CHECK-NEXT: _d_b[0] += a[0] * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn13(double* x, const double* w) {
double wCopy[2];
for(std::size_t i = 0; i < 2; ++i) { wCopy[i] = w[i]; }
return multiply(x, wCopy + 1);
}

// CHECK: void fn13_grad_0(double *x, const double *w, clad::array_ref<double> _d_x) {
// CHECK-NEXT: clad::array<double> _d_wCopy(2UL);
// CHECK-NEXT: unsigned long _t0;
// CHECK-NEXT: std::size_t _d_i = 0;
// CHECK-NEXT: std::size_t i = 0;
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: double *_t2;
// CHECK-NEXT: double *_t3;
// CHECK-NEXT: double wCopy[2];
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: for (i = 0; i < 2; ++i) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, wCopy[i]);
// CHECK-NEXT: wCopy[i] = w[i];
// CHECK-NEXT: }
// CHECK-NEXT: _t2 = x;
// CHECK-NEXT: _t3 = wCopy + 1;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: x = _t2;
// CHECK-NEXT: multiply_pullback(_t2, _t3, 1, _d_x, _d_wCopy.ptr_ref() + 1);
// CHECK-NEXT: double *_r0 = _d_x;
// CHECK-NEXT: double *_r1 = _d_wCopy.ptr_ref() + 1;
// CHECK-NEXT: }
// CHECK-NEXT: for (; _t0; _t0--) {
// CHECK-NEXT: --i;
// CHECK-NEXT: {
// CHECK-NEXT: wCopy[i] = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_wCopy[i];
// CHECK-NEXT: _d_wCopy[i] -= _r_d0;
// CHECK-NEXT: _d_wCopy[i];
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -743,4 +798,12 @@ int main() {
TEST2(fn10, 8, 5); // CHECK-EXEC: {0.00, 7.00}
TEST2(fn11, 3, 5); // CHECK-EXEC: {1.00, 1.00}
TEST2(fn12, 3, 5); // CHECK-EXEC: {1.00, 0.00}
}

// Testing the partial gradient of a function with multiple pointer arguments
auto fn13_grad_0 = clad::gradient(fn13, "x");
double x = 2.0;
double w[] = {2.0, 3.0};
double fn13_result = 0.0;
fn13_grad_0.execute(&x, w, &fn13_result);
printf("{%.2f}\n", fn13_result); // CHECK-EXEC: {3.00}
}

0 comments on commit 8faca4f

Please sign in to comment.