Skip to content

Commit

Permalink
Implement visitor to check varied expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Nov 5, 2024
1 parent 90d17a0 commit 8386d91
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
10 changes: 5 additions & 5 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,13 +627,13 @@ namespace clad {
ArrayRef<ParmVarDecl*> FDparam = Function->parameters();
std::vector<ParmVarDecl*> derivedParam;

for(auto* parameter: FDparam){
for (auto* parameter : FDparam) {
QualType parType = parameter->getType();
if(parType->isPointerType()){
if(!parType->getPointeeType().isConstQualified())
derivedParam.push_back(parameter);
}else if(!parType.isConstQualified())
if (parType->isPointerType()) {
if (!parType->getPointeeType().isConstQualified())
derivedParam.push_back(parameter);
} else if (!parType.isConstQualified())
derivedParam.push_back(parameter);
}

std::copy(derivedParam.begin(), derivedParam.end(),
Expand Down
32 changes: 31 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(clonedDRE);
}

class VariedChecker : public RecursiveASTVisitor<VariedChecker> {
public:
explicit VariedChecker(const DiffRequest& DR) : Request(DR) {}

bool isVariedE(const clang::Expr* E) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
TraverseStmt(const_cast<clang::Expr*>(E));
return containsTargetVar;
}

bool VisitDeclRefExpr(const clang::DeclRefExpr* DRE) {
if (isa<VarDecl>(DRE->getDecl())) {
if (Request.shouldHaveAdjoint(dyn_cast<VarDecl>(DRE->getDecl()))) {
containsTargetVar = true;
return false;
}
} else {
containsTargetVar = true;
return false;
}
return true;
}

private:
bool containsTargetVar = false;
const DiffRequest& Request;
};

StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
auto* Constant0 =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down Expand Up @@ -1800,7 +1828,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!arg->isEvaluatable(m_Context)) {

VariedChecker analyzer(m_DiffReq);
if (analyzer.isVariedE(arg)) {
allArgsAreConstantLiterals = false;
break;
}
Expand Down

0 comments on commit 8386d91

Please sign in to comment.