Skip to content

Commit

Permalink
Don't create pullbacks for function with not varied parameters (vgvas…
Browse files Browse the repository at this point in the history
…silev#1127)

This PR enables clad not to create pullbacks if the parameters are either constant or not varied.

Fixes: vgvassilev#642, Fixes vgvassilev#682
  • Loading branch information
ovdiiuv authored Nov 7, 2024
1 parent 9070023 commit fa89545
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 10 deletions.
7 changes: 2 additions & 5 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,11 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
m_Varied = true;
m_Marking = true;
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
m_Varied = false;
m_Marking = false;
}
return true;
}
Expand All @@ -141,7 +137,8 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied)
QualType VDTy = cast<VarDecl>(D)->getType();
if (m_Varied || utils::isArrayOrPointerType(VDTy))
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
Expand Down
16 changes: 12 additions & 4 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,19 @@ namespace clad {
if (!EnableVariedAnalysis)
return true;

if (VD->getType()->isPointerType() || isa<ArrayType>(VD->getType()))
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
std::copy(Function->param_begin(), Function->param_end(),
ArrayRef<ParmVarDecl*> FDparam = Function->parameters();
std::vector<ParmVarDecl*> derivedParam;

for (auto* parameter : FDparam) {
QualType parType = parameter->getType();
while (parType->isPointerType())
parType = parType->getPointeeType();
if (!parType.isConstQualified())
derivedParam.push_back(parameter);
}

std::copy(derivedParam.begin(), derivedParam.end(),
std::inserter(m_ActivityRunInfo.ToBeRecorded,
m_ActivityRunInfo.ToBeRecorded.end()));

Expand Down
22 changes: 21 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!arg->isEvaluatable(m_Context)) {
// FIXME: We should consider moving this code in the VariedAnalysis
// where we could decide to remove pullback requests from the
// diff graph.
class VariedChecker : public RecursiveASTVisitor<VariedChecker> {
const DiffRequest& Request;

public:
VariedChecker(const DiffRequest& DR) : Request(DR) {}
bool isVariedE(const clang::Expr* E) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return !TraverseStmt(const_cast<clang::Expr*>(E));
}
bool VisitDeclRefExpr(const clang::DeclRefExpr* DRE) {
if (!isa<VarDecl>(DRE->getDecl()))
return true;
if (Request.shouldHaveAdjoint(cast<VarDecl>(DRE->getDecl())))
return false;
return true;
}
} analyzer(m_DiffReq);
if (analyzer.isVariedE(arg)) {
allArgsAreConstantLiterals = false;
break;
}
Expand Down
46 changes: 46 additions & 0 deletions test/Analyses/ActivityReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,46 @@ double f8(double x){
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn9(double x, double const *obs)
{
double res = 0.0;
for (int loopIdx0 = 0; loopIdx0 < 2; loopIdx0++) {
res += std::lgamma(obs[2 + loopIdx0] + 1) + x;
}
return res;
}

// CHECK: void fn9_grad(double x, const double *obs, double *_d_x, double *_d_obs) {
// CHECK-NEXT: int loopIdx0 = 0;
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = 0.;
// CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
// CHECK-NEXT: for (loopIdx0 = 0; ; loopIdx0++) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!(loopIdx0 < 2))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, res);
// CHECK-NEXT: res += std::lgamma(obs[2 + loopIdx0] + 1) + x;
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: loopIdx0--;
// CHECK-NEXT: {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_x += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }


#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_va>(F);\
Expand All @@ -272,7 +312,10 @@ double f8(double x){
}

int main(){
double arr[] = {1,2,3,4,5};
double darr[] = {0,0,0,0,0};
double result[3] = {};
double dx;
TEST(f1, 3);// CHECK-EXEC: {6.00}
TEST(f2, 3);// CHECK-EXEC: {6.00}
TEST(f3, 3);// CHECK-EXEC: {0.00}
Expand All @@ -281,6 +324,9 @@ int main(){
TEST(f6, 3);// CHECK-EXEC: {0.00}
TEST(f7, 3);// CHECK-EXEC: {1.00}
TEST(f8, 3);// CHECK-EXEC: {1.00}
auto grad = clad::gradient<clad::opts::enable_va>(fn9);
grad.execute(3, arr, &dx, darr);
printf("%.2f\n", dx);// CHECK-EXEC: 2.00
}

// CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) {
Expand Down

0 comments on commit fa89545

Please sign in to comment.