Skip to content

Commit

Permalink
Add support for CStyleCastExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 22, 2024
1 parent 3d8feec commit 1cb590e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class BaseForwardModeVisitor
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ namespace clad {
StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE);
StmtDiff
VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE);
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
Expand Down
19 changes: 19 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,25 @@ StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr(
return StmtDiff(Clone(E), Clone(E));
}

StmtDiff
BaseForwardModeVisitor::VisitCStyleCastExpr(const CStyleCastExpr* CSCE) {
StmtDiff subExprDiff = Visit(CSCE->getSubExpr());
// Create a new CStyleCastExpr with the same type and the same subexpression
// as the original one.
Expr* castExpr = m_Sema
.BuildCStyleCastExpr(
CSCE->getLParenLoc(), CSCE->getTypeInfoAsWritten(),
CSCE->getRParenLoc(), subExprDiff.getExpr())
.get();
Expr* castExprDiff =
m_Sema
.BuildCStyleCastExpr(CSCE->getLParenLoc(),
CSCE->getTypeInfoAsWritten(),
CSCE->getRParenLoc(), subExprDiff.getExpr_dx())
.get();
return StmtDiff(castExpr, castExprDiff);
}

StmtDiff
BaseForwardModeVisitor::VisitCXXDefaultArgExpr(const CXXDefaultArgExpr* DE) {
// FIXME: Shouldn't we simply clone the CXXDefaultArgExpr?
Expand Down
19 changes: 18 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
if (allArgsAreConstantLiterals)
return StmtDiff(Clone(CE));
return StmtDiff(Clone(CE), Clone(CE));
}

// Stores the call arguments for the function to be derived
Expand Down Expand Up @@ -2943,6 +2943,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {Clone(IVIE), Clone(IVIE)};
}

StmtDiff ReverseModeVisitor::VisitCStyleCastExpr(const CStyleCastExpr* CSCE) {
StmtDiff subExprDiff = Visit(CSCE->getSubExpr(), dfdx());
Expr* castExpr = m_Sema
.BuildCStyleCastExpr(
CSCE->getLParenLoc(), CSCE->getTypeInfoAsWritten(),
CSCE->getRParenLoc(), subExprDiff.getExpr())
.get();
Expr* castExprDiff = subExprDiff.getExpr_dx();
if (castExprDiff != nullptr)
castExprDiff = m_Sema
.BuildCStyleCastExpr(
CSCE->getLParenLoc(), CSCE->getTypeInfoAsWritten(),
CSCE->getRParenLoc(), subExprDiff.getExpr_dx())
.get();
return {castExpr, castExprDiff};
}

StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) {
auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase());
auto* field = ME->getMemberDecl();
Expand Down

0 comments on commit 1cb590e

Please sign in to comment.