Skip to content

Commit

Permalink
Fix char and string literals in forward mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 7, 2024
1 parent 46adbfd commit 19cc205
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 21 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BaseForwardModeVisitor
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL);
StmtDiff VisitStringLiteral(const clang::StringLiteral* SL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
Expand Down
32 changes: 24 additions & 8 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,14 +935,14 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
QualType T = IL->getType();
llvm::APInt zero(m_Context.getIntWidth(T), /*value*/ 0);
auto constant0 = IntegerLiteral::Create(m_Context, zero, T, noLoc);
auto* constant0 = IntegerLiteral::Create(m_Context, zero, T, noLoc);
return StmtDiff(Clone(IL), constant0);
}

StmtDiff
BaseForwardModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) {
llvm::APFloat zero = llvm::APFloat::getZero(FL->getSemantics());
auto constant0 =
auto* constant0 =
FloatingLiteral::Create(m_Context, zero, true, FL->getType(), noLoc);
return StmtDiff(Clone(FL), constant0);
}
Expand Down Expand Up @@ -1108,7 +1108,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE);
if (!isa<FloatingLiteral>(arg) && !isa<IntegerLiteral>(arg)) {
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
break;
}
Expand Down Expand Up @@ -1455,11 +1455,26 @@ BaseForwardModeVisitor::VisitCXXDefaultArgExpr(const CXXDefaultArgExpr* DE) {
StmtDiff
BaseForwardModeVisitor::VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr* BL) {
llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/ 0);
auto constant0 =
auto* constant0 =
IntegerLiteral::Create(m_Context, zero, m_Context.IntTy, noLoc);
return StmtDiff(Clone(BL), constant0);
}

StmtDiff
BaseForwardModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) {
llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/ 0);
auto* constant0 =
IntegerLiteral::Create(m_Context, zero, m_Context.IntTy, noLoc);
return StmtDiff(Clone(CL), constant0);
}

StmtDiff BaseForwardModeVisitor::VisitStringLiteral(const StringLiteral* SL) {
llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/ 0);
auto* constant0 =
IntegerLiteral::Create(m_Context, zero, m_Context.IntTy, noLoc);
return StmtDiff(Clone(SL), constant0);
}

StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
// begin scope for while loop
beginScope(Scope::ContinueScope | Scope::BreakScope | Scope::DeclScope |
Expand Down Expand Up @@ -1926,15 +1941,16 @@ StmtDiff BaseForwardModeVisitor::VisitCXXStaticCastExpr(
StmtDiff BaseForwardModeVisitor::VisitCXXFunctionalCastExpr(
const clang::CXXFunctionalCastExpr* FCE) {
StmtDiff castExprDiff = Visit(FCE->getSubExpr());
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
Expr* clonedFCE = m_Sema
.BuildCXXFunctionalCastExpr(
FCE->getTypeInfoAsWritten(), FCE->getType(), noLoc,
castExprDiff.getExpr(), noLoc)
FCE->getTypeInfoAsWritten(), FCE->getType(),
fakeLoc, castExprDiff.getExpr(), fakeLoc)
.get();
Expr* derivedFCE = m_Sema
.BuildCXXFunctionalCastExpr(
FCE->getTypeInfoAsWritten(), FCE->getType(), noLoc,
castExprDiff.getExpr_dx(), noLoc)
FCE->getTypeInfoAsWritten(), FCE->getType(),
fakeLoc, castExprDiff.getExpr_dx(), fakeLoc)
.get();
return {clonedFCE, derivedFCE};
}
Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/CodeGenSimple.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -I%S/../../include -oCodeGenSimple.out -Xclang -verify 2>&1 | FileCheck %s
// RUN: %cladclang %s -I%S/../../include -oCodeGenSimple.out 2>&1 | FileCheck %s
// RUN: ./CodeGenSimple.out | FileCheck -check-prefix=CHECK-EXEC %s

//CHECK-NOT: {{.*error|warning|note:.*}}
Expand All @@ -7,7 +7,7 @@
extern "C" int printf(const char* fmt, ...);

int f_1(int x) {
printf("I am being run!\n"); //expected-warning{{attempted to differentiate unsupported statement, no changes applied}} //expected-warning{{function 'printf' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives', and function may not be eligible for numerical differentiation.}}
printf("I am being run!\n");
return x * x;
}
// CHECK: int f_1_darg0(int x) {
Expand Down
29 changes: 18 additions & 11 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,6 @@ double nonRealParamFn(const char* a, const char* b = nullptr) {
return 1;
}

// CHECK: clad::ValueAndPushforward<double, double> nonRealParamFn_pushforward(const char *a, const char *b, const char *_d_a, const char *_d_b) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: }

double fn4(double i, double j) {
double res = nonRealParamFn(0, 0);
res += i;
Expand All @@ -183,9 +179,8 @@ double fn4(double i, double j) {
// CHECK: double fn4_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = nonRealParamFn_pushforward(0, 0, 0, 0);
// CHECK-NEXT: double _d_res = _t0.pushforward;
// CHECK-NEXT: double res = _t0.value;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = nonRealParamFn(0, 0);
// CHECK-NEXT: _d_res += _d_i;
// CHECK-NEXT: res += i;
// CHECK-NEXT: return _d_res;
Expand Down Expand Up @@ -294,10 +289,21 @@ double sum(double* arr, int n) {
// CHECK-NEXT: return {val, _d_val};
// CHECK-NEXT: }

double check_and_return(double x, char c) {
if (c == 'a')
return x;
return 1;
}
// CHECK: clad::ValueAndPushforward<double, double> check_and_return_pushforward(double x, char c, double _d_x, char _d_c) {
// CHECK-NEXT: if (c == 'a')
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: }

double fn8(double i, double j) {
double arr[5] = {};
modifyArr(arr, 5, i*j);
return sum(arr, 5) * std::tanh(1.0);
return check_and_return(sum(arr, 5), 'a') * std::tanh(1.0);
}

// CHECK: double fn8_darg0(double i, double j) {
Expand All @@ -307,9 +313,10 @@ double fn8(double i, double j) {
// CHECK-NEXT: double arr[5] = {};
// CHECK-NEXT: modifyArr_pushforward(arr, 5, i * j, _d_arr, 0, _d_i * j + i * _d_j);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = sum_pushforward(arr, 5, _d_arr, 0);
// CHECK-NEXT: double &_t1 = _t0.value;
// CHECK-NEXT: double _t2 = std::tanh(1.);
// CHECK-NEXT: return _t0.pushforward * _t2 + _t1 * 0;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = check_and_return_pushforward(_t0.value, 'a', _t0.pushforward, 0);
// CHECK-NEXT: double &_t2 = _t1.value;
// CHECK-NEXT: double _t3 = std::tanh(1.);
// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0;
// CHECK-NEXT: }

float test_1_darg0(float x);
Expand Down

0 comments on commit 19cc205

Please sign in to comment.