Skip to content

Commit

Permalink
Avoid creation of local derivative of const parameter (#1131)
Browse files Browse the repository at this point in the history
Closes #1130.
  • Loading branch information
kchristin22 authored Nov 3, 2024
1 parent effbb7b commit a50432f
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ namespace clad {
StmtDiff
VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE);
StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE);
StmtDiff VisitCXXConstCastExpr(const clang::CXXConstCastExpr* CCE);
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
Expand Down
8 changes: 8 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
i == m_DiffReq->getNumParams() - 1)
continue;
auto VDDerivedType = param->getType();
if (VDDerivedType.isConstQualified())
continue;
// We cannot initialize derived variable for pointer types because
// we do not know the correct size.
if (utils::isArrayOrPointerType(VDDerivedType))
Expand Down Expand Up @@ -4576,6 +4578,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return subExprDiff;
}

StmtDiff ReverseModeVisitor::VisitCXXConstCastExpr(
const clang::CXXConstCastExpr* CCE) {
StmtDiff subExprDiff = Visit(CCE->getSubExpr(), dfdx());
return {Clone(CCE), subExprDiff.getExpr_dx()};
}

clang::QualType ReverseModeVisitor::ComputeAdjointType(clang::QualType T) {
if (T->isReferenceType()) {
QualType TValueType = utils::GetValueType(T);
Expand Down
71 changes: 71 additions & 0 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,62 @@ double f23(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double constVal(double y, const double x) {
const double z = y;
y *= z;
return y * x;
}

//CHECK: void constVal_grad_0(double y, const double x, double *_d_y) {
//CHECK-NEXT: double _d_z = 0.;
//CHECK-NEXT: const double z = y;
//CHECK-NEXT: double _t0 = y;
//CHECK-NEXT: y *= z;
//CHECK-NEXT: *_d_y += 1 * x;
//CHECK-NEXT: {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: double _r_d0 = *_d_y;
//CHECK-NEXT: *_d_y = 0.;
//CHECK-NEXT: *_d_y += _r_d0 * z;
//CHECK-NEXT: _d_z += y * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: *_d_y += _d_z;
//CHECK-NEXT:}

double constValConstCast(double y, const double x) {
double z = const_cast<double&>(x);
y *= z;
return y * x;
}

// CHECK: void constValConstCast_grad(double y, const double x, double *_d_y, double *_d_x) {
//CHECK-NEXT: double _d_z = 0.;
//CHECK-NEXT: double z = const_cast<double &>(x);
//CHECK-NEXT: double _t0 = y;
//CHECK-NEXT: y *= z;
//CHECK-NEXT: {
//CHECK-NEXT: *_d_y += 1 * x;
//CHECK-NEXT: *_d_x += y * 1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: double _r_d0 = *_d_y;
//CHECK-NEXT: *_d_y = 0.;
//CHECK-NEXT: *_d_y += _r_d0 * z;
//CHECK-NEXT: _d_z += y * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: *_d_x += _d_z;
//CHECK-NEXT:}

double constValInput(const double x) {
return x;
}

//CHECK: void constValInput_grad(const double x, double *_d_x) {
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT:}


#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -884,4 +940,19 @@ int main() {
TEST(f21, 6, 4); // CHECK-EXEC: {1.00, 0.00}
TEST(f22, 6, 4); // CHECK-EXEC: {0.00, 0.00}
TEST(f23, 7, 5); // CHECK-EXEC: {1.00, 1.00}

auto const_test = clad::gradient(constVal, "y");
double const_test_result = 0;
const_test.execute(3, 4, &const_test_result);
printf("%.2f\n", const_test_result); // CHECK-EXEC: 24.00

auto const_test_const_cast = clad::gradient(constValConstCast);
double const_test_const_cast_result[2] = {0};
const_test_const_cast.execute(3, 4, &const_test_const_cast_result[0], &const_test_const_cast_result[1]);
printf("{%.2f, %.2f}\n", const_test_const_cast_result[0], const_test_const_cast_result[1]); // CHECK-EXEC: {16.00, 24.00}

auto const_test_input = clad::gradient(constValInput);
double const_test_input_result = 0;
const_test_input.execute(3, &const_test_input_result);
printf("%.2f\n", const_test_input_result); // CHECK-EXEC: 1.00
}

0 comments on commit a50432f

Please sign in to comment.