Skip to content

Commit

Permalink
Fix char and string literals in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 7, 2024
1 parent 899525f commit 76fa8d1
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 19 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ namespace clad {
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL);
StmtDiff VisitStringLiteral(const clang::StringLiteral* SL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
Expand Down
5 changes: 4 additions & 1 deletion include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,14 @@ namespace utils {
DECLARE_CLONE_FN(PseudoObjectExpr)
DECLARE_CLONE_FN(SubstNonTypeTemplateParmExpr)
DECLARE_CLONE_FN(CXXScalarValueInitExpr)
DECLARE_CLONE_FN(ValueStmt)
// `ConstantExpr` node is only available after clang 7.
#if CLANG_VERSION_MAJOR > 7
DECLARE_CLONE_FN(ConstantExpr)
#endif
// `ValueStmt` node is only available after clang 8.
#if CLANG_VERSION_MAJOR > 8
DECLARE_CLONE_FN(ValueStmt)
#endif

clang::Stmt* VisitStmt(clang::Stmt*);
};
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,14 +935,14 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
QualType T = IL->getType();
llvm::APInt zero(m_Context.getIntWidth(T), /*value*/ 0);
auto constant0 = IntegerLiteral::Create(m_Context, zero, T, noLoc);
auto* constant0 = IntegerLiteral::Create(m_Context, zero, T, noLoc);
return StmtDiff(Clone(IL), constant0);
}

StmtDiff
BaseForwardModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) {
llvm::APFloat zero = llvm::APFloat::getZero(FL->getSemantics());
auto constant0 =
auto* constant0 =
FloatingLiteral::Create(m_Context, zero, true, FL->getType(), noLoc);
return StmtDiff(Clone(FL), constant0);
}
Expand Down
19 changes: 7 additions & 12 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,14 +563,12 @@ namespace clad {
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack) {
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements()) {
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
}
} else {
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
}
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
if (derivative_order == 0) {
Expand Down Expand Up @@ -608,17 +606,14 @@ namespace clad {
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack) {
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements()) {
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
}
} else {
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) {
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
}
// reverse vector mode is not yet supported.
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
Expand Down
12 changes: 11 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return Clone(BL);
}

StmtDiff
ReverseModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) {
return Clone(CL);
}

StmtDiff ReverseModeVisitor::VisitStringLiteral(const StringLiteral* SL) {
// FIXME: Returning an emptry string literal returns an error.
return StmtDiff(Clone(SL), Clone(SL));
}

StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
// Initially, df/df = 1.
const Expr* value = RS->getRetValue();
Expand Down Expand Up @@ -1386,7 +1396,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!isa<FloatingLiteral>(arg) && !isa<IntegerLiteral>(arg)) {
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
break;
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,10 @@ DEFINE_CLONE_STMT(BreakStmt, (Node->getBreakLoc()))
DEFINE_CLONE_STMT(CXXCatchStmt, (Node->getCatchLoc(),
CloneDeclOrNull(Node->getExceptionDecl()),
Clone(Node->getHandlerBlock())))

#if CLANG_VERSION_MAJOR > 8
DEFINE_CLONE_STMT(ValueStmt, (Node->getStmtClass()))
#endif

Stmt* StmtClone::VisitCXXTryStmt(CXXTryStmt* Node) {
llvm::SmallVector<Stmt*, 4> CatchStmts(std::max(1u, Node->getNumHandlers()));
Expand Down
31 changes: 28 additions & 3 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -456,20 +456,45 @@ double fn7(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double check_and_return(double x, char c) {
if (c == 'a')
return x;
return 1;
}
// CHECK: void check_and_return_pullback(double x, char c, double _d_y, clad::array_ref<double> _d_x, clad::array_ref<char> _d_c) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: _cond0 = c == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: goto _label1;
// CHECK-NEXT: _label1:
// CHECK-NEXT: ;
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: _label0:
// CHECK-NEXT: * _d_x += _d_y;
// CHECK-NEXT: }

double fn8(double x, double y) {
return x*y*std::tanh(1.0)*std::max(1.0, 2.0);
return check_and_return(x, 'a') * y * std::tanh(1.0) * std::max(1.0, 2.0);
}

// CHECK: void fn8_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: _t2 = check_and_return(x, 'a');
// CHECK-NEXT: _t1 = std::tanh(1.);
// CHECK-NEXT: _t0 = std::max(1., 2.);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: * _d_x += 1 * _t0 * _t1 * y;
// CHECK-NEXT: * _d_y += x * 1 * _t0 * _t1;
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: char _grad1 = 0i8;
// CHECK-NEXT: check_and_return_pullback(x, 'a', 1 * _t0 * _t1 * y, &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: char _r1 = _grad1;
// CHECK-NEXT: * _d_y += _t2 * 1 * _t0 * _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down

0 comments on commit 76fa8d1

Please sign in to comment.