From 770faa833bf06ec3ddc7b48eed413015f116f503 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 3 Apr 2024 02:45:07 +0200 Subject: [PATCH] Revert "Don't create adjoint pullback parameters for non-differentiable arguments." This reverts commit 5b66f0e909c0f88ce9fdd5f093255f7cc229b010. --- lib/Differentiator/DiffPlanner.cpp | 4 -- lib/Differentiator/ReverseModeVisitor.cpp | 20 ++++---- test/Gradient/FunctionCalls.C | 58 ----------------------- 3 files changed, 8 insertions(+), 74 deletions(-) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index a86576d8e..e86a3c18e 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -273,10 +273,6 @@ namespace clad { } void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) { - // Diff info for pullbacks is generated automatically, - // its parameters are not provided by the user. - if (Mode == DiffMode::experimental_pullback) - return; DVI.clear(); auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d9283efa3..e42545a1d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -479,11 +479,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, assert(m_Function && "Must not be null."); DiffParams args{}; - if (!request.DVI.empty()) - for (const auto& dParam : request.DVI) - args.push_back(dParam.param); - else - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + #ifndef NDEBUG bool isStaticMethod = utils::IsStaticMethod(FD); assert((!args.empty() || !isStaticMethod) && @@ -1520,6 +1517,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // statements there later. std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); + // FIXME: We should add instructions for handling non-differentiable + // arguments. Currently we are implicitly assuming function call only + // contains differentiable arguments. bool isCXXOperatorCall = isa(CE); for (std::size_t i = static_cast(isCXXOperatorCall), @@ -1737,9 +1737,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "corresponding dfdx()."); } - for (Expr* arg : DerivedCallOutputArgs) - if (arg) - DerivedCallArgs.push_back(arg); + DerivedCallArgs.insert(DerivedCallArgs.end(), + DerivedCallOutputArgs.begin(), + DerivedCallOutputArgs.end()); pullbackCallArgs = DerivedCallArgs; if (pullback) @@ -1790,10 +1790,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Silence diag outputs in nested derivation process. pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = enableTBR; - bool isaMethod = isa(FD); - for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (DerivedCallOutputArgs[i + isaMethod]) - pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index e522d4227..364b8be4f 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -758,59 +758,6 @@ double fn16(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } -double add(double a, double* b) { - return a + b[0]; -} - -//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a) { -//CHECK-NEXT: goto _label0; -//CHECK-NEXT: _label0: -//CHECK-NEXT: *_d_a += _d_y; -//CHECK-NEXT: } - -//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a, double *_d_b) { -//CHECK-NEXT: goto _label0; -//CHECK-NEXT: _label0: -//CHECK-NEXT: { -//CHECK-NEXT: *_d_a += _d_y; -//CHECK-NEXT: _d_b[0] += _d_y; -//CHECK-NEXT: } -//CHECK-NEXT: } - -double fn17 (double x, double* y) { - x = add(x, y); - x = add(x, &x); - return x; -} - -//CHECK: void fn17_grad_0(double x, double *y, double *_d_x) { -//CHECK-NEXT: double _t0; -//CHECK-NEXT: double _t1; -//CHECK-NEXT: _t0 = x; -//CHECK-NEXT: x = add(x, y); -//CHECK-NEXT: _t1 = x; -//CHECK-NEXT: x = add(x, &x); -//CHECK-NEXT: goto _label0; -//CHECK-NEXT: _label0: -//CHECK-NEXT: *_d_x += 1; -//CHECK-NEXT: { -//CHECK-NEXT: x = _t1; -//CHECK-NEXT: double _r_d1 = *_d_x; -//CHECK-NEXT: *_d_x -= _r_d1; -//CHECK-NEXT: double _r1 = 0; -//CHECK-NEXT: add_pullback(x, &x, _r_d1, &_r1, &*_d_x); -//CHECK-NEXT: *_d_x += _r1; -//CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: x = _t0; -//CHECK-NEXT: double _r_d0 = *_d_x; -//CHECK-NEXT: *_d_x -= _r_d0; -//CHECK-NEXT: double _r0 = 0; -//CHECK-NEXT: add_pullback(x, y, _r_d0, &_r0); -//CHECK-NEXT: *_d_x += _r0; -//CHECK-NEXT: } -//CHECK-NEXT: } - template void reset(T* arr, int n) { for (int i=0; i