From c3df195bed9dd74e3531fe545f2877f5f385c960 Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Wed, 17 Jul 2024 15:23:41 +0300 Subject: [PATCH] Fix temporary expressions in reverse mode fixes #917 --- lib/Differentiator/ReverseModeVisitor.cpp | 35 +++++++++++++++++------ 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6394ee9dd..bfa9505b1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1369,6 +1369,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } + auto* CEModified = dyn_cast(Clone(CE)); + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not @@ -1618,8 +1620,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); - Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); - baseDiff.updateStmt(baseDiffStore); + + if (auto* ME = dyn_cast(CEModified->getCallee())) + ME->setBase(baseDiff.getExpr()); + Expr* baseDerivative = baseDiff.getExpr_dx(); if (!baseDerivative->getType()->isPointerType()) baseDerivative = @@ -1878,13 +1882,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, nullptr, resAdjoint); } // Recreate the original call expression. + call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, + .ActOnCallExpr(getCurrentScope(), CEModified->getCallee(), Loc, CallArgs, Loc) .get(); - return StmtDiff(call); - return {}; + return StmtDiff(call); } Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall( @@ -3747,8 +3751,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const clang::MaterializeTemporaryExpr* MTE) { // `MaterializeTemporaryExpr` node will be created automatically if it is // required by `ActOn`/`Build` Sema functions. - StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); - return MTEDiff; + if(dfdx()) { + StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); + return MTEDiff; + } + + Expr* MTEStore = + GlobalStoreAndRef(Clone(clad_compat::GetSubExpr(MTE)), "_t", + /*force=*/true); + + auto* MTEStoreDRE = dyn_cast(MTEStore); + DeclDiff MTEDerived = + DifferentiateVarDecl(dyn_cast(MTEStoreDRE->getDecl())); + addToCurrentBlock(BuildDeclStmt(MTEDerived.getDecl_dx())); + + return StmtDiff{MTEStore, BuildDeclRef(MTEDerived.getDecl_dx())}; } StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr( @@ -3928,8 +3945,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_DiffReq.Mode == DiffMode::experimental_pullback && !m_DiffReq->getReturnType()->isVoidType()) { IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); + /*QualType pullbackType =*/ + /* derivativeFnType->getParamType(m_DiffReq->getNumParams());*/ QualType pullbackType = - derivativeFnType->getParamType(m_DiffReq->getNumParams()); + m_Context.DoubleTy; ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, pullbackParamII, pullbackType); paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD);