Skip to content

Commit

Permalink
Merge the vector pow custom derivative with the non-vector one using …
Browse files Browse the repository at this point in the history
…templates
  • Loading branch information
PetroZarytskyi committed Nov 13, 2024
1 parent cd40d2f commit f9d57a8
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 56 deletions.
39 changes: 14 additions & 25 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,32 +263,21 @@ ValueAndPushforward<float, float> sqrtf_pushforward(float x, float d_x) {

#endif

template <typename T1, typename T2>
CUDA_HOST_DEVICE ValueAndPushforward<decltype(::std::pow(T1(), T2())),
decltype(::std::pow(T1(), T2()))>
pow_pushforward(T1 x, T2 exponent, T1 d_x, T2 d_exponent) {
auto val = ::std::pow(x, exponent);
auto derivative = (exponent * ::std::pow(x, exponent - 1)) * d_x;
// Only add directional derivative of base^exp w.r.t exp if the directional
// seed d_exponent is non-zero. This is required because if base is less than
// or equal to 0, then log(base) is undefined, and therefore if user only
// requested directional derivative of base^exp w.r.t base -- which is valid
// --, the result would be undefined because as per C++ valid number + NaN * 0
// = NaN.
if (d_exponent)
derivative += (::std::pow(x, exponent) * ::std::log(x)) * d_exponent;
return {val, derivative};
}
template <typename T, typename dT> struct AdjOutType {
using type = T;
};

template <typename T, typename dT> struct AdjOutType<T, clad::array<dT>> {
using type = clad::array<T>;
};

template <typename T1, typename T2>
CUDA_HOST_DEVICE
ValueAndPushforward<decltype(::std::pow(T1(), T2())),
clad::array<decltype(::std::pow(T1(), T2()))>>
pow_pushforward(T1 x, T2 exponent, clad::array<T1> d_x,
clad::array<T2> d_exponent) {
decltype(::std::pow(T1(), T2())) val = ::std::pow(x, exponent);
clad::array<decltype(::std::pow(T1(), T2()))> derivative =
(exponent * ::std::pow(x, exponent - 1)) * d_x;
template <typename T1, typename T2, typename dT1, typename dT2,
typename T_out = decltype(::std::pow(T1(), T2())),
typename dT_out = typename AdjOutType<T_out, dT1>::type>
CUDA_HOST_DEVICE ValueAndPushforward<T_out, dT_out>
pow_pushforward(T1 x, T2 exponent, dT1 d_x, dT2 d_exponent) {
T_out val = ::std::pow(x, exponent);
dT_out derivative = (exponent * ::std::pow(x, exponent - 1)) * d_x;
// Only add directional derivative of base^exp w.r.t exp if the directional
// seed d_exponent is non-zero. This is required because if base is less than
// or equal to 0, then log(base) is undefined, and therefore if user only
Expand Down
8 changes: 4 additions & 4 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ float f7(float x) {

// CHECK: float f7_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<decltype(::std::pow(float(), double())), decltype(::std::pow(float(), double()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 2., _d_x, 0.);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 2., _d_x, 0.);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand All @@ -120,7 +120,7 @@ double f8(float x) {

// CHECK: double f8_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 2, _d_x, 0);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 2, _d_x, 0);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand All @@ -142,7 +142,7 @@ float f9(float x, float y) {
// CHECK: float f9_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: {{(clad::)?}}clad::ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand All @@ -165,7 +165,7 @@ double f10(float x, int y) {
// CHECK: double f10_darg0(float x, int y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ int main() {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(double(), double())), decltype(::std::pow(double(), double()))> _t0 = clad::custom_derivatives::pow_pushforward(a.data[i], b.data[i], _d_a.data[i], _d_b.data[i]);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::pow_pushforward(a.data[i], b.data[i], _d_a.data[i], _d_b.data[i]);
// CHECK-NEXT: _d_res.data[i] = _t0.pushforward;
// CHECK-NEXT: res.data[i] = _t0.value;
// CHECK-NEXT: }
Expand Down
28 changes: 14 additions & 14 deletions test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,17 @@ int main() {

// CHECK: float f4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent);
// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent);

// CHECK: void f4_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x0, 0.F);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x0, 0.F);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
Expand All @@ -293,15 +293,15 @@ int main() {

// CHECK: float f5_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f5_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x0);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
Expand All @@ -319,7 +319,7 @@ int main() {
// CHECK: float f6_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand All @@ -328,8 +328,8 @@ int main() {
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: float _d__d_y = 0.F;
// CHECK-NEXT: float _d_y0 = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x0, _d_y0);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x0, _d_y0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
Expand All @@ -349,7 +349,7 @@ int main() {
// CHECK: float f6_darg1(float x, float y) {
// CHECK-NEXT: float _d_x = 0;
// CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand All @@ -358,8 +358,8 @@ int main() {
// CHECK-NEXT: float _d_x0 = 0;
// CHECK-NEXT: float _d__d_y = 0.F;
// CHECK-NEXT: float _d_y0 = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x0, _d_y0);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x0, _d_y0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
Expand Down Expand Up @@ -434,7 +434,7 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent) {
// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: float _t1;
// CHECK-NEXT: float _t2;
Expand Down
24 changes: 12 additions & 12 deletions test/NthDerivative/CustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ float test_trig(float x, float y, int a, int b) {
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t1 = pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t1 = _t1.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<double, double>, ValueAndPushforward<double, double> > _t1 = pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _d__t1 = _t1.pushforward;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = clad::custom_derivatives::std::cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t2 = _t2.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t20 = _t2.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t3 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t3 = _t3.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t30 = _t3.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<double, double>, ValueAndPushforward<double, double> > _t3 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _d__t3 = _t3.pushforward;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t30 = _t3.value;
// CHECK-NEXT: double &_d__t4 = _d__t1.value;
// CHECK-NEXT: double &_t40 = _t10.value;
// CHECK-NEXT: double &_d__t5 = _d__t3.value;
Expand Down Expand Up @@ -92,15 +92,15 @@ float test_trig(float x, float y, int a, int b) {
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t1 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t1 = _t1.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<double, double>, ValueAndPushforward<double, double> > _t1 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _d__t1 = _t1.pushforward;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = clad::custom_derivatives::std::cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t2 = _t2.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t20 = _t2.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t3 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t3 = _t3.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t30 = _t3.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<double, double>, ValueAndPushforward<double, double> > _t3 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _d__t3 = _t3.pushforward;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t30 = _t3.value;
// CHECK-NEXT: double &_d__t4 = _d__t1.value;
// CHECK-NEXT: double &_t40 = _t10.value;
// CHECK-NEXT: double &_d__t5 = _d__t3.value;
Expand Down

0 comments on commit f9d57a8

Please sign in to comment.