Skip to content

Commit

Permalink
Revert "Don't create adjoint pullback parameters for non-differentiab…
Browse files Browse the repository at this point in the history
…le arguments."

This reverts commit 5b66f0e.
  • Loading branch information
vaithak committed Apr 3, 2024
1 parent 09522b6 commit 770faa8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 74 deletions.
4 changes: 0 additions & 4 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,6 @@ namespace clad {
}

void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) {
// Diff info for pullbacks is generated automatically,
// its parameters are not provided by the user.
if (Mode == DiffMode::experimental_pullback)
return;
DVI.clear();
auto& C = semaRef.getASTContext();
const Expr* diffArgs = Args;
Expand Down
20 changes: 8 additions & 12 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,11 +479,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assert(m_Function && "Must not be null.");

DiffParams args{};
if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

#ifndef NDEBUG
bool isStaticMethod = utils::IsStaticMethod(FD);
assert((!args.empty() || !isStaticMethod) &&
Expand Down Expand Up @@ -1520,6 +1517,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statements there later.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
Expand Down Expand Up @@ -1737,9 +1737,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"corresponding dfdx().");
}

for (Expr* arg : DerivedCallOutputArgs)
if (arg)
DerivedCallArgs.push_back(arg);
DerivedCallArgs.insert(DerivedCallArgs.end(),
DerivedCallOutputArgs.begin(),
DerivedCallOutputArgs.end());
pullbackCallArgs = DerivedCallArgs;

if (pullback)
Expand Down Expand Up @@ -1790,10 +1790,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = enableTBR;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));
FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
Expand Down
58 changes: 0 additions & 58 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -758,59 +758,6 @@ double fn16(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double add(double a, double* b) {
return a + b[0];
}

//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: }

//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a, double *_d_b) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: _d_b[0] += _d_y;
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn17 (double x, double* y) {
x = add(x, y);
x = add(x, &x);
return x;
}

//CHECK: void fn17_grad_0(double x, double *y, double *_d_x) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = add(x, y);
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: x = add(x, &x);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t1;
//CHECK-NEXT: double _r_d1 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d1;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: add_pullback(x, &x, _r_d1, &_r1, &*_d_x);
//CHECK-NEXT: *_d_x += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: double _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: add_pullback(x, y, _r_d0, &_r0);
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -897,9 +844,4 @@ int main() {
TEST2(fn15, 6, -2) // CHECK-EXEC: {1.00, 1.00}
INIT(fn16);
TEST2(fn16, 12, 8) // CHECK-EXEC: {8.00, 8.00}

auto fn17_grad_0 = clad::gradient(fn17, "x");
double y[] = {3.0, 2.0}, dx = 0;
fn17_grad_0.execute(5, y, &dx);
printf("{%.2f}\n", dx); // CHECK-EXEC: {2.00}
}

0 comments on commit 770faa8

Please sign in to comment.