diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index a31b239bb..3f9dcae67 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -30,6 +30,11 @@ template struct ValueAndPushforward { } }; +template +ValueAndPushforward make_value_and_pushforward(T value, U pushforward) { + return {value, pushforward}; +} + template struct ValueAndAdjoint { T value; U adjoint; diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 9f3cb42c0..11192b16d 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -1,8 +1,11 @@ #ifndef CLAD_STL_BUILTINS_H #define CLAD_STL_BUILTINS_H +#include #include +#include #include +#include #include namespace clad { @@ -338,7 +341,85 @@ void constructor_pullback(::std::array* a, const ::std::array& arr, (*d_arr)[i] += (*d_a)[i]; } +template +clad::ValueAndPushforward<::std::tuple, ::std::tuple> +operator_equal_pushforward(::std::tuple* tu, + ::std::tuple&& in, + ::std::tuple* d_tu, + ::std::tuple&& d_in) noexcept { + ::std::tuple t1 = (*tu = in); + ::std::tuple t2 = (*d_tu = d_in); + return {t1, t2}; +} + } // namespace class_functions + +namespace std { + +// Helper functions for selecting subtuples +template <::std::size_t shift_amount, ::std::size_t... Is> +constexpr auto shift_sequence(IndexSequence) { + return IndexSequence{}; +} + +template +auto select_tuple_elements(const Tuple& tpl, IndexSequence) { + return ::std::make_tuple(::std::get(tpl)...); +} + +template auto first_half_tuple(const Tuple& tpl) { + // static_assert(::std::tuple_size::value % 2 == 0); + constexpr ::std::size_t half = ::std::tuple_size::value / 2; + + constexpr MakeIndexSequence first_half; + return select_tuple_elements(tpl, first_half); +} + +template auto second_half_tuple(const Tuple& tpl) { + // static_assert(::std::tuple_size::value % 2 == 0); + constexpr ::std::size_t half = ::std::tuple_size::value / 2; + + constexpr MakeIndexSequence first_half; + constexpr auto second_half = shift_sequence(first_half); + return select_tuple_elements(tpl, second_half); +} + +template +auto select_tuple_elements_tie(const Tuple& tpl, IndexSequence) { + return ::std::tie(::std::get(tpl)...); +} + +template auto first_half_tuple_tie(const Tuple& tpl) { + // static_assert(::std::tuple_size::value % 2 == 0); + constexpr ::std::size_t half = ::std::tuple_size::value / 2; + + constexpr MakeIndexSequence first_half; + return select_tuple_elements_tie(tpl, first_half); +} + +template auto second_half_tuple_tie(const Tuple& tpl) { + // static_assert(::std::tuple_size::value % 2 == 0); + constexpr ::std::size_t half = ::std::tuple_size::value / 2; + + constexpr MakeIndexSequence first_half; + constexpr auto second_half = shift_sequence(first_half); + return select_tuple_elements_tie(tpl, second_half); +} + +template auto tie_pushforward(Args&&... args) noexcept { + ::std::tuple t = ::std::tie(args...); + return clad::make_value_and_pushforward(first_half_tuple_tie(t), + second_half_tuple_tie(t)); +} + +template auto make_tuple_pushforward(Args... args) noexcept { + ::std::tuple t = ::std::make_tuple(args...); + return clad::make_value_and_pushforward(first_half_tuple(t), + second_half_tuple(t)); +} + +} // namespace std + } // namespace custom_derivatives } // namespace clad diff --git a/test/ForwardMode/STLCustomDerivatives.C b/test/ForwardMode/STLCustomDerivatives.C index cfb1f7813..7d441fa0f 100644 --- a/test/ForwardMode/STLCustomDerivatives.C +++ b/test/ForwardMode/STLCustomDerivatives.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -std=c++14 -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s // RUN: ./STLCustomDerivatives.out | %filecheck_exec %s #include "clad/Differentiator/Differentiator.h" @@ -7,6 +7,7 @@ #include #include #include +#include #include "../TestUtils.h" #include "../PrintOverloads.h" @@ -181,6 +182,33 @@ double fnArr2(double x) { //CHECK-NEXT: return (_t0.pushforward * _t3 + _t2 * _t1.pushforward) * _t6 + _t5 * _t4.pushforward; //CHECK-NEXT: } +auto pack(double x) { + return std::make_tuple(x, 2*x, 3*x); +} + +double fnTuple1(double x, double y) { + double u, v = 288*x, w; + + std::tie(u, v, w) = pack(x+y); + + return v; +} // = 2x + 2y + +//CHECK: double fnTuple1_darg0(double x, double y) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: double _d_y = 0; +//CHECK-NEXT: double _d_u, _d_v = 0 * x + 288 * _d_x, _d_w; +//CHECK-NEXT: double u, v = 288 * x, w; +//CHECK-NEXT: clad::ValueAndPushforward, tuple > _t0 = clad::custom_derivatives::std::tie_pushforward(u, v, w, _d_u, _d_v, _d_w); +//CHECK-NEXT: clad::ValueAndPushforward<{{.*}}> _t1 = pack_pushforward(x + y, _d_x + _d_y); +//CHECK-NEXT: clad::ValueAndPushforward<{{.*}}> _t2 = clad::custom_derivatives::class_functions::operator_equal_pushforward(&_t0.value, static_cast &&>(_t1.value), &_t0.pushforward, static_cast &&>(_t1.pushforward)); +//CHECK-NEXT: return _d_v; +//CHECK-NEXT: } +//CHECK: clad::ValueAndPushforward<{{.*}}> pack_pushforward({{.*}}) { +//CHECK-NEXT: clad::ValueAndPushforward, tuple > _t0 = clad::custom_derivatives::std::make_tuple_pushforward(x, 2 * x, 3 * x, _d_x, 0 * x + 2 * _d_x, 0 * x + 3 * _d_x); +//CHECK-NEXT: return {_t0.value, _t0.pushforward}; +//CHECK-NEXT: } + int main() { INIT_DIFFERENTIATE(fnVec1, "u"); INIT_DIFFERENTIATE(fnVec2, "u"); @@ -188,6 +216,7 @@ int main() { INIT_DIFFERENTIATE(fnVec4, "u"); INIT_DIFFERENTIATE(fnArr1, "x"); INIT_DIFFERENTIATE(fnArr2, "x"); + INIT_DIFFERENTIATE(fnTuple1, "x"); TEST_DIFFERENTIATE(fnVec1, 3, 5); // CHECK-EXEC: {10.00} TEST_DIFFERENTIATE(fnVec2, 3, 5); // CHECK-EXEC: {5.00} @@ -195,4 +224,5 @@ int main() { TEST_DIFFERENTIATE(fnVec4, 3, 5); // CHECK-EXEC: {30.00} TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00} TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00} + TEST_DIFFERENTIATE(fnTuple1, 3, 4); // CHECK-EXEC: {2.00} } diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index e0fca6587..922270c96 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out | %filecheck %s +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oUserDefinedTypes.out | %filecheck %s // RUN: ./UserDefinedTypes.out | %filecheck_exec %s // XFAIL: asserts