Skip to content

Commit

Permalink
Introduce type cloning.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Nov 3, 2023
1 parent cf4ee10 commit a884c8c
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 89 deletions.
12 changes: 10 additions & 2 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ namespace utils {
template<class StmtTy>
StmtTy* Clone(const StmtTy* S);

// visitor part (not for public use)
// Stmt.def could be used if ABSTR_STMT is introduced
/// Cloning types is necessary since VariableArrayType
/// store a pointer to their size expression.
clang::QualType CloneType(const clang::QualType T);

// visitor part (not for public use)
// Stmt.def could be used if ABSTR_STMT is introduced
#define DECLARE_CLONE_FN(CLASS) clang::Stmt* Visit ## CLASS(clang::CLASS *Node);
DECLARE_CLONE_FN(BinaryOperator)
DECLARE_CLONE_FN(UnaryOperator)
Expand Down Expand Up @@ -153,6 +157,10 @@ namespace utils {
ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S,
const clang::FunctionDecl* FD);
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
bool VisitStmt(clang::Stmt* S);
/// Used to update the size expression of QT
/// if QT is VariableArrayType.
void updateType(clang::QualType QT);
};
} // namespace utils
} // namespace clad
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,9 @@ namespace clad {
clang::Stmt* Clone(const clang::Stmt* S);
/// A shorthand to simplify cloning of expressions.
clang::Expr* Clone(const clang::Expr* E);
/// Cloning types is necessary since VariableArrayType
/// store a pointer to their size expression.
clang::QualType CloneType(const clang::QualType T);
};
} // end namespace clad

Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2534,11 +2534,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
if (VD->getType()->isRecordType())
VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(),
VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(),
VD->getTypeSourceInfo(), VD->getInitStyle());
else
VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(),
VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit());
Expr* derivedVDE = BuildDeclRef(VDDerived);

Expand Down
Loading

0 comments on commit a884c8c

Please sign in to comment.