Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup more usage of compound functions throughout Math #2950

Merged
merged 8 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stan/math/fwd/fun/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace math {
template <typename T>
inline fvar<T> log1m_inv_logit(const fvar<T>& x) {
using std::exp;
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ / (1 + exp(-x.val_)));
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ * inv_logit(x.val_));
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/fwd/fun/log1p_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace math {
template <typename T>
inline fvar<T> log1p_exp(const fvar<T>& x) {
using std::exp;
return fvar<T>(log1p_exp(x.val_), x.d_ / (1 + exp(-x.val_)));
return fvar<T>(log1p_exp(x.val_), x.d_ * inv_logit(x.val_));
}

} // namespace math
Expand Down
3 changes: 2 additions & 1 deletion stan/math/fwd/fun/log_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <cmath>

Expand All @@ -12,7 +13,7 @@ namespace math {
template <typename T>
inline fvar<T> log_inv_logit(const fvar<T>& x) {
using std::exp;
return fvar<T>(log_inv_logit(x.val_), x.d_ / (1 + exp(x.val_)));
return fvar<T>(log_inv_logit(x.val_), x.d_ * inv_logit(-x.val_));
}
} // namespace math
} // namespace stan
Expand Down
6 changes: 3 additions & 3 deletions stan/math/fwd/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ template <typename T>
inline fvar<T> log_sum_exp(const fvar<T>& x1, const fvar<T>& x2) {
using std::exp;
return fvar<T>(log_sum_exp(x1.val_, x2.val_),
x1.d_ / (1 + exp(x2.val_ - x1.val_))
+ x2.d_ / (exp(x1.val_ - x2.val_) + 1));
x1.d_ * inv_logit(-(x2.val_ - x1.val_))
+ x2.d_ * inv_logit(-(x1.val_ - x2.val_)));
}

template <typename T>
Expand All @@ -28,7 +28,7 @@ inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
if (x1 == NEGATIVE_INFTY) {
return fvar<T>(x2.val_, x2.d_);
}
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ / (exp(x1 - x2.val_) + 1));
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ * inv_logit(-(x1 - x2.val_)));
}

template <typename T>
Expand Down
4 changes: 3 additions & 1 deletion stan/math/opencl/kernel_generator/elt_function_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(digamma,
opencl_kernels::digamma_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m, opencl_kernels::log1m_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(log_inv_logit,
opencl_kernels::log1p_exp_device_function,
opencl_kernels::log_inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m_exp,
opencl_kernels::log1m_exp_device_function)
Expand All @@ -317,7 +318,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
opencl_kernels::phi_device_function,
opencl_kernels::inv_phi_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(
log1m_inv_logit, opencl_kernels::log1m_inv_logit_device_function)
log1m_inv_logit, opencl_kernels::log1p_exp_device_function,
opencl_kernels::log1m_inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(trigamma,
opencl_kernels::trigamma_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ static const char* log1m_inv_logit_device_function
*/
inline double log1m_inv_logit(double x) {
if (x > 0.0) {
return -x - log1p(exp(-x)); // prevent underflow
return -x - log1p_exp(-x); // prevent underflow
}
return -log1p(exp(x));
return -log1p_exp(x);
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/log_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ static const char* log_inv_logit_device_function
*/
double log_inv_logit(double x) {
if (x < 0.0) {
return x - log1p(exp(x)); // prevent underflow
return x - log1p_exp(x); // prevent underflow
}
return -log1p(exp(-x));
return -log1p_exp(-x);
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
14 changes: 6 additions & 8 deletions stan/math/opencl/prim/binomial_logit_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <stan/math/prim/fun/log1m_inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -60,18 +61,15 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N,
= check_cl(function, "Probability parameter", alpha_val, "finite");
auto alpha_finite = isfinite(alpha_val);

auto inv_logit_alpha = inv_logit(alpha_val);
auto inv_logit_neg_alpha = inv_logit(-alpha_val);
auto log_inv_logit_alpha = log(inv_logit_alpha);
auto log_inv_logit_neg_alpha = log(inv_logit_neg_alpha);
auto log_inv_logit_alpha = log_inv_logit(alpha_val);
auto log1m_inv_logit_alpha = log1m_inv_logit(alpha_val);
auto n_diff = N - n;
auto logp_expr1 = elt_multiply(n, log_inv_logit_alpha)
+ elt_multiply(n_diff, log_inv_logit_neg_alpha);
+ elt_multiply(n_diff, log1m_inv_logit_alpha);
auto logp_expr
= static_select<include_summand<propto, T_n_cl, T_N_cl>::value>(
logp_expr1 + binomial_coefficient_log(N, n), logp_expr1);
auto alpha_deriv = elt_multiply(n, inv_logit_neg_alpha)
- elt_multiply(n_diff, inv_logit_alpha);
auto alpha_deriv = n - elt_multiply(N, exp(log_inv_logit_alpha));

matrix_cl<double> logp_cl;
matrix_cl<double> alpha_deriv_cl;
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/logistic_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
Expand Down Expand Up @@ -75,7 +76,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> logistic_lpdf(
auto y_minus_mu = y_val - mu_val;
auto y_minus_mu_div_sigma = elt_multiply(y_minus_mu, inv_sigma);

auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p(exp(-y_minus_mu_div_sigma));
auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p_exp(-y_minus_mu_div_sigma);
auto logp_expr
= colwise_sum(static_select<include_summand<propto, T_scale_cl>::value>(
logp1 - log(sigma_val), logp1));
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/pareto_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ return_type_t<T_y_cl, T_scale_cl, T_shape_cl> pareto_lcdf(

auto log_quot = log(elt_divide(y_min_val, y_val));
auto exp_prod = exp(elt_multiply(alpha_val, log_quot));
auto lcdf_expr = colwise_sum(log(1.0 - exp_prod));
// TODO(Andrew) Further simplify derivatives and log1m_exp below
auto lcdf_expr = colwise_sum(log1m(exp_prod));

auto common_deriv = elt_divide(exp_prod, 1.0 - exp_prod);

Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/weibull_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ return_type_t<T_y_cl, T_shape_cl, T_scale_cl> weibull_lcdf(

auto pow_n = pow(elt_divide(y_val, sigma_val), alpha_val);
auto exp_n = exp(-pow_n);
auto lcdf_expr = colwise_sum(log(1.0 - exp_n));
// TODO(Andrew) Further simplify derivatives and log1m_exp below
auto lcdf_expr = colwise_sum(log1m(exp_n));

auto rep_deriv = elt_divide(pow_n, elt_divide(1.0, exp_n) - 1.0);
auto deriv_y_sigma = elt_multiply(rep_deriv, alpha_val);
Expand Down
4 changes: 2 additions & 2 deletions stan/math/prim/fun/inc_beta_ddb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/inc_beta_dda.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -87,7 +87,7 @@ T inc_beta_ddb(T a, T b, T z, T digamma_b, T digamma_ab) {
}
}

return inc_beta(a, b, z) * (log(1 - z) - digamma_b + sum_numer / sum_denom);
return inc_beta(a, b, z) * (log1m(z) - digamma_b + sum_numer / sum_denom);
}

} // namespace math
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/inc_beta_ddz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <boost/math/special_functions/beta.hpp>
#include <cmath>

Expand All @@ -29,7 +30,7 @@ template <typename T>
T inc_beta_ddz(T a, T b, T z) {
using std::exp;
using std::log;
return exp((b - 1) * log(1 - z) + (a - 1) * log(z) + lgamma(a + b) - lgamma(a)
return exp((b - 1) * log1m(z) + (a - 1) * log(z) + lgamma(a + b) - lgamma(a)
- lgamma(b));
}

Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <cmath>

Expand Down Expand Up @@ -36,9 +36,9 @@ namespace math {
inline double log1m_inv_logit(double u) {
using std::exp;
if (u > 0.0) {
return -u - log1p(exp(-u)); // prevent underflow
return -u - log1p_exp(-u); // prevent underflow
}
return -log1p(exp(u));
return -log1p_exp(u);
}

/**
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/log_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <cmath>

Expand Down Expand Up @@ -34,9 +34,9 @@ namespace math {
inline double log_inv_logit(double u) {
using std::exp;
if (u < 0.0) {
return u - log1p(exp(u)); // prevent underflow
return u - log1p_exp(u); // prevent underflow
}
return -log1p(exp(-u));
return -log1p_exp(-u);
}

/**
Expand Down
13 changes: 6 additions & 7 deletions stan/math/prim/fun/prob_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#define STAN_MATH_PRIM_FUN_PROB_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -49,10 +49,9 @@ inline T prob_constrain(const T& x) {
*/
template <typename T>
inline T prob_constrain(const T& x, T& lp) {
using std::log;
T inv_logit_x = inv_logit(x);
lp += log(inv_logit_x) + log1m(inv_logit_x);
return inv_logit_x;
T log_inv_logit_x = log_inv_logit(x);
lp += log_inv_logit_x + log1m_inv_logit(x);
return exp(log_inv_logit_x);
}

/**
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/logistic_cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/prob/logistic_log.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>
Expand Down Expand Up @@ -70,8 +71,8 @@ return_type_t<T_y, T_loc, T_scale> logistic_cdf(const T_y& y, const T_loc& mu,
const T_partials_return sigma_dbl = sigma_vec.val(n);
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);

const T_partials_return Pn
= 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
// TODO(Andrew) Further simplify derivatives and log scale below
const T_partials_return Pn = inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);

P *= Pn;

Expand Down
4 changes: 3 additions & 1 deletion stan/math/prim/prob/logistic_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
Expand Down Expand Up @@ -71,8 +72,9 @@ return_type_t<T_y, T_loc, T_scale> logistic_lccdf(const T_y& y, const T_loc& mu,
const T_partials_return sigma_dbl = sigma_vec.val(n);
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);

// TODO(Andrew) Further simplify derivatives and log-scale below
const T_partials_return Pn
= 1.0 - 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
= 1.0 - inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
P += log(Pn);

if (!is_constant_all<T_y>::value) {
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/logistic_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/size.hpp>
Expand Down Expand Up @@ -71,8 +72,8 @@ return_type_t<T_y, T_loc, T_scale> logistic_lcdf(const T_y& y, const T_loc& mu,
const T_partials_return sigma_dbl = sigma_vec.val(n);
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);

const T_partials_return Pn
= 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
// TODO(Andrew) Further simplify derivatives and log-scale below
const T_partials_return Pn = inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
P += log(Pn);

if (!is_constant_all<T_y>::value) {
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/prob/logistic_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
Expand Down Expand Up @@ -63,7 +64,7 @@ return_type_t<T_y, T_loc, T_scale> logistic_lpdf(const T_y& y, const T_loc& mu,

size_t N = max_size(y, mu, sigma);
T_partials_return logp = -sum(y_minus_mu_div_sigma)
- 2.0 * sum(log1p(exp(-y_minus_mu_div_sigma)));
- 2.0 * sum(log1p_exp(-y_minus_mu_div_sigma));
if (include_summand<propto, T_scale>::value) {
logp -= sum(log(sigma_val)) * N / math::size(sigma);
}
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
Expand Down Expand Up @@ -153,8 +154,8 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
T_precision_val log_phi = log(phi_arr);
Array<T_partials_return, Dynamic, 1> logsumexp_theta_logphi
= (theta > log_phi)
.select(theta + log1p(exp(log_phi - theta)),
log_phi + log1p(exp(theta - log_phi)));
.select(theta + log1p_exp(log_phi - theta),
log_phi + log1p_exp(theta - log_phi));

T_sum_val y_plus_phi = y_arr + phi_arr;

Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/prob/pareto_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ return_type_t<T_y, T_scale, T_shape> pareto_lcdf(const T_y& y,
const auto& exp_prod
= to_ref_if<!is_constant_all<T_y, T_scale, T_shape>::value>(
exp(alpha_val * log_quot));
T_partials_return P = sum(log(1 - exp_prod));
// TODO(Andrew) Further simplify derivatives and log1m_exp below
T_partials_return P = sum(log1m(exp_prod));

if (!is_constant_all<T_y, T_scale, T_shape>::value) {
const auto& common_deriv = to_ref_if<(!is_constant_all<T_y, T_scale>::value
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/prob/skew_double_exponential_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ return_type_t<T_y, T_loc, T_scale, T_skewness> skew_double_exponential_lccdf(
if (y_dbl <= mu_dbl) {
cdf_log += log1m(tau_dbl * exp(-2.0 * expo));
} else {
cdf_log += log(1 - tau_dbl) - 2.0 * expo;
cdf_log += log1m(tau_dbl) - 2.0 * expo;
}

if (!is_constant_all<T_y>::value) {
Expand Down
Loading