Skip to content

Commit

Permalink
Fix paranthesis for derivative of division operator
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed May 8, 2024
1 parent 3f0e5a1 commit b7cec61
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 18 deletions.
8 changes: 4 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2267,10 +2267,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* dr = nullptr;
if (dfdx()) {
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
dr = BuildOp(
BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildOp(BO_Div, LStored.getRevSweepAsExpr(), RxR)));
dr = BuildOp(BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildParens(BuildOp(
BO_Div, LStored.getRevSweepAsExpr(), RxR))));
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
Expand Down
2 changes: 1 addition & 1 deletion test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ int main() {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r0 = _d_y * -(1 / (_t0 * _t0));
//CHECK-NEXT: _d_params[0] += _r0 * params[0];
//CHECK-NEXT: _d_params[0] += params[0] * _r0;
//CHECK-NEXT: }
Expand Down
2 changes: 1 addition & 1 deletion test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ float func2(float x, float y) {
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(_d_z * z * {{.+}});
//CHECK-NEXT: *_d_y += _d_z / x;
//CHECK-NEXT: float _r0 = _d_z * -y / (x * x);
//CHECK-NEXT: float _r0 = _d_z * -(y / (x * x));
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: {
Expand Down
2 changes: 1 addition & 1 deletion test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ float func4(float x, float y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: *_d_y += 1 / x;
//CHECK-NEXT: float _r0 = 1 * -y / (x * x);
//CHECK-NEXT: float _r0 = 1 * -(y / (x * x));
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0) {
Expand Down
2 changes: 1 addition & 1 deletion test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ double divSum(float* a, float* b, int n) {
//CHECK-NEXT: b_size = std::max(b_size, i);
//CHECK-NEXT: _d_a[i] += _r_d0 / b[i];
//CHECK-NEXT: a_size = std::max(a_size, i);
//CHECK-NEXT: double _r0 = _r_d0 * -a[i] / (b[i] * b[i]);
//CHECK-NEXT: double _r0 = _r_d0 * -(a[i] / (b[i] * b[i]));
//CHECK-NEXT: _d_b[i] += _r0;
//CHECK-NEXT: b_size = std::max(b_size, i);
//CHECK-NEXT: }
Expand Down
27 changes: 22 additions & 5 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ double f_div1(double x, double y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += 1 / y;
//CHECK-NEXT: double _r0 = 1 * -x / (y * y);
//CHECK-NEXT: double _r0 = 1 * -(x / (y * y));
//CHECK-NEXT: *_d_y += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand All @@ -143,7 +143,7 @@ double f_div2(double x, double y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += 3 * 1 / _t0;
//CHECK-NEXT: double _r0 = 1 * -3 * x / (_t0 * _t0);
//CHECK-NEXT: double _r0 = 1 * -(3 * x / (_t0 * _t0));
//CHECK-NEXT: *_d_y += 4 * _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand All @@ -169,7 +169,7 @@ double f_div3(double x, double y) {
//CHECK-NEXT: double _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: double _r0 = 1 * -_t2 / (_t0 * _t0);
//CHECK-NEXT: double _r0 = 1 * -(_t2 / (_t0 * _t0));
//CHECK-NEXT: *_d_y += _r0 * y;
//CHECK-NEXT: *_d_y += y * _r0;
//CHECK-NEXT: }
Expand All @@ -190,7 +190,7 @@ double f_c(double x, double y) {
//CHECK-NEXT: *_d_x += 1 * (x / y);
//CHECK-NEXT: *_d_y += 1 * (x / y);
//CHECK-NEXT: *_d_x += (x + y) * 1 / y;
//CHECK-NEXT: double _r0 = (x + y) * 1 * -x / (y * y);
//CHECK-NEXT: double _r0 = (x + y) * 1 * -(x / (y * y));
//CHECK-NEXT: *_d_y += _r0;
//CHECK-NEXT: *_d_x += -1 * x;
//CHECK-NEXT: *_d_x += x * -1;
Expand Down Expand Up @@ -455,7 +455,7 @@ void f_norm_grad(double x,
//CHECK-NEXT: *_d_y += _r2;
//CHECK-NEXT: *_d_z += _r3;
//CHECK-NEXT: *_d_d += _r4;
//CHECK-NEXT: double _r6 = _r5 * -1 / (d * d);
//CHECK-NEXT: double _r6 = _r5 * -(1 / (d * d));
//CHECK-NEXT: *_d_d += _r6;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -777,6 +777,19 @@ double fn_template_non_type(double x) {
// CHECK-NEXT: _d_maxN += _d_m;
// CHECK-NEXT: }

double fn_div(double x) {
return -0.5 / x;
}

// CHECK: void fn_div_grad(double x, double *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 1 * -(-0.5 / (x * x));
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -833,4 +846,8 @@ int main() {
double x = 5, dx = 0;
fn_template_non_type_dx.execute(x, &dx);
printf("Result is = %.2f\n", dx); // CHECK-EXEC: Result is = 15.00

INIT_GRADIENT(fn_div);
dx = 0;
TEST_GRADIENT(fn_div, /*numOfDerivativeArgs=*/1, 2, &dx); // CHECK-EXEC: 0.12
}
6 changes: 3 additions & 3 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ double f_log_gaus(double* x, double* p /*means*/, double n, double sigma) {
//CHECK-NEXT: _d_gaus += _r8;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: double _r3 = _d_gaus * _t5 * -1. / (_t6 * _t6);
//CHECK-NEXT: double _r3 = _d_gaus * _t5 * -(1. / (_t6 * _t6));
//CHECK-NEXT: double _r4 = 0;
//CHECK-NEXT: _r4 += _r3 * clad::custom_derivatives::sqrt_pushforward(_t7 * sigma, 1.).pushforward;
//CHECK-NEXT: double _r5 = 0;
Expand All @@ -377,7 +377,7 @@ double f_log_gaus(double* x, double* p /*means*/, double n, double sigma) {
//CHECK-NEXT: double _r_d1 = _d_power;
//CHECK-NEXT: _d_power -= _r_d1;
//CHECK-NEXT: _d_power += -_r_d1 / _t3;
//CHECK-NEXT: double _r1 = _r_d1 * --power / (_t3 * _t3);
//CHECK-NEXT: double _r1 = _r_d1 * -(-power / (_t3 * _t3));
//CHECK-NEXT: double _r2 = 0;
//CHECK-NEXT: sq_pullback(sigma, 2 * _r1, &_r2);
//CHECK-NEXT: _d_sigma += _r2;
Expand Down Expand Up @@ -1611,7 +1611,7 @@ double f_loop_init_var(double lower, double upper) {
// CHECK-NEXT: {
// CHECK-NEXT: *_d_upper += _d_interval / num_points;
// CHECK-NEXT: *_d_lower += -_d_interval / num_points;
// CHECK-NEXT: double _r0 = _d_interval * -(upper - lower) / (num_points * num_points);
// CHECK-NEXT: double _r0 = _d_interval * -((upper - lower) / (num_points * num_points));
// CHECK-NEXT: _d_num_points += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/constexprTest.C
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ constexpr double fn( double a, double b, double c) {
//CHECK-NEXT: {
//CHECK-NEXT: *_d_a += _d_result * 100 * (a + b) / c * b;
//CHECK-NEXT: *_d_b += a * _d_result * 100 * (a + b) / c;
//CHECK-NEXT: double _r0 = _d_result * 100 * (a + b) * -a * b / (c * c);
//CHECK-NEXT: double _r0 = _d_result * 100 * (a + b) * -(a * b / (c * c));
//CHECK-NEXT: *_d_c += _r0;
//CHECK-NEXT: *_d_a += a * b / c * _d_result * 100;
//CHECK-NEXT: *_d_b += a * b / c * _d_result * 100;
Expand Down
2 changes: 1 addition & 1 deletion test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ int main() {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: _r0 += _d_y.value * clad::custom_derivatives{{(::std)?}}::log_pushforward(x, 1.F).pushforward;
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: double _r1 = _d_y.pushforward * d_x * -1. / (x * x);
// CHECK-NEXT: double _r1 = _d_y.pushforward * d_x * -(1. / (x * x));
// CHECK-NEXT: *_d_x += _r1;
// CHECK-NEXT: *_d_d_x += (1. / x) * _d_y.pushforward;
// CHECK-NEXT: }
Expand Down

0 comments on commit b7cec61

Please sign in to comment.