Skip to content

Commit

Permalink
Call isVariedParam for subexpression
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Nov 4, 2024
1 parent 0cda873 commit f70473b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 30 deletions.
8 changes: 2 additions & 6 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,11 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
// m_Varied = true;
// m_Marking = true;
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
// m_Varied = false;
// m_Marking = false;
}
return true;
}
Expand All @@ -141,8 +137,8 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied || cast<VarDecl>(D)->getType()->isPointerType() ||
cast<VarDecl>(D)->getType()->isArrayType())
QualType VDTy = cast<VarDecl>(D)->getType();
if (m_Varied || utils::isArrayOrPointerType(VDTy))
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
Expand Down
7 changes: 3 additions & 4 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,9 @@ namespace clad {

for (auto* parameter : FDparam) {
QualType parType = parameter->getType();
if (parType->isPointerType()) {
if (!parType->getPointeeType().isConstQualified())
derivedParam.push_back(parameter);
} else if (!parType.isConstQualified())
while (parType->isPointerType())
parType = parType->getPointeeType();
if (!parType.isConstQualified())
derivedParam.push_back(parameter);
}

Expand Down
27 changes: 7 additions & 20 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,41 +1654,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clad_compat::Expr_EvaluateAsConstantExpr(paramE, dummy, m_Context);
if (isConst)
return false;

if (isa<BinaryOperator>(paramE)) {
auto* binparam = dyn_cast<BinaryOperator>(paramE->IgnoreImpCasts());
for (auto* subexpr : binparam->children()) {
auto* subexprE = dyn_cast<Expr>(subexpr)->IgnoreImpCasts();
bool isConst = clad_compat::Expr_EvaluateAsConstantExpr(subexprE, dummy,
m_Context);
if (isConst)
continue;
if (const auto* ASE = dyn_cast<ArraySubscriptExpr>(subexprE)) {
const Expr* baseExpr = ASE->getBase()->IgnoreImpCasts();
auto* paramVD =
dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(baseExpr)->getDecl());
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else if (const auto* DRE = dyn_cast<DeclRefExpr>(subexprE)) {
auto* paramVD = dyn_cast<VarDecl>(DRE->getDecl());
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else
return true;
return isVariedParam(subexprE);
}
return false;
} else if (const auto* ASE =
dyn_cast<ArraySubscriptExpr>(paramE->IgnoreImpCasts())) {
const Expr* baseExpr = ASE->getBase()->IgnoreImpCasts();
auto* paramVD =
dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(baseExpr)->getDecl());
return m_DiffReq.shouldHaveAdjoint(paramVD);
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else if (const auto* DRE =
dyn_cast<DeclRefExpr>(paramE->IgnoreImpCasts())) {
auto* paramVD = dyn_cast<VarDecl>(DRE->getDecl());
return m_DiffReq.shouldHaveAdjoint(paramVD);
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else {
return true;
}
return true;
}

StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down

0 comments on commit f70473b

Please sign in to comment.