Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient computation of higher order functions #645

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1345,9 +1345,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// Create the (_d_param[idx] += dfdx) statement.
if (dfdx()) {
Expr* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
// FIXME: not sure if this is generic.
// Don't update derivatives of non-record types.
if (!decl->getType()->isRecordType()) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check this with isa<RecordDecl>(decl)?

Copy link
Collaborator Author

@vaithak vaithak Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had that earlier but it failed with clang-16. isRecordType does the same thing inside it, but performs the operation on the CanonicalType instead of the qualified type.
source code of isRecordType.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, ok. That means that if we did isa<RecordDecl>(decl->getCanonicalDecl()) would work?

Copy link
Collaborator Author

@vaithak vaithak Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that works. I saw isRecordType being used at different places in our codebase, so I used that.
Should I change this to use isa<RecordDecl>? Any reason to prefer that?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, let's use your current solution if it is used elsewhere.

auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
}
return StmtDiff(clonedDRE, it->second, it->second);
}
Expand Down Expand Up @@ -1694,10 +1698,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else
gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative);
} else {
// Declare: diffArgType _grad = 0;
gradVarDecl = BuildVarDecl(
PVD->getType(), gradVarII,
ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0));
// Declare: diffArgType _grad;
Expr* initVal = nullptr;
if (!PVD->getType()->isRecordType()) {
// If the argument is not a class type, then initialize the grad
// variable with 0.
initVal =
ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0);
}
gradVarDecl = BuildVarDecl(PVD->getType(), gradVarII, initVal);
// Pass the address of the declared variable
gradVarExpr = BuildDeclRef(gradVarDecl);
gradArgExpr =
Expand Down
89 changes: 89 additions & 0 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,18 @@ double CallFunctor(double i, double j) {
return E(i, j);
}

// A function taking functor as an argument.
template<typename Func>
double FunctorAsArg(Func fn, double i, double j) {
return fn(i, j);
}

// A wrapper for function taking functor as an argument.
double FunctorAsArgWrapper(double i, double j) {
Experiment E(3, 5);
return FunctorAsArg(E, i, j);
}

#define INIT(E) \
auto E##_grad = clad::gradient(&E); \
auto E##Ref_grad = clad::gradient(E);
Expand Down Expand Up @@ -332,4 +344,81 @@ int main() {
double di = 0, dj = 0;
CallFunctor_grad.execute(7, 9, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00

// CHECK: void FunctorAsArg_grad(Experiment fn, double i, double j, clad::array_ref<Experiment> _d_fn, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: Experiment _t2;
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = fn;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, 1, &(* _d_fn), &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a function taking functor as an argument
auto FunctorAsArg_grad = clad::gradient(FunctorAsArg<Experiment>);
di = 0, dj = 0;
Experiment E_temp(3, 5), dE_temp;
FunctorAsArg_grad.execute(E_temp, 7, 9, &dE_temp, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00

// CHECK: void FunctorAsArg_pullback(Experiment fn, double i, double j, double _d_y, clad::array_ref<Experiment> _d_fn, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: Experiment _t2;
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = fn;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, _d_y, &(* _d_fn), &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void FunctorAsArgWrapper_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: Experiment _d_E({});
// CHECK-NEXT: Experiment _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: Experiment E(3, 5);
// CHECK-NEXT: _t0 = E
// CHECK-NEXT: _t1 = i;
// CHECK-NEXT: _t2 = j;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: Experiment _grad0;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: double _grad2 = 0.;
// CHECK-NEXT: FunctorAsArg_pullback(_t0, _t1, _t2, 1, &_grad0, &_grad1, &_grad2);
// CHECK-NEXT: Experiment _r0(_grad0);
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_i += _r1;
// CHECK-NEXT: double _r2 = _grad2;
// CHECK-NEXT: * _d_j += _r2;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a wrapper for function taking functor as an argument
auto FunctorAsArgWrapper_grad = clad::gradient(FunctorAsArgWrapper);
di = 0, dj = 0;
FunctorAsArgWrapper_grad.execute(7, 9, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00
}
Loading