Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Jul 11, 2024
1 parent 23f39bc commit e9504a6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
37 changes: 18 additions & 19 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto gradientParams = m_Derivative->parameters();
std::string name =
m_DiffReq.BaseFunctionName + "_grad" + diffParamsPostfix(m_DiffReq);
if (m_DiffReq.use_enzyme)
name += "_enzyme";
IdentifierInfo* II = &m_Context.Idents.get(name);
DeclarationNameInfo DNI(II, noLoc);
// Calculate the total number of parameters that would be required for
Expand All @@ -131,6 +133,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_DiffReq->getNumParams() * 2 + numExtraParams;
std::size_t numOfDerivativeParams =
m_DiffReq->getNumParams() + numExtraParams;
// "Pullback parameter" here means the middle _d_y parameters used in
// pullbacks to represent the adjoint of the corresponding function call.
// Only Enzyme gradients and pullbacks of void functions don't have it.
bool hasPullbackParam =
!m_DiffReq.use_enzyme && !m_DiffReq->getReturnType()->isVoidType();
// Account for the this pointer.
if (isa<CXXMethodDecl>(m_DiffReq.Function) &&
!utils::IsStaticMethod(m_DiffReq.Function))
Expand Down Expand Up @@ -189,10 +196,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
callArgs.push_back(BuildDeclRef(VD));
}

std::size_t firstAdjParamIdx = m_DiffReq->getNumParams();
if (hasPullbackParam)
++firstAdjParamIdx;
for (std::size_t i = 0; i < numOfDerivativeParams; ++i) {
IdentifierInfo* II = nullptr;
StorageClass SC = StorageClass::SC_None;
std::size_t effectiveGradientIndex = m_DiffReq->getNumParams() + i + 1;
std::size_t effectiveGradientIndex = firstAdjParamIdx + i;
// `effectiveGradientIndex < gradientParams.size()` implies that this
// parameter represents an actual derivative of one of the function
// original parameters.
Expand Down Expand Up @@ -229,11 +239,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Build derivatives to be used in the call to the actual derived function.
// These are initialised by effectively casting the derivative parameters of
// overloaded derived function to the correct type.
for (std::size_t i = m_DiffReq->getNumParams() + 1;
i < gradientParams.size(); ++i) {
// Overloads don't have the _d_y parameter like pullbacks.
// Therefore, we have to shift the parameter index by 1.
auto* overloadParam = overloadParams[i - 1];
for (std::size_t i = firstAdjParamIdx; i < gradientParams.size(); ++i) {
// Overloads don't have the _d_y parameter like most pullbacks.
// Therefore, we have to shift the parameter index by 1 if the pullback
// has it.
auto* overloadParam = overloadParams[i - hasPullbackParam];
auto* gradientParam = gradientParams[i];
TypeSourceInfo* typeInfo =
m_Context.getTrivialTypeSourceInfo(gradientParam->getType());
Expand Down Expand Up @@ -3653,18 +3663,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_DiffReq.Mode == DiffMode::reverse) {
QualType effectiveReturnType =
m_DiffReq->getReturnType().getNonReferenceType();
// FIXME: Generally, we use the function's return type as the argument's
// derivative type. We cannot follow this strategy for `void` function
// return type. Thus, temporarily use `double` type as the placeholder
// type for argument derivatives. We should think of a more uniform and
// consistent solution to this problem. One effective strategy that may
// hold well: If we are differentiating a variable of type Y with
// respect to variable of type X, then the derivative should be of type
// X. Check this related issue for more details:
// https://github.com/vgvassilev/clad/issues/385
if (effectiveReturnType->isVoidType())
effectiveReturnType = m_Context.DoubleTy;
else
if (!effectiveReturnType->isVoidType() && !m_DiffReq.use_enzyme)
paramTypes.push_back(effectiveReturnType);

if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
Expand Down Expand Up @@ -3701,7 +3700,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::size_t dParamTypesIdx = m_DiffReq->getNumParams();

if (m_DiffReq.Mode == DiffMode::reverse &&
!m_DiffReq->getReturnType()->isVoidType()) {
!m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq.use_enzyme) {
++dParamTypesIdx;
}

Expand Down
6 changes: 3 additions & 3 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ double foo(double x, double y){
return x*y;
}

// CHECK: void foo_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK: void foo_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
// CHECK-NEXT: {
// CHECK-NEXT: *_d_x += 1 * y;
// CHECK-NEXT: *_d_y += x * 1;
// CHECK-NEXT: *_d_x += _d_y0 * y;
// CHECK-NEXT: *_d_y += x * _d_y0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down

0 comments on commit e9504a6

Please sign in to comment.