Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand scalar_seq_view to work with nested containers and tuples #3058

Open
wants to merge 21 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@
#include <stan/math/prim/fun/scaled_add.hpp>
#include <stan/math/prim/fun/sd.hpp>
#include <stan/math/prim/fun/segment.hpp>
#include <stan/math/prim/fun/sequential_index.hpp>
#include <stan/math/prim/fun/serializer.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/cumulative_sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <vector>
#include <numeric>
#include <functional>
Expand Down
50 changes: 40 additions & 10 deletions stan/math/prim/fun/num_elements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <vector>
#include <algorithm>

namespace stan {
namespace math {
Expand All @@ -16,7 +18,7 @@ namespace math {
* @return 1
*/
template <typename T, require_stan_scalar_t<T>* = nullptr>
inline int num_elements(const T& x) {
inline size_t num_elements(const T& x) {
return 1;
}

Expand All @@ -29,25 +31,53 @@ inline int num_elements(const T& x) {
* @return size of matrix
*/
template <typename T, require_matrix_t<T>* = nullptr>
inline int num_elements(const T& m) {
inline size_t num_elements(const T& m) {
return m.size();
}

/**
* Returns the number of elements in the specified vector.
* This assumes it is not ragged and that each of its contained
* elements has the same number of elements.
* @tparam T type of elements in the vector
* @param v argument vector
* @return number of contained arguments
*/
template <typename T, require_stan_scalar_t<T>* = nullptr>
inline size_t num_elements(const std::vector<T>& v) {
return v.size();
}

/**
* Returns the number of elements in the specified vector
*
* @tparam T type of elements in the vector
* @param v argument vector
* @return number of contained arguments
*/
template <typename T>
inline int num_elements(const std::vector<T>& v) {
if (v.size() == 0) {
return 0;
}
return v.size() * num_elements(v[0]);
template <typename T, require_container_t<T>* = nullptr>
inline size_t num_elements(const std::vector<T>& v) {
size_t size = 0;
std::for_each(v.cbegin(), v.cend(),
[&size](auto&& x) { size += num_elements(x); });
return size;
}

/**
* Returns the number of elements in the specified tuple
*
* @tparam T type of tuple
* @param v tuple
* @return number of contained arguments
*/
template <typename T, require_tuple_t<T>* = nullptr>
inline size_t num_elements(const T& v) {
size_t size = 0;
math::apply(
[&size](auto&&... args) {
static_cast<void>(
std::initializer_list<int>{(size += num_elements(args), 0)...});
},
v);
return size;
}

} // namespace math
Expand Down
62 changes: 52 additions & 10 deletions stan/math/prim/fun/scalar_seq_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
#define STAN_MATH_PRIM_FUN_SCALAR_SEQ_VIEW_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <type_traits>
#include <utility>
#include <stan/math/prim/fun/num_elements.hpp>
#include <stan/math/prim/fun/sequential_index.hpp>

namespace stan {
namespace internal {
template <typename T>
using require_nested_t = require_t<math::disjunction<
math::is_tuple<T>,
math::conjunction<is_std_vector<T>, is_container<value_type_t<T>>>>>;
}
/**
* scalar_seq_view provides a uniform sequence-like wrapper around either a
* scalar or a sequence of scalars.
Expand All @@ -18,7 +23,7 @@ template <typename C, typename = void>
class scalar_seq_view;

template <typename C>
class scalar_seq_view<C, require_eigen_vector_t<C>> {
class scalar_seq_view<C, require_eigen_t<C>> {
public:
template <typename T,
typename = require_same_t<plain_type_t<T>, plain_type_t<C>>>
Expand All @@ -30,6 +35,7 @@ class scalar_seq_view<C, require_eigen_vector_t<C>> {
* @return the element at the specified position in the container
*/
inline auto operator[](size_t i) const { return c_.coeff(i); }
inline auto& operator[](size_t i) { return c_.coeffRef(i); }

inline auto size() const noexcept { return c_.size(); }

Expand All @@ -47,7 +53,7 @@ class scalar_seq_view<C, require_eigen_vector_t<C>> {
}

private:
ref_type_t<C> c_;
plain_type_t<C> c_;
};

template <typename C>
Expand Down Expand Up @@ -83,7 +89,7 @@ class scalar_seq_view<C, require_var_matrix_t<C>> {
};

template <typename C>
class scalar_seq_view<C, require_std_vector_t<C>> {
class scalar_seq_view<C, require_std_vector_vt<is_stan_scalar, C>> {
public:
template <typename T,
typename = require_same_t<plain_type_t<T>, plain_type_t<C>>>
Expand All @@ -95,6 +101,7 @@ class scalar_seq_view<C, require_std_vector_t<C>> {
* @return the element at the specified position in the container
*/
inline auto operator[](size_t i) const { return c_[i]; }
inline auto& operator[](size_t i) { return c_[i]; }
inline auto size() const noexcept { return c_.size(); }
inline const auto* data() const noexcept { return c_.data(); }

Expand All @@ -109,7 +116,41 @@ class scalar_seq_view<C, require_std_vector_t<C>> {
}

private:
const C& c_;
std::decay_t<C> c_;
};

template <typename C>
class scalar_seq_view<C, internal::require_nested_t<C>> {
public:
template <typename T>
explicit scalar_seq_view(T&& c)
: c_(std::forward<T>(c)), size_(math::num_elements(c_)) {}

inline auto size() const noexcept { return size_; }

inline auto operator[](size_t i) const {
return math::sequential_index(i, std::forward<decltype(c_)>(c_));
}

inline auto& operator[](size_t i) {
return math::sequential_index(i, std::forward<decltype(c_)>(c_));
}

inline const auto* data() const noexcept { return c_.data(); }

template <typename T = C, require_st_arithmetic<T>* = nullptr>
inline decltype(auto) val(size_t i) const {
return this[i];
}

template <typename T = C, require_st_autodiff<T>* = nullptr>
inline decltype(auto) val(size_t i) const {
return this[i].val();
}

private:
std::decay_t<C> c_;
size_t size_;
};

template <typename C>
Expand Down Expand Up @@ -154,15 +195,16 @@ class scalar_seq_view<C, require_stan_scalar_t<C>> {
public:
explicit scalar_seq_view(const C& t) noexcept : t_(t) {}

inline decltype(auto) operator[](int /* i */) const noexcept { return t_; }
inline auto operator[](size_t /* i */) const { return t_; }
inline auto& operator[](size_t /* i */) { return t_; }

template <typename T = C, require_st_arithmetic<T>* = nullptr>
inline decltype(auto) val(int /* i */) const noexcept {
inline decltype(auto) val(size_t /* i */) const noexcept {
return t_;
}

template <typename T = C, require_st_autodiff<T>* = nullptr>
inline decltype(auto) val(int /* i */) const noexcept {
inline decltype(auto) val(size_t /* i */) const noexcept {
return t_.val();
}

Expand Down
118 changes: 118 additions & 0 deletions stan/math/prim/fun/sequential_index.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#ifndef STAN_MATH_PRIM_FUN_SEQUENTIAL_INDEX_HPP
#define STAN_MATH_PRIM_FUN_SEQUENTIAL_INDEX_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/meta/is_tuple.hpp>
#include <stan/math/prim/fun/num_elements.hpp>
#include <stan/math/prim/functor/apply_at.hpp>

namespace stan {
namespace math {

/**
* Utility function for indexing arbitrary types as sequential values, for use
* as both lvalues and rvalues.
*
* Base template for scalars where no indexing is needed.
*
* @tparam Type of input scalar
* @param x Input scalar
* @return Input scalar unchanged
*/
template <typename T, require_stan_scalar_t<T>* = nullptr>
inline decltype(auto) sequential_index(size_t /* i */, T&& x) {
return std::forward<T>(x);
}

/**
* Utility function for indexing arbitrary types as sequential values, for use
* as both lvalues and rvalues.
*
* Template for non-nested std::vectors
*
* @tparam Type of non-nested std::vector
* @param i Index of desired value
* @param x Input vector
* @return Value at desired index in container
*/
template <typename T, require_std_vector_vt<is_stan_scalar, T>* = nullptr>
inline decltype(auto) sequential_index(size_t i, T&& x) {
return x[i];
}

/**
* Utility function for indexing arbitrary types as sequential values, for use
* as both lvalues and rvalues.
*
* Template for Eigen types
*
* @tparam Type of Eigen input
* @param i Index of desired value
* @param x Input Eigen object
* @return Value at desired index in container
*/
template <typename T, require_eigen_t<T>* = nullptr>
inline decltype(auto) sequential_index(size_t i, T&& x) {
return x.coeffRef(i);
}

/**
* Utility function for indexing arbitrary types as sequential values, for use
* as both lvalues and rvalues.
*
* Template for nested std::vectors
*
* @tparam Type of nested std::vector
* @param i Index of desired value
* @param x Input vector
* @return Value at desired index in container (recursively extracted)
*/
template <typename T, require_std_vector_vt<is_container, T>* = nullptr>
inline decltype(auto) sequential_index(size_t i, T&& x) {
size_t inner_idx = i;
size_t elem = 0;
for (auto&& x_val : x) {
size_t num_elems = math::num_elements(x_val);
if (inner_idx <= (num_elems - 1)) {
break;
}
elem++;
inner_idx -= num_elems;
}
return sequential_index(inner_idx, std::forward<decltype(x[elem])>(x[elem]));
}

/**
* Utility function for indexing arbitrary types as sequential values, for use
* as both lvalues and rvalues.
*
* Template for tuples.
*
* @tparam Type of tuple
* @param i Index of desired value
* @param x Input tuple
* @return Value at desired index in tuple (recursively extracted if needed)
*/
template <typename T, math::require_tuple_t<T>* = nullptr>
inline decltype(auto) sequential_index(size_t i, T&& x) {
size_t inner_idx = i;
size_t elem = 0;

auto num_functor = [](auto&& arg) { return math::num_elements(arg); };
for (size_t j = 0; j < std::tuple_size<std::decay_t<T>>{}; j++) {
size_t num_elems = math::apply_at(num_functor, j, std::forward<T>(x));
if (inner_idx <= (num_elems - 1)) {
break;
}
elem++;
inner_idx -= num_elems;
}

auto index_func = [inner_idx](auto&& t_elem) -> decltype(auto) {
return sequential_index(inner_idx, std::forward<decltype(t_elem)>(t_elem));
};
return math::apply_at(index_func, elem, std::forward<T>(x));
}
} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUNCTOR_HPP

#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/prim/functor/apply_at.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <stan/math/prim/functor/apply_scalar_ternary.hpp>
Expand Down
Loading