Skip to content

Commit

Permalink
Add type cloning.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Sep 12, 2023
1 parent b46f3db commit 06c4bed
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 98 deletions.
6 changes: 4 additions & 2 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ namespace utils {
template<class StmtTy>
StmtTy* Clone(const StmtTy* S);

// visitor part (not for public use)
// Stmt.def could be used if ABSTR_STMT is introduced
clang::QualType CloneType(const clang::QualType T);

// visitor part (not for public use)
// Stmt.def could be used if ABSTR_STMT is introduced
#define DECLARE_CLONE_FN(CLASS) clang::Stmt* Visit ## CLASS(clang::CLASS *Node);
DECLARE_CLONE_FN(BinaryOperator)
DECLARE_CLONE_FN(UnaryOperator)
Expand Down
12 changes: 6 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,8 +1267,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (std::size_t i = 0; i < Indices.size(); i++) {
/// FIXME: Remove redundant indices vectors.
StmtDiff IdxDiff = Visit(Indices[i]);
clonedIndices[i] = IdxDiff.getExpr();
reverseIndices[i] = IdxDiff.getExpr();
clonedIndices[i] = Clone(IdxDiff.getExpr());
reverseIndices[i] = Clone(IdxDiff.getExpr());
// reverseIndices[i] = Clone(IdxDiff.getExpr());
forwSweepDerivativeIndices[i] = IdxDiff.getExpr();
}
Expand Down Expand Up @@ -1988,8 +1988,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = Visit(E, dfdx());
auto EStored = GlobalStoreAndRef(diff.getExpr());
if (EStored.getExpr() != diff.getExpr()) {
auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(),
EStored.getExpr_dx());
auto* assign = BuildOp(BinaryOperatorKind::BO_Assign,
Clone(diff.getExpr()), EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
Expand All @@ -2002,8 +2002,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = Visit(E, dfdx());
auto EStored = GlobalStoreAndRef(diff.getExpr());
if (EStored.getExpr() != diff.getExpr()) {
auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(),
EStored.getExpr_dx());
auto* assign = BuildOp(BinaryOperatorKind::BO_Assign,
Clone(diff.getExpr()), EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
Expand Down
Loading

0 comments on commit 06c4bed

Please sign in to comment.