
Commit c37e97f
fix
PetroZarytskyi committed Jul 11, 2024
1 parent 23f39bc commit c37e97f
Showing 1 changed file with 16 additions and 19 deletions: lib/Differentiator/ReverseModeVisitor.cpp
@@ -131,6 +131,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         m_DiffReq->getNumParams() * 2 + numExtraParams;
     std::size_t numOfDerivativeParams =
         m_DiffReq->getNumParams() + numExtraParams;
+    // "Pullback parameter" here means the middle _d_y parameter used in
+    // pullbacks to represent the adjoint of the corresponding function call.
+    // Only Enzyme gradients and pullbacks of void functions don't have it.
+    bool hasPullbackParam =
+        !m_DiffReq.use_enzyme && !m_DiffReq->getReturnType()->isVoidType();
     // Account for the this pointer.
     if (isa<CXXMethodDecl>(m_DiffReq.Function) &&
         !utils::IsStaticMethod(m_DiffReq.Function))
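
For orientation, here is a minimal sketch of the parameter layouts this bookkeeping refers to, for a hypothetical double f(double x, double y) and void g(double x, double* out). The names (f_pullback, _d_y_out, and so on) are illustrative assumptions, not clad's literal output: the _d_y seed sits between the original parameters and the adjoint outputs, and it disappears for void functions and for Enzyme gradients (the hasPullbackParam == false case).

// Illustrative sketch only: assumed names and layout, not clad's exact output.
double f(double x, double y) { return x * y; }

// Pullback of a non-void function: the extra `_d_y` parameter is the seed
// (adjoint of f's result), placed between the original parameters and the
// adjoint outputs.
void f_pullback(double x, double y, double _d_y, double* _d_x,
                double* _d_y_out) {
  *_d_x += y * _d_y;     // d(x*y)/dx * seed
  *_d_y_out += x * _d_y; // d(x*y)/dy * seed
}

void g(double x, double* out) { *out = x * x; }

// Pullback of a void function: there is no result to seed, so no `_d_y`
// parameter appears (hasPullbackParam == false).
void g_pullback(double x, double* out, double* _d_x, double* _d_out) {
  *_d_x += 2 * x * *_d_out;
}

int main() {
  double dx = 0, dy = 0;
  f_pullback(/*x=*/2.0, /*y=*/3.0, /*_d_y=*/1.0, &dx, &dy); // dx == 3, dy == 2
  return 0;
}
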
@@ -189,10 +194,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       callArgs.push_back(BuildDeclRef(VD));
     }
 
+    std::size_t firstAdjParamIdx = m_DiffReq->getNumParams();
+    if (hasPullbackParam)
+      ++firstAdjParamIdx;
     for (std::size_t i = 0; i < numOfDerivativeParams; ++i) {
       IdentifierInfo* II = nullptr;
       StorageClass SC = StorageClass::SC_None;
-      std::size_t effectiveGradientIndex = m_DiffReq->getNumParams() + i + 1;
+      std::size_t effectiveGradientIndex = firstAdjParamIdx + i;
       // `effectiveGradientIndex < gradientParams.size()` implies that this
       // parameter represents an actual derivative of one of the function
       // original parameters.
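
The index arithmetic above can be spelled out with a small, self-contained toy that assumes the same hypothetical f_pullback layout as in the earlier sketch: with two original parameters and a seed present, the first adjoint parameter sits at index 3, and the i-th derivative parameter lands at firstAdjParamIdx + i.

#include <cassert>
#include <cstddef>

int main() {
  // Toy recreation of the bookkeeping, not the visitor code itself.
  const std::size_t numParams = 2;    // x, y
  const bool hasPullbackParam = true; // non-void return, Enzyme not used

  // Assumed pullback layout: x(0), y(1), _d_y(2), _d_x(3), _d_y_out(4).
  std::size_t firstAdjParamIdx = numParams;
  if (hasPullbackParam)
    ++firstAdjParamIdx; // skip the _d_y seed -> index 3

  // effectiveGradientIndex = firstAdjParamIdx + i
  assert(firstAdjParamIdx + 0 == 3); // _d_x
  assert(firstAdjParamIdx + 1 == 4); // _d_y_out
  return 0;
}
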
@@ -229,11 +237,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     // Build derivatives to be used in the call to the actual derived function.
     // These are initialised by effectively casting the derivative parameters of
     // overloaded derived function to the correct type.
-    for (std::size_t i = m_DiffReq->getNumParams() + 1;
-         i < gradientParams.size(); ++i) {
-      // Overloads don't have the _d_y parameter like pullbacks.
-      // Therefore, we have to shift the parameter index by 1.
-      auto* overloadParam = overloadParams[i - 1];
+    for (std::size_t i = firstAdjParamIdx; i < gradientParams.size(); ++i) {
+      // Overloads don't have the _d_y parameter like most pullbacks.
+      // Therefore, we have to shift the parameter index by 1 if the pullback
+      // has it.
+      auto* overloadParam = overloadParams[i - hasPullbackParam];
       auto* gradientParam = gradientParams[i];
       TypeSourceInfo* typeInfo =
           m_Context.getTrivialTypeSourceInfo(gradientParam->getType());
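
The pairing in this loop can be pictured with the same hypothetical parameter lists: the overload's indices lag behind the pullback's by exactly one when, and only when, the pullback carries the _d_y seed. The names below are assumptions for illustration, not the actual overload signature clad emits.

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical parameter name lists; the overload omits the _d_y seed.
  const char* gradientParams[] = {"x", "y", "_d_y", "_d_x", "_d_y_out"};
  const char* overloadParams[] = {"x", "y", "_d_x", "_d_y_out"};
  const bool hasPullbackParam = true;
  const std::size_t firstAdjParamIdx = 3; // as computed in the previous hunk

  for (std::size_t i = firstAdjParamIdx; i < 5; ++i)
    std::printf("overload %-8s <-> pullback %s\n",
                overloadParams[i - hasPullbackParam], gradientParams[i]);
  // Prints the pairs _d_x <-> _d_x and _d_y_out <-> _d_y_out.
  return 0;
}
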
@@ -3653,18 +3661,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
   if (m_DiffReq.Mode == DiffMode::reverse) {
     QualType effectiveReturnType =
         m_DiffReq->getReturnType().getNonReferenceType();
-    // FIXME: Generally, we use the function's return type as the argument's
-    // derivative type. We cannot follow this strategy for `void` function
-    // return type. Thus, temporarily use `double` type as the placeholder
-    // type for argument derivatives. We should think of a more uniform and
-    // consistent solution to this problem. One effective strategy that may
-    // hold well: If we are differentiating a variable of type Y with
-    // respect to variable of type X, then the derivative should be of type
-    // X. Check this related issue for more details:
-    // https://github.com/vgvassilev/clad/issues/385
-    if (effectiveReturnType->isVoidType())
-      effectiveReturnType = m_Context.DoubleTy;
-    else
+    if (!effectiveReturnType->isVoidType() && !m_DiffReq.use_enzyme)
       paramTypes.push_back(effectiveReturnType);
 
     if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
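
The effect of the new condition on the generated signatures, again using the illustrative naming from the earlier sketches rather than clad's literal output: the seed parameter's type is the function's return type with any reference stripped, and it is only appended for non-void, non-Enzyme reverse-mode pullbacks.

// Assumed, illustrative signatures only.

const double& h(const double& x); // original returns a reference

// Return type with the reference stripped -> the seed (_d_y) is a plain
// double, appended after the original parameters:
void h_pullback(const double& x, double _d_y, double* _d_x);

void g(double x, double* out); // void return type

// Void return (or an Enzyme gradient): no seed parameter is appended:
void g_pullback(double x, double* out, double* _d_x, double* _d_out);
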
@@ -3701,7 +3698,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
   std::size_t dParamTypesIdx = m_DiffReq->getNumParams();
 
   if (m_DiffReq.Mode == DiffMode::reverse &&
-      !m_DiffReq->getReturnType()->isVoidType()) {
+      !m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq.use_enzyme) {
     ++dParamTypesIdx;
   }
 
