diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 75549acaa..2c234a4d7 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -386,8 +386,6 @@ namespace clad { /// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'. DerivativeAndOverload Derive(const clang::FunctionDecl* FD, const DiffRequest& request); - DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request); StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d688a64ae..7c86d4fe6 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -419,17 +419,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::reverse) { ReverseModeVisitor V(*this, request); - if (request.CallUpdateRequired) { - result = V.Derive(FD, request); - } else { - if (!m_ErrorEstHandler.empty()) { - InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); - V.AddExternalSource(*m_ErrorEstHandler.back()); - } - result = V.DerivePullback(FD, request); - if (!m_ErrorEstHandler.empty()) - CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); + if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) { + InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); + V.AddExternalSource(*m_ErrorEstHandler.back()); } + result = V.Derive(FD, request); + if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) + CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { ReverseModeForwPassVisitor V(*this, request); result = V.Derive(FD, request); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 5dd3c23cb..1950fcd9f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -120,6 +120,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto gradientParams = m_Derivative->parameters(); std::string name = m_DiffReq.BaseFunctionName + "_grad" + diffParamsPostfix(m_DiffReq); + if (m_DiffReq.use_enzyme) + name += "_enzyme"; IdentifierInfo* II = &m_Context.Idents.get(name); DeclarationNameInfo DNI(II, noLoc); // Calculate the total number of parameters that would be required for @@ -131,6 +133,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq->getNumParams() * 2 + numExtraParams; std::size_t numOfDerivativeParams = m_DiffReq->getNumParams() + numExtraParams; + // "Pullback parameter" here means the middle _d_y parameters used in + // pullbacks to represent the adjoint of the corresponding function call. + // Only Enzyme gradients and pullbacks of void functions don't have it. + bool hasPullbackParam = + !m_DiffReq.use_enzyme && !m_DiffReq->getReturnType()->isVoidType(); // Account for the this pointer. if (isa(m_DiffReq.Function) && !utils::IsStaticMethod(m_DiffReq.Function)) @@ -189,10 +196,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, callArgs.push_back(BuildDeclRef(VD)); } + std::size_t firstAdjParamIdx = m_DiffReq->getNumParams(); + if (hasPullbackParam) + ++firstAdjParamIdx; for (std::size_t i = 0; i < numOfDerivativeParams; ++i) { IdentifierInfo* II = nullptr; StorageClass SC = StorageClass::SC_None; - std::size_t effectiveGradientIndex = m_DiffReq->getNumParams() + i + 1; + std::size_t effectiveGradientIndex = firstAdjParamIdx + i; // `effectiveGradientIndex < gradientParams.size()` implies that this // parameter represents an actual derivative of one of the function // original parameters. @@ -229,11 +239,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Build derivatives to be used in the call to the actual derived function. // These are initialised by effectively casting the derivative parameters of // overloaded derived function to the correct type. - for (std::size_t i = m_DiffReq->getNumParams() + 1; - i < gradientParams.size(); ++i) { - // Overloads don't have the _d_y parameter like pullbacks. - // Therefore, we have to shift the parameter index by 1. - auto* overloadParam = overloadParams[i - 1]; + for (std::size_t i = firstAdjParamIdx; i < gradientParams.size(); ++i) { + // Overloads don't have the _d_y parameter like most pullbacks. + // Therefore, we have to shift the parameter index by 1 if the pullback + // has it. + auto* overloadParam = overloadParams[i - hasPullbackParam]; auto* gradientParam = gradientParams[i]; TypeSourceInfo* typeInfo = m_Context.getTrivialTypeSourceInfo(gradientParam->getType()); @@ -283,15 +293,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; - DiffInputVarsInfo DVI; - if (request.Args) { - DVI = request.DVI; - for (const auto& dParam : DVI) + if (!request.DVI.empty()) + for (const auto& dParam : request.DVI) args.push_back(dParam.param); - } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - if (args.empty()) + if (args.empty() && (!isa(FD) || utils::IsStaticMethod(FD))) return {}; if (m_ExternalSource) @@ -303,10 +310,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString(); } - auto derivativeBaseName = request.BaseFunctionName; - std::string gradientName = derivativeBaseName + funcPostfix(m_DiffReq); + std::string derivativeBaseName = request.BaseFunctionName; + std::string derivativeName = derivativeBaseName + funcPostfix(m_DiffReq); - IdentifierInfo* II = &m_Context.Idents.get(gradientName); + IdentifierInfo* II = &m_Context.Idents.get(derivativeName); DeclarationNameInfo name(II, noLoc); // If we are in error estimation mode, we have an extra `double&` @@ -323,7 +330,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If reverse mode differentiates only part of the arguments it needs to // generate an overload that can take in all the diff variables bool shouldCreateOverload = false; - if (request.Mode != DiffMode::jacobian) + if (request.Mode != DiffMode::jacobian && m_DiffReq.CallUpdateRequired) shouldCreateOverload = true; if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty()) // If the overload is already created, we don't need to create it again. @@ -347,7 +354,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - gradientName, DC, gradientFunctionType)) { + derivativeName, DC, gradientFunctionType)) { // Set m_Derivative for creating the overload. m_Derivative = customDerivative; FunctionDecl* gradientOverloadFD = nullptr; @@ -459,137 +466,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return DerivativeAndOverload{result.first, gradientOverloadFD}; } - DerivativeAndOverload - ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request) { - // FIXME: Duplication of external source here is a workaround - // for the two 'Derive's being different functions. - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; - // FIXME: We should not use const_cast to get the decl request here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(m_DiffReq) = request; - assert(m_DiffReq.Function && "Must not be null."); - - DiffParams args{}; - if (!request.DVI.empty()) - for (const auto& dParam : request.DVI) - args.push_back(dParam.param); - else - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); -#ifndef NDEBUG - bool isStaticMethod = utils::IsStaticMethod(FD); - assert((!args.empty() || !isStaticMethod) && - "Cannot generate pullback function of a function " - "with no differentiable arguments"); -#endif - - if (m_ExternalSource) - m_ExternalSource->ActAfterParsingDiffArgs(request, args); - - auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + - funcPostfix(m_DiffReq); - auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); - - auto paramTypes = ComputeParamTypes(args); - const auto* originalFnType = - dyn_cast(m_DiffReq->getType()); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - - QualType pullbackFnType = m_Context.getFunctionType( - m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); - - // Check if the function is already declared as a custom derivative. - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto* DC = const_cast(m_DiffReq->getDeclContext()); - if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - derivativeName, DC, pullbackFnType)) - return DerivativeAndOverload{customDerivative, nullptr}; - - llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(getCurrentScope(), - getEnclosingNamespaceOrTUScope()); - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); - - SourceLocation validLoc{m_DiffReq->getLocation()}; - DeclWithContext fnBuildRes = - m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext, - validLoc, DNI, pullbackFnType); - m_Derivative = fnBuildRes.first; - - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnScope(); - - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnScope(); - - auto params = BuildParams(args); - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParams(params); - - m_Derivative->setParams(params); - m_Derivative->setBody(nullptr); - - if (!request.DeclarationOnly) { - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); - - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - - beginBlock(); - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); - - StmtDiff bodyDiff = Visit(m_DiffReq->getBody()); - Stmt* forward = bodyDiff.getStmt(); - Stmt* reverse = bodyDiff.getStmt_dx(); - - // Create the body of the function. - // Firstly, all "global" Stmts are put into fn's body. - for (Stmt* S : m_Globals) - addToCurrentBlock(S, direction::forward); - // Forward pass. - if (auto* CS = dyn_cast(forward)) - for (Stmt* S : CS->body()) - addToCurrentBlock(S, direction::forward); - - // Reverse pass. - if (auto* RCS = dyn_cast(reverse)) - for (Stmt* S : RCS->body()) - addToCurrentBlock(S, direction::forward); - - if (m_ExternalSource) - m_ExternalSource->ActOnEndOfDerivedFnBody(); - - Stmt* fnBody = endBlock(); - m_Derivative->setBody(fnBody); - endScope(); // Function body scope - - // Size >= current derivative order means that there exists a declaration - // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) - m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); - } - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope - - return DerivativeAndOverload{fnBuildRes.first, nullptr}; - } - void ReverseModeVisitor::DifferentiateWithClad() { llvm::ArrayRef paramsRef = m_Derivative->parameters(); @@ -3905,18 +3781,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_DiffReq.Mode == DiffMode::reverse) { QualType effectiveReturnType = m_DiffReq->getReturnType().getNonReferenceType(); - // FIXME: Generally, we use the function's return type as the argument's - // derivative type. We cannot follow this strategy for `void` function - // return type. Thus, temporarily use `double` type as the placeholder - // type for argument derivatives. We should think of a more uniform and - // consistent solution to this problem. One effective strategy that may - // hold well: If we are differentiating a variable of type Y with - // respect to variable of type X, then the derivative should be of type - // X. Check this related issue for more details: - // https://github.com/vgvassilev/clad/issues/385 - if (effectiveReturnType->isVoidType()) - effectiveReturnType = m_Context.DoubleTy; - else + if (!effectiveReturnType->isVoidType() && !m_DiffReq.use_enzyme) paramTypes.push_back(effectiveReturnType); if (const auto* MD = dyn_cast(m_DiffReq.Function)) { @@ -3953,7 +3818,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); if (m_DiffReq.Mode == DiffMode::reverse && - !m_DiffReq->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq.use_enzyme) { ++dParamTypesIdx; } diff --git a/test/Enzyme/DifferentCladEnzymeDerivatives.C b/test/Enzyme/DifferentCladEnzymeDerivatives.C index 6ae2eb037..753201929 100644 --- a/test/Enzyme/DifferentCladEnzymeDerivatives.C +++ b/test/Enzyme/DifferentCladEnzymeDerivatives.C @@ -10,10 +10,10 @@ double foo(double x, double y){ return x*y; } -// CHECK: void foo_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK: void foo_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) { // CHECK-NEXT: { -// CHECK-NEXT: *_d_x += 1 * y; -// CHECK-NEXT: *_d_y += x * 1; +// CHECK-NEXT: *_d_x += _d_y0 * y; +// CHECK-NEXT: *_d_y += x * _d_y0; // CHECK-NEXT: } // CHECK-NEXT: }