Skip to content

Commit

Permalink
Fix assertion related to template arguments pack
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 6, 2024
1 parent d50a33f commit 77a8a22
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
6 changes: 4 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1929,11 +1929,13 @@ StmtDiff BaseForwardModeVisitor::VisitCXXFunctionalCastExpr(
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
Expr* clonedFCE = m_Sema
.BuildCXXFunctionalCastExpr(
FCE->getTypeInfoAsWritten(), FCE->getType(), fakeLoc, castExprDiff.getExpr(), fakeLoc)
FCE->getTypeInfoAsWritten(), FCE->getType(),
fakeLoc, castExprDiff.getExpr(), fakeLoc)
.get();
Expr* derivedFCE = m_Sema
.BuildCXXFunctionalCastExpr(
FCE->getTypeInfoAsWritten(), FCE->getType(), fakeLoc, castExprDiff.getExpr_dx(), fakeLoc)
FCE->getTypeInfoAsWritten(), FCE->getType(),
fakeLoc, castExprDiff.getExpr_dx(), fakeLoc)
.get();
return {clonedFCE, derivedFCE};
}
Expand Down
22 changes: 16 additions & 6 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,14 @@ namespace clad {
// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
for (auto const& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements()) {
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack) {
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements()) {
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
}
} else {
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
}
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
Expand Down Expand Up @@ -602,9 +607,14 @@ namespace clad {
// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
for (auto const& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements()) {
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack) {
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements()) {
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
}
} else {
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) {
request.use_enzyme = true;
Expand Down
6 changes: 4 additions & 2 deletions lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Stmt* StmtClone::VisitDeclRefExpr(DeclRefExpr *Node) {
Ctx, Node->getQualifierLoc(), Node->getTemplateKeywordLoc(),
Node->getDecl(), Node->refersToEnclosingVariableOrCapture(),
Node->getNameInfo(), CloneType(Node->getType()), Node->getValueKind(),
Node->getFoundDecl(), &TAListInfo);
Node->getFoundDecl(), &TAListInfo, Node->isNonOdrUse());
}
DEFINE_CREATE_EXPR(IntegerLiteral,
(Ctx, Node->getValue(), CloneType(Node->getType()),
Expand Down Expand Up @@ -227,7 +227,9 @@ DEFINE_CLONE_EXPR(VAArgExpr,
Node->getWrittenTypeInfo(), Node->getRParenLoc(),
CloneType(Node->getType()), Node->isMicrosoftABI()))
DEFINE_CLONE_EXPR(ImplicitValueInitExpr, (CloneType(Node->getType())))
DEFINE_CLONE_EXPR(CXXScalarValueInitExpr, (CloneType(Node->getType()), Node->getTypeSourceInfo(), Node->getRParenLoc()))
DEFINE_CLONE_EXPR(CXXScalarValueInitExpr,
(CloneType(Node->getType()), Node->getTypeSourceInfo(),
Node->getRParenLoc()))
DEFINE_CLONE_EXPR(ExtVectorElementExpr, (Node->getType(), Node->getValueKind(), Clone(Node->getBase()), Node->getAccessor(), Node->getAccessorLoc()))
DEFINE_CLONE_EXPR(CXXBoolLiteralExpr, (Node->getValue(), Node->getType(), Node->getSourceRange().getBegin()))
DEFINE_CLONE_EXPR(CXXNullPtrLiteralExpr, (Node->getType(), Node->getSourceRange().getBegin()))
Expand Down

0 comments on commit 77a8a22

Please sign in to comment.