Skip to content

Commit

Permalink
Don't consider arrays as a special case in DifferentiateVarDecl
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Dec 8, 2024
1 parent 20784b6 commit eee6faa
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 102 deletions.
182 changes: 85 additions & 97 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2755,116 +2755,104 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
if (const auto* AT = dyn_cast<ArrayType>(VDType)) {
if (!isa<VariableArrayType>(AT)) {
Expr* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
VDDerivedInit = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
}
if (promoteToFnScope) {
if (promoteToFnScope)
if (const auto* AT = dyn_cast<ArrayType>(VDType))
// If an array-type declaration is promoted to function global,
// its type is changed for clad::array. In that case we should
// initialize it with its size.
initDiff = getArraySizeExpr(AT, m_Context, *this);
}
VDDerived = BuildGlobalVarDecl(
VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
nullptr, VarDecl::InitializationStyle::CInit);
} else {
// If VD is a reference to a local variable, then the initial value is set
// to the derived variable of the corresponding local variable.
// If VD is a reference to a non-local variable (global variable, struct
// member etc), then no derived variable is available, thus `VDDerived`
// does not need to reference any variable, consequentially the
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VDType);

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
bool specialThisDiffCase = false;
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
if (VDDerivedType->isPointerType() && MD->isInstance()) {
specialThisDiffCase = true;
}
}
// If VD is a reference to a local variable, then the initial value is set
// to the derived variable of the corresponding local variable.
// If VD is a reference to a non-local variable (global variable, struct
// member etc), then no derived variable is available, thus `VDDerived`
// does not need to reference any variable, consequentially the
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VDType);

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
bool specialThisDiffCase = false;
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
if (VDDerivedType->isPointerType() && MD->isInstance())
specialThisDiffCase = true;
}

if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType());
isRefType = false;
}
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType());
isRefType = false;
}
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
m_TrackConstructorPullbackInfo = true;
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
m_TrackConstructorPullbackInfo = true;
// FIXME: Remove the special cases introduced by `specialThisDiffCase`
// once reverse mode supports pointers. `specialThisDiffCase` is only
// required for correctly differentiating the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
if (specialThisDiffCase && VD->getNameAsString() == "_d_this") {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType) {
if (!isInitializedByNewExpr)
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
// once reverse mode supports pointers. `specialThisDiffCase` is only
// required for correctly differentiating the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
if (specialThisDiffCase && VD->getNameAsString() == "_d_this") {
// If the pointer is const and derived expression is not available, then
// we should not create a derived variable for it. This will be useful
// for reducing number of differentiation variables in pullbacks.
bool constPointer = VDType->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType) {
if (!isInitializedByNewExpr)
initDiff = Visit(VD->getInit());

// If the pointer is const and derived expression is not available, then
// we should not create a derived variable for it. This will be useful
// for reducing number of differentiation variables in pullbacks.
bool constPointer = VDType->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VDType->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VDType->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
if (initializeDerivedVar)
VDDerived = BuildGlobalVarDecl(
VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
nullptr, VD->getInitStyle());
}
if (initializeDerivedVar)
VDDerived =
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr, VD->getInitStyle());

if (!m_DiffReq.shouldHaveAdjoint((VD)))
VDDerived = nullptr;
Expand Down
11 changes: 6 additions & 5 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,13 @@ namespace clad {

Expr* VisitorBase::getZeroInit(QualType T) {
// FIXME: Consolidate other uses of synthesizeLiteral for creation 0 or 1.
if (T->isVoidType())
if (T->isVoidType() || isa<VariableArrayType>(T))
return nullptr;
if ((T->isScalarType() || T->isPointerType()) && !T->isReferenceType()) {
ExprResult Zero =
ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
return Zero.get();
if ((T->isScalarType() || T->isPointerType()) && !T->isReferenceType())
return ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
if (isa<ConstantArrayType>(T)) {
Expr* zero = ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
return m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
}
return m_Sema.ActOnInitList(noLoc, {}, noLoc).get();
}
Expand Down

0 comments on commit eee6faa

Please sign in to comment.