Skip to content

Commit

Permalink
Fix custom reverse_forws for operators
Browse files Browse the repository at this point in the history
Previously, nullptr used to be set as the derivative of a
call in the reverse mode, if there was a custom reverse_forw
function available. This issue was overlooked at first, since
it doesn't cause any trouble, unless someone decides to use nested
operators (such as expressions of the form `a[i] = x*x`).

Fixes: #1070
  • Loading branch information
gojakuch committed Sep 3, 2024
1 parent c294ea5 commit 11443bf
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ template <typename T, typename U>
void push_back_reverse_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U d_val) {
v->push_back(val);
d_v->push_back(0);
d_v->push_back(d_val);
}

template <typename T, typename U>
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2056,7 +2056,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
return StmtDiff(resValue, resAdjoint, resAdjoint);
}
if (utils::isNonConstReferenceType(returnType) ||
returnType->isPointerType()) {
Expand Down
48 changes: 48 additions & 0 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,27 @@ double fn13(double u, double v) {
return vec[0] + vec[1] + vec[2];
}

double fn14(double x, double y) {
std::vector<double> a;
a.push_back(x);
a.push_back(x);
a[1] = x*x;
return a[1];
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
INIT_GRADIENT(fn11);
INIT_GRADIENT(fn12);
INIT_GRADIENT(fn13);
INIT_GRADIENT(fn14);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
TEST_GRADIENT(fn12, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00}
TEST_GRADIENT(fn13, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00}
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -381,3 +391,41 @@ int main() {
// CHECK-NEXT: {{.*}}constructor_pullback(&vec, count, u, allocator, &_d_vec, &_d_count, &*_d_u, &_d_allocator);
// CHECK-NEXT: *_d_u += _d_res;
// CHECK-NEXT: }

// CHECK: void fn14_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::vector<double> _d_a({});
// CHECK-NEXT: std::vector<double> a;
// CHECK-NEXT: double _t0 = x;
// CHECK-NEXT: std::vector<double> _t1 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, x, &_d_a, *_d_x);
// CHECK-NEXT: double _t2 = x;
// CHECK-NEXT: std::vector<double> _t3 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, x, &_d_a, *_d_x);
// CHECK-NEXT: std::vector<double> _t4 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0);
// CHECK-NEXT: double _t6 = _t5.value;
// CHECK-NEXT: _t5.value = x * x;
// CHECK-NEXT: std::vector<double> _t7 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t8 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r1 = 0;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t7, 1, 1, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t5.value = _t6;
// CHECK-NEXT: double _r_d0 = _t5.adjoint;
// CHECK-NEXT: _t5.adjoint = 0;
// CHECK-NEXT: *_d_x += _r_d0 * x;
// CHECK-NEXT: *_d_x += x * _r_d0;
// CHECK-NEXT: {{.*}} _r0 = 0;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 1, 0, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: x = _t2;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t3, _t2, &_d_a, &*_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t1, _t0, &_d_a, &*_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 11443bf

Please sign in to comment.