Skip to content

Commit

Permalink
Call isVariedParam for subexpression
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Nov 4, 2024
1 parent 0cda873 commit 397bf94
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 36 deletions.
8 changes: 2 additions & 6 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,11 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
// m_Varied = true;
// m_Marking = true;
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
// m_Varied = false;
// m_Marking = false;
}
return true;
}
Expand All @@ -141,8 +137,8 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied || cast<VarDecl>(D)->getType()->isPointerType() ||
cast<VarDecl>(D)->getType()->isArrayType())
QualType VDTy = cast<VarDecl>(D)->getType();
if (m_Varied || utils::isArrayOrPointerType(VDTy))
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
Expand Down
7 changes: 3 additions & 4 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,9 @@ namespace clad {

for (auto* parameter : FDparam) {
QualType parType = parameter->getType();
if (parType->isPointerType()) {
if (!parType->getPointeeType().isConstQualified())
derivedParam.push_back(parameter);
} else if (!parType.isConstQualified())
while (parType->isPointerType())
parType = parType->getPointeeType();
if (!parType.isConstQualified())
derivedParam.push_back(parameter);
}

Expand Down
39 changes: 13 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1649,46 +1649,29 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

bool ReverseModeVisitor::isVariedParam(const Expr* paramE) {
Expr::EvalResult dummy;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(paramE, dummy, m_Context);
if (isConst)
return false;

if (isa<BinaryOperator>(paramE)) {
auto* binparam = dyn_cast<BinaryOperator>(paramE->IgnoreImpCasts());
for (auto* subexpr : binparam->children()) {
auto* subexprE = dyn_cast<Expr>(subexpr)->IgnoreImpCasts();
bool isConst = clad_compat::Expr_EvaluateAsConstantExpr(subexprE, dummy,
m_Context);
if (isConst)
continue;
if (const auto* ASE = dyn_cast<ArraySubscriptExpr>(subexprE)) {
const Expr* baseExpr = ASE->getBase()->IgnoreImpCasts();
auto* paramVD =
dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(baseExpr)->getDecl());
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else if (const auto* DRE = dyn_cast<DeclRefExpr>(subexprE)) {
auto* paramVD = dyn_cast<VarDecl>(DRE->getDecl());
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else
return true;
return isVariedParam(subexprE);
}
return false;
} else if (const auto* ASE =
dyn_cast<ArraySubscriptExpr>(paramE->IgnoreImpCasts())) {
const Expr* baseExpr = ASE->getBase()->IgnoreImpCasts();
auto* paramVD =
dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(baseExpr)->getDecl());
return m_DiffReq.shouldHaveAdjoint(paramVD);
if (m_DiffReq.shouldHaveAdjoint(paramVD))
return true;
} else if (const auto* DRE =
dyn_cast<DeclRefExpr>(paramE->IgnoreImpCasts())) {
auto* paramVD = dyn_cast<VarDecl>(DRE->getDecl());
return m_DiffReq.shouldHaveAdjoint(paramVD);
if (m_DiffReq.shouldHaveAdjoint(paramVD))

Check warning on line 1669 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L1669

Added line #L1669 was not covered by tests
return true;
} else {
return true;
}
return true;
return false;
}

StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down Expand Up @@ -1843,7 +1826,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (isVariedParam(arg)) {
Expr::EvalResult dummy;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(arg, dummy, m_Context);

if (isVariedParam(arg) && !isConst) {
allArgsAreConstantLiterals = false;
break;
}
Expand Down

0 comments on commit 397bf94

Please sign in to comment.