Skip to content

Commit

Permalink
Add basic support for std::tie and tuples in the fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Sep 16, 2024
1 parent b414202 commit 7639909
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 2 deletions.
5 changes: 5 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ template <typename T, typename U> struct ValueAndPushforward {
}
};

template <typename T, typename U>
ValueAndPushforward<T, U> make_value_and_pushforward(T value, U pushforward) {
return {value, pushforward};
}

template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
Expand Down
81 changes: 81 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#ifndef CLAD_STL_BUILTINS_H
#define CLAD_STL_BUILTINS_H

#include <array>
#include <clad/Differentiator/BuiltinDerivatives.h>
#include <clad/Differentiator/FunctionTraits.h>
#include <initializer_list>
#include <tuple>
#include <vector>

namespace clad {
Expand Down Expand Up @@ -338,7 +341,85 @@ void constructor_pullback(::std::array<T, N>* a, const ::std::array<T, N>& arr,
(*d_arr)[i] += (*d_a)[i];
}

template <typename... Args1, typename... Args2>
clad::ValueAndPushforward<::std::tuple<Args1...>, ::std::tuple<Args1...>>
operator_equal_pushforward(::std::tuple<Args1...>* tu,
::std::tuple<Args2...>&& in,
::std::tuple<Args1...>* d_tu,
::std::tuple<Args2...>&& d_in) noexcept {
::std::tuple<Args1...> t1 = (*tu = in);
::std::tuple<Args1...> t2 = (*d_tu = d_in);
return {t1, t2};
}

} // namespace class_functions

namespace std {

// Helper functions for selecting subtuples
template <::std::size_t shift_amount, ::std::size_t... Is>
constexpr auto shift_sequence(IndexSequence<Is...>) {
return IndexSequence<shift_amount + Is...>{};
}

template <typename Tuple, ::std::size_t... Indices>
auto select_tuple_elements(const Tuple& tpl, IndexSequence<Indices...>) {
return ::std::make_tuple(::std::get<Indices>(tpl)...);
}

template <typename Tuple> auto first_half_tuple(const Tuple& tpl) {
// static_assert(::std::tuple_size<Tuple>::value % 2 == 0);
constexpr ::std::size_t half = ::std::tuple_size<Tuple>::value / 2;

constexpr MakeIndexSequence<half> first_half;
return select_tuple_elements(tpl, first_half);
}

template <typename Tuple> auto second_half_tuple(const Tuple& tpl) {
// static_assert(::std::tuple_size<Tuple>::value % 2 == 0);
constexpr ::std::size_t half = ::std::tuple_size<Tuple>::value / 2;

constexpr MakeIndexSequence<half> first_half;
constexpr auto second_half = shift_sequence<half>(first_half);
return select_tuple_elements(tpl, second_half);
}

template <typename Tuple, ::std::size_t... Indices>
auto select_tuple_elements_tie(const Tuple& tpl, IndexSequence<Indices...>) {
return ::std::tie(::std::get<Indices>(tpl)...);
}

template <typename Tuple> auto first_half_tuple_tie(const Tuple& tpl) {
// static_assert(::std::tuple_size<Tuple>::value % 2 == 0);
constexpr ::std::size_t half = ::std::tuple_size<Tuple>::value / 2;

constexpr MakeIndexSequence<half> first_half;
return select_tuple_elements_tie(tpl, first_half);
}

template <typename Tuple> auto second_half_tuple_tie(const Tuple& tpl) {
// static_assert(::std::tuple_size<Tuple>::value % 2 == 0);
constexpr ::std::size_t half = ::std::tuple_size<Tuple>::value / 2;

constexpr MakeIndexSequence<half> first_half;
constexpr auto second_half = shift_sequence<half>(first_half);
return select_tuple_elements_tie(tpl, second_half);
}

template <typename... Args> auto tie_pushforward(Args&&... args) noexcept {
::std::tuple<Args&...> t = ::std::tie(args...);
return clad::make_value_and_pushforward(first_half_tuple_tie(t),
second_half_tuple_tie(t));
}

template <typename... Args> auto make_tuple_pushforward(Args... args) noexcept {
::std::tuple<Args...> t = ::std::make_tuple(args...);
return clad::make_value_and_pushforward(first_half_tuple(t),
second_half_tuple(t));
}

} // namespace std

} // namespace custom_derivatives
} // namespace clad

Expand Down
32 changes: 31 additions & 1 deletion test/ForwardMode/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -std=c++14 -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s
// RUN: %cladclang -std=c++14 %s -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s
// RUN: ./STLCustomDerivatives.out | %filecheck_exec %s

#include "clad/Differentiator/Differentiator.h"
Expand All @@ -7,6 +7,7 @@
#include <array>
#include <vector>
#include <map>
#include <tuple>

#include "../TestUtils.h"
#include "../PrintOverloads.h"
Expand Down Expand Up @@ -181,18 +182,47 @@ double fnArr2(double x) {
//CHECK-NEXT: return (_t0.pushforward * _t3 + _t2 * _t1.pushforward) * _t6 + _t5 * _t4.pushforward;
//CHECK-NEXT: }

auto pack(double x) {
return std::make_tuple(x, 2*x, 3*x);
}

double fnTuple1(double x, double y) {
double u, v = 288*x, w;

std::tie(u, v, w) = pack(x+y);

return v;
} // = 2x + 2y

//CHECK: double fnTuple1_darg0(double x, double y) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: double _d_y = 0;
//CHECK-NEXT: double _d_u, _d_v = 0 * x + 288 * _d_x, _d_w;
//CHECK-NEXT: double u, v = 288 * x, w;
//CHECK-NEXT: clad::ValueAndPushforward<tuple<double &, double &, double &>, tuple<double &, double &, double &> > _t0 = clad::custom_derivatives::std::tie_pushforward(u, v, w, _d_u, _d_v, _d_w);
//CHECK-NEXT: clad::ValueAndPushforward<{{.*}}> _t1 = pack_pushforward(x + y, _d_x + _d_y);
//CHECK-NEXT: clad::ValueAndPushforward<{{.*}}> _t2 = clad::custom_derivatives::class_functions::operator_equal_pushforward(&_t0.value, static_cast<std::tuple<double, double, double> &&>(_t1.value), &_t0.pushforward, static_cast<std::tuple<double, double, double> &&>(_t1.pushforward));
//CHECK-NEXT: return _d_v;
//CHECK-NEXT: }
//CHECK: clad::ValueAndPushforward<{{.*}}> pack_pushforward({{.*}}) {
//CHECK-NEXT: clad::ValueAndPushforward<tuple<double, double, double>, tuple<double, double, double> > _t0 = clad::custom_derivatives::std::make_tuple_pushforward(x, 2 * x, 3 * x, _d_x, 0 * x + 2 * _d_x, 0 * x + 3 * _d_x);
//CHECK-NEXT: return {_t0.value, _t0.pushforward};
//CHECK-NEXT: }

int main() {
INIT_DIFFERENTIATE(fnVec1, "u");
INIT_DIFFERENTIATE(fnVec2, "u");
INIT_DIFFERENTIATE(fnVec3, "u");
INIT_DIFFERENTIATE(fnVec4, "u");
INIT_DIFFERENTIATE(fnArr1, "x");
INIT_DIFFERENTIATE(fnArr2, "x");
INIT_DIFFERENTIATE(fnTuple1, "x");

TEST_DIFFERENTIATE(fnVec1, 3, 5); // CHECK-EXEC: {10.00}
TEST_DIFFERENTIATE(fnVec2, 3, 5); // CHECK-EXEC: {5.00}
TEST_DIFFERENTIATE(fnVec3, 3, 5); // CHECK-EXEC: {2.00}
TEST_DIFFERENTIATE(fnVec4, 3, 5); // CHECK-EXEC: {30.00}
TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00}
TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00}
TEST_DIFFERENTIATE(fnTuple1, 3, 4); // CHECK-EXEC: {2.00}
}
2 changes: 1 addition & 1 deletion test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out | %filecheck %s
// RUN: %cladclang -std=c++14 %s -I%S/../../include -oUserDefinedTypes.out | %filecheck %s
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s

// XFAIL: asserts
Expand Down

0 comments on commit 7639909

Please sign in to comment.