From d06f5fdfa8a5c2ab4a92a27ee3e345c3ce253a47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Fri, 15 May 2020 06:47:47 +0200 Subject: [PATCH] Let value_of and value_of_rec return expressions (#1872) * let value_of and value_of_rec return expressions * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * added missing include * added missing include and generalized hmm_marginal_lpdf_val * added more missing includes * addressed review comments * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) Co-authored-by: Stan Jenkins --- .../prim/categorical_logit_glm_lpmf.hpp | 2 - .../prim/neg_binomial_2_log_glm_lpmf.hpp | 2 - stan/math/opencl/prim/normal_id_glm_lpdf.hpp | 3 +- .../opencl/prim/ordered_logistic_glm_lpmf.hpp | 3 - stan/math/prim/fun.hpp | 1 + stan/math/prim/fun/log_mix.hpp | 7 +- stan/math/prim/fun/to_ref.hpp | 34 ++++++++ stan/math/prim/fun/value_of.hpp | 16 +--- stan/math/prim/fun/value_of_rec.hpp | 12 +-- .../prim/prob/bernoulli_logit_glm_lpmf.hpp | 5 +- .../prim/prob/categorical_logit_glm_lpmf.hpp | 6 +- stan/math/prim/prob/hmm_marginal_lpdf.hpp | 24 +++--- .../prim/prob/neg_binomial_2_log_glm_lpmf.hpp | 11 +-- stan/math/prim/prob/normal_id_glm_lpdf.hpp | 7 +- .../prim/prob/ordered_logistic_glm_lpmf.hpp | 7 +- stan/math/prim/prob/poisson_log_glm_lpmf.hpp | 7 +- test/unit/math/prim/fun/value_of_rec_test.cpp | 82 +++++++++++++++++++ test/unit/math/prim/fun/value_of_test.cpp | 41 ++++++++++ test/unit/math/rev/fun/value_of_rec_test.cpp | 18 ++++ test/unit/math/rev/fun/value_of_test.cpp | 18 ++++ 20 files changed, 250 insertions(+), 56 deletions(-) create mode 100644 stan/math/prim/fun/to_ref.hpp diff --git a/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp b/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp index 3a496836523..8f6d9d1fb90 100644 --- a/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp +++ b/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp @@ -72,8 +72,6 @@ return_type_t categorical_logit_glm_lpmf( const auto& beta_val = value_of_rec(beta); const auto& alpha_val = value_of_rec(alpha); - const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose(); - const int local_size = opencl_kernels::categorical_logit_glm.get_option("LOCAL_SIZE_"); const int wgs = (N_instances + local_size - 1) / local_size; diff --git a/stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp b/stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp index cc0d07c58f6..01724a1c025 100644 --- a/stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp +++ b/stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp @@ -100,8 +100,6 @@ return_type_t neg_binomial_2_log_glm_lpmf( const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val); const auto& phi_val_vec = as_column_vector_or_scalar(phi_val); - const auto& phi_arr = as_array_or_scalar(phi_val_vec); - const int local_size = opencl_kernels::neg_binomial_2_log_glm.get_option("LOCAL_SIZE_"); const int wgs = (N + local_size - 1) / local_size; diff --git a/stan/math/opencl/prim/normal_id_glm_lpdf.hpp b/stan/math/opencl/prim/normal_id_glm_lpdf.hpp index 233229bfb88..32a7c1cfd1b 100644 --- a/stan/math/opencl/prim/normal_id_glm_lpdf.hpp +++ b/stan/math/opencl/prim/normal_id_glm_lpdf.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -96,7 +97,7 @@ return_type_t normal_id_glm_lpdf( const auto &beta_val_vec = as_column_vector_or_scalar(beta_val); const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val); - const auto &sigma_val_vec = as_column_vector_or_scalar(sigma_val); + const auto &sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val)); T_scale_val inv_sigma = 1 / as_array_or_scalar(sigma_val_vec); Matrix y_minus_mu_over_sigma_mat(N); diff --git a/stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp b/stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp index 936fafc27cf..11344b67fba 100644 --- a/stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp +++ b/stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp @@ -83,9 +83,6 @@ return_type_t ordered_logistic_glm_lpmf( const auto& beta_val = value_of_rec(beta); const auto& cuts_val = value_of_rec(cuts); - const auto& beta_val_vec = as_column_vector_or_scalar(beta_val); - const auto& cuts_val_vec = as_column_vector_or_scalar(cuts_val); - operands_and_partials, Eigen::Matrix> ops_partials(beta, cuts); diff --git a/stan/math/prim/fun.hpp b/stan/math/prim/fun.hpp index 64dbb661311..28997c51c61 100644 --- a/stan/math/prim/fun.hpp +++ b/stan/math/prim/fun.hpp @@ -306,6 +306,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/fun/log_mix.hpp b/stan/math/prim/fun/log_mix.hpp index 240c7ddac12..d6ecf369eb8 100644 --- a/stan/math/prim/fun/log_mix.hpp +++ b/stan/math/prim/fun/log_mix.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -86,8 +87,8 @@ return_type_t log_mix(const T_theta& theta, check_finite(function, "theta", theta); check_consistent_sizes(function, "theta", theta, "lambda", lambda); - const auto& theta_dbl = value_of(as_column_vector_or_scalar(theta)); - const auto& lam_dbl = value_of(as_column_vector_or_scalar(lambda)); + const auto& theta_dbl = to_ref(value_of(as_column_vector_or_scalar(theta))); + const auto& lam_dbl = to_ref(value_of(as_column_vector_or_scalar(lambda))); T_partials_return logp = log_sum_exp(log(theta_dbl) + lam_dbl); @@ -158,7 +159,7 @@ return_type_t> log_mix( check_consistent_sizes(function, "theta", theta, "lambda", lambda[n]); } - const auto& theta_dbl = value_of(as_column_vector_or_scalar(theta)); + const auto& theta_dbl = to_ref(value_of(as_column_vector_or_scalar(theta))); T_partials_mat lam_dbl(M, N); for (int n = 0; n < N; ++n) { diff --git a/stan/math/prim/fun/to_ref.hpp b/stan/math/prim/fun/to_ref.hpp new file mode 100644 index 00000000000..ff305bee6f0 --- /dev/null +++ b/stan/math/prim/fun/to_ref.hpp @@ -0,0 +1,34 @@ +#ifndef STAN_MATH_PRIM_FUN_TO_REF_HPP +#define STAN_MATH_PRIM_FUN_TO_REF_HPP + +#include + +namespace stan { +namespace math { + +/** + * No-op that should be optimized away. + * @tparam T non-Eigen argument type + * @param a argument + * @return argument + */ +template * = nullptr> +inline T to_ref(T&& a) { + return std::forward(a); +} + +/** + * Converts Eigen argument into `Eigen::Ref`. This evaluate expensive + * expressions. + * @tparam T argument type (Eigen expression) + * @param a argument + * @return argument converted to `Eigen::Ref` + */ +template * = nullptr> +inline Eigen::Ref> to_ref(T&& a) { + return std::forward(a); +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 059eb17341c..54e954633dd 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -99,22 +99,15 @@ inline Vec value_of(Vec&& x) { * T must implement value_of. See * test/math/fwd/fun/value_of.cpp for fvar and var usage. * - * @tparam T type of elements in the matrix - * @tparam R number of rows in the matrix, can be Eigen::Dynamic - * @tparam C number of columns in the matrix, can be Eigen::Dynamic + * @tparam EigMat type of the matrix * * @param[in] M Matrix to be converted * @return Matrix of values **/ template * = nullptr, require_not_vt_double_or_int* = nullptr> -inline Eigen::Matrix>::type, - EigMat::RowsAtCompileTime, EigMat::ColsAtCompileTime> -value_of(const EigMat& M) { - return M.array() - .unaryExpr([](const auto& scal) { return value_of(scal); }) - .matrix() - .eval(); +inline auto value_of(const EigMat& M) { + return M.unaryExpr([](const auto& scal) { return value_of(scal); }); } /** @@ -125,8 +118,7 @@ value_of(const EigMat& M) { * *

This inline pass-through no-op should be compiled away. * - * @tparam R number of rows in the matrix, can be Eigen::Dynamic - * @tparam C number of columns in the matrix, can be Eigen::Dynamic + * @tparam EigMat type of the matrix * * @param x Specified matrix. * @return Specified matrix. diff --git a/stan/math/prim/fun/value_of_rec.hpp b/stan/math/prim/fun/value_of_rec.hpp index d27b8e637d9..03dd59e3bd1 100644 --- a/stan/math/prim/fun/value_of_rec.hpp +++ b/stan/math/prim/fun/value_of_rec.hpp @@ -65,7 +65,7 @@ inline std::complex value_of_rec(const std::complex& x) { * @param[in] x std::vector to be converted * @return std::vector of values **/ -template +template * = nullptr> inline std::vector value_of_rec(const std::vector& x) { size_t x_size = x.size(); std::vector result(x_size); @@ -86,8 +86,10 @@ inline std::vector value_of_rec(const std::vector& x) { * @param x Specified std::vector. * @return Specified std::vector. */ -inline const std::vector& value_of_rec(const std::vector& x) { - return x; +template * = nullptr, + require_vt_same* = nullptr> +inline T value_of_rec(T&& x) { + return std::forward(x); } /** @@ -120,8 +122,8 @@ inline auto value_of_rec(const T& M) { */ template , typename = require_eigen_t> -inline const T& value_of_rec(const T& x) { - return x; +inline T value_of_rec(T&& x) { + return std::forward(x); } } // namespace math } // namespace stan diff --git a/stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp b/stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp index 86b1af1c5be..258265f87fd 100644 --- a/stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp +++ b/stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -82,13 +83,13 @@ return_type_t bernoulli_logit_glm_lpmf( } T_partials_return logp(0); - const auto &x_val = value_of_rec(x); + const auto &x_val = to_ref(value_of_rec(x)); const auto &y_val = value_of_rec(y); const auto &beta_val = value_of_rec(beta); const auto &alpha_val = value_of_rec(alpha); const auto &y_val_vec = as_column_vector_or_scalar(y_val); - const auto &beta_val_vec = as_column_vector_or_scalar(beta_val); + const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val)); const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val); T_y_val signs = 2 * as_array_or_scalar(y_val_vec) - 1; diff --git a/stan/math/prim/prob/categorical_logit_glm_lpmf.hpp b/stan/math/prim/prob/categorical_logit_glm_lpmf.hpp index f0c87178742..428cda67a83 100644 --- a/stan/math/prim/prob/categorical_logit_glm_lpmf.hpp +++ b/stan/math/prim/prob/categorical_logit_glm_lpmf.hpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include @@ -73,8 +75,8 @@ categorical_logit_glm_lpmf( return 0; } - const auto& x_val = value_of_rec(x); - const auto& beta_val = value_of_rec(beta); + const auto& x_val = to_ref(value_of_rec(x)); + const auto& beta_val = to_ref(value_of_rec(beta)); const auto& alpha_val = value_of_rec(alpha); const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose(); diff --git a/stan/math/prim/prob/hmm_marginal_lpdf.hpp b/stan/math/prim/prob/hmm_marginal_lpdf.hpp index 4dd4fe1cb09..67968c6d9b9 100644 --- a/stan/math/prim/prob/hmm_marginal_lpdf.hpp +++ b/stan/math/prim/prob/hmm_marginal_lpdf.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -14,14 +15,17 @@ namespace stan { namespace math { -template -inline auto hmm_marginal_lpdf_val( - const Eigen::Matrix& omegas, - const Eigen::Matrix& Gamma_val, - const Eigen::Matrix& rho_val, - Eigen::Matrix& alphas, - Eigen::Matrix& alpha_log_norms, - T_alpha& norm_norm) { +template * = nullptr, + require_all_eigen_col_vector_t* = nullptr, + require_stan_scalar_t* = nullptr, + require_all_vt_same* = nullptr> +inline auto hmm_marginal_lpdf_val(const T_omega& omegas, + const T_Gamma& Gamma_val, + const T_rho& rho_val, T_alphas& alphas, + T_alpha_log_norm& alpha_log_norms, + T_norm& norm_norm) { const int n_states = omegas.rows(); const int n_transitions = omegas.cols() - 1; alphas.col(0) = omegas.col(0).cwiseProduct(rho_val); @@ -100,10 +104,10 @@ inline auto hmm_marginal_lpdf( eig_matrix_partial alphas(n_states, n_transitions + 1); eig_vector_partial alpha_log_norms(n_transitions + 1); - auto Gamma_val = value_of(Gamma); + const auto& Gamma_val = to_ref(value_of(Gamma)); // compute the density using the forward algorithm. - auto rho_val = value_of(rho); + const auto& rho_val = to_ref(value_of(rho)); eig_matrix_partial omegas = value_of(log_omegas).array().exp(); T_partial_type norm_norm; auto log_marginal_density = hmm_marginal_lpdf_val( diff --git a/stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp b/stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp index 713a24c0217..70a09d0960a 100644 --- a/stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp +++ b/stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -105,16 +106,16 @@ neg_binomial_2_log_glm_lpmf( } T_partials_return logp(0); - const auto& x_val = value_of_rec(x); + const auto& x_val = to_ref(value_of_rec(x)); const auto& y_val = value_of_rec(y); const auto& beta_val = value_of_rec(beta); const auto& alpha_val = value_of_rec(alpha); const auto& phi_val = value_of_rec(phi); - const auto& y_val_vec = as_column_vector_or_scalar(y_val); - const auto& beta_val_vec = as_column_vector_or_scalar(beta_val); + const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val)); + const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val)); const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val); - const auto& phi_val_vec = as_column_vector_or_scalar(phi_val); + const auto& phi_val_vec = to_ref(as_column_vector_or_scalar(phi_val)); const auto& y_arr = as_array_or_scalar(y_val_vec); const auto& phi_arr = as_array_or_scalar(phi_val_vec); @@ -147,7 +148,7 @@ neg_binomial_2_log_glm_lpmf( } if (include_summand::value) { if (is_vector::value) { - scalar_seq_view phi_vec(phi_val); + scalar_seq_view phi_vec(phi_val_vec); for (size_t n = 0; n < N_instances; ++n) { logp += multiply_log(phi_vec[n], phi_vec[n]) - lgamma(phi_vec[n]); } diff --git a/stan/math/prim/prob/normal_id_glm_lpdf.hpp b/stan/math/prim/prob/normal_id_glm_lpdf.hpp index 5a32fe95f72..f24f72fc1f3 100644 --- a/stan/math/prim/prob/normal_id_glm_lpdf.hpp +++ b/stan/math/prim/prob/normal_id_glm_lpdf.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -88,15 +89,15 @@ return_type_t normal_id_glm_lpdf( return 0; } - const auto &x_val = value_of_rec(x); + const auto &x_val = to_ref(value_of_rec(x)); const auto &beta_val = value_of_rec(beta); const auto &alpha_val = value_of_rec(alpha); const auto &sigma_val = value_of_rec(sigma); const auto &y_val = value_of_rec(y); - const auto &beta_val_vec = as_column_vector_or_scalar(beta_val); + const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val)); const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val); - const auto &sigma_val_vec = as_column_vector_or_scalar(sigma_val); + const auto &sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val)); const auto &y_val_vec = as_column_vector_or_scalar(y_val); T_scale_val inv_sigma = 1 / as_array_or_scalar(sigma_val_vec); diff --git a/stan/math/prim/prob/ordered_logistic_glm_lpmf.hpp b/stan/math/prim/prob/ordered_logistic_glm_lpmf.hpp index 0f1656ad849..89cdca3bee2 100644 --- a/stan/math/prim/prob/ordered_logistic_glm_lpmf.hpp +++ b/stan/math/prim/prob/ordered_logistic_glm_lpmf.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -81,12 +82,12 @@ ordered_logistic_glm_lpmf( if (!include_summand::value) return 0; - const auto& x_val = value_of_rec(x); + const auto& x_val = to_ref(value_of_rec(x)); const auto& beta_val = value_of_rec(beta); const auto& cuts_val = value_of_rec(cuts); - const auto& beta_val_vec = as_column_vector_or_scalar(beta_val); - const auto& cuts_val_vec = as_column_vector_or_scalar(cuts_val); + const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val)); + const auto& cuts_val_vec = to_ref(as_column_vector_or_scalar(cuts_val)); scalar_seq_view y_seq(y); Array cuts_y1(N_instances), cuts_y2(N_instances); diff --git a/stan/math/prim/prob/poisson_log_glm_lpmf.hpp b/stan/math/prim/prob/poisson_log_glm_lpmf.hpp index 8571c5c3526..01bf233ab72 100644 --- a/stan/math/prim/prob/poisson_log_glm_lpmf.hpp +++ b/stan/math/prim/prob/poisson_log_glm_lpmf.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -82,13 +83,13 @@ return_type_t poisson_log_glm_lpmf( T_partials_return logp(0); - const auto& x_val = value_of_rec(x); + const auto& x_val = to_ref(value_of_rec(x)); const auto& y_val = value_of_rec(y); const auto& beta_val = value_of_rec(beta); const auto& alpha_val = value_of_rec(alpha); - const auto& y_val_vec = as_column_vector_or_scalar(y_val); - const auto& beta_val_vec = as_column_vector_or_scalar(beta_val); + const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val)); + const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val)); const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val); Array theta(N_instances); diff --git a/test/unit/math/prim/fun/value_of_rec_test.cpp b/test/unit/math/prim/fun/value_of_rec_test.cpp index 2b4da43c24e..59867a6dd12 100644 --- a/test/unit/math/prim/fun/value_of_rec_test.cpp +++ b/test/unit/math/prim/fun/value_of_rec_test.cpp @@ -2,6 +2,10 @@ #include #include +#define EXPECT_MATRIX_EQ(A, B) \ + for (int i = 0; i < A.size(); i++) \ + EXPECT_EQ(A(i), B(i)); + TEST(MathFunctions, value_of_rec) { using stan::math::value_of_rec; double x = 5.0; @@ -53,3 +57,81 @@ TEST(MathMatrixPrimMat, value_of_rec) { for (int j = 0; j < 5; ++j) EXPECT_FLOAT_EQ(a(i, j), d_a(i, j)); } + +TEST(MathMatrixPrimMat, value_of_rec_expression) { + using stan::math::value_of_rec; + + Eigen::MatrixXd a = Eigen::MatrixXd::Random(5, 4); + Eigen::MatrixXd res_a = value_of_rec(2 * a); + Eigen::MatrixXd correct_a = 2 * a; + EXPECT_MATRIX_EQ(res_a, correct_a); + + Eigen::VectorXi b = Eigen::VectorXi::Random(7); + Eigen::VectorXd res_b = value_of_rec(2 * b); + Eigen::VectorXd correct_b = (2 * b).cast(); + EXPECT_MATRIX_EQ(res_b, correct_b); + + Eigen::ArrayXXd c = a.array(); + Eigen::ArrayXXd res_c = value_of_rec(2 * c); + Eigen::ArrayXXd correct_c = 2 * c; + EXPECT_MATRIX_EQ(res_c, correct_c); +} + +TEST(MathFunctions, value_of_rec_return_type_short_circuit_std_vector) { + std::vector a(5); + const std::vector b(5); + EXPECT_TRUE((std::is_same&>::value)); + EXPECT_TRUE((std::is_same&>::value)); +} + +TEST(MathFunctions, value_of_rec_return_type_short_circuit_vector_xd) { + Eigen::Matrix a(5); + const Eigen::Matrix b(5); + EXPECT_TRUE((std::is_same&>::value)); + EXPECT_TRUE( + (std::is_same&>::value)); +} + +TEST(MathFunctions, value_of_rec_return_type_short_circuit_row_vector_xd) { + Eigen::Matrix a(5); + const Eigen::Matrix b(5); + EXPECT_TRUE((std::is_same&>::value)); + EXPECT_TRUE( + (std::is_same&>::value)); +} + +TEST(MathFunctions, value_of_rec_return_type_short_circuit_matrix_xd) { + Eigen::Matrix a(5, 4); + const Eigen::Matrix b(5, 4); + EXPECT_TRUE((std::is_same< + decltype(stan::math::value_of_rec(a)), + Eigen::Matrix&>::value)); + EXPECT_TRUE((std::is_same&>::value)); +} + +TEST(MathFunctions, value_of_rec_return_type_short_circuit_expression) { + const Eigen::Matrix a(5, 4); + + const auto& expr = 3 * a; + + EXPECT_TRUE((std::is_same::value)); +} + +TEST(MathFunctions, + value_of_rec_return_type_short_circuit_static_sized_matrix) { + Eigen::Matrix a; + const Eigen::Matrix b; + EXPECT_TRUE((std::is_same&>::value)); + EXPECT_TRUE((std::is_same&>::value)); +} diff --git a/test/unit/math/prim/fun/value_of_test.cpp b/test/unit/math/prim/fun/value_of_test.cpp index 33ef68fefde..822a113ad75 100644 --- a/test/unit/math/prim/fun/value_of_test.cpp +++ b/test/unit/math/prim/fun/value_of_test.cpp @@ -5,6 +5,10 @@ #include #include +#define EXPECT_MATRIX_EQ(A, B) \ + for (int i = 0; i < A.size(); i++) \ + EXPECT_EQ(A(i), B(i)); + TEST(MathFunctions, value_of) { using stan::math::value_of; double x = 5.0; @@ -81,6 +85,34 @@ TEST(MathMatrixPrimMat, value_of) { EXPECT_FLOAT_EQ(a(i, j), d_a(i, j)); } +TEST(MathMatrixPrimMat, value_of_expression) { + using stan::math::value_of; + + Eigen::MatrixXd a = Eigen::MatrixXd::Random(5, 4); + Eigen::MatrixXd res_a = value_of(2 * a); + Eigen::MatrixXd correct_a = 2 * a; + EXPECT_MATRIX_EQ(res_a, correct_a); + + Eigen::VectorXi b = Eigen::VectorXi::Random(7); + Eigen::VectorXi res_b = value_of(2 * b); + Eigen::VectorXi correct_b = 2 * b; + EXPECT_MATRIX_EQ(res_b, correct_b); + + Eigen::ArrayXXd c = a.array(); + Eigen::ArrayXXd res_c = value_of(2 * c); + Eigen::ArrayXXd correct_c = 2 * c; + EXPECT_MATRIX_EQ(res_c, correct_c); +} + +TEST(MathFunctions, value_of_return_type_short_circuit_std_vector) { + std::vector a(5); + const std::vector b(5); + EXPECT_TRUE((std::is_same&>::value)); + EXPECT_TRUE((std::is_same&>::value)); +} + TEST(MathFunctions, value_of_return_type_short_circuit_vector_xd) { Eigen::Matrix a(5); const Eigen::Matrix b(5); @@ -112,6 +144,15 @@ TEST(MathFunctions, value_of_return_type_short_circuit_matrix_xd) { Eigen::Dynamic>&>::value)); } +TEST(MathFunctions, value_of_return_type_short_circuit_expression) { + const Eigen::Matrix a(5, 4); + + const auto& expr = 3 * a; + + EXPECT_TRUE((std::is_same::value)); +} + TEST(MathFunctions, value_of_return_type_short_circuit_static_sized_matrix) { Eigen::Matrix a; const Eigen::Matrix b; diff --git a/test/unit/math/rev/fun/value_of_rec_test.cpp b/test/unit/math/rev/fun/value_of_rec_test.cpp index 5b4326636bb..c6188640e89 100644 --- a/test/unit/math/rev/fun/value_of_rec_test.cpp +++ b/test/unit/math/rev/fun/value_of_rec_test.cpp @@ -3,6 +3,10 @@ #include #include +#define EXPECT_MATRIX_NEAR(A, B, DELTA) \ + for (int i = 0; i < A.size(); i++) \ + EXPECT_NEAR(A(i), B(i), DELTA); + TEST(AgradRev, value_of_rec) { using stan::math::value_of_rec; using stan::math::var; @@ -77,3 +81,17 @@ TEST(AgradMatrixRev, value_of_rec) { EXPECT_FLOAT_EQ(a(i, j), d_v_a(i, j)); } } + +TEST(AgradMatrixRev, value_of_rec_expression) { + using Eigen::Array; + using Eigen::ArrayXXd; + using Eigen::Matrix; + using Eigen::MatrixXd; + using stan::math::value_of; + using stan::math::var; + Matrix a = MatrixXd::Random(7, 4); + MatrixXd res = value_of_rec(2 * a); + MatrixXd correct = 2 * value_of_rec(a); + + EXPECT_MATRIX_NEAR(res, correct, 1e-10); +} diff --git a/test/unit/math/rev/fun/value_of_test.cpp b/test/unit/math/rev/fun/value_of_test.cpp index ef64108a6b0..d530db530a1 100644 --- a/test/unit/math/rev/fun/value_of_test.cpp +++ b/test/unit/math/rev/fun/value_of_test.cpp @@ -3,6 +3,10 @@ #include #include +#define EXPECT_MATRIX_NEAR(A, B, DELTA) \ + for (int i = 0; i < A.size(); i++) \ + EXPECT_NEAR(A(i), B(i), DELTA); + TEST(AgradRev, value_of) { using stan::math::value_of; using stan::math::var; @@ -85,3 +89,17 @@ TEST(AgradMatrix, value_of) { EXPECT_FLOAT_EQ(a(i, j), d_v_a(i, j)); } } + +TEST(AgradMatrix, value_of_expression) { + using Eigen::Array; + using Eigen::ArrayXXd; + using Eigen::Matrix; + using Eigen::MatrixXd; + using stan::math::value_of; + using stan::math::var; + Matrix a = MatrixXd::Random(7, 4); + MatrixXd res = value_of(2 * a); + MatrixXd correct = 2 * value_of(a); + + EXPECT_MATRIX_NEAR(res, correct, 1e-10); +}