From bca7dee64f4601ba1e23be0af42d1d48889dba0a Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Tue, 3 Dec 2024 17:51:05 +0200 Subject: [PATCH] No need to handle recursive calls separately --- lib/Differentiator/ReverseModeVisitor.cpp | 166 ++++++++++------------ 1 file changed, 74 insertions(+), 92 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9ecf043c8..5718ad8de 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1953,99 +1953,81 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Derivative was not found, check if it is a recursive call if (!OverloadedDerivedFn) { - if (FD == m_DiffReq.Function && - m_DiffReq.Mode == DiffMode::experimental_pullback) { - // Recursive call. - Expr* selfRef = - m_Sema - .BuildDeclarationNameExpr( - CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) - .get(); + if (m_ExternalSource) + m_ExternalSource->ActBeforeDifferentiatingCallExpr( + pullbackCallArgs, PreCallStmts, dfdx()); - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, Loc, - pullbackCallArgs, Loc, CUDAExecConfig) - .get(); - } else { - if (m_ExternalSource) - m_ExternalSource->ActBeforeDifferentiatingCallExpr( - pullbackCallArgs, PreCallStmts, dfdx()); - - // Overloaded derivative was not found, request the CladPlugin to - // derive the called function. - DiffRequest pullbackRequest{}; - pullbackRequest.Function = FD; - - // Mark the indexes of the global args. Necessary if the argument of the - // call has a different name than the function's signature parameter. - pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs; - - pullbackRequest.BaseFunctionName = - clad::utils::ComputeEffectiveFnName(FD); - pullbackRequest.Mode = DiffMode::experimental_pullback; - // Silence diag outputs in nested derivation process. - pullbackRequest.VerboseDiags = false; - pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; - pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; - bool isaMethod = isa(FD); - for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (MD && isLambdaCallOperator(MD)) { - if (const auto* paramDecl = FD->getParamDecl(i)) - pullbackRequest.DVI.push_back(paramDecl); - } else if (DerivedCallOutputArgs[i + isaMethod]) - pullbackRequest.DVI.push_back(FD->getParamDecl(i)); - - FunctionDecl* pullbackFD = nullptr; - if (m_ExternalSource) - // FIXME: Error estimation currently uses singleton objects - - // m_ErrorEstHandler and m_EstModel, which is cleared after each - // error_estimate request. This requires the pullback to be derived - // at the same time to access the singleton objects. - pullbackFD = - plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); - else - pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest); - - // Clad failed to derive it. - // FIXME: Add support for reference arguments to the numerical diff. If - // it already correctly support reference arguments then confirm the - // support and add tests for the same. - if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(FD)) { - // Try numerically deriving it. - if (NArgs == 1) { - OverloadedDerivedFn = GetSingleArgCentralDiffCall( - Clone(CE->getCallee()), DerivedCallArgs[0], - /*targetPos=*/0, - /*numArgs=*/1, DerivedCallArgs, CUDAExecConfig); - asGrad = !OverloadedDerivedFn; - } else { - auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema); - OverloadedDerivedFn = GetMultiArgCentralDiffCall( - Clone(CE->getCallee()), CEType.getCanonicalType(), - CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts, - DerivedCallArgs, CallArgDx, CUDAExecConfig); - } - CallExprDiffDiagnostics(FD, CE->getBeginLoc()); - if (!OverloadedDerivedFn) { - Stmts& block = getCurrentBlock(direction::reverse); - block.insert(block.begin(), PreCallStmts.begin(), - PreCallStmts.end()); - return StmtDiff(Clone(CE)); - } - } else if (pullbackFD) { - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); - OverloadedDerivedFn = BuildCallExprToMemFn( - baseE, pullbackFD->getName(), pullbackCallArgs, Loc); - } else { - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD), - Loc, pullbackCallArgs, Loc, CUDAExecConfig) - .get(); - } + // Overloaded derivative was not found, request the CladPlugin to + // derive the called function. + DiffRequest pullbackRequest{}; + pullbackRequest.Function = FD; + + // Mark the indexes of the global args. Necessary if the argument of the + // call has a different name than the function's signature parameter. + pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs; + + pullbackRequest.BaseFunctionName = + clad::utils::ComputeEffectiveFnName(FD); + pullbackRequest.Mode = DiffMode::experimental_pullback; + // Silence diag outputs in nested derivation process. + pullbackRequest.VerboseDiags = false; + pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; + pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + bool isaMethod = isa(FD); + for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) + pullbackRequest.DVI.push_back(FD->getParamDecl(i)); + + FunctionDecl* pullbackFD = nullptr; + if (m_ExternalSource) + // FIXME: Error estimation currently uses singleton objects - + // m_ErrorEstHandler and m_EstModel, which is cleared after each + // error_estimate request. This requires the pullback to be derived + // at the same time to access the singleton objects. + pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + else + pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest); + + // Clad failed to derive it. + // FIXME: Add support for reference arguments to the numerical diff. If + // it already correctly support reference arguments then confirm the + // support and add tests for the same. + if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) && + !isa(FD)) { + // Try numerically deriving it. + if (NArgs == 1) { + OverloadedDerivedFn = GetSingleArgCentralDiffCall( + Clone(CE->getCallee()), DerivedCallArgs[0], + /*targetPos=*/0, + /*numArgs=*/1, DerivedCallArgs, CUDAExecConfig); + asGrad = !OverloadedDerivedFn; + } else { + auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema); + OverloadedDerivedFn = GetMultiArgCentralDiffCall( + Clone(CE->getCallee()), CEType.getCanonicalType(), + CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts, + DerivedCallArgs, CallArgDx, CUDAExecConfig); + } + CallExprDiffDiagnostics(FD, CE->getBeginLoc()); + if (!OverloadedDerivedFn) { + Stmts& block = getCurrentBlock(direction::reverse); + block.insert(block.begin(), PreCallStmts.begin(), PreCallStmts.end()); + return StmtDiff(Clone(CE)); + } + } else if (pullbackFD) { + if (baseDiff.getExpr()) { + Expr* baseE = baseDiff.getExpr(); + OverloadedDerivedFn = BuildCallExprToMemFn( + baseE, pullbackFD->getName(), pullbackCallArgs, Loc); + } else { + OverloadedDerivedFn = + m_Sema + .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD), + Loc, pullbackCallArgs, Loc, CUDAExecConfig) + .get(); } } }