Skip to content

Commit

Permalink
Add support for std::array in the fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Jul 19, 2024
1 parent c1f30f5 commit 7eef009
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 1 deletion.
74 changes: 74 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,80 @@ operator_subscript_pushforward(const ::std::vector<T>* v, unsigned idx,
return {(*v)[idx], (*d_v)[idx]};
}

template <typename T, ::std::size_t N>
constexpr clad::ValueAndPushforward<T&, T&>
operator_subscript_pushforward(::std::array<T, N>* a, ::std::size_t i,
::std::array<T, N>* _d_a,
::std::size_t _d_i) noexcept {
return {(*a)[i], (*_d_a)[i]};
}

template <typename T, ::std::size_t N>
constexpr clad::ValueAndPushforward<T&, T&>
at_pushforward(::std::array<T, N>* a, ::std::size_t i, ::std::array<T, N>* _d_a,
::std::size_t _d_i) noexcept {
return {(*a)[i], (*_d_a)[i]};
}

template <typename T, ::std::size_t N>
constexpr clad::ValueAndPushforward<const T&, const T&>
operator_subscript_pushforward(const ::std::array<T, N>* a, ::std::size_t i,
const ::std::array<T, N>* _d_a,
::std::size_t _d_i) noexcept {
return {(*a)[i], (*_d_a)[i]};
}

template <typename T, ::std::size_t N>
constexpr clad::ValueAndPushforward<const T&, const T&>
at_pushforward(const ::std::array<T, N>* a, ::std::size_t i,
const ::std::array<T, N>* _d_a, ::std::size_t _d_i) noexcept {
return {(*a)[i], (*_d_a)[i]};
}

template <typename T, ::std::size_t N>
constexpr clad::ValueAndPushforward<::std::array<T, N>&, ::std::array<T, N>&>
operator_equal_pushforward(::std::array<T, N>* a,
const ::std::array<T, N>& param,
::std::array<T, N>* _d_a,
const ::std::array<T, N>& _d_param) noexcept {
(*a) = param;
(*_d_a) = _d_param;
return {*a, *_d_a};
}

template <typename T, ::std::size_t N>
inline constexpr clad::ValueAndPushforward<const T&, const T&>
front_pushforward(const ::std::array<T, N>* a,
const ::std::array<T, N>* _d_a) noexcept {
return {a->front(), _d_a->front()};
}

template <typename T, ::std::size_t N>
inline constexpr clad::ValueAndPushforward<T&, T&>
front_pushforward(::std::array<T, N>* a, ::std::array<T, N>* _d_a) noexcept {
return {a->front(), _d_a->front()};
}

template <typename T, ::std::size_t N>
inline constexpr clad::ValueAndPushforward<const T&, const T&>
back_pushforward(const ::std::array<T, N>* a,
const ::std::array<T, N>* _d_a) noexcept {
return {a->back(), _d_a->back()};
}

template <typename T, ::std::size_t N>
inline constexpr clad::ValueAndPushforward<T&, T&>
back_pushforward(::std::array<T, N>* a, ::std::array<T, N>* _d_a) noexcept {
return {a->back(), _d_a->back()};
}

template <typename T, ::std::size_t N>
void fill_pushforward(::std::array<T, N>* a, const T& u,
::std::array<T, N>* _d_a, const T& _d_u) {
a->fill(u);
_d_a->fill(_d_u);
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
72 changes: 71 additions & 1 deletion test/ForwardMode/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s
// RUN: %cladclang %s -Wc++14-extensions -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s
// RUN: ./STLCustomDerivatives.out | %filecheck_exec %s

// CHECK-NOT: {{.*error|warning|note:.*}}
Expand Down Expand Up @@ -116,14 +116,84 @@ double fnVec4(double u, double v) {
// CHECK-NEXT: return _t1.pushforward * _t4 + _t3 * _t2.pushforward;
// CHECK-NEXT: }

double fnArr1(double x) {
std::array<double, 3> a;
a.fill(x);

for (size_t i = 0; i < a.size(); ++i) {
a[i] *= i;
}

double res = 0;
for (size_t i = 0; i < a.size(); ++i) {
res += a[i];
}

return res;
}

//CHECK: double fnArr1_darg0(double x) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: std::array<double, 3> _d_a;
//CHECK-NEXT: std::array<double, 3> a;
//CHECK-NEXT: clad::custom_derivatives::class_functions::fill_pushforward(&a, x, &_d_a, _d_x);
//CHECK-NEXT: {
//CHECK-NEXT: size_t _d_i = 0;
//CHECK-NEXT: for (size_t i = 0; i < a.size(); ++i) {
//CHECK-NEXT: clad::ValueAndPushforward<double &, double &> _t0 = clad::custom_derivatives::class_functions::operator_subscript_pushforward(&a, i, &_d_a, _d_i);
//CHECK-NEXT: double &_t1 = _t0.pushforward;
//CHECK-NEXT: double &_t2 = _t0.value;
//CHECK-NEXT: _t1 = _t1 * i + _t2 * _d_i;
//CHECK-NEXT: _t2 *= i;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: {
//CHECK-NEXT: size_t _d_i = 0;
//CHECK-NEXT: for (size_t i = 0; i < a.size(); ++i) {
//CHECK-NEXT: clad::ValueAndPushforward<double &, double &> _t3 = clad::custom_derivatives::class_functions::operator_subscript_pushforward(&a, i, &_d_a, _d_i);
//CHECK-NEXT: _d_res += _t3.pushforward;
//CHECK-NEXT: res += _t3.value;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: return _d_res;
//CHECK-NEXT: }

double fnArr2(double x) {
std::array<double, 2> a{5, x};
const std::array<double, 3> b{x, 0, x*x};
return a.back() * b.front() * b.at(2);
}

//CHECK: double fnArr2_darg0(double x) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: std::array<double, 2> _d_a{{[{][{]0, _d_x[}][}]}};
//CHECK-NEXT: std::array<double, 2> a{{[{][{]5, x[}][}]}};
//CHECK-NEXT: const std::array<double, 3> _d_b{{[{][{]_d_x, 0, _d_x \* x \+ x \* _d_x[}][}]}};
//CHECK-NEXT: const std::array<double, 3> b{{[{][{]x, 0, x \* x[}][}]}};
//CHECK-NEXT: clad::ValueAndPushforward<double &, double &> _t0 = clad::custom_derivatives::class_functions::back_pushforward(&a, &_d_a);
//CHECK-NEXT: clad::ValueAndPushforward<const double &, const double &> _t1 = clad::custom_derivatives::class_functions::front_pushforward(&b, &_d_b);
//CHECK-NEXT: double &_t2 = _t0.value;
//CHECK-NEXT: const double _t3 = _t1.value;
//CHECK-NEXT: clad::ValueAndPushforward<const double &, const double &> _t4 = clad::custom_derivatives::class_functions::at_pushforward(&b, 2, &_d_b, 0);
//CHECK-NEXT: double _t5 = _t2 * _t3;
//CHECK-NEXT: const double _t6 = _t4.value;
//CHECK-NEXT: return (_t0.pushforward * _t3 + _t2 * _t1.pushforward) * _t6 + _t5 * _t4.pushforward;
//CHECK-NEXT: }

int main() {
INIT_DIFFERENTIATE(fnVec1, "u");
INIT_DIFFERENTIATE(fnVec2, "u");
INIT_DIFFERENTIATE(fnVec3, "u");
INIT_DIFFERENTIATE(fnVec4, "u");
INIT_DIFFERENTIATE(fnArr1, "x");
INIT_DIFFERENTIATE(fnArr2, "x");

TEST_DIFFERENTIATE(fnVec1, 3, 5); // CHECK-EXEC: {10.00}
TEST_DIFFERENTIATE(fnVec2, 3, 5); // CHECK-EXEC: {5.00}
TEST_DIFFERENTIATE(fnVec3, 3, 5); // CHECK-EXEC: {2.00}
TEST_DIFFERENTIATE(fnVec4, 3, 5); // CHECK-EXEC: {30.00}
TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00}
TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00}
}

0 comments on commit 7eef009

Please sign in to comment.