Skip to content

Commit

Permalink
Fix call expr to functor inside a function
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Sep 13, 2023
1 parent 59fd672 commit 9286fdc
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 26 deletions.
73 changes: 47 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(request, args);

auto derivativeName = request.BaseFunctionName + "_pullback";
auto derivativeName =
utils::ComputeEffectiveFnName(m_Function) + "_pullback";
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -1412,12 +1413,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statements there later.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

// `CXXOperatorCallExpr` have the `base` expression as the first argument.
size_t skipFirstArg = 0;

// Here we do not need to check if FD is an instance method or a static
// method because C++ forbids creating operator overloads as static methods.
if (isa<CXXOperatorCallExpr>(CE) && isa<CXXMethodDecl>(FD))
skipFirstArg = 1;

// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) {
const Expr* arg = CE->getArg(i);
auto PVD = FD->getParamDecl(i);
const auto* PVD = FD->getParamDecl(i - skipFirstArg);
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
// We do not need to create result arg for arguments passed by reference
Expand Down Expand Up @@ -1597,26 +1606,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

/// Add base derivative expression in the derived call output args list if
/// `CE` is a call to an instance member function.
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE)) {
baseDiff = Visit(MCE->getImplicitObjectArgument());
StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
if (isInsideLoop) {
addToCurrentBlock(baseDiffStore.getExpr());
VarDecl* baseLocalVD = BuildVarDecl(
baseDiffStore.getExpr_dx()->getType(),
CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(),
/*DirectInit=*/false, /*TSI=*/nullptr,
VarDecl::InitializationStyle::CInit);
auto& block = getCurrentBlock(direction::reverse);
block.insert(block.begin() + insertionPoint,
BuildDeclStmt(baseLocalVD));
insertionPoint += 1;
Expr* baseLocalE = BuildDeclRef(baseLocalVD);
baseDiffStore = {baseDiffStore.getExpr(), baseLocalE};
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
else if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
if (isInsideLoop) {
addToCurrentBlock(baseDiffStore.getExpr());
VarDecl* baseLocalVD = BuildVarDecl(
baseDiffStore.getExpr_dx()->getType(),
CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(),
/*DirectInit=*/false, /*TSI=*/nullptr,
VarDecl::InitializationStyle::CInit);
auto& block = getCurrentBlock(direction::reverse);
block.insert(block.begin() + insertionPoint,
BuildDeclStmt(baseLocalVD));
insertionPoint += 1;
Expr* baseLocalE = BuildDeclRef(baseLocalVD);
baseDiffStore = {baseDiffStore.getExpr(), baseLocalE};
}
baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()};
Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative =
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative);
DerivedCallOutputArgs.push_back(baseDerivative);
}
baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()};
DerivedCallOutputArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr_dx()));
}

for (auto argDerivative : CallArgDx) {
Expand Down Expand Up @@ -1689,7 +1709,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackCallArgs = DerivedCallArgs;

if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(),
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() -
static_cast<int>(skipFirstArg),
pullback);

// Try to find it in builtin derivatives
Expand Down Expand Up @@ -1775,7 +1796,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
usingNumericalDiff = true;
}
} else if (pullbackFD) {
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD);
Expand Down Expand Up @@ -1878,7 +1899,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// We cannot reuse the derivatives previously computed because
// they might contain 'clad::pop(..)` expression.
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr_dx()) {
Expr* derivedBase = baseDiff.getExpr_dx();
// FIXME: We may need this if-block once we support pointers, and
// passing pointers-by-reference if
Expand Down Expand Up @@ -1906,7 +1927,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get());
}
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, calleeFnForwPassFD);
Expand Down
34 changes: 34 additions & 0 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ struct ExperimentNNS {
} // namespace inner
} // namespace outer

// A function calling operator() on a functor.
double CallFunctor(double i, double j) {
Experiment E(3, 5);
return E(i, j);
}

#define INIT(E) \
auto E##_grad = clad::gradient(&E); \
auto E##Ref_grad = clad::gradient(E);
Expand Down Expand Up @@ -298,4 +304,32 @@ int main() {
// CHECK-EXEC: 54.00 42.00
TEST_LAMBDA(lambdaWithCapture); // CHECK-EXEC: 54.00 42.00
// CHECK-EXEC: 54.00 42.00

// CHECK: void CallFunctor_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: Experiment _d_E({});
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: Experiment _t2;
// CHECK-NEXT: Experiment E(3, 5);
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = E;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, 1, &_d_E, &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a function calling operator() on a functor
auto CallFunctor_grad = clad::gradient(CallFunctor);
double di = 0, dj = 0;
CallFunctor_grad.execute(7, 9, &di, &dj);
printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00
}

0 comments on commit 9286fdc

Please sign in to comment.