Skip to content

Commit

Permalink
Fix char and string literals in reverse mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 7, 2024
1 parent 19cc205 commit 241f526
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ namespace clad {
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL);
StmtDiff VisitStringLiteral(const clang::StringLiteral* SL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
Expand Down
13 changes: 12 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return Clone(BL);
}

StmtDiff
ReverseModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) {
return Clone(CL);
}

StmtDiff ReverseModeVisitor::VisitStringLiteral(const StringLiteral* SL) {
return StmtDiff(Clone(SL), StringLiteral::Create(
m_Context, "", SL->getKind(), SL->isPascal(),
SL->getType(), utils::GetValidSLoc(m_Sema)));
}

StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
// Initially, df/df = 1.
const Expr* value = RS->getRetValue();
Expand Down Expand Up @@ -1386,7 +1397,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!isa<FloatingLiteral>(arg) && !isa<IntegerLiteral>(arg)) {
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
break;
}
Expand Down
39 changes: 33 additions & 6 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: %cladnumdiffclang -std=c++17 %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s
// RUN: %cladnumdiffclang -std=c++17 -Wno-writable-strings %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s
// RUN: ./FunctionCalls.out | FileCheck -check-prefix=CHECK-EXEC %s
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -std=c++17 %s -I%S/../../include -oFunctionCalls.out
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -std=c++17 -Wno-writable-strings %s -I%S/../../include -oFunctionCalls.out
// RUN: ./FunctionCalls.out | FileCheck -check-prefix=CHECK-EXEC %s

//CHECK-NOT: {{.*error|warning|note:.*}}
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"

Expand Down Expand Up @@ -456,20 +456,47 @@ double fn7(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double check_and_return(double x, char c, const char* s) {
if (c == 'a' && s[0] == 'a')
return x;
return 1;
}
// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, clad::array_ref<double> _d_x, clad::array_ref<char> _d_c, clad::array_ref<char> _d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: goto _label1;
// CHECK-NEXT: _label1:
// CHECK-NEXT: ;
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: _label0:
// CHECK-NEXT: * _d_x += _d_y;
// CHECK-NEXT: }

double fn8(double x, double y) {
return x*y*std::tanh(1.0)*std::max(1.0, 2.0);
return check_and_return(x, 'a', "aa") * y * std::tanh(1.0) * std::max(1.0, 2.0); // expected-warning {{ISO C++11 does not allow conversion from string literal to 'char *' [-Wwritable-strings]}}
}

// CHECK: void fn8_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: _t3 = check_and_return(x, 'a', "aa");
// CHECK-NEXT: _t1 = std::tanh(1.);
// CHECK-NEXT: _t0 = std::max(1., 2.);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: * _d_x += 1 * _t0 * _t1 * y;
// CHECK-NEXT: * _d_y += x * 1 * _t0 * _t1;
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: char _grad1 = 0i8;
// CHECK-NEXT: clad::array_ref<char> _t2 = {"", 3UL};
// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_grad0, &_grad1, _t2);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: char _r1 = _grad1;
// CHECK-NEXT: clad::array<char> _r2({"", 3UL});
// CHECK-NEXT: * _d_y += _t3 * 1 * _t0 * _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down

0 comments on commit 241f526

Please sign in to comment.