Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't consider arrays as a special case in DifferentiateVarDecl #1164

Merged
merged 2 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/clang-tidy-review-post.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
steps:
- name: Post review comments
id: post-review
uses: ZedThree/clang-tidy-review/post@v0.18.0
uses: ZedThree/clang-tidy-review/post@v0.20.1
with:
max_comments: 10

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/clang-tidy-review.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
git config --global --add safe.directory /github/workspace

- name: Run clang-tidy
uses: ZedThree/clang-tidy-review@v0.18.0
uses: ZedThree/clang-tidy-review@v0.20.1
id: review
with:
build_dir: build
Expand All @@ -47,4 +47,4 @@ jobs:
-DCMAKE_EXPORT_COMPILE_COMMANDS=On

- name: Upload artifacts
uses: ZedThree/clang-tidy-review/upload@v0.18.0
uses: ZedThree/clang-tidy-review/upload@v0.20.1
182 changes: 85 additions & 97 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2755,116 +2755,104 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
if (const auto* AT = dyn_cast<ArrayType>(VDType)) {
if (!isa<VariableArrayType>(AT)) {
Expr* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
VDDerivedInit = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
}
if (promoteToFnScope) {
if (promoteToFnScope)
if (const auto* AT = dyn_cast<ArrayType>(VDType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "clang::dyn_cast" is directly included [misc-include-cleaner]

      if (const auto* AT = dyn_cast<ArrayType>(VDType))
                           ^

// If an array-type declaration is promoted to function global,
// its type is changed for clad::array. In that case we should
// initialize it with its size.
initDiff = getArraySizeExpr(AT, m_Context, *this);
}
VDDerived = BuildGlobalVarDecl(
VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
nullptr, VarDecl::InitializationStyle::CInit);
} else {
// If VD is a reference to a local variable, then the initial value is set
// to the derived variable of the corresponding local variable.
// If VD is a reference to a non-local variable (global variable, struct
// member etc), then no derived variable is available, thus `VDDerived`
// does not need to reference any variable, consequentially the
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VDType);

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
bool specialThisDiffCase = false;
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
if (VDDerivedType->isPointerType() && MD->isInstance()) {
specialThisDiffCase = true;
}
}
// If VD is a reference to a local variable, then the initial value is set
// to the derived variable of the corresponding local variable.
// If VD is a reference to a non-local variable (global variable, struct
// member etc), then no derived variable is available, thus `VDDerived`
// does not need to reference any variable, consequentially the
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VDType);

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
bool specialThisDiffCase = false;
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
if (VDDerivedType->isPointerType() && MD->isInstance())
specialThisDiffCase = true;
}

if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType());
isRefType = false;
}
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType());
isRefType = false;
}
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
m_TrackConstructorPullbackInfo = true;
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
m_TrackConstructorPullbackInfo = true;
// FIXME: Remove the special cases introduced by `specialThisDiffCase`
// once reverse mode supports pointers. `specialThisDiffCase` is only
// required for correctly differentiating the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
if (specialThisDiffCase && VD->getNameAsString() == "_d_this") {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType) {
if (!isInitializedByNewExpr)
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
// once reverse mode supports pointers. `specialThisDiffCase` is only
// required for correctly differentiating the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
if (specialThisDiffCase && VD->getNameAsString() == "_d_this") {
// If the pointer is const and derived expression is not available, then
// we should not create a derived variable for it. This will be useful
// for reducing number of differentiation variables in pullbacks.
bool constPointer = VDType->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType) {
if (!isInitializedByNewExpr)
initDiff = Visit(VD->getInit());

// If the pointer is const and derived expression is not available, then
// we should not create a derived variable for it. This will be useful
// for reducing number of differentiation variables in pullbacks.
bool constPointer = VDType->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VDType->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VDType->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
if (initializeDerivedVar)
VDDerived = BuildGlobalVarDecl(
VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
nullptr, VD->getInitStyle());
}
if (initializeDerivedVar)
VDDerived =
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr, VD->getInitStyle());

if (!m_DiffReq.shouldHaveAdjoint((VD)))
VDDerived = nullptr;
Expand Down
11 changes: 6 additions & 5 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,13 @@ namespace clad {

Expr* VisitorBase::getZeroInit(QualType T) {
// FIXME: Consolidate other uses of synthesizeLiteral for creation 0 or 1.
if (T->isVoidType())
if (T->isVoidType() || isa<VariableArrayType>(T))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "clang::VariableArrayType" is directly included [misc-include-cleaner]

lib/Differentiator/VisitorBase.cpp:28:

- #include <numeric>
+ #include <clang/AST/Type.h>
+ #include <numeric>

return nullptr;
if ((T->isScalarType() || T->isPointerType()) && !T->isReferenceType()) {
ExprResult Zero =
ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
return Zero.get();
if ((T->isScalarType() || T->isPointerType()) && !T->isReferenceType())
return ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
if (isa<ConstantArrayType>(T)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "clang::ConstantArrayType" is directly included [misc-include-cleaner]

    if (isa<ConstantArrayType>(T)) {
            ^

Expr* zero = ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
return m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
}
return m_Sema.ActOnInitList(noLoc, {}, noLoc).get();
}
Expand Down
Loading