Skip to content

Commit

Permalink
Fix temporary expressions in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
MihailMihov committed Jul 17, 2024
1 parent 8963b8c commit c3df195
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

auto* CEModified = dyn_cast<CallExpr>(Clone(CE));

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
Expand Down Expand Up @@ -1618,8 +1620,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
baseDiff.updateStmt(baseDiffStore);

if (auto* ME = dyn_cast<MemberExpr>(CEModified->getCallee()))
ME->setBase(baseDiff.getExpr());

Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative =
Expand Down Expand Up @@ -1878,13 +1882,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
} // Recreate the original call expression.

call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
.ActOnCallExpr(getCurrentScope(), CEModified->getCallee(), Loc,
CallArgs, Loc)
.get();
return StmtDiff(call);

return {};
return StmtDiff(call);
}

Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall(
Expand Down Expand Up @@ -3747,8 +3751,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const clang::MaterializeTemporaryExpr* MTE) {
// `MaterializeTemporaryExpr` node will be created automatically if it is
// required by `ActOn`/`Build` Sema functions.
StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx());
return MTEDiff;
if(dfdx()) {
StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx());
return MTEDiff;
}

Expr* MTEStore =
GlobalStoreAndRef(Clone(clad_compat::GetSubExpr(MTE)), "_t",
/*force=*/true);

auto* MTEStoreDRE = dyn_cast<DeclRefExpr>(MTEStore);
DeclDiff<VarDecl> MTEDerived =
DifferentiateVarDecl(dyn_cast<VarDecl>(MTEStoreDRE->getDecl()));
addToCurrentBlock(BuildDeclStmt(MTEDerived.getDecl_dx()));

return StmtDiff{MTEStore, BuildDeclRef(MTEDerived.getDecl_dx())};
}

StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr(
Expand Down Expand Up @@ -3928,8 +3945,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_DiffReq.Mode == DiffMode::experimental_pullback &&
!m_DiffReq->getReturnType()->isVoidType()) {
IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y");
/*QualType pullbackType =*/
/* derivativeFnType->getParamType(m_DiffReq->getNumParams());*/
QualType pullbackType =
derivativeFnType->getParamType(m_DiffReq->getNumParams());
m_Context.DoubleTy;
ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, pullbackParamII, pullbackType);
paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD);
Expand Down

0 comments on commit c3df195

Please sign in to comment.