diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6f1df389a..ed53eda93 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1368,6 +1368,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } + auto* CEModified = dyn_cast(Clone(CE)); + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not @@ -1610,8 +1612,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); - Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); - baseDiff.updateStmt(baseDiffStore); + + if(auto* ME = dyn_cast(CEModified->getCallee())) + ME->setBase(baseDiff.getExpr()); + Expr* baseDerivative = baseDiff.getExpr_dx(); if (!baseDerivative->getType()->isPointerType()) baseDerivative = @@ -1867,13 +1871,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, nullptr, resAdjoint); } // Recreate the original call expression. + call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, + .ActOnCallExpr(getCurrentScope(), CEModified->getCallee(), Loc, CallArgs, Loc) .get(); - return StmtDiff(call); - return {}; + return StmtDiff(call); } Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall( @@ -3708,8 +3712,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const clang::MaterializeTemporaryExpr* MTE) { // `MaterializeTemporaryExpr` node will be created automatically if it is // required by `ActOn`/`Build` Sema functions. - StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); - return MTEDiff; + Expr* MTEStore = GlobalStoreAndRef(Clone(clad_compat::GetSubExpr(MTE)), "_t", + /*force=*/true); + + auto* MTEStoreDRE = dyn_cast(MTEStore); + DeclDiff MTEDerived = + DifferentiateVarDecl(dyn_cast(MTEStoreDRE->getDecl())); + addToCurrentBlock(BuildDeclStmt(MTEDerived.getDecl_dx())); + + return StmtDiff{MTEStore, BuildDeclRef(MTEDerived.getDecl_dx())}; } StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr(