Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add variables for objects in VisitMaterializeTemporaryExpr
Browse files Browse the repository at this point in the history
MihailMihov committed Jun 25, 2024
1 parent 4f487e3 commit b97b75e
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -1368,6 +1368,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

auto* CEModified = dyn_cast<CallExpr>(Clone(CE));

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
@@ -1610,8 +1612,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
baseDiff.updateStmt(baseDiffStore);

if(auto* ME = dyn_cast<MemberExpr>(CEModified->getCallee()))
ME->setBase(baseDiff.getExpr());

Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative =
@@ -1867,13 +1871,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
} // Recreate the original call expression.

call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
.ActOnCallExpr(getCurrentScope(), CEModified->getCallee(), Loc,
CallArgs, Loc)
.get();
return StmtDiff(call);

return {};
return StmtDiff(call);
}

Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall(
@@ -3708,8 +3712,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const clang::MaterializeTemporaryExpr* MTE) {
// `MaterializeTemporaryExpr` node will be created automatically if it is
// required by `ActOn`/`Build` Sema functions.
StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx());
return MTEDiff;
Expr* MTEStore = GlobalStoreAndRef(Clone(clad_compat::GetSubExpr(MTE)), "_t",
/*force=*/true);

auto* MTEStoreDRE = dyn_cast<DeclRefExpr>(MTEStore);
DeclDiff<VarDecl> MTEDerived =
DifferentiateVarDecl(dyn_cast<VarDecl>(MTEStoreDRE->getDecl()));
addToCurrentBlock(BuildDeclStmt(MTEDerived.getDecl_dx()));

return StmtDiff{MTEStore, BuildDeclRef(MTEDerived.getDecl_dx())};
}

StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr(

0 comments on commit b97b75e

Please sign in to comment.