diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 7b09cc282..375ae88d5 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -56,6 +56,8 @@ class BaseForwardModeVisitor StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); + StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL); + StmtDiff VisitStringLiteral(const clang::StringLiteral* SL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 2564683ca..4dc24a7a3 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1108,7 +1108,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // subexpression. if (const auto* MTE = dyn_cast(arg)) arg = clad_compat::GetSubExpr(MTE); - if (!isa(arg) && !isa(arg)) { + if (!arg->isEvaluatable(m_Context)) { allArgsAreConstantLiterals = false; break; } @@ -1460,6 +1460,21 @@ BaseForwardModeVisitor::VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr* BL) { return StmtDiff(Clone(BL), constant0); } +StmtDiff +BaseForwardModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) { + llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/ 0); + auto constant0 = + IntegerLiteral::Create(m_Context, zero, m_Context.IntTy, noLoc); + return StmtDiff(Clone(CL), constant0); +} + +StmtDiff BaseForwardModeVisitor::VisitStringLiteral(const StringLiteral* SL) { + llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/ 0); + auto constant0 = + IntegerLiteral::Create(m_Context, zero, m_Context.IntTy, noLoc); + return StmtDiff(Clone(SL), constant0); +} + StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) { // begin scope for while loop beginScope(Scope::ContinueScope | Scope::BreakScope | Scope::DeclScope | diff --git a/test/FirstDerivative/CodeGenSimple.C b/test/FirstDerivative/CodeGenSimple.C index 0759b15a2..02a815c92 100644 --- a/test/FirstDerivative/CodeGenSimple.C +++ b/test/FirstDerivative/CodeGenSimple.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oCodeGenSimple.out -Xclang -verify 2>&1 | FileCheck %s +// RUN: %cladclang %s -I%S/../../include -oCodeGenSimple.out 2>&1 | FileCheck %s // RUN: ./CodeGenSimple.out | FileCheck -check-prefix=CHECK-EXEC %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -7,7 +7,7 @@ extern "C" int printf(const char* fmt, ...); int f_1(int x) { - printf("I am being run!\n"); //expected-warning{{attempted to differentiate unsupported statement, no changes applied}} //expected-warning{{function 'printf' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives', and function may not be eligible for numerical differentiation.}} + printf("I am being run!\n"); return x * x; } // CHECK: int f_1_darg0(int x) { diff --git a/test/FirstDerivative/FunctionCallsWithResults.C b/test/FirstDerivative/FunctionCallsWithResults.C index aa99017fc..f694aaaa3 100644 --- a/test/FirstDerivative/FunctionCallsWithResults.C +++ b/test/FirstDerivative/FunctionCallsWithResults.C @@ -170,10 +170,6 @@ double nonRealParamFn(const char* a, const char* b = nullptr) { return 1; } -// CHECK: clad::ValueAndPushforward nonRealParamFn_pushforward(const char *a, const char *b, const char *_d_a, const char *_d_b) { -// CHECK-NEXT: return {1, 0}; -// CHECK-NEXT: } - double fn4(double i, double j) { double res = nonRealParamFn(0, 0); res += i; @@ -183,9 +179,8 @@ double fn4(double i, double j) { // CHECK: double fn4_darg0(double i, double j) { // CHECK-NEXT: double _d_i = 1; // CHECK-NEXT: double _d_j = 0; -// CHECK-NEXT: clad::ValueAndPushforward _t0 = nonRealParamFn_pushforward(0, 0, 0, 0); -// CHECK-NEXT: double _d_res = _t0.pushforward; -// CHECK-NEXT: double res = _t0.value; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = nonRealParamFn(0, 0); // CHECK-NEXT: _d_res += _d_i; // CHECK-NEXT: res += i; // CHECK-NEXT: return _d_res; @@ -294,10 +289,21 @@ double sum(double* arr, int n) { // CHECK-NEXT: return {val, _d_val}; // CHECK-NEXT: } +double check_and_return(double x, char c) { + if (c == 'a') + return x; + return 1; +} +// CHECK: clad::ValueAndPushforward check_and_return_pushforward(double x, char c, double _d_x, char _d_c) { +// CHECK-NEXT: if (c == 'a') +// CHECK-NEXT: return {x, _d_x}; +// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: } + double fn8(double i, double j) { double arr[5] = {}; modifyArr(arr, 5, i*j); - return sum(arr, 5) * std::tanh(1.0); + return check_and_return(sum(arr, 5), 'a') * std::tanh(1.0); } // CHECK: double fn8_darg0(double i, double j) { @@ -307,9 +313,10 @@ double fn8(double i, double j) { // CHECK-NEXT: double arr[5] = {}; // CHECK-NEXT: modifyArr_pushforward(arr, 5, i * j, _d_arr, 0, _d_i * j + i * _d_j); // CHECK-NEXT: clad::ValueAndPushforward _t0 = sum_pushforward(arr, 5, _d_arr, 0); -// CHECK-NEXT: double &_t1 = _t0.value; -// CHECK-NEXT: double _t2 = std::tanh(1.); -// CHECK-NEXT: return _t0.pushforward * _t2 + _t1 * 0; +// CHECK-NEXT: clad::ValueAndPushforward _t1 = check_and_return_pushforward(_t0.value, 'a', _t0.pushforward, 0); +// CHECK-NEXT: double &_t2 = _t1.value; +// CHECK-NEXT: double _t3 = std::tanh(1.); +// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0; // CHECK-NEXT: } float test_1_darg0(float x);