From 8b96c450cce144caf71ffca11cd54915d3205711 Mon Sep 17 00:00:00 2001 From: stevebronder Date: Fri, 12 Apr 2024 17:22:07 -0400 Subject: [PATCH 1/9] fixes aos csr matrix bug, still debugging soa matrix bug --- stan/math/prim/fun/value_of.hpp | 24 +++++- stan/math/prim/meta/promote_scalar_type.hpp | 11 ++- stan/math/rev/fun.hpp | 1 + stan/math/rev/fun/csr_matrix_times_vector.hpp | 68 ++++++--------- stan/math/rev/fun/to_soa_sparse_matrix.hpp | 84 +++++++++++++++++++ stan/math/rev/fun/to_var_value.hpp | 6 +- .../mix/fun/csr_matrix_times_vector_test.cpp | 21 ++++- 7 files changed, 163 insertions(+), 52 deletions(-) create mode 100644 stan/math/rev/fun/to_soa_sparse_matrix.hpp diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 7cc37ac9b2c..cdb5115d104 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_PRIM_FUN_VALUE_OF_HPP #define STAN_MATH_PRIM_FUN_VALUE_OF_HPP +#include #include #include #include @@ -67,7 +68,7 @@ inline auto value_of(const T& x) { * @param[in] M Matrix to be converted * @return Matrix of values **/ -template * = nullptr, +template * = nullptr, require_not_st_arithmetic* = nullptr> inline auto value_of(EigMat&& M) { return make_holder( @@ -77,6 +78,27 @@ inline auto value_of(EigMat&& M) { std::forward(M)); } +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto value_of(EigMat&& M) { + auto&& M_ref = to_ref(M); + using scalar_t = decltype(value_of(std::declval>())); + promote_scalar_t> ret(M_ref.rows(), M_ref.cols()); + ret.reserve(M_ref.nonZeros()); + for (int k = 0; k < M_ref.outerSize(); ++k) { + for (typename std::decay_t::InnerIterator it(M_ref, k); it; ++it) { + ret.insert(it.row(), it.col()) = value_of(it.valueRef()); + } + } + ret.makeCompressed(); + return ret; +} +template * = nullptr, + require_st_arithmetic* = nullptr> +inline auto value_of(EigMat&& M) { + return std::forward(M); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/meta/promote_scalar_type.hpp b/stan/math/prim/meta/promote_scalar_type.hpp index fee623ab141..0c30b5a6140 100644 --- a/stan/math/prim/meta/promote_scalar_type.hpp +++ b/stan/math/prim/meta/promote_scalar_type.hpp @@ -80,7 +80,7 @@ struct promote_scalar_type -struct promote_scalar_type> { +struct promote_scalar_type> { /** * The promoted type. */ @@ -93,6 +93,15 @@ struct promote_scalar_type> { S::RowsAtCompileTime, S::ColsAtCompileTime>>::type; }; +template +struct promote_scalar_type> { + /** + * The promoted type. + */ + using type = Eigen::SparseMatrix::type, + S::Options, typename S::StorageIndex>; +}; + template struct promote_scalar_type, std::tuple> { diff --git a/stan/math/rev/fun.hpp b/stan/math/rev/fun.hpp index 9324df2edcc..2bba861f771 100644 --- a/stan/math/rev/fun.hpp +++ b/stan/math/rev/fun.hpp @@ -183,6 +183,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/fun/csr_matrix_times_vector.hpp b/stan/math/rev/fun/csr_matrix_times_vector.hpp index d0cb4bd2373..8857f1ce23d 100644 --- a/stan/math/rev/fun/csr_matrix_times_vector.hpp +++ b/stan/math/rev/fun/csr_matrix_times_vector.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP #include +#include #include #include #include @@ -11,35 +12,21 @@ namespace stan { namespace math { namespace internal { -template * = nullptr> -void update_w(T1& w, int m, int n, std::vector>& u, - std::vector>& v, T2&& b, Res&& res) { - Eigen::Map> w_mat( - m, n, w.size(), u.data(), v.data(), w.data()); - for (int k = 0; k < w_mat.outerSize(); ++k) { - for (Eigen::Map>::InnerIterator - it(w_mat, k); - it; ++it) { - it.valueRef().adj() - += res.adj().coeff(it.row()) * value_of(b).coeff(it.col()); - } - } -} template * = nullptr> -void update_w(T1& w, int m, int n, std::vector>& u, - std::vector>& v, T2&& b, Res&& res) { - Eigen::Map> w_mat( - m, n, w.size(), u.data(), v.data(), w.adj().data()); - for (int k = 0; k < w_mat.outerSize(); ++k) { - for (Eigen::Map>::InnerIterator - it(w_mat, k); +inline void update_w(T1& w, T2&& b, Res&& res) { + std::cout << "pre w adj: \n" << w.adj() << std::endl; + std::cout << "pre b: \n" << value_of(b).eval() << std::endl; + std::cout << "pre res: \n" << res.adj() << std::endl; + for (int k = 0; k < w.adj().outerSize(); ++k) { + for (typename T1::vari_type::InnerIterator + it(w.adj(), k); it; ++it) { it.valueRef() += res.adj().coeff(it.row()) * value_of(b).coeff(it.col()); } } + std::cout << "post w adj: \n" << w.adj() << std::endl; } } // namespace internal @@ -79,7 +66,7 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, const std::vector& v, const std::vector& u, const T2& b) { using sparse_val_mat - = Eigen::Map>; + = Eigen::Map>; using sparse_dense_mul_type = decltype((std::declval() * value_of(b)).eval()); using return_t = return_var_matrix_t; @@ -102,43 +89,36 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w, [](auto&& x) { return x - 1; }); if (!is_constant::value && !is_constant::value) { arena_t> b_arena = b; - arena_t> w_arena = to_arena(w); - auto w_val_arena = to_arena(value_of(w_arena)); - sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(), - v_arena.data(), w_val_arena.data()); - arena_t res = w_val_mat * value_of(b_arena); + auto w_mat_arena = to_soa_sparse_matrix(m, n, w, u_arena, v_arena); + arena_t res = w_mat_arena.val() * value_of(b_arena); reverse_pass_callback( - [m, n, w_arena, w_val_arena, v_arena, u_arena, res, b_arena]() mutable { - sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(), - v_arena.data(), w_val_arena.data()); - internal::update_w(w_arena, m, n, u_arena, v_arena, b_arena, res); - b_arena.adj() += w_val_mat.transpose() * res.adj(); + [res, w_mat_arena, b_arena]() mutable { + internal::update_w(w_mat_arena, b_arena, res); + //w_mat_arena.adj() += res.adj() * b_arena.val().transpose(); + b_arena.adj() += w_mat_arena.val().transpose() * res.adj(); }); + std::cout << "NOT CALLED" << std::endl; return return_t(res); } else if (!is_constant::value) { arena_t> b_arena = b; auto w_val_arena = to_arena(value_of(w)); sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(), v_arena.data(), w_val_arena.data()); - arena_t res = w_val_mat * value_of(b_arena); reverse_pass_callback( - [m, n, w_val_arena, v_arena, u_arena, res, b_arena]() mutable { - sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(), - v_arena.data(), w_val_arena.data()); + [w_val_mat, res, b_arena]() mutable { b_arena.adj() += w_val_mat.transpose() * res.adj(); }); + std::cout << "NOT CALLED" << std::endl; return return_t(res); } else { - arena_t> w_arena = to_arena(w); - auto&& w_val = eval(value_of(w_arena)); - sparse_val_mat w_val_mat(m, n, w_val.size(), u_arena.data(), v_arena.data(), - w_val.data()); + auto w_mat_arena = to_soa_sparse_matrix(m, n, w, u_arena, v_arena); auto b_arena = to_arena(value_of(b)); - arena_t res = w_val_mat * b_arena; + arena_t res = w_mat_arena.val() * b_arena; reverse_pass_callback( - [m, n, w_arena, v_arena, u_arena, res, b_arena]() mutable { - internal::update_w(w_arena, m, n, u_arena, v_arena, b_arena, res); + [res, w_mat_arena, b_arena]() mutable { + internal::update_w(w_mat_arena, b_arena, res); + std::cout << "w_mat: " << w_mat_arena.adj(); }); return return_t(res); } diff --git a/stan/math/rev/fun/to_soa_sparse_matrix.hpp b/stan/math/rev/fun/to_soa_sparse_matrix.hpp new file mode 100644 index 00000000000..be95ad33476 --- /dev/null +++ b/stan/math/rev/fun/to_soa_sparse_matrix.hpp @@ -0,0 +1,84 @@ +#ifndef STAN_MATH_REV_FUN_TO_SOA_SPARSE_MATRIX_HPP +#define STAN_MATH_REV_FUN_TO_SOA_SPARSE_MATRIX_HPP + +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +template * = nullptr, + require_eigen_dense_base_t>* = nullptr, + require_all_std_vector_vt* = nullptr> +inline auto to_soa_sparse_matrix(int m, int n, T1&& w, T2&& u, T3&& v) { + arena_t v_arena(std::forward(v)); + arena_t u_arena(std::forward(u)); + using sparse_arena_mat_t = arena_t>; + sparse_arena_mat_t arena_val_x(m, n, w.val().size(), + u_arena.data(), v_arena.data(), w.vi_->val_.data()); + var_value> var_x(arena_val_x); + sparse_arena_mat_t arena_adj_x(m, n, w.adj().size(), + u_arena.data(), v_arena.data(), w.vi_->adj_.data()); + reverse_pass_callback([arena_adj_x, var_x]() mutable { + std::cout << "pre varx adj: \n" << var_x.adj() << std::endl; + std::cout << "pre arena_adj: \n" << arena_adj_x << std::endl; + arena_adj_x += var_x.adj(); + std::cout << "post varx adj: \n" << var_x.adj() << std::endl; + std::cout << "post arena_adj: \n" << arena_adj_x << std::endl; + }); + return var_x; + } + + +template * = nullptr, + require_all_std_vector_vt* = nullptr> +inline auto to_soa_sparse_matrix(int m, int n, T1&& w, T2&& u, T3&& v) { + arena_t w_arena(std::forward(w)); + arena_t v_arena(std::forward(v)); + arena_t u_arena(std::forward(u)); + arena_t> arena_x(m, n, w_arena.size(), + u_arena.data(), v_arena.data(), w_arena.data()); + var_value> var_x(value_of(arena_x)); + // No need to copy adj, but need to backprop + reverse_pass_callback([arena_x, var_x]() mutable { + using var_sparse_iterator_t = typename arena_t>::InnerIterator; + using dbl_sparse_iterator_t = typename arena_t>::InnerIterator; + // arena_x.adj() += var_x.adj() once custom adj() for var sparse matrix + for (int k = 0; k < arena_x.outerSize(); ++k) { + var_sparse_iterator_t it_arena_x(arena_x, k); + dbl_sparse_iterator_t it_var_x(var_x.adj(), k); + for (; bool(it_arena_x) && bool(it_var_x); ++it_arena_x, ++it_var_x) { + it_arena_x.valueRef().adj() += it_var_x.valueRef(); + } + } + }); + return var_x; + } + +template * = nullptr, + require_all_std_vector_vt* = nullptr> + inline auto to_soa_sparse_matrix(int m, int n, T1&& w, T2&& u, T3&& v) { + arena_t w_arena(std::forward(w)); + arena_t v_arena(std::forward(v)); + arena_t u_arena(std::forward(u)); + arena_t> arena_x(m, n, w_arena.size(), + u_arena.data(), v_arena.data(), w_arena.data()); + return var_value>(arena_x); + } + + template , is_eigen_sparse_base>>>* = nullptr> + inline auto to_soa_sparse_matrix(T&& x) { + return std::forward(x); + } + +} +} + +#endif + diff --git a/stan/math/rev/fun/to_var_value.hpp b/stan/math/rev/fun/to_var_value.hpp index 004b520b6f2..6664b7d7521 100644 --- a/stan/math/rev/fun/to_var_value.hpp +++ b/stan/math/rev/fun/to_var_value.hpp @@ -18,7 +18,7 @@ namespace math { * @param a matrix to convert */ template * = nullptr> -var_value> +inline var_value> to_var_value(const T& a) { arena_matrix> a_arena = a; var_value> res(a_arena.val()); @@ -34,7 +34,7 @@ to_var_value(const T& a) { * @param a matrix to convert */ template * = nullptr> -T to_var_value(T&& a) { +inline T to_var_value(T&& a) { return std::forward(a); } @@ -46,7 +46,7 @@ T to_var_value(T&& a) { * @param a std::vector of elements to convert */ template -auto to_var_value(const std::vector& a) { +inline auto to_var_value(const std::vector& a) { std::vector()))> out; out.reserve(a.size()); for (size_t i = 0; i < a.size(); ++i) { diff --git a/test/unit/math/mix/fun/csr_matrix_times_vector_test.cpp b/test/unit/math/mix/fun/csr_matrix_times_vector_test.cpp index 20c73955243..87e3377f417 100644 --- a/test/unit/math/mix/fun/csr_matrix_times_vector_test.cpp +++ b/test/unit/math/mix/fun/csr_matrix_times_vector_test.cpp @@ -1,7 +1,7 @@ #include #include - -TEST(MathMixMatFun, csr_matrix_times_vector) { +/* +TEST(MathMixMatFun, csr_matrix_times_vector1) { auto f = [](const auto& w, const auto& b) { using stan::math::csr_matrix_times_vector; std::vector v{1, 2, 3, 1, 2}; @@ -15,5 +15,20 @@ TEST(MathMixMatFun, csr_matrix_times_vector) { b << 1, 2, 3, 4, 5; stan::test::expect_ad(f, w, b); - stan::test::expect_ad_matvar(f, w, b); +} +*/ +TEST(MathMixMatFun, csr_matrix_times_vector2) { + auto f = [](const auto& w) { + using stan::math::csr_matrix_times_vector; + std::vector v{1, 2, 3, 1, 2}; + std::vector u{1, 2, 3, 4, 5, 6}; + Eigen::VectorXd b(5); + b << 1, 2, 3, 4, 5; + return csr_matrix_times_vector(5, 5, w, v, u, b); + }; + + Eigen::VectorXd w(5); + w << -0.67082, 0.5, -0.223607, -0.223607, -0.5; + + stan::test::expect_ad_matvar(f, w); } From fec3689225e7d00172cabfcdcb7513ac6ecf959a Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Mon, 15 Apr 2024 15:25:53 -0400 Subject: [PATCH 2/9] update csr matrix multiply to avoid linker error for windows. Adds to_soa_sparse_matrix for making sparse matrices from csr formatted data --- stan/math/prim/meta/is_eigen_dense_base.hpp | 16 ++++ stan/math/rev/core/arena_matrix.hpp | 25 +++--- stan/math/rev/core/var.hpp | 12 +++ stan/math/rev/core/vari.hpp | 29 ++++--- stan/math/rev/fun/csr_matrix_times_vector.hpp | 70 +++++++--------- stan/math/rev/fun/to_soa_sparse_matrix.hpp | 79 +++++++++++++----- .../mix/fun/csr_matrix_times_vector_test.cpp | 34 ++++++-- .../rev/fun/to_soa_sparse_matrix_test.cpp | 80 +++++++++++++++++++ 8 files changed, 259 insertions(+), 86 deletions(-) create mode 100644 test/unit/math/rev/fun/to_soa_sparse_matrix_test.cpp diff --git a/stan/math/prim/meta/is_eigen_dense_base.hpp b/stan/math/prim/meta/is_eigen_dense_base.hpp index 7b1ebc6cd1e..940f65bb2ef 100644 --- a/stan/math/prim/meta/is_eigen_dense_base.hpp +++ b/stan/math/prim/meta/is_eigen_dense_base.hpp @@ -33,6 +33,22 @@ using require_eigen_dense_base_t = require_t>>; /*! @} */ +/*! \ingroup require_eigens_types */ +/*! \defgroup eigen_dense_base_types eigen_dense_base_types */ +/*! \addtogroup eigen_dense_base_types */ +/*! @{ */ + +/*! \brief Require type satisfies @ref is_eigen_dense_base */ +/*! and value type satisfies `TypeCheck` */ +/*! @tparam TypeCheck The type trait to check the value type against */ +/*! @tparam Check The type to test @ref is_eigen_dense_base for and whose + * @ref value_type is checked with `TypeCheck` */ +template