Skip to content

Commit

Permalink
Set derivative to 0 for fxn calls with literal arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Oct 20, 2023
1 parent dd4c9fc commit 16cd3f3
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 22 deletions.
29 changes: 29 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,35 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
.get();
}

// If all arguments are constant literals, then this does not contribute to
// the gradient.
if (!callDiff) {
if (!isa<CXXOperatorCallExpr>(CE) && !isa<CXXMemberCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) {
const Expr* arg = CE->getArg(i);
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = MTE->getSubExpr();
if (!isa<FloatingLiteral>(arg) && !isa<IntegerLiteral>(arg)) {
allArgsAreConstantLiterals = false;
break;
}
}
if (allArgsAreConstantLiterals) {
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.get();
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
return StmtDiff(call, zero);
}
}
}

if (!callDiff) {
// Overloaded derivative was not found, request the CladPlugin to
// derive the called function.
Expand Down
36 changes: 23 additions & 13 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
!isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

// If all arguments are constant literals, then this does not contribute to
// the gradient.
if (!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
for (const Expr* arg : CE->arguments()) {
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = MTE->getSubExpr();

Check warning on line 1392 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L1392

Added line #L1392 was not covered by tests
if (!isa<FloatingLiteral>(arg) && !isa<IntegerLiteral>(arg)) {
allArgsAreConstantLiterals = false;
break;
}
}
if (allArgsAreConstantLiterals)
return StmtDiff(Clone(CE));
}

// Stores the call arguments for the function to be derived
llvm::SmallVector<Expr*, 16> CallArgs{};
// Stores the dx of the call arguments for the function to be derived
Expand Down Expand Up @@ -1419,14 +1437,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statements there later.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

// `CXXOperatorCallExpr` have the `base` expression as the first argument.
size_t skipFirstArg = 0;

// Here we do not need to check if FD is an instance method or a static
// method because C++ forbids creating operator overloads as static methods.
if (isa<CXXOperatorCallExpr>(CE) && isa<CXXMethodDecl>(FD))
skipFirstArg = 1;

// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
Expand Down Expand Up @@ -1665,7 +1675,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// is required because the pullback function expects `clad::array_ref`
// type for representing array derivatives. Currently, only constant
// array data members have derivatives of constant array types.
if (isa<ConstantArrayType>(argDerivative->getType())) {
if (argDerivative &&
isa<ConstantArrayType>(argDerivative->getType())) {
Expr* init =
utils::BuildCladArrayInitByConstArray(m_Sema, argDerivative);
auto derivativeArrayRefVD = BuildVarDecl(
Expand All @@ -1676,11 +1687,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ArgDeclStmts.push_back(BuildDeclStmt(derivativeArrayRefVD));
argDerivative = BuildDeclRef(derivativeArrayRefVD);
}
if (isCladArrayType(argDerivative->getType())) {
if (argDerivative && isCladArrayType(argDerivative->getType()))
gradArgExpr = argDerivative;
} else {
else
gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative);
}
} else {
// Declare: diffArgType _grad = 0;
gradVarDecl = BuildVarDecl(
Expand Down Expand Up @@ -1721,7 +1731,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() -
static_cast<int>(skipFirstArg),
static_cast<int>(isCXXOperatorCall),
pullback);

// Try to find it in builtin derivatives
Expand Down
7 changes: 1 addition & 6 deletions test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,9 @@ float test_4(int x) {
return overloaded();
}

// CHECK: {{(clad::)?}}ValueAndPushforward<int, int> overloaded_pushforward() {
// CHECK-NEXT: return {3, 0};
// CHECK-NEXT: }

// CHECK: float test_4_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<int, int> _t0 = overloaded_pushforward();
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: return 0;
// CHECK-NEXT: }

float test_5(int x) {
Expand Down
8 changes: 5 additions & 3 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ double sum(double* arr, int n) {
double fn8(double i, double j) {
double arr[5] = {};
modifyArr(arr, 5, i*j);
return sum(arr, 5);
return sum(arr, 5) * std::tanh(1.0);
}

// CHECK: double fn8_darg0(double i, double j) {
Expand All @@ -307,7 +307,9 @@ double fn8(double i, double j) {
// CHECK-NEXT: double arr[5] = {};
// CHECK-NEXT: modifyArr_pushforward(arr, 5, i * j, _d_arr, 0, _d_i * j + i * _d_j);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = sum_pushforward(arr, 5, _d_arr, 0);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: double &_t1 = _t0.value;
// CHECK-NEXT: double _t2 = std::tanh(1.);
// CHECK-NEXT: return _t0.pushforward * _t2 + _t1 * 0;
// CHECK-NEXT: }

float test_1_darg0(float x);
Expand Down Expand Up @@ -346,6 +348,6 @@ int main () {
TEST(fn5, 3, 5); // CHECK-EXEC: {1.00}
TEST(fn6, 3, 5, 7); // CHECK-EXEC: {3.00}
TEST(fn7, 3, 5); // CHECK-EXEC: {8.00}
TEST(fn8, 3, 5); // CHECK-EXEC: {25.00}
TEST(fn8, 3, 5); // CHECK-EXEC: {19.04}
return 0;
}
26 changes: 26 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,30 @@ double fn7(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn8(double x, double y) {
return x*y*std::tanh(1.0);
}

// CHECK: void fn8_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: _t2 = x;
// CHECK-NEXT: _t1 = y;
// CHECK-NEXT: _t3 = _t2 * _t1;
// CHECK-NEXT: _t0 = std::tanh(1.);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 1 * _t0;
// CHECK-NEXT: double _r1 = _r0 * _t1;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: double _r2 = _t2 * _r0;
// CHECK-NEXT: * _d_y += _r2;
// CHECK-NEXT: double _r3 = _t3 * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
Expand Down Expand Up @@ -513,6 +537,7 @@ int main() {
INIT(fn5);
INIT(fn6);
INIT(fn7);
INIT(fn8);

TEST1_float(fn1, 11); // CHECK-EXEC: {3.00}
TEST2(fn2, 3, 5); // CHECK-EXEC: {1.00, 3.00}
Expand All @@ -522,4 +547,5 @@ int main() {
TEST_ARR5(fn5, arr, 5); // CHECK-EXEC: {5.00, 1.00, 0.00, 0.00, 0.00}
TEST2(fn6, 3, 5); // CHECK-EXEC: {5.00, 3.00}
TEST2(fn7, 3, 5); // CHECK-EXEC: {10.00, 71.00}
TEST2(fn8, 3, 5); // CHECK-EXEC: {3.81, 2.28}
}

0 comments on commit 16cd3f3

Please sign in to comment.