From 701aa1b7230f01827791b9bdb0448553cb28487d Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Sat, 7 Dec 2024 00:08:27 +0200 Subject: [PATCH] Don't consider arrays as a special case in DifferentiateVarDecl --- lib/Differentiator/ReverseModeVisitor.cpp | 182 ++++++++++------------ lib/Differentiator/VisitorBase.cpp | 11 +- 2 files changed, 91 insertions(+), 102 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 144430d21..9d280197b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2755,116 +2755,104 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // VDDerivedInit now serves two purposes -- as the initial derivative value // or the size of the derivative array -- depending on the primal type. - if (const auto* AT = dyn_cast(VDType)) { - if (!isa(AT)) { - Expr* zero = - ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); - VDDerivedInit = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get(); - } - if (promoteToFnScope) { + if (promoteToFnScope) + if (const auto* AT = dyn_cast(VDType)) // If an array-type declaration is promoted to function global, // its type is changed for clad::array. In that case we should // initialize it with its size. initDiff = getArraySizeExpr(AT, m_Context, *this); - } - VDDerived = BuildGlobalVarDecl( - VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, - nullptr, VarDecl::InitializationStyle::CInit); - } else { - // If VD is a reference to a local variable, then the initial value is set - // to the derived variable of the corresponding local variable. - // If VD is a reference to a non-local variable (global variable, struct - // member etc), then no derived variable is available, thus `VDDerived` - // does not need to reference any variable, consequentially the - // `VDDerivedType` is the corresponding non-reference type and the initial - // value is set to 0. - // Otherwise, for non-reference types, the initial value is set to 0. - if (!VDDerivedInit) - VDDerivedInit = getZeroInit(VDType); - - // `specialThisDiffCase` is only required for correctly differentiating - // the following code: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // Computation of hessian requires this code to be correctly - // differentiated. - bool specialThisDiffCase = false; - if (const auto* MD = dyn_cast(m_DiffReq.Function)) { - if (VDDerivedType->isPointerType() && MD->isInstance()) { - specialThisDiffCase = true; - } - } + // If VD is a reference to a local variable, then the initial value is set + // to the derived variable of the corresponding local variable. + // If VD is a reference to a non-local variable (global variable, struct + // member etc), then no derived variable is available, thus `VDDerived` + // does not need to reference any variable, consequentially the + // `VDDerivedType` is the corresponding non-reference type and the initial + // value is set to 0. + // Otherwise, for non-reference types, the initial value is set to 0. + if (!VDDerivedInit) + VDDerivedInit = getZeroInit(VDType); + + // `specialThisDiffCase` is only required for correctly differentiating + // the following code: + // ``` + // Class _d_this_obj; + // Class* _d_this = &_d_this_obj; + // ``` + // Computation of hessian requires this code to be correctly + // differentiated. + bool specialThisDiffCase = false; + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { + if (VDDerivedType->isPointerType() && MD->isInstance()) + specialThisDiffCase = true; + } - if (isRefType) { - initDiff = Visit(VD->getInit()); - if (!initDiff.getForwSweepExpr_dx()) { - VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType()); - isRefType = false; - } - if (promoteToFnScope || !isRefType) - VDDerivedInit = getZeroInit(VDDerivedType); - else - VDDerivedInit = initDiff.getForwSweepExpr_dx(); + if (isRefType) { + initDiff = Visit(VD->getInit()); + if (!initDiff.getForwSweepExpr_dx()) { + VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType()); + isRefType = false; } + if (promoteToFnScope || !isRefType) + VDDerivedInit = getZeroInit(VDDerivedType); + else + VDDerivedInit = initDiff.getForwSweepExpr_dx(); + } + + if (VDType->isStructureOrClassType()) { + m_TrackConstructorPullbackInfo = true; + initDiff = Visit(VD->getInit()); + m_TrackConstructorPullbackInfo = false; + constructorPullbackInfo = getConstructorPullbackCallInfo(); + resetConstructorPullbackCallInfo(); + if (initDiff.getForwSweepExpr_dx()) + VDDerivedInit = initDiff.getForwSweepExpr_dx(); + } - if (VDType->isStructureOrClassType()) { - m_TrackConstructorPullbackInfo = true; + // FIXME: Remove the special cases introduced by `specialThisDiffCase` + // once reverse mode supports pointers. `specialThisDiffCase` is only + // required for correctly differentiating the following code: + // ``` + // Class _d_this_obj; + // Class* _d_this = &_d_this_obj; + // ``` + // Computation of hessian requires this code to be correctly + // differentiated. + if (specialThisDiffCase && VD->getNameAsString() == "_d_this") { + VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); + initDiff = Visit(VD->getInit()); + if (initDiff.getExpr_dx()) + VDDerivedInit = initDiff.getExpr_dx(); + } + // if VD is a pointer type, then the initial value is set to the derived + // expression of the corresponding pointer type. + else if (isPointerType) { + if (!isInitializedByNewExpr) initDiff = Visit(VD->getInit()); - m_TrackConstructorPullbackInfo = false; - constructorPullbackInfo = getConstructorPullbackCallInfo(); - resetConstructorPullbackCallInfo(); - if (initDiff.getForwSweepExpr_dx()) - VDDerivedInit = initDiff.getForwSweepExpr_dx(); - } - // FIXME: Remove the special cases introduced by `specialThisDiffCase` - // once reverse mode supports pointers. `specialThisDiffCase` is only - // required for correctly differentiating the following code: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // Computation of hessian requires this code to be correctly - // differentiated. - if (specialThisDiffCase && VD->getNameAsString() == "_d_this") { + // If the pointer is const and derived expression is not available, then + // we should not create a derived variable for it. This will be useful + // for reducing number of differentiation variables in pullbacks. + bool constPointer = VDType->getPointeeType().isConstQualified(); + if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx()) + initializeDerivedVar = false; + else { VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); - initDiff = Visit(VD->getInit()); - if (initDiff.getExpr_dx()) - VDDerivedInit = initDiff.getExpr_dx(); - } - // if VD is a pointer type, then the initial value is set to the derived - // expression of the corresponding pointer type. - else if (isPointerType) { - if (!isInitializedByNewExpr) - initDiff = Visit(VD->getInit()); - - // If the pointer is const and derived expression is not available, then - // we should not create a derived variable for it. This will be useful - // for reducing number of differentiation variables in pullbacks. - bool constPointer = VDType->getPointeeType().isConstQualified(); - if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx()) - initializeDerivedVar = false; - else { - VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); - // If it's a pointer to a constant type, then remove the constness. - if (constPointer) { - // first extract the pointee type - auto pointeeType = VDType->getPointeeType(); - // then remove the constness - pointeeType.removeLocalConst(); - // then create a new pointer type with the new pointee type - VDDerivedType = m_Context.getPointerType(pointeeType); - } - VDDerivedInit = getZeroInit(VDDerivedType); + // If it's a pointer to a constant type, then remove the constness. + if (constPointer) { + // first extract the pointee type + auto pointeeType = VDType->getPointeeType(); + // then remove the constness + pointeeType.removeLocalConst(); + // then create a new pointer type with the new pointee type + VDDerivedType = m_Context.getPointerType(pointeeType); } + VDDerivedInit = getZeroInit(VDDerivedType); } - if (initializeDerivedVar) - VDDerived = BuildGlobalVarDecl( - VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, - nullptr, VD->getInitStyle()); } + if (initializeDerivedVar) + VDDerived = + BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + VDDerivedInit, false, nullptr, VD->getInitStyle()); if (!m_DiffReq.shouldHaveAdjoint((VD))) VDDerived = nullptr; diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 87c55b4cd..7884ca934 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -410,12 +410,13 @@ namespace clad { Expr* VisitorBase::getZeroInit(QualType T) { // FIXME: Consolidate other uses of synthesizeLiteral for creation 0 or 1. - if (T->isVoidType()) + if (T->isVoidType() || isa(T)) return nullptr; - if ((T->isScalarType() || T->isPointerType()) && !T->isReferenceType()) { - ExprResult Zero = - ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0); - return Zero.get(); + if ((T->isScalarType() || T->isPointerType()) && !T->isReferenceType()) + return ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0); + if (isa(T)) { + Expr* zero = ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0); + return m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get(); } return m_Sema.ActOnInitList(noLoc, {}, noLoc).get(); }