Skip to content

Commit

Permalink
Merge pull request #2950 from stan-dev/compound-funs
Browse files Browse the repository at this point in the history
Cleanup more usage of compound functions throughout Math
  • Loading branch information
andrjohns authored Oct 4, 2023
2 parents efbc688 + cc8dc55 commit c7cb0ea
Show file tree
Hide file tree
Showing 24 changed files with 62 additions and 49 deletions.
2 changes: 1 addition & 1 deletion stan/math/fwd/fun/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace math {
template <typename T>
inline fvar<T> log1m_inv_logit(const fvar<T>& x) {
using std::exp;
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ / (1 + exp(-x.val_)));
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ * inv_logit(x.val_));
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/fwd/fun/log1p_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace math {
template <typename T>
inline fvar<T> log1p_exp(const fvar<T>& x) {
using std::exp;
return fvar<T>(log1p_exp(x.val_), x.d_ / (1 + exp(-x.val_)));
return fvar<T>(log1p_exp(x.val_), x.d_ * inv_logit(x.val_));
}

} // namespace math
Expand Down
3 changes: 2 additions & 1 deletion stan/math/fwd/fun/log_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <cmath>

Expand All @@ -12,7 +13,7 @@ namespace math {
template <typename T>
inline fvar<T> log_inv_logit(const fvar<T>& x) {
using std::exp;
return fvar<T>(log_inv_logit(x.val_), x.d_ / (1 + exp(x.val_)));
return fvar<T>(log_inv_logit(x.val_), x.d_ * inv_logit(-x.val_));
}
} // namespace math
} // namespace stan
Expand Down
6 changes: 3 additions & 3 deletions stan/math/fwd/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ template <typename T>
inline fvar<T> log_sum_exp(const fvar<T>& x1, const fvar<T>& x2) {
using std::exp;
return fvar<T>(log_sum_exp(x1.val_, x2.val_),
x1.d_ / (1 + exp(x2.val_ - x1.val_))
+ x2.d_ / (exp(x1.val_ - x2.val_) + 1));
x1.d_ * inv_logit(-(x2.val_ - x1.val_))
+ x2.d_ * inv_logit(-(x1.val_ - x2.val_)));
}

template <typename T>
Expand All @@ -28,7 +28,7 @@ inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
if (x1 == NEGATIVE_INFTY) {
return fvar<T>(x2.val_, x2.d_);
}
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ / (exp(x1 - x2.val_) + 1));
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ * inv_logit(-(x1 - x2.val_)));
}

template <typename T>
Expand Down
4 changes: 3 additions & 1 deletion stan/math/opencl/kernel_generator/elt_function_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(digamma,
opencl_kernels::digamma_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m, opencl_kernels::log1m_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(log_inv_logit,
opencl_kernels::log1p_exp_device_function,
opencl_kernels::log_inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m_exp,
opencl_kernels::log1m_exp_device_function)
Expand All @@ -317,7 +318,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
opencl_kernels::phi_device_function,
opencl_kernels::inv_phi_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(
log1m_inv_logit, opencl_kernels::log1m_inv_logit_device_function)
log1m_inv_logit, opencl_kernels::log1p_exp_device_function,
opencl_kernels::log1m_inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(trigamma,
opencl_kernels::trigamma_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ static const char* log1m_inv_logit_device_function
*/
inline double log1m_inv_logit(double x) {
if (x > 0.0) {
return -x - log1p(exp(-x)); // prevent underflow
return -x - log1p_exp(-x); // prevent underflow
}
return -log1p(exp(x));
return -log1p_exp(x);
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/log_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ static const char* log_inv_logit_device_function
*/
double log_inv_logit(double x) {
if (x < 0.0) {
return x - log1p(exp(x)); // prevent underflow
return x - log1p_exp(x); // prevent underflow
}
return -log1p(exp(-x));
return -log1p_exp(-x);
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
14 changes: 6 additions & 8 deletions stan/math/opencl/prim/binomial_logit_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <stan/math/prim/fun/log1m_inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -60,18 +61,15 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N,
= check_cl(function, "Probability parameter", alpha_val, "finite");
auto alpha_finite = isfinite(alpha_val);

auto inv_logit_alpha = inv_logit(alpha_val);
auto inv_logit_neg_alpha = inv_logit(-alpha_val);
auto log_inv_logit_alpha = log(inv_logit_alpha);
auto log_inv_logit_neg_alpha = log(inv_logit_neg_alpha);
auto log_inv_logit_alpha = log_inv_logit(alpha_val);
auto log1m_inv_logit_alpha = log1m_inv_logit(alpha_val);
auto n_diff = N - n;
auto logp_expr1 = elt_multiply(n, log_inv_logit_alpha)
+ elt_multiply(n_diff, log_inv_logit_neg_alpha);
+ elt_multiply(n_diff, log1m_inv_logit_alpha);
auto logp_expr
= static_select<include_summand<propto, T_n_cl, T_N_cl>::value>(
logp_expr1 + binomial_coefficient_log(N, n), logp_expr1);
auto alpha_deriv = elt_multiply(n, inv_logit_neg_alpha)
- elt_multiply(n_diff, inv_logit_alpha);
auto alpha_deriv = n - elt_multiply(N, exp(log_inv_logit_alpha));

matrix_cl<double> logp_cl;
matrix_cl<double> alpha_deriv_cl;
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/logistic_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
Expand Down Expand Up @@ -75,7 +76,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> logistic_lpdf(
auto y_minus_mu = y_val - mu_val;
auto y_minus_mu_div_sigma = elt_multiply(y_minus_mu, inv_sigma);

auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p(exp(-y_minus_mu_div_sigma));
auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p_exp(-y_minus_mu_div_sigma);
auto logp_expr
= colwise_sum(static_select<include_summand<propto, T_scale_cl>::value>(
logp1 - log(sigma_val), logp1));
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/pareto_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ return_type_t<T_y_cl, T_scale_cl, T_shape_cl> pareto_lcdf(

auto log_quot = log(elt_divide(y_min_val, y_val));
auto exp_prod = exp(elt_multiply(alpha_val, log_quot));
auto lcdf_expr = colwise_sum(log(1.0 - exp_prod));
// TODO(Andrew) Further simplify derivatives and log1m_exp below
auto lcdf_expr = colwise_sum(log1m(exp_prod));

auto common_deriv = elt_divide(exp_prod, 1.0 - exp_prod);

Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/weibull_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ return_type_t<T_y_cl, T_shape_cl, T_scale_cl> weibull_lcdf(

auto pow_n = pow(elt_divide(y_val, sigma_val), alpha_val);
auto exp_n = exp(-pow_n);
auto lcdf_expr = colwise_sum(log(1.0 - exp_n));
// TODO(Andrew) Further simplify derivatives and log1m_exp below
auto lcdf_expr = colwise_sum(log1m(exp_n));

auto rep_deriv = elt_divide(pow_n, elt_divide(1.0, exp_n) - 1.0);
auto deriv_y_sigma = elt_multiply(rep_deriv, alpha_val);
Expand Down
4 changes: 2 additions & 2 deletions stan/math/prim/fun/inc_beta_ddb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/inc_beta_dda.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -87,7 +87,7 @@ T inc_beta_ddb(T a, T b, T z, T digamma_b, T digamma_ab) {
}
}

return inc_beta(a, b, z) * (log(1 - z) - digamma_b + sum_numer / sum_denom);
return inc_beta(a, b, z) * (log1m(z) - digamma_b + sum_numer / sum_denom);
}

} // namespace math
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/inc_beta_ddz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <boost/math/special_functions/beta.hpp>
#include <cmath>

Expand All @@ -29,7 +30,7 @@ template <typename T>
T inc_beta_ddz(T a, T b, T z) {
using std::exp;
using std::log;
return exp((b - 1) * log(1 - z) + (a - 1) * log(z) + lgamma(a + b) - lgamma(a)
return exp((b - 1) * log1m(z) + (a - 1) * log(z) + lgamma(a + b) - lgamma(a)
- lgamma(b));
}

Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <cmath>

Expand Down Expand Up @@ -36,9 +36,9 @@ namespace math {
inline double log1m_inv_logit(double u) {
using std::exp;
if (u > 0.0) {
return -u - log1p(exp(-u)); // prevent underflow
return -u - log1p_exp(-u); // prevent underflow
}
return -log1p(exp(u));
return -log1p_exp(u);
}

/**
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/log_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <cmath>

Expand Down Expand Up @@ -34,9 +34,9 @@ namespace math {
inline double log_inv_logit(double u) {
using std::exp;
if (u < 0.0) {
return u - log1p(exp(u)); // prevent underflow
return u - log1p_exp(u); // prevent underflow
}
return -log1p(exp(-u));
return -log1p_exp(-u);
}

/**
Expand Down
13 changes: 6 additions & 7 deletions stan/math/prim/fun/prob_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#define STAN_MATH_PRIM_FUN_PROB_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -49,10 +49,9 @@ inline T prob_constrain(const T& x) {
*/
template <typename T>
inline T prob_constrain(const T& x, T& lp) {
using std::log;
T inv_logit_x = inv_logit(x);
lp += log(inv_logit_x) + log1m(inv_logit_x);
return inv_logit_x;
T log_inv_logit_x = log_inv_logit(x);
lp += log_inv_logit_x + log1m_inv_logit(x);
return exp(log_inv_logit_x);
}

/**
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/logistic_cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/prob/logistic_log.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>
Expand Down Expand Up @@ -70,8 +71,8 @@ return_type_t<T_y, T_loc, T_scale> logistic_cdf(const T_y& y, const T_loc& mu,
const T_partials_return sigma_dbl = sigma_vec.val(n);
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);

const T_partials_return Pn
= 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
// TODO(Andrew) Further simplify derivatives and log scale below
const T_partials_return Pn = inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);

P *= Pn;

Expand Down
4 changes: 3 additions & 1 deletion stan/math/prim/prob/logistic_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
Expand Down Expand Up @@ -71,8 +72,9 @@ return_type_t<T_y, T_loc, T_scale> logistic_lccdf(const T_y& y, const T_loc& mu,
const T_partials_return sigma_dbl = sigma_vec.val(n);
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);

// TODO(Andrew) Further simplify derivatives and log-scale below
const T_partials_return Pn
= 1.0 - 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
= 1.0 - inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
P += log(Pn);

if (!is_constant_all<T_y>::value) {
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/logistic_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/size.hpp>
Expand Down Expand Up @@ -71,8 +72,8 @@ return_type_t<T_y, T_loc, T_scale> logistic_lcdf(const T_y& y, const T_loc& mu,
const T_partials_return sigma_dbl = sigma_vec.val(n);
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);

const T_partials_return Pn
= 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
// TODO(Andrew) Further simplify derivatives and log-scale below
const T_partials_return Pn = inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
P += log(Pn);

if (!is_constant_all<T_y>::value) {
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/prob/logistic_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
Expand Down Expand Up @@ -63,7 +64,7 @@ return_type_t<T_y, T_loc, T_scale> logistic_lpdf(const T_y& y, const T_loc& mu,

size_t N = max_size(y, mu, sigma);
T_partials_return logp = -sum(y_minus_mu_div_sigma)
- 2.0 * sum(log1p(exp(-y_minus_mu_div_sigma)));
- 2.0 * sum(log1p_exp(-y_minus_mu_div_sigma));
if (include_summand<propto, T_scale>::value) {
logp -= sum(log(sigma_val)) * N / math::size(sigma);
}
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
Expand Down Expand Up @@ -153,8 +154,8 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
T_precision_val log_phi = log(phi_arr);
Array<T_partials_return, Dynamic, 1> logsumexp_theta_logphi
= (theta > log_phi)
.select(theta + log1p(exp(log_phi - theta)),
log_phi + log1p(exp(theta - log_phi)));
.select(theta + log1p_exp(log_phi - theta),
log_phi + log1p_exp(theta - log_phi));

T_sum_val y_plus_phi = y_arr + phi_arr;

Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/prob/pareto_lcdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ return_type_t<T_y, T_scale, T_shape> pareto_lcdf(const T_y& y,
const auto& exp_prod
= to_ref_if<!is_constant_all<T_y, T_scale, T_shape>::value>(
exp(alpha_val * log_quot));
T_partials_return P = sum(log(1 - exp_prod));
// TODO(Andrew) Further simplify derivatives and log1m_exp below
T_partials_return P = sum(log1m(exp_prod));

if (!is_constant_all<T_y, T_scale, T_shape>::value) {
const auto& common_deriv = to_ref_if<(!is_constant_all<T_y, T_scale>::value
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/prob/skew_double_exponential_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ return_type_t<T_y, T_loc, T_scale, T_skewness> skew_double_exponential_lccdf(
if (y_dbl <= mu_dbl) {
cdf_log += log1m(tau_dbl * exp(-2.0 * expo));
} else {
cdf_log += log(1 - tau_dbl) - 2.0 * expo;
cdf_log += log1m(tau_dbl) - 2.0 * expo;
}

if (!is_constant_all<T_y>::value) {
Expand Down
Loading

0 comments on commit c7cb0ea

Please sign in to comment.