From 76fa8d17a19ce1e1ea6e190091cb5e8c7a5d99a6 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 7 Feb 2024 10:58:39 +0100 Subject: [PATCH] Fix char and string literals in reverse mode --- .../clad/Differentiator/ReverseModeVisitor.h | 2 ++ include/clad/Differentiator/StmtClone.h | 5 ++- lib/Differentiator/BaseForwardModeVisitor.cpp | 4 +-- lib/Differentiator/DiffPlanner.cpp | 19 +++++------- lib/Differentiator/ReverseModeVisitor.cpp | 12 ++++++- lib/Differentiator/StmtClone.cpp | 3 ++ test/Gradient/FunctionCalls.C | 31 +++++++++++++++++-- 7 files changed, 57 insertions(+), 19 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c0010bce3..129d02b98 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -347,6 +347,8 @@ namespace clad { virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); + StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL); + StmtDiff VisitStringLiteral(const clang::StringLiteral* SL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); diff --git a/include/clad/Differentiator/StmtClone.h b/include/clad/Differentiator/StmtClone.h index 31e54ae04..8d5a2f713 100644 --- a/include/clad/Differentiator/StmtClone.h +++ b/include/clad/Differentiator/StmtClone.h @@ -122,11 +122,14 @@ namespace utils { DECLARE_CLONE_FN(PseudoObjectExpr) DECLARE_CLONE_FN(SubstNonTypeTemplateParmExpr) DECLARE_CLONE_FN(CXXScalarValueInitExpr) - DECLARE_CLONE_FN(ValueStmt) // `ConstantExpr` node is only available after clang 7. #if CLANG_VERSION_MAJOR > 7 DECLARE_CLONE_FN(ConstantExpr) #endif + // `ValueStmt` node is only available after clang 8. + #if CLANG_VERSION_MAJOR > 8 + DECLARE_CLONE_FN(ValueStmt) + #endif clang::Stmt* VisitStmt(clang::Stmt*); }; diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 4dc24a7a3..8367f3168 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -935,14 +935,14 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { QualType T = IL->getType(); llvm::APInt zero(m_Context.getIntWidth(T), /*value*/ 0); - auto constant0 = IntegerLiteral::Create(m_Context, zero, T, noLoc); + auto* constant0 = IntegerLiteral::Create(m_Context, zero, T, noLoc); return StmtDiff(Clone(IL), constant0); } StmtDiff BaseForwardModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) { llvm::APFloat zero = llvm::APFloat::getZero(FL->getSemantics()); - auto constant0 = + auto* constant0 = FloatingLiteral::Create(m_Context, zero, true, FL->getType(), noLoc); return StmtDiff(Clone(FL), constant0); } diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 5e7fb331f..87ec8df05 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -563,14 +563,12 @@ namespace clad { // do bitwise or of all the values to get the final value. unsigned bitmasked_opts_value = 0; const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); - if (template_arg.getKind() == TemplateArgument::Pack) { + if (template_arg.getKind() == TemplateArgument::Pack) for (const auto& arg : - FD->getTemplateSpecializationArgs()->get(0).pack_elements()) { + FD->getTemplateSpecializationArgs()->get(0).pack_elements()) bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); - } - } else { + else bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); - } unsigned derivative_order = clad::GetDerivativeOrder(bitmasked_opts_value); if (derivative_order == 0) { @@ -608,17 +606,14 @@ namespace clad { // do bitwise or of all the values to get the final value. unsigned bitmasked_opts_value = 0; const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); - if (template_arg.getKind() == TemplateArgument::Pack) { + if (template_arg.getKind() == TemplateArgument::Pack) for (const auto& arg : - FD->getTemplateSpecializationArgs()->get(0).pack_elements()) { + FD->getTemplateSpecializationArgs()->get(0).pack_elements()) bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); - } - } else { + else bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); - } - if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) { + if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) request.use_enzyme = true; - } // reverse vector mode is not yet supported. if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d0ec52a6c..ed0c41db9 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1147,6 +1147,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Clone(BL); } + StmtDiff + ReverseModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) { + return Clone(CL); + } + + StmtDiff ReverseModeVisitor::VisitStringLiteral(const StringLiteral* SL) { + // FIXME: Returning an emptry string literal returns an error. + return StmtDiff(Clone(SL), Clone(SL)); + } + StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { // Initially, df/df = 1. const Expr* value = RS->getRetValue(); @@ -1386,7 +1396,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // subexpression. if (const auto* MTE = dyn_cast(arg)) arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts(); - if (!isa(arg) && !isa(arg)) { + if (!arg->isEvaluatable(m_Context)) { allArgsAreConstantLiterals = false; break; } diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index a366ba000..724f48b13 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -432,7 +432,10 @@ DEFINE_CLONE_STMT(BreakStmt, (Node->getBreakLoc())) DEFINE_CLONE_STMT(CXXCatchStmt, (Node->getCatchLoc(), CloneDeclOrNull(Node->getExceptionDecl()), Clone(Node->getHandlerBlock()))) + +#if CLANG_VERSION_MAJOR > 8 DEFINE_CLONE_STMT(ValueStmt, (Node->getStmtClass())) +#endif Stmt* StmtClone::VisitCXXTryStmt(CXXTryStmt* Node) { llvm::SmallVector CatchStmts(std::max(1u, Node->getNumHandlers())); diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index b15f2ab74..7a11c8ef2 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -456,20 +456,45 @@ double fn7(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double check_and_return(double x, char c) { + if (c == 'a') + return x; + return 1; +} +// CHECK: void check_and_return_pullback(double x, char c, double _d_y, clad::array_ref _d_x, clad::array_ref _d_c) { +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: _cond0 = c == 'a'; +// CHECK-NEXT: if (_cond0) +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: goto _label1; +// CHECK-NEXT: _label1: +// CHECK-NEXT: ; +// CHECK-NEXT: if (_cond0) +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_x += _d_y; +// CHECK-NEXT: } + double fn8(double x, double y) { - return x*y*std::tanh(1.0)*std::max(1.0, 2.0); + return check_and_return(x, 'a') * y * std::tanh(1.0) * std::max(1.0, 2.0); } // CHECK: void fn8_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: _t2 = check_and_return(x, 'a'); // CHECK-NEXT: _t1 = std::tanh(1.); // CHECK-NEXT: _t0 = std::max(1., 2.); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: * _d_x += 1 * _t0 * _t1 * y; -// CHECK-NEXT: * _d_y += x * 1 * _t0 * _t1; +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: char _grad1 = 0i8; +// CHECK-NEXT: check_and_return_pullback(x, 'a', 1 * _t0 * _t1 * y, &_grad0, &_grad1); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_x += _r0; +// CHECK-NEXT: char _r1 = _grad1; +// CHECK-NEXT: * _d_y += _t2 * 1 * _t0 * _t1; // CHECK-NEXT: } // CHECK-NEXT: }