From cba208d18ea4faee3b673d1c1d8714ca1a8e2275 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Fri, 29 Nov 2024 00:07:30 +0100 Subject: [PATCH] Use a single point to process non-differentiable functions --- lib/Differentiator/ReverseModeVisitor.cpp | 65 ++++++++--------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9ecf043c8..f771d6d3e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1467,27 +1467,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (const auto* KCE = dyn_cast(CE)) CUDAExecConfig = Clone(KCE->getConfig()); - // If the function is non_differentiable, return zero derivative. - if (clad::utils::hasNonDifferentiableAttribute(CE)) { - // Calling the function without computing derivatives - llvm::SmallVector ClonedArgs; - for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) - ClonedArgs.push_back(Clone(CE->getArg(i))); - - SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema); - Expr* Call = - m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), - validLoc, ClonedArgs, validLoc, CUDAExecConfig) - .get(); - // Creating a zero derivative - auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, - /*val=*/0); - - // Returning the function call and zero derivative - return StmtDiff(Call, zero); - } - // begin and end are common enough to have a more efficient and nice-looking // special case. Instead of _forw and a useless _pullback functions, we can // express the result in terms of the same std::begin / std::end. Note: @@ -1512,13 +1491,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } auto NArgs = FD->getNumParams(); - // If the function has no args and is not a member function call then we - // assume that it is not related to independent variables and does not - // contribute to gradient. - if ((NArgs == 0U) && !isa(CE) && - !isa(CE)) - return StmtDiff(Clone(CE)); - SourceLocation Loc = CE->getExprLoc(); // Stores the call arguments for the function to be derived @@ -1603,11 +1575,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(); } + bool nonDiff = clad::utils::hasNonDifferentiableAttribute(CE); + + // If the result does not depend on the result of the call, just clone + // the call and visit arguments (since they may contain side-effects like + // f(x = y)) + // If the callee function takes arguments by reference then it can affect + // derivatives even if there is no `dfdx()` and thus we should call the + // derived function. In the case of member functions, `implicit` + // this object is always passed by reference. + if (!nonDiff && !dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) && + !isa(CE) && !isa(CE)) + nonDiff = true; + // If all arguments are constant literals, then this does not contribute to // the gradient. - // FIXME: revert this when this is integrated in the activity analysis pass. - if (!isa(CE) && !isa(CE)) { - bool allArgsAreConstantLiterals = true; + if (!nonDiff && !isa(CE) && + !isa(CE)) { + bool allArgsAreConstant = true; for (const Expr* arg : CE->arguments()) { // if it's of type MaterializeTemporaryExpr, then check its // subexpression. @@ -1634,25 +1619,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } analyzer(m_DiffReq); if (analyzer.isVariedE(arg)) { - allArgsAreConstantLiterals = false; + allArgsAreConstant = false; break; } } - if (allArgsAreConstantLiterals) - return StmtDiff(Clone(CE), Clone(CE)); + if (allArgsAreConstant) + nonDiff = true; } - // If the result does not depend on the result of the call, just clone - // the call and visit arguments (since they may contain side-effects like - // f(x = y)) - // If the callee function takes arguments by reference then it can affect - // derivatives even if there is no `dfdx()` and thus we should call the - // derived function. In the case of member functions, `implicit` - // this object is always passed by reference. - if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(CE) && !isa(CE)) { + if (nonDiff) { for (const Expr* Arg : CE->arguments()) { - StmtDiff ArgDiff = Visit(Arg, dfdx()); + StmtDiff ArgDiff = Visit(Arg); CallArgs.push_back(ArgDiff.getExpr()); } Expr* call =