Fix custom reverse_forws for operators #1076

Merged
merged 1 commit into from
Sep 3, 2024
2 changes: 1 addition & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
@@ -2056,7 +2056,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
- return StmtDiff(resValue, nullptr, resAdjoint);
+ return StmtDiff(resValue, resAdjoint, resAdjoint);
}
if (utils::isNonConstReferenceType(returnType) ||
returnType->isPointerType()) {
48 changes: 48 additions & 0 deletions test/Gradient/STLCustomDerivatives.C
@@ -103,17 +103,27 @@ double fn13(double u, double v) {
return vec[0] + vec[1] + vec[2];
}

double fn14(double x, double y) {
std::vector<double> a;
a.push_back(x);
a.push_back(x);
a[1] = x*x;
return a[1];
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
INIT_GRADIENT(fn11);
INIT_GRADIENT(fn12);
INIT_GRADIENT(fn13);
INIT_GRADIENT(fn14);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
TEST_GRADIENT(fn12, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00}
TEST_GRADIENT(fn13, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00}
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
@@ -381,3 +391,41 @@ int main() {
// CHECK-NEXT: {{.*}}constructor_pullback(&vec, count, u, allocator, &_d_vec, &_d_count, &*_d_u, &_d_allocator);
// CHECK-NEXT: *_d_u += _d_res;
// CHECK-NEXT: }

// CHECK: void fn14_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::vector<double> _d_a({});
// CHECK-NEXT: std::vector<double> a;
// CHECK-NEXT: double _t0 = x;
// CHECK-NEXT: std::vector<double> _t1 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, x, &_d_a, *_d_x);
// CHECK-NEXT: double _t2 = x;
// CHECK-NEXT: std::vector<double> _t3 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, x, &_d_a, *_d_x);
// CHECK-NEXT: std::vector<double> _t4 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0);
// CHECK-NEXT: double _t6 = _t5.value;
// CHECK-NEXT: _t5.value = x * x;
// CHECK-NEXT: std::vector<double> _t7 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t8 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1);
Collaborator

nullptr was being passed to reverse_forws intentionally. We cannot pass adjoints uniformly for non-reference/non-pointer cases. Now the generated code cannot be compiled, because _r1 is passed to the reverse_forw function call even though it is only defined later, in the reverse pass.

Collaborator Author

I know. I opened an issue about this (#1066) this week, and I think it was partially introduced in the original PR about custom reverse_forws for constructors, so I didn't really register this as an issue. It's still better this way because it at least has the correct logic now. But such misuses of yet-to-be-declared variables are happening all over the reverse mode (at least that's what I thought), so I didn't pay much attention to this. I'll add this as a note to #1066 or open a separate issue about it. Thanks for your comment, because I'd have forgotten to track this otherwise.

Collaborator

> it's still better this way because it at least has the correct logic now.

I don't think it is the correct logic now. The reverse_forw call now refers to a variable that is defined much later; that does not seem like correct logic to me.

> but such misuses of yet-to-be-declared variables are happening all over in the reverse mode (at least that's what I thought)

Please correct me if I am wrong, but I think @PetroZarytskyi made significant efforts to reduce such cases. We certainly should not knowingly add more of them.

> I'll add this as a note to #1066 or open a separate issue about this. thanks for your comment, cos I'd have forgotten to track this otherwise.

I think passing nullptr where the adjoint is unavailable in the forward pass is the right way to go. @vgvassilev, what do you recommend here?

Owner

I think I see both points. On the one hand, @gojakuch is unblocked by this change; on the other, it is problematic. Do we have an alternative way to fix his example cases where the nullptr approach breaks?

Collaborator Author

> Do we have an alternative way to fix his example cases where the nullptr approach breaks?

I opened #1079 yesterday to track this. I think there should be a way; I'll get back to this.
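
For readers following the thread, here is a minimal, hand-written sketch of the data flow under discussion. It is not clad's generated code or clad's API: `ValueAndAdjoint` below is a simplified stand-in for `clad::ValueAndAdjoint`, and `subscript_reverse_forw` and `fn14_grad_sketch` are hypothetical names. The sketch assumes the adjoint vector already exists in the forward pass, and it deliberately leaves out the index adjoint (`_r0`/`_r1` in the CHECK lines), so it does not settle whether `nullptr` or an adjoint should be passed there.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for clad::ValueAndAdjoint (assumption, not clad's code).
template <typename T, typename U> struct ValueAndAdjoint {
  T value;
  U adjoint;
};

// Hypothetical forward-pass helper for a[idx]: returns references to the
// element and to the matching slot of the adjoint vector.
ValueAndAdjoint<double&, double&>
subscript_reverse_forw(std::vector<double>* a, std::size_t idx,
                       std::vector<double>* d_a) {
  assert(a->size() == d_a->size() && idx < a->size());
  return {(*a)[idx], (*d_a)[idx]};
}

// Hand-written reverse-mode gradient of
//   double fn14(double x, double y) { a = {x, x}; a[1] = x * x; return a[1]; }
void fn14_grad_sketch(double x, double y, double* d_x, double* d_y) {
  (void)y;   // y does not influence the result,
  (void)d_y; // so its adjoint is left untouched.

  // Forward pass.
  std::vector<double> a{x, x};
  std::vector<double> d_a(a.size(), 0.0);
  auto t5 = subscript_reverse_forw(&a, 1, &d_a); // lhs of a[1] = x * x
  double t6 = t5.value;                          // save the overwritten value
  t5.value = x * x;
  auto t8 = subscript_reverse_forw(&a, 1, &d_a); // a[1] in the return

  // Reverse pass.
  t8.adjoint += 1.0;        // seed: d(result)/d(a[1]) = 1
  t5.value = t6;            // restore a[1]
  double r_d0 = t5.adjoint; // adjoint that flowed into the assignment
  t5.adjoint = 0.0;         // the assignment overwrote a[1]
  *d_x += r_d0 * x;         // product rule, left factor
  *d_x += x * r_d0;         // product rule, right factor
  *d_x += d_a[1];           // pullback of the second push_back(x)
  *d_x += d_a[0];           // pullback of the first push_back(x)
}
```

With `dx` and `dy` initialized to 0, `fn14_grad_sketch(3, 5, &dx, &dy)` leaves `dx == 6` and `dy == 0`, which agrees with the `{6.00, 0.00}` expected by the new `TEST_GRADIENT` line for `fn14`.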

// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r1 = 0;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t7, 1, 1, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t5.value = _t6;
// CHECK-NEXT: double _r_d0 = _t5.adjoint;
// CHECK-NEXT: _t5.adjoint = 0;
// CHECK-NEXT: *_d_x += _r_d0 * x;
// CHECK-NEXT: *_d_x += x * _r_d0;
// CHECK-NEXT: {{.*}} _r0 = 0;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 1, 0, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: x = _t2;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t3, _t2, &_d_a, &*_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t1, _t0, &_d_a, &*_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: }