Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let value_of and value_of_rec return expressions #1872

Merged
merged 8 commits into from
May 15, 2020
Merged
2 changes: 0 additions & 2 deletions stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ return_type_t<T_alpha_scalar, T_beta_scalar> categorical_logit_glm_lpmf(
const auto& beta_val = value_of_rec(beta);
const auto& alpha_val = value_of_rec(alpha);

const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();

const int local_size
= opencl_kernels::categorical_logit_glm.get_option("LOCAL_SIZE_");
const int wgs = (N_instances + local_size - 1) / local_size;
Expand Down
2 changes: 0 additions & 2 deletions stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ return_type_t<T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
const auto& phi_val_vec = as_column_vector_or_scalar(phi_val);

const auto& phi_arr = as_array_or_scalar(phi_val_vec);

const int local_size
= opencl_kernels::neg_binomial_2_log_glm.get_option("LOCAL_SIZE_");
const int wgs = (N + local_size - 1) / local_size;
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/normal_id_glm_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <stan/math/prim/prob/normal_id_glm_lpdf.hpp>
#include <stan/math/opencl/copy.hpp>
Expand Down Expand Up @@ -96,7 +97,7 @@ return_type_t<T_alpha, T_beta, T_scale> normal_id_glm_lpdf(

const auto &beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);
const auto &sigma_val_vec = as_column_vector_or_scalar(sigma_val);
const auto &sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val));

T_scale_val inv_sigma = 1 / as_array_or_scalar(sigma_val_vec);
Matrix<T_partials_return, Dynamic, 1> y_minus_mu_over_sigma_mat(N);
Expand Down
3 changes: 0 additions & 3 deletions stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ return_type_t<T_beta_scalar, T_cuts_scalar> ordered_logistic_glm_lpmf(
const auto& beta_val = value_of_rec(beta);
const auto& cuts_val = value_of_rec(cuts);

const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto& cuts_val_vec = as_column_vector_or_scalar(cuts_val);

operands_and_partials<Eigen::Matrix<T_beta_scalar, Eigen::Dynamic, 1>,
Eigen::Matrix<T_cuts_scalar, Eigen::Dynamic, 1>>
ops_partials(beta, cuts);
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@
#include <stan/math/prim/fun/to_array_1d.hpp>
#include <stan/math/prim/fun/to_array_2d.hpp>
#include <stan/math/prim/fun/to_matrix.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/to_row_vector.hpp>
#include <stan/math/prim/fun/to_vector.hpp>
#include <stan/math/prim/fun/trace.hpp>
Expand Down
7 changes: 4 additions & 3 deletions stan/math/prim/fun/log_mix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <vector>
#include <cmath>
Expand Down Expand Up @@ -86,8 +87,8 @@ return_type_t<T_theta, T_lam> log_mix(const T_theta& theta,
check_finite(function, "theta", theta);
check_consistent_sizes(function, "theta", theta, "lambda", lambda);

const auto& theta_dbl = value_of(as_column_vector_or_scalar(theta));
const auto& lam_dbl = value_of(as_column_vector_or_scalar(lambda));
const auto& theta_dbl = to_ref(value_of(as_column_vector_or_scalar(theta)));
const auto& lam_dbl = to_ref(value_of(as_column_vector_or_scalar(lambda)));

T_partials_return logp = log_sum_exp(log(theta_dbl) + lam_dbl);

Expand Down Expand Up @@ -158,7 +159,7 @@ return_type_t<T_theta, std::vector<T_lam>> log_mix(
check_consistent_sizes(function, "theta", theta, "lambda", lambda[n]);
}

const auto& theta_dbl = value_of(as_column_vector_or_scalar(theta));
const auto& theta_dbl = to_ref(value_of(as_column_vector_or_scalar(theta)));

T_partials_mat lam_dbl(M, N);
for (int n = 0; n < N; ++n) {
Expand Down
34 changes: 34 additions & 0 deletions stan/math/prim/fun/to_ref.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef STAN_MATH_PRIM_FUN_TO_REF_HPP
#define STAN_MATH_PRIM_FUN_TO_REF_HPP

#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {

/**
* No-op that should be optimized away.
* @tparam T non-Eigen argument type
* @param a argument
* @return argument
*/
template <typename T, require_not_eigen_t<T>* = nullptr>
inline T to_ref(T&& a) {
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
return std::forward<T>(a);
}

/**
* Converts Eigen argument into `Eigen::Ref`. This evaluate expensive
* expressions.
* @tparam T argument type (Eigen expression)
* @param a argument
* @return argument converted to `Eigen::Ref`
*/
template <typename T, require_eigen_t<T>* = nullptr>
inline Eigen::Ref<const plain_type_t<T>> to_ref(T&& a) {
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
return std::forward<T>(a);
}

} // namespace math
} // namespace stan
#endif
16 changes: 4 additions & 12 deletions stan/math/prim/fun/value_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,15 @@ inline Vec value_of(Vec&& x) {
* T must implement value_of. See
* test/math/fwd/fun/value_of.cpp for fvar and var usage.
*
* @tparam T type of elements in the matrix
* @tparam R number of rows in the matrix, can be Eigen::Dynamic
* @tparam C number of columns in the matrix, can be Eigen::Dynamic
* @tparam EigMat type of the matrix
*
* @param[in] M Matrix to be converted
* @return Matrix of values
**/
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
require_not_vt_double_or_int<EigMat>* = nullptr>
inline Eigen::Matrix<typename child_type<value_type_t<EigMat>>::type,
EigMat::RowsAtCompileTime, EigMat::ColsAtCompileTime>
value_of(const EigMat& M) {
return M.array()
.unaryExpr([](const auto& scal) { return value_of(scal); })
.matrix()
.eval();
inline auto value_of(const EigMat& M) {
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
return M.unaryExpr([](const auto& scal) { return value_of(scal); });
}

/**
Expand All @@ -125,8 +118,7 @@ value_of(const EigMat& M) {
*
* <p>This inline pass-through no-op should be compiled away.
*
* @tparam R number of rows in the matrix, can be Eigen::Dynamic
* @tparam C number of columns in the matrix, can be Eigen::Dynamic
* @tparam EigMat type of the matrix
*
* @param x Specified matrix.
* @return Specified matrix.
Expand Down
12 changes: 7 additions & 5 deletions stan/math/prim/fun/value_of_rec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ inline std::complex<double> value_of_rec(const std::complex<T>& x) {
* @param[in] x std::vector to be converted
* @return std::vector of values
**/
template <typename T>
template <typename T, require_not_same_t<double, T>* = nullptr>
inline std::vector<double> value_of_rec(const std::vector<T>& x) {
size_t x_size = x.size();
std::vector<double> result(x_size);
Expand All @@ -86,8 +86,10 @@ inline std::vector<double> value_of_rec(const std::vector<T>& x) {
* @param x Specified std::vector.
* @return Specified std::vector.
*/
inline const std::vector<double>& value_of_rec(const std::vector<double>& x) {
return x;
template <typename T, require_std_vector_t<T>* = nullptr,
require_vt_same<double, T>* = nullptr>
inline T value_of_rec(T&& x) {
return std::forward<T>(x);
}

/**
Expand Down Expand Up @@ -120,8 +122,8 @@ inline auto value_of_rec(const T& M) {
*/
template <typename T, typename = require_st_same<T, double>,
typename = require_eigen_t<T>>
inline const T& value_of_rec(const T& x) {
return x;
inline T value_of_rec(T&& x) {
return std::forward<T>(x);
}
} // namespace math
} // namespace stan
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>

Expand Down Expand Up @@ -82,13 +83,13 @@ return_type_t<T_x_scalar, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
}

T_partials_return logp(0);
const auto &x_val = value_of_rec(x);
const auto &x_val = to_ref(value_of_rec(x));
const auto &y_val = value_of_rec(y);
const auto &beta_val = value_of_rec(beta);
const auto &alpha_val = value_of_rec(alpha);

const auto &y_val_vec = as_column_vector_or_scalar(y_val);
const auto &beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);

T_y_val signs = 2 * as_array_or_scalar(y_val_vec) - 1;
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/prob/categorical_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <Eigen/Core>
#include <cmath>

Expand Down Expand Up @@ -73,8 +75,8 @@ categorical_logit_glm_lpmf(
return 0;
}

const auto& x_val = value_of_rec(x);
const auto& beta_val = value_of_rec(beta);
const auto& x_val = to_ref(value_of_rec(x));
const auto& beta_val = to_ref(value_of_rec(beta));
const auto& alpha_val = value_of_rec(alpha);

const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();
Expand Down
24 changes: 14 additions & 10 deletions stan/math/prim/prob/hmm_marginal_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@
#include <stan/math/prim/fun/col.hpp>
#include <stan/math/prim/fun/transpose.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/core.hpp>
#include <vector>

namespace stan {
namespace math {

template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alpha>
inline auto hmm_marginal_lpdf_val(
const Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>& omegas,
const Eigen::Matrix<T_Gamma, Eigen::Dynamic, Eigen::Dynamic>& Gamma_val,
const Eigen::Matrix<T_rho, Eigen::Dynamic, 1>& rho_val,
Eigen::Matrix<T_alpha, Eigen::Dynamic, Eigen::Dynamic>& alphas,
Eigen::Matrix<T_alpha, Eigen::Dynamic, 1>& alpha_log_norms,
T_alpha& norm_norm) {
template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alphas,
typename T_alpha_log_norm, typename T_norm,
require_all_eigen_matrix_t<T_omega, T_Gamma, T_alphas>* = nullptr,
require_all_eigen_col_vector_t<T_rho, T_alpha_log_norm>* = nullptr,
require_stan_scalar_t<T_norm>* = nullptr,
require_all_vt_same<T_alphas, T_alpha_log_norm, T_norm>* = nullptr>
inline auto hmm_marginal_lpdf_val(const T_omega& omegas,
const T_Gamma& Gamma_val,
const T_rho& rho_val, T_alphas& alphas,
T_alpha_log_norm& alpha_log_norms,
T_norm& norm_norm) {
const int n_states = omegas.rows();
const int n_transitions = omegas.cols() - 1;
alphas.col(0) = omegas.col(0).cwiseProduct(rho_val);
Expand Down Expand Up @@ -100,10 +104,10 @@ inline auto hmm_marginal_lpdf(

eig_matrix_partial alphas(n_states, n_transitions + 1);
eig_vector_partial alpha_log_norms(n_transitions + 1);
auto Gamma_val = value_of(Gamma);
const auto& Gamma_val = to_ref(value_of(Gamma));

// compute the density using the forward algorithm.
auto rho_val = value_of(rho);
const auto& rho_val = to_ref(value_of(rho));
eig_matrix_partial omegas = value_of(log_omegas).array().exp();
T_partial_type norm_norm;
auto log_marginal_density = hmm_marginal_lpdf_val(
Expand Down
11 changes: 6 additions & 5 deletions stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <stan/math/prim/fun/multiply_log.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <vector>
#include <cmath>
Expand Down Expand Up @@ -105,16 +106,16 @@ neg_binomial_2_log_glm_lpmf(
}

T_partials_return logp(0);
const auto& x_val = value_of_rec(x);
const auto& x_val = to_ref(value_of_rec(x));
const auto& y_val = value_of_rec(y);
const auto& beta_val = value_of_rec(beta);
const auto& alpha_val = value_of_rec(alpha);
const auto& phi_val = value_of_rec(phi);
andrjohns marked this conversation as resolved.
Show resolved Hide resolved

const auto& y_val_vec = as_column_vector_or_scalar(y_val);
const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val));
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
const auto& phi_val_vec = as_column_vector_or_scalar(phi_val);
const auto& phi_val_vec = to_ref(as_column_vector_or_scalar(phi_val));

const auto& y_arr = as_array_or_scalar(y_val_vec);
const auto& phi_arr = as_array_or_scalar(phi_val_vec);
Expand Down Expand Up @@ -147,7 +148,7 @@ neg_binomial_2_log_glm_lpmf(
}
if (include_summand<propto, T_precision>::value) {
if (is_vector<T_precision>::value) {
scalar_seq_view<decltype(phi_val)> phi_vec(phi_val);
scalar_seq_view<decltype(phi_val_vec)> phi_vec(phi_val_vec);
for (size_t n = 0; n < N_instances; ++n) {
logp += multiply_log(phi_vec[n], phi_vec[n]) - lgamma(phi_vec[n]);
}
Expand Down
7 changes: 4 additions & 3 deletions stan/math/prim/prob/normal_id_glm_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>

Expand Down Expand Up @@ -88,15 +89,15 @@ return_type_t<T_y, T_x_scalar, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
return 0;
}

const auto &x_val = value_of_rec(x);
const auto &x_val = to_ref(value_of_rec(x));
const auto &beta_val = value_of_rec(beta);
const auto &alpha_val = value_of_rec(alpha);
const auto &sigma_val = value_of_rec(sigma);
const auto &y_val = value_of_rec(y);

const auto &beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);
const auto &sigma_val_vec = as_column_vector_or_scalar(sigma_val);
const auto &sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val));
const auto &y_val_vec = as_column_vector_or_scalar(y_val);

T_scale_val inv_sigma = 1 / as_array_or_scalar(sigma_val_vec);
Expand Down
7 changes: 4 additions & 3 deletions stan/math/prim/prob/ordered_logistic_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stan/math/prim/fun/log1m_exp.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>

Expand Down Expand Up @@ -81,12 +82,12 @@ ordered_logistic_glm_lpmf(
if (!include_summand<propto, T_x_scalar, T_beta_scalar, T_cuts_scalar>::value)
return 0;

const auto& x_val = value_of_rec(x);
const auto& x_val = to_ref(value_of_rec(x));
const auto& beta_val = value_of_rec(beta);
const auto& cuts_val = value_of_rec(cuts);

const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto& cuts_val_vec = as_column_vector_or_scalar(cuts_val);
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
const auto& cuts_val_vec = to_ref(as_column_vector_or_scalar(cuts_val));

scalar_seq_view<T_y> y_seq(y);
Array<double, Dynamic, 1> cuts_y1(N_instances), cuts_y2(N_instances);
Expand Down
7 changes: 4 additions & 3 deletions stan/math/prim/prob/poisson_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>

Expand Down Expand Up @@ -82,13 +83,13 @@ return_type_t<T_x_scalar, T_alpha, T_beta> poisson_log_glm_lpmf(

T_partials_return logp(0);

const auto& x_val = value_of_rec(x);
const auto& x_val = to_ref(value_of_rec(x));
const auto& y_val = value_of_rec(y);
const auto& beta_val = value_of_rec(beta);
const auto& alpha_val = value_of_rec(alpha);

const auto& y_val_vec = as_column_vector_or_scalar(y_val);
const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val));
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);

Array<T_partials_return, Dynamic, 1> theta(N_instances);
Expand Down
Loading