Skip to content

Commit

Permalink
Enhance the support of std::vector and std::array in the fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Sep 17, 2024
1 parent c0c782c commit f97c16e
Show file tree
Hide file tree
Showing 2 changed files with 397 additions and 3 deletions.
199 changes: 196 additions & 3 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace clad {
namespace custom_derivatives {
namespace class_functions {

// vector forward mode

template <typename T>
void clear_pushforward(::std::vector<T>* v, ::std::vector<T>* d_v) {
d_v->clear();
Expand Down Expand Up @@ -131,6 +133,181 @@ operator_subscript_pushforward(const ::std::vector<T>* v, unsigned idx,
return {(*v)[idx], (*d_v)[idx]};
}

template <typename T>
ValueAndPushforward<T&, T&> at_pushforward(::std::vector<T>* v, unsigned idx,
::std::vector<T>* d_v,
unsigned d_idx) {
return {(*v)[idx], (*d_v)[idx]};
}

template <typename T>
ValueAndPushforward<const T&, const T&>
at_pushforward(const ::std::vector<T>* v, unsigned idx,
const ::std::vector<T>* d_v, unsigned d_idx) {
return {(*v)[idx], (*d_v)[idx]};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>&, ::std::vector<T>&>
operator_equal_pushforward(::std::vector<T>* a, const ::std::vector<T>& param,
::std::vector<T>* d_a,
const ::std::vector<T>& d_param) noexcept {
(*a) = param;
(*d_a) = d_param;
return {*a, *d_a};
}

template <typename T>
inline clad::ValueAndPushforward<const T&, const T&>
front_pushforward(const ::std::vector<T>* a,
const ::std::vector<T>* d_a) noexcept {
return {a->front(), d_a->front()};
}

template <typename T>
inline clad::ValueAndPushforward<T&, T&>
front_pushforward(::std::vector<T>* a, ::std::vector<T>* d_a) noexcept {
return {a->front(), d_a->front()};
}

template <typename T>
inline clad::ValueAndPushforward<const T&, const T&>
back_pushforward(const ::std::vector<T>* a,
const ::std::vector<T>* d_a) noexcept {
return {a->back(), d_a->back()};
}

template <typename T>
inline clad::ValueAndPushforward<T&, T&>
back_pushforward(::std::vector<T>* a, ::std::vector<T>* d_a) noexcept {
return {a->back(), d_a->back()};
}

template <typename T>
ValueAndPushforward<typename ::std::vector<T>::iterator,
typename ::std::vector<T>::iterator>
begin_pushforward(::std::vector<T>* v, ::std::vector<T>* d_v) {
return {v->begin(), d_v->begin()};
}

template <typename T>
ValueAndPushforward<typename ::std::vector<T>::iterator,
typename ::std::vector<T>::iterator>
end_pushforward(::std::vector<T>* v, ::std::vector<T>* d_v) {
return {v->end(), d_v->end()};
}

template <typename T>
ValueAndPushforward<typename ::std::vector<T>::iterator,
typename ::std::vector<T>::iterator>
erase_pushforward(::std::vector<T>* v,
typename ::std::vector<T>::const_iterator pos,
::std::vector<T>* d_v,
typename ::std::vector<T>::const_iterator d_pos) {
return {v->erase(pos), d_v->erase(d_pos)};
}

template <typename T, typename U>
ValueAndPushforward<typename ::std::vector<T>::iterator,
typename ::std::vector<T>::iterator>
insert_pushforward(::std::vector<T>* v,
typename ::std::vector<T>::const_iterator pos, U u,
::std::vector<T>* d_v,
typename ::std::vector<T>::const_iterator d_pos, U d_u) {
return {v->insert(pos, u), d_v->insert(d_pos, d_u)};
}

template <typename T, typename U>
ValueAndPushforward<typename ::std::vector<T>::iterator,
typename ::std::vector<T>::iterator>
insert_pushforward(::std::vector<T>* v,
typename ::std::vector<T>::const_iterator pos,
::std::initializer_list<U> list, ::std::vector<T>* d_v,
typename ::std::vector<T>::const_iterator d_pos,
::std::initializer_list<U> d_list) {
return {v->insert(pos, list), d_v->insert(d_pos, d_list)};
}

template <typename T, typename U>
ValueAndPushforward<typename ::std::vector<T>::iterator,
typename ::std::vector<T>::iterator>
insert_pushforward(::std::vector<T>* v,
typename ::std::vector<T>::const_iterator pos, U first,
U last, ::std::vector<T>* d_v,
typename ::std::vector<T>::const_iterator d_pos, U d_first,
U d_last) {
return {v->insert(pos, first, last), d_v->insert(d_pos, d_first, d_last)};
}

template <typename T, typename U>
void assign_pushforward(::std::vector<T>* v,
typename ::std::vector<T>::size_type n, const U& val,
::std::vector<T>* d_v,
typename ::std::vector<T>::size_type /*d_n*/,
const U& d_val) {
v->assign(n, val);
d_v->assign(n, d_val);
}

template <typename T, typename U>
void assign_pushforward(::std::vector<T>* v, U first, U last,
::std::vector<T>* d_v, U d_first, U d_last) {
v->assign(first, last);
d_v->assign(d_first, d_last);
}

template <typename T, typename U>
void assign_pushforward(::std::vector<T>* v, ::std::initializer_list<U> list,
::std::vector<T>* d_v,
::std::initializer_list<U> d_list) {
v->assign(list);
d_v->assign(d_list);
}

template <typename T>
void reserve_pushforward(::std::vector<T>* v,
typename ::std::vector<T>::size_type n,
::std::vector<T>* d_v,
typename ::std::vector<T>::size_type /*d_n*/) {
v->reserve(n);
d_v->reserve(n);
}

template <typename T>
void shrink_to_fit_pushforward(::std::vector<T>* v, ::std::vector<T>* d_v) {
v->shrink_to_fit();
d_v->shrink_to_fit();
}

template <typename T, typename U>
void push_back_pushforward(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U d_val) {
v->push_back(val);
d_v->push_back(d_val);
}

template <typename T>
void pop_back_pushforward(::std::vector<T>* v, ::std::vector<T>* d_v) noexcept {
v->pop_back();
d_v->pop_back();
}

template <typename T>
clad::ValueAndPushforward<::std::size_t, ::std::size_t>
size_pushforward(const ::std::vector<T>* v,
const ::std::vector<T>* d_v) noexcept {
return {v->size(), 0};
}

template <typename T>
clad::ValueAndPushforward<::std::size_t, ::std::size_t>
capacity_pushforward(const ::std::vector<T>* v,
const ::std::vector<T>* d_v) noexcept {
return {v->capacity(), 0};
}

// array forward mode

template <typename T, ::std::size_t N>
constexpr clad::ValueAndPushforward<T&, T&>
operator_subscript_pushforward(::std::array<T, N>* a, ::std::size_t i,
Expand Down Expand Up @@ -198,13 +375,23 @@ back_pushforward(::std::array<T, N>* a, ::std::array<T, N>* d_a) noexcept {
return {a->back(), d_a->back()};
}

template <typename T, ::std::size_t N>
void fill_pushforward(::std::array<T, N>* a, const T& u,
::std::array<T, N>* d_a, const T& d_u) {
template <typename T, typename U, ::std::size_t N>
void fill_pushforward(::std::array<T, N>* a, const U& u,
::std::array<T, N>* d_a, const U& d_u) {
a->fill(u);
d_a->fill(d_u);
}

template <typename T, ::std::size_t N>
clad::ValueAndPushforward<::std::size_t, ::std::size_t>
size_pushforward(const ::std::array<T, N>* a,
const ::std::array<T, N>* d_a) noexcept {
return {a->size(), 0};
}

// vector reverse mode
// more can be found in tests: test/Gradient/STLCustomDerivatives.C

template <typename T, typename U>
void push_back_reverse_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U d_val) {
Expand Down Expand Up @@ -256,6 +443,8 @@ void constructor_pullback(::std::vector<T>* v, S count, U val,
d_v->clear();
}

// array reverse mode

template <typename T, ::std::size_t N>
clad::ValueAndAdjoint<T&, T&> operator_subscript_reverse_forw(
::std::array<T, N>* arr, typename ::std::array<T, N>::size_type idx,
Expand Down Expand Up @@ -341,6 +530,8 @@ void constructor_pullback(::std::array<T, N>* a, const ::std::array<T, N>& arr,
(*d_arr)[i] += (*d_a)[i];
}

// tuple forward mode

template <typename... Args1, typename... Args2>
clad::ValueAndPushforward<::std::tuple<Args1...>, ::std::tuple<Args1...>>
operator_equal_pushforward(::std::tuple<Args1...>* tu,
Expand All @@ -356,6 +547,8 @@ operator_equal_pushforward(::std::tuple<Args1...>* tu,

namespace std {

// tie and maketuple forward mode

// Helper functions for selecting subtuples
template <::std::size_t shift_amount, ::std::size_t... Is>
constexpr auto shift_sequence(IndexSequence<Is...>) {
Expand Down
Loading

0 comments on commit f97c16e

Please sign in to comment.