Skip to content

Commit

Permalink
Reorder filechecks in tests for pullbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 10, 2024
1 parent 076babf commit 785ba1f
Show file tree
Hide file tree
Showing 14 changed files with 757 additions and 653 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace clad {
class CladPlugin;
clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,
DiffRequest& request);
// FIXME: This function should be removed and the entire plans array
// should be somehow made accessible to all the visitors.
void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request);
} // namespace plugin

Expand Down
6 changes: 4 additions & 2 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,13 @@ namespace clad {
// Diff info for pullbacks is generated automatically,
// its parameters are not provided by the user.
if (Mode == DiffMode::experimental_pullback) {
// Might need to update DVI args, as they may be pointing to the
// declaration parameters, not the definition parameters.
if (!Function->getPreviousDecl())
// If the function was never declared before, we can safely assume
// that the parameters are correctly referring to the definition ones.
return;
const FunctionDecl* FD = Function->getPreviousDecl();
// Might need to update DVI args, as they may be pointing to the
// declaration parameters, not the definition parameters.
for (size_t i = 0, e = DVI.size(), paramIdx = 0;
i < e && paramIdx < FD->getNumParams(); ++i) {
const auto* param = DVI[i].param;
Expand Down
12 changes: 8 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1804,17 +1804,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!m_ExternalSource) {
// Derive the declaration of the pullback function.
pullbackRequest.DeclarationOnly = true;
pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
pullbackFD =
plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);

// Add the request to derive the definition of the pullback function.
pullbackRequest.DeclarationOnly = false;
pullbackRequest.DerivedFDPrototype = pullbackFD;
plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest);
} else {
// FIXME: Error estimation currently uses singleton objects - m_ErrorEstHandler and m_EstModel, which is cleared after each error_estimate request. This requires the pullback to be derived at the same time to access the singleton objects.
pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
// FIXME: Error estimation currently uses singleton objects -
// m_ErrorEstHandler and m_EstModel, which is cleared after each
// error_estimate request. This requires the pullback to be derived at
// the same time to access the singleton objects.
pullbackFD =
plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
}


// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
Expand Down
162 changes: 87 additions & 75 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,7 @@ double addArr(const double *arr, int n) {
return ret;
}

//CHECK: void addArr_pullback(const double *arr, int n, double _d_y, double *_d_arr, int *_d_n) {
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, ret);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_ret += _d_y;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: ret = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_ret;
//CHECK-NEXT: _d_arr[i] += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK: void addArr_pullback(const double *arr, int n, double _d_y, double *_d_arr, int *_d_n);

double f(double *arr) {
return addArr(arr, 3);
Expand Down Expand Up @@ -104,11 +80,7 @@ float helper(float x) {
return 2 * x;
}

// CHECK: void helper_pullback(float x, float _d_y, float *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_x += 2 * _d_y;
// CHECK-NEXT: }
// CHECK: void helper_pullback(float x, float _d_y, float *_d_x);

float func2(float* a) {
float sum = 0;
Expand Down Expand Up @@ -345,17 +317,7 @@ double inv_square(double *params) {
return 1 / (params[0] * params[0]);
}

//CHECK: void inv_square_pullback(double *params, double _d_y, double *_d_params) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = (params[0] * params[0]);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: _d_params[0] += _r0 * params[0];
//CHECK-NEXT: _d_params[0] += params[0] * _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK: void inv_square_pullback(double *params, double _d_y, double *_d_params);

double func7(double *params) {
double out = 0.0;
Expand Down Expand Up @@ -407,14 +369,7 @@ double helper2(double i, double *arr, int n) {
return arr[0]*i;
}

//CHECK: void helper2_pullback(double i, double *arr, int n, double _d_y, double *_d_i, double *_d_arr, int *_d_n) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: _d_arr[0] += _d_y * i;
//CHECK-NEXT: *_d_i += arr[0] * _d_y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK: void helper2_pullback(double i, double *arr, int n, double _d_y, double *_d_i, double *_d_arr, int *_d_n);

double func8(double i, double *arr, int n) {
double res = 0;
Expand Down Expand Up @@ -465,17 +420,7 @@ void modify(double& elem, double val) {
elem = val;
}

//CHECK: void modify_pullback(double &elem, double val, double *_d_elem, double *_d_val) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = elem;
//CHECK-NEXT: elem = val;
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = *_d_elem;
//CHECK-NEXT: *_d_elem -= _r_d0;
//CHECK-NEXT: *_d_val += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK: void modify_pullback(double &elem, double val, double *_d_elem, double *_d_val);

double func9(double i, double j) {
double arr[5] = {};
Expand Down Expand Up @@ -525,21 +470,7 @@ double sq(double& elem) {
return elem;
}

//CHECK: void sq_pullback(double &elem, double _d_y, double *_d_elem) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = elem;
//CHECK-NEXT: elem = elem * elem;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_elem += _d_y;
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = *_d_elem;
//CHECK-NEXT: *_d_elem -= _r_d0;
//CHECK-NEXT: *_d_elem += _r_d0 * elem;
//CHECK-NEXT: *_d_elem += elem * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK: void sq_pullback(double &elem, double _d_y, double *_d_elem);

double func10(double *arr, int n) {
double res = 0;
Expand Down Expand Up @@ -645,3 +576,84 @@ int main() {
func10grad.execute(arr3, 5, _d_arr3);
printf("Result (arr) = {%.2f, %.2f, %.2f, %.2f, %.2f}\n", _d_arr3[0], _d_arr3[1], _d_arr3[2], _d_arr3[3], _d_arr3[4]); // CHECK-EXEC: Result (arr) = {2.00, 4.00, 6.00, 8.00, 10.00}
}

//CHECK: void addArr_pullback(const double *arr, int n, double _d_y, double *_d_arr, int *_d_n) {
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, ret);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_ret += _d_y;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: ret = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_ret;
//CHECK-NEXT: _d_arr[i] += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

// CHECK: void helper_pullback(float x, float _d_y, float *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_x += 2 * _d_y;
// CHECK-NEXT: }

//CHECK: void inv_square_pullback(double *params, double _d_y, double *_d_params) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = (params[0] * params[0]);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: _d_params[0] += _r0 * params[0];
//CHECK-NEXT: _d_params[0] += params[0] * _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }

//CHECK: void helper2_pullback(double i, double *arr, int n, double _d_y, double *_d_i, double *_d_arr, int *_d_n) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: _d_arr[0] += _d_y * i;
//CHECK-NEXT: *_d_i += arr[0] * _d_y;
//CHECK-NEXT: }
//CHECK-NEXT: }

//CHECK: void modify_pullback(double &elem, double val, double *_d_elem, double *_d_val) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = elem;
//CHECK-NEXT: elem = val;
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = *_d_elem;
//CHECK-NEXT: *_d_elem -= _r_d0;
//CHECK-NEXT: *_d_val += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

//CHECK: void sq_pullback(double &elem, double _d_y, double *_d_elem) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = elem;
//CHECK-NEXT: elem = elem * elem;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_elem += _d_y;
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = *_d_elem;
//CHECK-NEXT: *_d_elem -= _r_d0;
//CHECK-NEXT: *_d_elem += _r_d0 * elem;
//CHECK-NEXT: *_d_elem += elem * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Loading

0 comments on commit 785ba1f

Please sign in to comment.