From a884c8c3556c0654fdf9ae189f18fd96985bf5b9 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 2 Nov 2023 17:58:03 +0200 Subject: [PATCH 1/2] Introduce type cloning. --- include/clad/Differentiator/StmtClone.h | 12 +- include/clad/Differentiator/VisitorBase.h | 3 + lib/Differentiator/ReverseModeVisitor.cpp | 4 +- lib/Differentiator/StmtClone.cpp | 292 +++++++++++++++------- lib/Differentiator/VisitorBase.cpp | 6 + 5 files changed, 228 insertions(+), 89 deletions(-) diff --git a/include/clad/Differentiator/StmtClone.h b/include/clad/Differentiator/StmtClone.h index 61e79e6b6..0d12059f7 100644 --- a/include/clad/Differentiator/StmtClone.h +++ b/include/clad/Differentiator/StmtClone.h @@ -48,8 +48,12 @@ namespace utils { template StmtTy* Clone(const StmtTy* S); - // visitor part (not for public use) - // Stmt.def could be used if ABSTR_STMT is introduced + /// Cloning types is necessary since VariableArrayType + /// store a pointer to their size expression. + clang::QualType CloneType(const clang::QualType T); + + // visitor part (not for public use) + // Stmt.def could be used if ABSTR_STMT is introduced #define DECLARE_CLONE_FN(CLASS) clang::Stmt* Visit ## CLASS(clang::CLASS *Node); DECLARE_CLONE_FN(BinaryOperator) DECLARE_CLONE_FN(UnaryOperator) @@ -153,6 +157,10 @@ namespace utils { ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S, const clang::FunctionDecl* FD); bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); + bool VisitStmt(clang::Stmt* S); + /// Used to update the size expression of QT + /// if QT is VariableArrayType. + void updateType(clang::QualType QT); }; } // namespace utils } // namespace clad diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 81e4af9a2..9036c2c96 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -558,6 +558,9 @@ namespace clad { clang::Stmt* Clone(const clang::Stmt* S); /// A shorthand to simplify cloning of expressions. clang::Expr* Clone(const clang::Expr* E); + /// Cloning types is necessary since VariableArrayType + /// store a pointer to their size expression. + clang::QualType CloneType(const clang::QualType T); }; } // end namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e5d7020e..cd39fb06f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2534,11 +2534,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. if (VD->getType()->isRecordType()) - VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), + VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit(), VD->getTypeSourceInfo(), VD->getInitStyle()); else - VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), + VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit()); Expr* derivedVDE = BuildDeclRef(VDDerived); diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 241fe79ff..8907335ec 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -63,53 +63,115 @@ Stmt* StmtClone::Visit ## CLASS(CLASS *Node) \ return result; \ } -DEFINE_CLONE_EXPR_CO11(BinaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc(), Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams))) -DEFINE_CLONE_EXPR_CO11(UnaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getSubExpr()), Node->getOpcode(), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc() CLAD_COMPAT_CLANG7_UnaryOperator_ExtraParams CLAD_COMPAT_CLANG11_UnaryOperator_ExtraParams)) +DEFINE_CLONE_EXPR_CO11( + BinaryOperator, + (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), + Clone(Node->getRHS()), Node->getOpcode(), CloneType(Node->getType()), + Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc(), + Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams))) +DEFINE_CLONE_EXPR_CO11( + UnaryOperator, + (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getSubExpr()), + Node->getOpcode(), CloneType(Node->getType()), Node->getValueKind(), + Node->getObjectKind(), + Node->getOperatorLoc() CLAD_COMPAT_CLANG7_UnaryOperator_ExtraParams + CLAD_COMPAT_CLANG11_UnaryOperator_ExtraParams)) Stmt* StmtClone::VisitDeclRefExpr(DeclRefExpr *Node) { TemplateArgumentListInfo TAListInfo; Node->copyTemplateArgumentsInto(TAListInfo); - return DeclRefExpr::Create(Ctx, Node->getQualifierLoc(), Node->getTemplateKeywordLoc(), Node->getDecl(), Node->refersToEnclosingVariableOrCapture(), Node->getNameInfo(), Node->getType(), Node->getValueKind(), Node->getFoundDecl(), &TAListInfo); + return DeclRefExpr::Create( + Ctx, Node->getQualifierLoc(), Node->getTemplateKeywordLoc(), + Node->getDecl(), Node->refersToEnclosingVariableOrCapture(), + Node->getNameInfo(), CloneType(Node->getType()), Node->getValueKind(), + Node->getFoundDecl(), &TAListInfo); } -DEFINE_CREATE_EXPR(IntegerLiteral, (Ctx, Node->getValue(), Node->getType(), Node->getLocation())) -DEFINE_CLONE_EXPR_CO(PredefinedExpr, (CLAD_COMPAT_CLANG8_Ctx_ExtraParams Node->getLocation(), Node->getType(), Node->getIdentKind() CLAD_COMPAT_CLANG17_IsTransparent(Node), Node->getFunctionName())) -DEFINE_CLONE_EXPR(CharacterLiteral, (Node->getValue(), Node->getKind(), Node->getType(), Node->getLocation())) -DEFINE_CLONE_EXPR(ImaginaryLiteral, (Clone(Node->getSubExpr()), Node->getType())) +DEFINE_CREATE_EXPR(IntegerLiteral, + (Ctx, Node->getValue(), CloneType(Node->getType()), + Node->getLocation())) +DEFINE_CLONE_EXPR_CO(PredefinedExpr, + (CLAD_COMPAT_CLANG8_Ctx_ExtraParams Node->getLocation(), + CloneType(Node->getType()), + Node->getIdentKind() + CLAD_COMPAT_CLANG17_IsTransparent(Node), + Node->getFunctionName())) +DEFINE_CLONE_EXPR(CharacterLiteral, + (Node->getValue(), Node->getKind(), + CloneType(Node->getType()), Node->getLocation())) +DEFINE_CLONE_EXPR(ImaginaryLiteral, + (Clone(Node->getSubExpr()), CloneType(Node->getType()))) DEFINE_CLONE_EXPR(ParenExpr, (Node->getLParen(), Node->getRParen(), Clone(Node->getSubExpr()))) -DEFINE_CLONE_EXPR(ArraySubscriptExpr, (Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getRBracketLoc())) +DEFINE_CLONE_EXPR(ArraySubscriptExpr, + (Clone(Node->getLHS()), Clone(Node->getRHS()), + CloneType(Node->getType()), Node->getValueKind(), + Node->getObjectKind(), Node->getRBracketLoc())) DEFINE_CREATE_EXPR(CXXDefaultArgExpr, (Ctx, SourceLocation(), Node->getParam() CLAD_COMPAT_CLANG16_CXXDefaultArgExpr_getRewrittenExpr_Param(Node) CLAD_COMPAT_CLANG9_CXXDefaultArgExpr_getUsedContext_Param(Node))) Stmt* StmtClone::VisitMemberExpr(MemberExpr* Node) { TemplateArgumentListInfo TemplateArgs; if (Node->hasExplicitTemplateArgs()) Node->copyTemplateArgumentsInto(TemplateArgs); - MemberExpr* result = MemberExpr::Create(Ctx, - Clone(Node->getBase()), - Node->isArrow(), - Node->getOperatorLoc(), - Node->getQualifierLoc(), - Node->getTemplateKeywordLoc(), - Node->getMemberDecl(), - Node->getFoundDecl(), - Node->getMemberNameInfo(), - &TemplateArgs, - Node->getType(), - Node->getValueKind(), - Node->getObjectKind() - CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( - Node->isNonOdrUse())); + MemberExpr* result = MemberExpr::Create( + Ctx, Clone(Node->getBase()), Node->isArrow(), Node->getOperatorLoc(), + Node->getQualifierLoc(), Node->getTemplateKeywordLoc(), + Node->getMemberDecl(), Node->getFoundDecl(), Node->getMemberNameInfo(), + &TemplateArgs, CloneType(Node->getType()), Node->getValueKind(), + Node->getObjectKind() + CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams(Node->isNonOdrUse())); // Copy Value and Type dependent clad_compat::ExprSetDeps(result, Node); return result; } -DEFINE_CLONE_EXPR(CompoundLiteralExpr, (Node->getLParenLoc(), Node->getTypeSourceInfo(), Node->getType(), Node->getValueKind(), Clone(Node->getInitializer()), Node->isFileScope())) -DEFINE_CREATE_EXPR(ImplicitCastExpr, (Ctx, Node->getType(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getValueKind() /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node) )) -DEFINE_CREATE_EXPR(CStyleCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getTypeInfoAsWritten(), Node->getLParenLoc(), Node->getRParenLoc())) -DEFINE_CREATE_EXPR(CXXStaticCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten() /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXDynamicCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXConstCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Clone(Node->getSubExpr()), Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXConstructExpr, (Ctx, Node->getType(), Node->getLocation(), Node->getConstructor(), Node->isElidable(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), Node->getConstructionKind(), Node->getParenOrBraceRange())) -DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getTypeInfoAsWritten(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getLParenLoc(), Node->getRParenLoc())) +DEFINE_CLONE_EXPR(CompoundLiteralExpr, + (Node->getLParenLoc(), Node->getTypeSourceInfo(), + CloneType(Node->getType()), Node->getValueKind(), + Clone(Node->getInitializer()), Node->isFileScope())) +DEFINE_CREATE_EXPR( + ImplicitCastExpr, + (Ctx, CloneType(Node->getType()), Node->getCastKind(), + Clone(Node->getSubExpr()), 0, + Node->getValueKind() /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node))) +DEFINE_CREATE_EXPR(CStyleCastExpr, + (Ctx, CloneType(Node->getType()), Node->getValueKind(), + Node->getCastKind(), Clone(Node->getSubExpr()), + 0 /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), + Node->getTypeInfoAsWritten(), Node->getLParenLoc(), + Node->getRParenLoc())) +DEFINE_CREATE_EXPR( + CXXStaticCastExpr, + (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), + Clone(Node->getSubExpr()), 0, + Node->getTypeInfoAsWritten() /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO( + Node), + Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXDynamicCastExpr, + (Ctx, CloneType(Node->getType()), Node->getValueKind(), + Node->getCastKind(), Clone(Node->getSubExpr()), 0, + Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), + Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, + (Ctx, CloneType(Node->getType()), Node->getValueKind(), + Node->getCastKind(), Clone(Node->getSubExpr()), 0, + Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), + Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXConstCastExpr, + (Ctx, CloneType(Node->getType()), Node->getValueKind(), + Clone(Node->getSubExpr()), Node->getTypeInfoAsWritten(), + Node->getOperatorLoc(), Node->getRParenLoc(), + Node->getAngleBrackets())) +DEFINE_CREATE_EXPR( + CXXConstructExpr, + (Ctx, CloneType(Node->getType()), Node->getLocation(), + Node->getConstructor(), Node->isElidable(), + clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), + Node->hadMultipleCandidates(), Node->isListInitialization(), + Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), + Node->getConstructionKind(), Node->getParenOrBraceRange())) +DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, + (Ctx, CloneType(Node->getType()), Node->getValueKind(), + Node->getTypeInfoAsWritten(), Node->getCastKind(), + Clone(Node->getSubExpr()), + 0 /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), + Node->getLParenLoc(), Node->getRParenLoc())) DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), Node->cleanupsHaveSideEffects(), {})) // clang <= 7 do not have `ConstantExpr` node. @@ -117,27 +179,72 @@ DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), DEFINE_CREATE_EXPR(ConstantExpr, (Ctx, Clone(Node->getSubExpr()) CLAD_COMPAT_ConstantExpr_Create_ExtraParams)) #endif -DEFINE_CLONE_EXPR_CO(CXXTemporaryObjectExpr, (Ctx, Node->getConstructor(), Node->getType(), Node->getTypeSourceInfo(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->getSourceRange(), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization())) - -DEFINE_CLONE_EXPR(MaterializeTemporaryExpr, (Node->getType(), CLAD_COMPAT_CLANG10_GetTemporaryExpr(Node), Node->isBoundToLvalueReference())) -DEFINE_CLONE_EXPR_CO11(CompoundAssignOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), Node->getType(), - Node->getValueKind(), Node->getObjectKind(), CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed Node->getOperatorLoc(), Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams) CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved)) -DEFINE_CLONE_EXPR(ConditionalOperator, (Clone(Node->getCond()), Node->getQuestionLoc(), Clone(Node->getLHS()), Node->getColonLoc(), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind())) -DEFINE_CLONE_EXPR(AddrLabelExpr, (Node->getAmpAmpLoc(), Node->getLabelLoc(), Node->getLabel(), Node->getType())) -DEFINE_CLONE_EXPR(StmtExpr, (Clone(Node->getSubStmt()), Node->getType(), Node->getLParenLoc(), Node->getRParenLoc() CLAD_COMPAT_CLANG10_StmtExpr_Create_ExtraParams )) -DEFINE_CLONE_EXPR(ChooseExpr, (Node->getBuiltinLoc(), Clone(Node->getCond()), Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getRParenLoc(), Node->isConditionTrue() CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed)) -DEFINE_CLONE_EXPR(GNUNullExpr, (Node->getType(), Node->getTokenLocation())) -DEFINE_CLONE_EXPR(VAArgExpr, (Node->getBuiltinLoc(), Clone(Node->getSubExpr()), Node->getWrittenTypeInfo(), Node->getRParenLoc(), Node->getType(), Node->isMicrosoftABI())) -DEFINE_CLONE_EXPR(ImplicitValueInitExpr, (Node->getType())) +DEFINE_CLONE_EXPR_CO( + CXXTemporaryObjectExpr, + (Ctx, Node->getConstructor(), CloneType(Node->getType()), + Node->getTypeSourceInfo(), + clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), + Node->getSourceRange(), Node->hadMultipleCandidates(), + Node->isListInitialization(), Node->isStdInitListInitialization(), + Node->requiresZeroInitialization())) + +DEFINE_CLONE_EXPR(MaterializeTemporaryExpr, + (CloneType(Node->getType()), + CLAD_COMPAT_CLANG10_GetTemporaryExpr(Node), + Node->isBoundToLvalueReference())) +DEFINE_CLONE_EXPR_CO11( + CompoundAssignOperator, + (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), + Clone(Node->getRHS()), Node->getOpcode(), CloneType(Node->getType()), + Node->getValueKind(), Node->getObjectKind(), + CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed Node + ->getOperatorLoc(), + Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams) + CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved)) +DEFINE_CLONE_EXPR(ConditionalOperator, + (Clone(Node->getCond()), Node->getQuestionLoc(), + Clone(Node->getLHS()), Node->getColonLoc(), + Clone(Node->getRHS()), CloneType(Node->getType()), + Node->getValueKind(), Node->getObjectKind())) +DEFINE_CLONE_EXPR(AddrLabelExpr, (Node->getAmpAmpLoc(), Node->getLabelLoc(), + Node->getLabel(), CloneType(Node->getType()))) +DEFINE_CLONE_EXPR(StmtExpr, + (Clone(Node->getSubStmt()), CloneType(Node->getType()), + Node->getLParenLoc(), + Node->getRParenLoc() + CLAD_COMPAT_CLANG10_StmtExpr_Create_ExtraParams)) +DEFINE_CLONE_EXPR(ChooseExpr, + (Node->getBuiltinLoc(), Clone(Node->getCond()), + Clone(Node->getLHS()), Clone(Node->getRHS()), + CloneType(Node->getType()), Node->getValueKind(), + Node->getObjectKind(), Node->getRParenLoc(), + Node->isConditionTrue() + CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed)) +DEFINE_CLONE_EXPR(GNUNullExpr, + (CloneType(Node->getType()), Node->getTokenLocation())) +DEFINE_CLONE_EXPR(VAArgExpr, + (Node->getBuiltinLoc(), Clone(Node->getSubExpr()), + Node->getWrittenTypeInfo(), Node->getRParenLoc(), + CloneType(Node->getType()), Node->isMicrosoftABI())) +DEFINE_CLONE_EXPR(ImplicitValueInitExpr, (CloneType(Node->getType()))) DEFINE_CLONE_EXPR(ExtVectorElementExpr, (Node->getType(), Node->getValueKind(), Clone(Node->getBase()), Node->getAccessor(), Node->getAccessorLoc())) DEFINE_CLONE_EXPR(CXXBoolLiteralExpr, (Node->getValue(), Node->getType(), Node->getSourceRange().getBegin())) DEFINE_CLONE_EXPR(CXXNullPtrLiteralExpr, (Node->getType(), Node->getSourceRange().getBegin())) DEFINE_CLONE_EXPR(CXXThisExpr, (Node->getSourceRange().getBegin(), Node->getType(), Node->isImplicit())) DEFINE_CLONE_EXPR(CXXThrowExpr, (Clone(Node->getSubExpr()), Node->getType(), Node->getThrowLoc(), Node->isThrownVariableInScope())) #if CLANG_VERSION_MAJOR < 16 -DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getParameter(), CLAD_COMPAT_SubstNonTypeTemplateParmExpr_isReferenceParameter_ExtraParam(Node) Node->getReplacement())) +DEFINE_CLONE_EXPR( + SubstNonTypeTemplateParmExpr, + (CloneType(Node->getType()), Node->getValueKind(), Node->getBeginLoc(), + Node->getParameter(), + CLAD_COMPAT_SubstNonTypeTemplateParmExpr_isReferenceParameter_ExtraParam( + Node) Node->getReplacement())) #else -DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getReplacement(), Node->getAssociatedDecl(), Node->getIndex(), Node->getPackIndex(), Node->isReferenceParameter())); +DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, + (CloneType(Node->getType()), Node->getValueKind(), + Node->getBeginLoc(), Node->getReplacement(), + Node->getAssociatedDecl(), Node->getIndex(), + Node->getPackIndex(), Node->isReferenceParameter())); #endif DEFINE_CREATE_EXPR(PseudoObjectExpr, (Ctx, Node->getSyntacticForm(), llvm::SmallVector(Node->semantics_begin(), Node->semantics_end()), Node->getResultExprIndex())) //BlockExpr @@ -147,15 +254,14 @@ Stmt* StmtClone::VisitStringLiteral(StringLiteral* Node) { llvm::SmallVector concatLocations(Node->tokloc_begin(), Node->tokloc_end()); return StringLiteral::Create(Ctx, Node->getString(), Node->getKind(), - Node->isPascal(), Node->getType(), + Node->isPascal(), CloneType(Node->getType()), &concatLocations[0], concatLocations.size()); } Stmt* StmtClone::VisitFloatingLiteral(FloatingLiteral* Node) { - FloatingLiteral* clone = FloatingLiteral::Create(Ctx, Node->getValue(), - Node->isExact(), - Node->getType(), - Node->getLocation()); + FloatingLiteral* clone = + FloatingLiteral::Create(Ctx, Node->getValue(), Node->isExact(), + CloneType(Node->getType()), Node->getLocation()); clone->setSemantics(Node->getSemantics()); return clone; } @@ -193,25 +299,20 @@ Stmt* StmtClone::VisitDesignatedInitExpr(DesignatedInitExpr* Node) { Stmt* StmtClone::VisitUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr* Node) { if (Node->isArgumentType()) - return new (Ctx) UnaryExprOrTypeTraitExpr(Node->getKind(), - Node->getArgumentTypeInfo(), - Node->getType(), - Node->getOperatorLoc(), - Node->getRParenLoc()); - return new (Ctx) UnaryExprOrTypeTraitExpr(Node->getKind(), - Clone(Node->getArgumentExpr()), - Node->getType(), - Node->getOperatorLoc(), - Node->getRParenLoc()); + return new (Ctx) + UnaryExprOrTypeTraitExpr(Node->getKind(), Node->getArgumentTypeInfo(), + CloneType(Node->getType()), + Node->getOperatorLoc(), Node->getRParenLoc()); + return new (Ctx) UnaryExprOrTypeTraitExpr( + Node->getKind(), Clone(Node->getArgumentExpr()), + CloneType(Node->getType()), Node->getOperatorLoc(), Node->getRParenLoc()); } Stmt* StmtClone::VisitCallExpr(CallExpr* Node) { - CallExpr* result = clad_compat::CallExpr_Create(Ctx, Clone(Node->getCallee()), - llvm::ArrayRef(), - Node->getType(), - Node->getValueKind(), - Node->getRParenLoc() - CLAD_COMPAT_CLANG8_CallExpr_ExtraParams); + CallExpr* result = clad_compat::CallExpr_Create( + Ctx, Clone(Node->getCallee()), llvm::ArrayRef(), + CloneType(Node->getType()), Node->getValueKind(), + Node->getRParenLoc() CLAD_COMPAT_CLANG8_CallExpr_ExtraParams); result->setNumArgsUnsafe(Node->getNumArgs()); for (unsigned i = 0, e = Node->getNumArgs(); i < e; ++i) result->setArg(i, Clone(Node->getArg(i))); @@ -248,7 +349,7 @@ Stmt* StmtClone::VisitCXXOperatorCallExpr(CXXOperatorCallExpr* Node) { } CXXOperatorCallExpr* result = clad_compat::CXXOperatorCallExpr_Create( Ctx, Node->getOperator(), Clone(Node->getCallee()), clonedArgs, - Node->getType(), Node->getValueKind(), Node->getRParenLoc(), + CloneType(Node->getType()), Node->getValueKind(), Node->getRParenLoc(), Node->getFPFeatures() CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParams); //### result->setNumArgs(Ctx, Node->getNumArgs()); @@ -263,14 +364,12 @@ Stmt* StmtClone::VisitCXXOperatorCallExpr(CXXOperatorCallExpr* Node) { } Stmt* StmtClone::VisitCXXMemberCallExpr(CXXMemberCallExpr * Node) { - CXXMemberCallExpr* result - = clad_compat::CXXMemberCallExpr_Create(Ctx, Clone(Node->getCallee()), 0, - Node->getType(), - Node->getValueKind(), - Node->getRParenLoc() - /*FP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node) - ); -//### result->setNumArgs(Ctx, Node->getNumArgs()); + CXXMemberCallExpr* result = clad_compat::CXXMemberCallExpr_Create( + Ctx, Clone(Node->getCallee()), 0, CloneType(Node->getType()), + Node->getValueKind(), + Node->getRParenLoc() + /*FP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node)); + // ### result->setNumArgs(Ctx, Node->getNumArgs()); result->setNumArgsUnsafe(Node->getNumArgs()); for (unsigned i = 0, e = Node->getNumArgs(); i < e; ++i) @@ -288,9 +387,9 @@ Stmt* StmtClone::VisitShuffleVectorExpr(ShuffleVectorExpr* Node) { cloned[i] = Clone(Node->getExpr(i)); llvm::ArrayRef clonedRef = clad_compat::makeArrayRef(cloned.data(), cloned.size()); - return new (Ctx) ShuffleVectorExpr(Ctx, clonedRef, Node->getType(), - Node->getBuiltinLoc(), - Node->getRParenLoc()); + return new (Ctx) + ShuffleVectorExpr(Ctx, clonedRef, CloneType(Node->getType()), + Node->getBuiltinLoc(), Node->getRParenLoc()); } Stmt* StmtClone::VisitCaseStmt(CaseStmt* Node) { @@ -363,12 +462,10 @@ Decl* StmtClone::CloneDecl(Decl* Node) { if (Node->getKind() == Decl::Var) { VarDecl* VD = static_cast(Node); - VarDecl* cloned_Decl = VarDecl::Create(Ctx, VD->getDeclContext(), - VD->getLocation(), - VD->getInnerLocStart(), - VD->getIdentifier(), VD->getType(), - VD->getTypeSourceInfo(), - VD->getStorageClass()); + VarDecl* cloned_Decl = VarDecl::Create( + Ctx, VD->getDeclContext(), VD->getLocation(), VD->getInnerLocStart(), + VD->getIdentifier(), CloneType(VD->getType()), VD->getTypeSourceInfo(), + VD->getStorageClass()); if (VD->getInit()) m_Sema.AddInitializerToDecl(cloned_Decl, Clone(VD->getInit()), VD->isDirectInit()); cloned_Decl->setTSCSpec(VD->getTSCSpec()); @@ -435,9 +532,34 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) { VD->setReferenced(); VD->setIsUsed(); } + updateType(DRE->getType()); return true; } +bool ReferencesUpdater::VisitStmt(clang::Stmt* S) { + if (auto* E = dyn_cast(S)) + updateType(E->getType()); + return true; +} + +void ReferencesUpdater::updateType(QualType QT) { + if (auto* varArrType = dyn_cast(QT)) + TraverseStmt(varArrType->getSizeExpr()); +} + +QualType StmtClone::CloneType(const clang::QualType T) { + if (const auto* varArrType = + dyn_cast(T.getTypePtr())) { + auto elemType = varArrType->getElementType(); + return Ctx.getVariableArrayType(elemType, Clone(varArrType->getSizeExpr()), + varArrType->getSizeModifier(), + T.getQualifiers().getAsOpaqueValue(), + SourceRange()); + } + + return clang::QualType(T.getTypePtr(), T.getQualifiers().getAsOpaqueValue()); +} + //--------------------------------------------------------- } // end namespace utils } // end namespace clad diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 90f42dc21..279b7bf5d 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -301,6 +301,12 @@ namespace clad { return llvm::cast(Clone(S)); } + QualType VisitorBase::CloneType(const QualType QT) { + auto clonedType = m_Builder.m_NodeCloner->CloneType(QT); + utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function); + up.updateType(clonedType); + return clonedType; + } Expr* VisitorBase::BuildOp(UnaryOperatorKind OpCode, Expr* E, SourceLocation OpLoc) { if (!E) From c77113a56065d43aa6d4ff3d36f5fa4289736a97 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 2 Nov 2023 21:26:23 +0200 Subject: [PATCH 2/2] Apply suggestions from clang-tidy and clang-format. --- include/clad/Differentiator/StmtClone.h | 2 +- include/clad/Differentiator/VisitorBase.h | 2 +- lib/Differentiator/StmtClone.cpp | 21 +++++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/include/clad/Differentiator/StmtClone.h b/include/clad/Differentiator/StmtClone.h index 0d12059f7..cf32b813f 100644 --- a/include/clad/Differentiator/StmtClone.h +++ b/include/clad/Differentiator/StmtClone.h @@ -50,7 +50,7 @@ namespace utils { /// Cloning types is necessary since VariableArrayType /// store a pointer to their size expression. - clang::QualType CloneType(const clang::QualType T); + clang::QualType CloneType(clang::QualType T); // visitor part (not for public use) // Stmt.def could be used if ABSTR_STMT is introduced diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 9036c2c96..08a6201f8 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -560,7 +560,7 @@ namespace clad { clang::Expr* Clone(const clang::Expr* E); /// Cloning types is necessary since VariableArrayType /// store a pointer to their size expression. - clang::QualType CloneType(const clang::QualType T); + clang::QualType CloneType(clang::QualType T); }; } // end namespace clad diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 8907335ec..e14c7fa78 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -62,7 +62,7 @@ Stmt* StmtClone::Visit ## CLASS(CLASS *Node) \ clad_compat::ExprSetDeps(result, Node); \ return result; \ } - +// NOLINTBEGIN(modernize-use-auto) DEFINE_CLONE_EXPR_CO11( BinaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), @@ -128,29 +128,29 @@ DEFINE_CLONE_EXPR(CompoundLiteralExpr, DEFINE_CREATE_EXPR( ImplicitCastExpr, (Ctx, CloneType(Node->getType()), Node->getCastKind(), - Clone(Node->getSubExpr()), 0, + Clone(Node->getSubExpr()), nullptr, Node->getValueKind() /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node))) DEFINE_CREATE_EXPR(CStyleCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), - 0 /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), + nullptr /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getTypeInfoAsWritten(), Node->getLParenLoc(), Node->getRParenLoc())) DEFINE_CREATE_EXPR( CXXStaticCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), - Clone(Node->getSubExpr()), 0, + Clone(Node->getSubExpr()), nullptr, Node->getTypeInfoAsWritten() /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO( Node), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXDynamicCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), - Node->getCastKind(), Clone(Node->getSubExpr()), 0, + Node->getCastKind(), Clone(Node->getSubExpr()), nullptr, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), - Node->getCastKind(), Clone(Node->getSubExpr()), 0, + Node->getCastKind(), Clone(Node->getSubExpr()), nullptr, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXConstCastExpr, @@ -170,7 +170,7 @@ DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getTypeInfoAsWritten(), Node->getCastKind(), Clone(Node->getSubExpr()), - 0 /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), + nullptr /*EP*/ CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getLParenLoc(), Node->getRParenLoc())) DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), Node->cleanupsHaveSideEffects(), {})) @@ -247,8 +247,9 @@ DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, Node->getPackIndex(), Node->isReferenceParameter())); #endif DEFINE_CREATE_EXPR(PseudoObjectExpr, (Ctx, Node->getSyntacticForm(), llvm::SmallVector(Node->semantics_begin(), Node->semantics_end()), Node->getResultExprIndex())) -//BlockExpr -//BlockDeclRefExpr +// NOLINTEND(modernize-use-auto) +// BlockExpr +// BlockDeclRefExpr Stmt* StmtClone::VisitStringLiteral(StringLiteral* Node) { llvm::SmallVector concatLocations(Node->tokloc_begin(), @@ -543,7 +544,7 @@ bool ReferencesUpdater::VisitStmt(clang::Stmt* S) { } void ReferencesUpdater::updateType(QualType QT) { - if (auto* varArrType = dyn_cast(QT)) + if (const auto* varArrType = dyn_cast(QT)) TraverseStmt(varArrType->getSizeExpr()); }