Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix call expr to functor inside a function #626

Merged
merged 1 commit into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(request, args);

auto derivativeName = request.BaseFunctionName + "_pullback";
auto derivativeName =
utils::ComputeEffectiveFnName(m_Function) + "_pullback";
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -1412,12 +1413,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statements there later.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

// `CXXOperatorCallExpr` have the `base` expression as the first argument.
size_t skipFirstArg = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the late review.

What is the main issue if we do not skip the first arg? I suppose it might be required if we modify the functor's member variables in the call operator.

Copy link
Collaborator Author

@vaithak vaithak Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first arg in the operator call is the base expression which is not needed in calling the pullback function, I used the logic of skipping the first arg from this in forward mode: https://github.com/vgvassilev/clad/blob/master/lib/Differentiator/BaseForwardModeVisitor.cpp#L980


// Here we do not need to check if FD is an instance method or a static
// method because C++ forbids creating operator overloads as static methods.
if (isa<CXXOperatorCallExpr>(CE) && isa<CXXMethodDecl>(FD))
skipFirstArg = 1;

// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) {
vaithak marked this conversation as resolved.
Show resolved Hide resolved
const Expr* arg = CE->getArg(i);
auto PVD = FD->getParamDecl(i);
const auto* PVD = FD->getParamDecl(i - skipFirstArg);
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
// We do not need to create result arg for arguments passed by reference
Expand Down Expand Up @@ -1597,26 +1606,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

/// Add base derivative expression in the derived call output args list if
/// `CE` is a call to an instance member function.
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE)) {
baseDiff = Visit(MCE->getImplicitObjectArgument());
StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
if (isInsideLoop) {
addToCurrentBlock(baseDiffStore.getExpr());
VarDecl* baseLocalVD = BuildVarDecl(
baseDiffStore.getExpr_dx()->getType(),
CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(),
/*DirectInit=*/false, /*TSI=*/nullptr,
VarDecl::InitializationStyle::CInit);
auto& block = getCurrentBlock(direction::reverse);
block.insert(block.begin() + insertionPoint,
BuildDeclStmt(baseLocalVD));
insertionPoint += 1;
Expr* baseLocalE = BuildDeclRef(baseLocalVD);
baseDiffStore = {baseDiffStore.getExpr(), baseLocalE};
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
else if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
if (isInsideLoop) {
addToCurrentBlock(baseDiffStore.getExpr());
VarDecl* baseLocalVD = BuildVarDecl(
baseDiffStore.getExpr_dx()->getType(),
CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(),
/*DirectInit=*/false, /*TSI=*/nullptr,
VarDecl::InitializationStyle::CInit);
auto& block = getCurrentBlock(direction::reverse);
block.insert(block.begin() + insertionPoint,
BuildDeclStmt(baseLocalVD));
insertionPoint += 1;
Expr* baseLocalE = BuildDeclRef(baseLocalVD);
baseDiffStore = {baseDiffStore.getExpr(), baseLocalE};
}
baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()};
Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative =
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative);
DerivedCallOutputArgs.push_back(baseDerivative);
}
baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()};
DerivedCallOutputArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr_dx()));
}

for (auto argDerivative : CallArgDx) {
Expand Down Expand Up @@ -1689,7 +1709,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackCallArgs = DerivedCallArgs;

if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(),
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() -
static_cast<int>(skipFirstArg),
pullback);

// Try to find it in builtin derivatives
Expand Down Expand Up @@ -1775,7 +1796,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
usingNumericalDiff = true;
}
} else if (pullbackFD) {
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD);
Expand Down Expand Up @@ -1878,7 +1899,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// We cannot reuse the derivatives previously computed because
// they might contain 'clad::pop(..)` expression.
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr_dx()) {
Expr* derivedBase = baseDiff.getExpr_dx();
// FIXME: We may need this if-block once we support pointers, and
// passing pointers-by-reference if
Expand Down Expand Up @@ -1906,7 +1927,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get());
}
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, calleeFnForwPassFD);
Expand Down
34 changes: 34 additions & 0 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ struct ExperimentNNS {
} // namespace inner
} // namespace outer

// A function calling operator() on a functor.
double CallFunctor(double i, double j) {
Experiment E(3, 5);
return E(i, j);
}

#define INIT(E) \
auto E##_grad = clad::gradient(&E); \
auto E##Ref_grad = clad::gradient(E);
Expand Down Expand Up @@ -298,4 +304,32 @@ int main() {
// CHECK-EXEC: 54.00 42.00
TEST_LAMBDA(lambdaWithCapture); // CHECK-EXEC: 54.00 42.00
// CHECK-EXEC: 54.00 42.00

// CHECK: void CallFunctor_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: Experiment _d_E({});
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: Experiment _t2;
// CHECK-NEXT: Experiment E(3, 5);
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = E;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, 1, &_d_E, &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a function calling operator() on a functor
auto CallFunctor_grad = clad::gradient(CallFunctor);
double di = 0, dj = 0;
CallFunctor_grad.execute(7, 9, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00
}
Loading