From c37e97f5dd2047646bfe2013218dc6189cecfa4e Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 11 Jul 2024 15:59:08 +0300 Subject: [PATCH] Fix derivative parameter indexing when the pullback has no _d_y parameter (Enzyme, void return) --- lib/Differentiator/ReverseModeVisitor.cpp | 35 +++++++++++------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0716107d8..1388af0d7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -131,6 +131,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq->getNumParams() * 2 + numExtraParams; std::size_t numOfDerivativeParams = m_DiffReq->getNumParams() + numExtraParams; + // "Pullback parameter" here means the middle _d_y parameters used in + // pullbacks to represent the adjoint of the corresponding function call. + // Only Enzyme gradients and pullbacks of void functions don't have it. + bool hasPullbackParam = + !m_DiffReq.use_enzyme && !m_DiffReq->getReturnType()->isVoidType(); // Account for the this pointer. if (isa(m_DiffReq.Function) && !utils::IsStaticMethod(m_DiffReq.Function)) @@ -189,10 +194,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, callArgs.push_back(BuildDeclRef(VD)); } + std::size_t firstAdjParamIdx = m_DiffReq->getNumParams(); + if (hasPullbackParam) + ++firstAdjParamIdx; for (std::size_t i = 0; i < numOfDerivativeParams; ++i) { IdentifierInfo* II = nullptr; StorageClass SC = StorageClass::SC_None; - std::size_t effectiveGradientIndex = m_DiffReq->getNumParams() + i + 1; + std::size_t effectiveGradientIndex = firstAdjParamIdx + i; // `effectiveGradientIndex < gradientParams.size()` implies that this // parameter represents an actual derivative of one of the function // original parameters. @@ -229,11 +237,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Build derivatives to be used in the call to the actual derived function. 
// These are initialised by effectively casting the derivative parameters of // overloaded derived function to the correct type. - for (std::size_t i = m_DiffReq->getNumParams() + 1; - i < gradientParams.size(); ++i) { - // Overloads don't have the _d_y parameter like pullbacks. - // Therefore, we have to shift the parameter index by 1. - auto* overloadParam = overloadParams[i - 1]; + for (std::size_t i = firstAdjParamIdx; i < gradientParams.size(); ++i) { + // Overloads don't have the _d_y parameter like most pullbacks. + // Therefore, we have to shift the parameter index by 1 if the pullback + // has it. + auto* overloadParam = overloadParams[i - hasPullbackParam]; auto* gradientParam = gradientParams[i]; TypeSourceInfo* typeInfo = m_Context.getTrivialTypeSourceInfo(gradientParam->getType()); @@ -3653,18 +3661,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_DiffReq.Mode == DiffMode::reverse) { QualType effectiveReturnType = m_DiffReq->getReturnType().getNonReferenceType(); - // FIXME: Generally, we use the function's return type as the argument's - // derivative type. We cannot follow this strategy for `void` function - // return type. Thus, temporarily use `double` type as the placeholder - // type for argument derivatives. We should think of a more uniform and - // consistent solution to this problem. One effective strategy that may - // hold well: If we are differentiating a variable of type Y with - // respect to variable of type X, then the derivative should be of type - // X. 
Check this related issue for more details: - // https://github.com/vgvassilev/clad/issues/385 - if (effectiveReturnType->isVoidType()) - effectiveReturnType = m_Context.DoubleTy; - else + if (!effectiveReturnType->isVoidType() && !m_DiffReq.use_enzyme) paramTypes.push_back(effectiveReturnType); if (const auto* MD = dyn_cast(m_DiffReq.Function)) { @@ -3701,7 +3698,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); if (m_DiffReq.Mode == DiffMode::reverse && - !m_DiffReq->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq.use_enzyme) { ++dParamTypesIdx; }