Skip to content

Commit

Permalink
Add support for std::min, std::max and std::clamp functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Oct 26, 2023
1 parent 8bda639 commit da2b40c
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 3 deletions.
59 changes: 59 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,61 @@ CUDA_HOST_DEVICE void fma_pullback(T1 a, T2 b, T3 c, T4 d_y,
*d_c += d_y;
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T>
min_pushforward(const T& a, const T& b, const T& d_a, const T& d_b) {
return {::std::min(a, b), a < b ? d_a : d_b};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T>
max_pushforward(const T& a, const T& b, const T& d_a, const T& d_b) {
return {::std::max(a, b), a < b ? d_b : d_a};
}

template <typename T, typename U>
CUDA_HOST_DEVICE void min_pullback(const T& a, const T& b, U d_y,
clad::array_ref<decltype(T())> d_a,
clad::array_ref<decltype(T())> d_b) {
if (a < b)
*d_a += d_y;
else
*d_b += d_y;
}

template <typename T, typename U>
CUDA_HOST_DEVICE void max_pullback(const T& a, const T& b, U d_y,
clad::array_ref<decltype(T())> d_a,
clad::array_ref<decltype(T())> d_b) {
if (a < b)
*d_b += d_y;
else
*d_a += d_y;
}

#if __cplusplus >= 201703L
template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T>
clamp_pushforward(const T& v, const T& lo, const T& hi, const T& d_v,
const T& d_lo, const T& d_hi) {
return {::std::clamp(v, lo, hi), v < lo ? d_lo : hi < v ? d_hi : d_v};
}

template <typename T, typename U>
CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi,
const U& d_y,
clad::array_ref<decltype(T())> d_v,
clad::array_ref<decltype(T())> d_lo,
clad::array_ref<decltype(T())> d_hi) {
if (v < lo)
*d_lo += d_y;
else if (hi < v)
*d_hi += d_y;
else
*d_v += d_y;
}
#endif

} // namespace std
// These are required because C variants of mathematical functions are
// defined in global namespace.
Expand All @@ -150,6 +205,10 @@ using std::floor_pushforward;
using std::fma_pullback;
using std::fma_pushforward;
using std::log_pushforward;
using std::max_pullback;
using std::max_pushforward;
using std::min_pullback;
using std::min_pushforward;
using std::pow_pullback;
using std::pow_pushforward;
using std::sin_pushforward;
Expand Down
6 changes: 4 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE);
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!isa<FloatingLiteral>(arg) && !isa<IntegerLiteral>(arg)) {
allArgsAreConstantLiterals = false;
break;
Expand Down Expand Up @@ -1934,7 +1934,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

Expr* call = nullptr;

if (FD->getReturnType()->isReferenceType()) {
QualType returnType = FD->getReturnType();
if (returnType->isReferenceType() &&
!returnType.getNonReferenceType().isConstQualified()) {
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
Expand Down
65 changes: 64 additions & 1 deletion test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladnumdiffclang %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s
// RUN: %cladnumdiffclang -std=c++17 %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s
// RUN: ./FunctionCalls.out | FileCheck -check-prefix=CHECK-EXEC %s

//CHECK-NOT: {{.*error|warning|note:.*}}
Expand Down Expand Up @@ -533,6 +533,67 @@ double fn9(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn10(double x, double y) {
double out = x;
out = std::max(out, 0.0);
out = std::min(out, 10.0);
out = std::clamp(out, 3.0, 7.0);
return out * y;
}

// CHECK: void fn10_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _d_out = 0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: double _t4;
// CHECK-NEXT: double out = x;
// CHECK-NEXT: _t0 = out;
// CHECK-NEXT: out = std::max(out, 0.);
// CHECK-NEXT: _t1 = out;
// CHECK-NEXT: out = std::min(out, 10.);
// CHECK-NEXT: _t2 = out;
// CHECK-NEXT: out = std::clamp(out, 3., 7.);
// CHECK-NEXT: _t4 = out;
// CHECK-NEXT: _t3 = y;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r7 = 1 * _t3;
// CHECK-NEXT: _d_out += _r7;
// CHECK-NEXT: double _r8 = _t4 * 1;
// CHECK-NEXT: * _d_y += _r8;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r_d2 = _d_out;
// CHECK-NEXT: double _grad5 = 0.;
// CHECK-NEXT: double _grad6 = 0.;
// CHECK-NEXT: clad::custom_derivatives::std::clamp_pullback(_t2, 3., 7., _r_d2, &_d_out, &_grad5, &_grad6);
// CHECK-NEXT: double _r4 = _d_out;
// CHECK-NEXT: double _r5 = _grad5;
// CHECK-NEXT: double _r6 = _grad6;
// CHECK-NEXT: _d_out -= _r_d2;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r_d1 = _d_out;
// CHECK-NEXT: double _grad3 = 0.;
// CHECK-NEXT: clad::custom_derivatives::std::min_pullback(_t1, 10., _r_d1, &_d_out, &_grad3);
// CHECK-NEXT: double _r2 = _d_out;
// CHECK-NEXT: double _r3 = _grad3;
// CHECK-NEXT: _d_out -= _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r_d0 = _d_out;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::std::max_pullback(_t0, 0., _r_d0, &_d_out, &_grad1);
// CHECK-NEXT: double _r0 = _d_out;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: _d_out -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: * _d_x += _d_out;
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -587,6 +648,7 @@ int main() {
INIT(fn7);
INIT(fn8);
INIT(fn9);
INIT(fn10);

TEST1_float(fn1, 11); // CHECK-EXEC: {3.00}
TEST2(fn2, 3, 5); // CHECK-EXEC: {1.00, 3.00}
Expand All @@ -598,4 +660,5 @@ int main() {
TEST2(fn7, 3, 5); // CHECK-EXEC: {10.00, 71.00}
TEST2(fn8, 3, 5); // CHECK-EXEC: {7.62, 4.57}
TEST2(fn9, 3, 5); // CHECK-EXEC: {5.00, 3.00}
TEST2(fn10, 8, 5); // CHECK-EXEC: {0.00, 7.00}
}

0 comments on commit da2b40c

Please sign in to comment.