From 241f526a13685ce2d4e5af0b3e947db72643db7a Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 8 Feb 2024 00:25:50 +0100 Subject: [PATCH] Fix char and string literals in reverse mode AD --- .../clad/Differentiator/ReverseModeVisitor.h | 2 + lib/Differentiator/ReverseModeVisitor.cpp | 13 ++++++- test/Gradient/FunctionCalls.C | 39 ++++++++++++++++--- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c0010bce3..129d02b98 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -347,6 +347,8 @@ namespace clad { virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); + StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL); + StmtDiff VisitStringLiteral(const clang::StringLiteral* SL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d0ec52a6c..1b89faa51 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1147,6 +1147,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Clone(BL); } + StmtDiff + ReverseModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) { + return Clone(CL); + } + + StmtDiff ReverseModeVisitor::VisitStringLiteral(const StringLiteral* SL) { + return StmtDiff(Clone(SL), StringLiteral::Create( + m_Context, "", SL->getKind(), SL->isPascal(), + SL->getType(), utils::GetValidSLoc(m_Sema))); + } + StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { // Initially, df/df = 1. const Expr* value = RS->getRetValue(); @@ -1386,7 +1397,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // subexpression. if (const auto* MTE = dyn_cast(arg)) arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts(); - if (!isa(arg) && !isa(arg)) { + if (!arg->isEvaluatable(m_Context)) { allArgsAreConstantLiterals = false; break; } diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index b15f2ab74..c2fb5aa90 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -1,9 +1,9 @@ -// RUN: %cladnumdiffclang -std=c++17 %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s +// RUN: %cladnumdiffclang -std=c++17 -Wno-writable-strings %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s // RUN: ./FunctionCalls.out | FileCheck -check-prefix=CHECK-EXEC %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -std=c++17 %s -I%S/../../include -oFunctionCalls.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -std=c++17 -Wno-writable-strings %s -I%S/../../include -oFunctionCalls.out // RUN: ./FunctionCalls.out | FileCheck -check-prefix=CHECK-EXEC %s -//CHECK-NOT: {{.*error|warning|note:.*}} +// CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" @@ -456,20 +456,47 @@ double fn7(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double check_and_return(double x, char c, const char* s) { + if (c == 'a' && s[0] == 'a') + return x; + return 1; +} +// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, clad::array_ref _d_x, clad::array_ref _d_c, clad::array_ref _d_s) { +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a'; +// CHECK-NEXT: if (_cond0) +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: goto _label1; +// CHECK-NEXT: _label1: +// CHECK-NEXT: ; +// CHECK-NEXT: if (_cond0) +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_x += _d_y; +// CHECK-NEXT: } + double fn8(double x, double y) { - return x*y*std::tanh(1.0)*std::max(1.0, 2.0); + return check_and_return(x, 'a', "aa") * y * std::tanh(1.0) * std::max(1.0, 2.0); // expected-warning {{ISO C++11 does not allow conversion from string literal to 'char *' [-Wwritable-strings]}} } // CHECK: void fn8_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: _t3 = check_and_return(x, 'a', "aa"); // CHECK-NEXT: _t1 = std::tanh(1.); // CHECK-NEXT: _t0 = std::max(1., 2.); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: * _d_x += 1 * _t0 * _t1 * y; -// CHECK-NEXT: * _d_y += x * 1 * _t0 * _t1; +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: char _grad1 = 0i8; +// CHECK-NEXT: clad::array_ref _t2 = {"", 3UL}; +// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_grad0, &_grad1, _t2); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_x += _r0; +// CHECK-NEXT: char _r1 = _grad1; +// CHECK-NEXT: clad::array _r2({"", 3UL}); +// CHECK-NEXT: * _d_y += _t3 * 1 * _t0 * _t1; // CHECK-NEXT: } // CHECK-NEXT: }