From 7061c34b55fdf63a0775d90296da44e27d8c7b71 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Wed, 6 Nov 2024 23:07:42 +0100 Subject: [PATCH] Implement visitor to check varied expression --- lib/Differentiator/DiffPlanner.cpp | 11 +++++------ lib/Differentiator/ReverseModeVisitor.cpp | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d501ac614..25dceb4a7 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -627,13 +627,12 @@ namespace clad { ArrayRef FDparam = Function->parameters(); std::vector derivedParam; - for(auto* parameter: FDparam){ + for (auto* parameter : FDparam) { QualType parType = parameter->getType(); - if(parType->isPointerType()){ - if(!parType->getPointeeType().isConstQualified()) - derivedParam.push_back(parameter); - }else if(!parType.isConstQualified()) - derivedParam.push_back(parameter); + while (parType->isPointerType()) + parType = parType->getPointeeType(); + if (!parType.isConstQualified()) + derivedParam.push_back(parameter); } std::copy(derivedParam.begin(), derivedParam.end(), diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 8995a2a67..11945fa69 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1800,7 +1800,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // subexpression. if (const auto* MTE = dyn_cast(arg)) arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts(); - if (!arg->isEvaluatable(m_Context)) { + class VariedChecker : public RecursiveASTVisitor { + const DiffRequest& Request; + + public: + VariedChecker(const DiffRequest& DR) : Request(DR) {} + bool isVariedE(const clang::Expr* E) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return !TraverseStmt(const_cast(E)); + } + bool VisitDeclRefExpr(const clang::DeclRefExpr* DRE) { + if (!isa(DRE->getDecl())) + return true; + if (Request.shouldHaveAdjoint(cast(DRE->getDecl()))) + return false; + return true; + } + } analyzer(m_DiffReq); + if (analyzer.isVariedE(arg)) { allArgsAreConstantLiterals = false; break; }