Skip to content

Commit

Permalink
Merge branch 'develop' into feature/inbedded-expression-tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
andrjohns committed Sep 19, 2023
2 parents ed5de89 + 9f2689e commit 545cdee
Show file tree
Hide file tree
Showing 16 changed files with 137 additions and 138 deletions.
3 changes: 2 additions & 1 deletion stan/math/opencl/kernel_generator/elt_function_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square,
opencl_kernels::inv_square_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit,
opencl_kernels::inv_logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function,
opencl_kernels::logit_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx,
opencl_kernels::inv_logit_device_function,
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernels/device_functions/Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ static const char* phi_device_function
if (x < -37.5) {
return 0;
} else if (x < -5.0) {
return 0.5 * erfc(-1.0 / sqrt(2.0) * x);
return 0.5 * erfc(-M_SQRT1_2 * x);
} else if (x > 8.25) {
return 1;
} else {
return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x));
return 0.5 * (1.0 + erf(M_SQRT1_2 * x));
}
}
// \cond
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static const char* inv_logit_device_function
*/
double inv_logit(double x) {
if (x < 0) {
if (x < log(2.2204460492503131E-16)) {
if (x < log(DBL_EPSILON)) {
return exp(x);
}
return exp(x) / (1 + exp(x));
Expand Down
8 changes: 5 additions & 3 deletions stan/math/opencl/kernels/device_functions/lbeta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,23 @@ static const char* lbeta_device_function
return lgamma(x) + lgamma(y) - lgamma(x + y);
}
double x_over_xy = x / (x + y);
double log_xpy = log(x + y);
if (x < LGAMMA_STIRLING_DIFF_USEFUL) {
// y large, x small
double stirling_diff
= lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y);
double stirling
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y));
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy);
return stirling + lgamma(x) + stirling_diff;
}

// both large
double stirling_diff = lgamma_stirling_diff(x)
+ lgamma_stirling_diff(y)
- lgamma_stirling_diff(x + y);
double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy)
+ 0.5 * log(2.0 * M_PI) - 0.5 * log(y);
double stirling = (x - 0.5) * (log(x) - log_xpy)
+ y * log1p(-x_over_xy)
+ 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y);
return stirling + stirling_diff;
}
// \cond
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function
* @return Stirling's approximation to lgamma(x).
*/
double lgamma_stirling(double x) {
return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x;
return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x;
}
// \cond
) "\n#endif\n"; // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernels/device_functions/logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static const char* logit_device_function
* @param x argument
* @return log odds of argument
*/
double logit(double x) { return log(x / (1 - x)); }
double logit(double x) { return log(x) - log1m(x); }
// \cond
) "\n#endif\n"; // NOLINT
// \endcond
Expand Down
7 changes: 4 additions & 3 deletions stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/digamma.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -92,9 +93,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
double log_phi = log(phi);
double logsumexp_theta_logphi;
if (theta > log_phi) {
logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta));
logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta);
} else {
logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi));
logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi);
}
double y_plus_phi = y + phi;
if (need_logp1) {
Expand Down Expand Up @@ -196,7 +197,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
int, int, int, int, int, int, int, int, int>
neg_binomial_2_log_glm("neg_binomial_2_log_glm",
{digamma_device_function,
{digamma_device_function, log1p_exp_device_function,
neg_binomial_2_log_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

Expand Down
16 changes: 4 additions & 12 deletions stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -87,20 +88,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY(

if (need_location_derivative || need_cuts_derivative) {
double exp_cuts_diff = exp(cut_y2 - cut_y1);
if (cut2 > 0) {
double exp_m_cut2 = exp(-cut2);
d1 = exp_m_cut2 / (1 + exp_m_cut2);
} else {
d1 = 1 / (1 + exp(cut2));
}
d1 = inv_logit(-cut2);
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
d2 = 1 / (1 - exp_cuts_diff);
if (cut1 > 0) {
double exp_m_cut1 = exp(-cut1);
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
} else {
d2 -= 1 / (1 + exp(cut1));
}
d2 -= inv_logit(-cut1);

if (need_location_derivative) {
location_derivative[gid] = d1 - d2;
Expand Down Expand Up @@ -181,6 +172,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
ordered_logistic_glm("ordered_logistic_glm",
{log1p_exp_device_function, log1m_exp_device_function,
inv_logit_device_function,
ordered_logistic_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

Expand Down
17 changes: 4 additions & 13 deletions stan/math/opencl/kernels/ordered_logistic_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/kernel_cl.hpp>
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -83,20 +84,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY(

if (need_lambda_derivative || need_cuts_derivative) {
double exp_cuts_diff = exp(cut_y2 - cut_y1);
if (cut2 > 0) {
double exp_m_cut2 = exp(-cut2);
d1 = exp_m_cut2 / (1 + exp_m_cut2);
} else {
d1 = 1 / (1 + exp(cut2));
}
d1 = inv_logit(-cut2);
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
d2 = 1 / (1 - exp_cuts_diff);
if (cut1 > 0) {
double exp_m_cut1 = exp(-cut1);
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
} else {
d2 -= 1 / (1 + exp(cut1));
}
d2 -= inv_logit(-cut1);

if (need_lambda_derivative) {
lambda_derivative[gid] = d1 - d2;
Expand Down Expand Up @@ -175,7 +166,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer,
in_buffer, int, int, int, int, int, int>
ordered_logistic("ordered_logistic",
{log1p_exp_device_function, log1m_exp_device_function,
ordered_logistic_kernel_code},
inv_logit_device_function, ordered_logistic_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

} // namespace opencl_kernels
Expand Down
6 changes: 3 additions & 3 deletions stan/math/opencl/kernels/tridiagonalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY(
q = q_local[0];
alpha = q_local[1];
if (q != 0) {
double multi = sqrt(2.) / q;
double multi = M_SQRT2 / q;
// normalize the Householder vector
for (int i = lid + 1; i < P_span; i += lsize) {
P[P_start + i] *= multi;
}
}
if (gid == 0) {
P[P_rows * (k + j + 1) + k + j]
= P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha;
= P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
}
}
// \cond
Expand Down Expand Up @@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY(
v[i] -= acc * u[i];
}
if (gid == 0) {
P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0];
P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
}
}
// \cond
Expand Down
46 changes: 12 additions & 34 deletions stan/math/prim/prob/bernoulli_cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/any.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>

namespace stan {
Expand Down Expand Up @@ -36,50 +34,30 @@ return_type_t<T_prob> bernoulli_cdf(const T_n& n, const T_prob& theta) {
check_consistent_sizes(function, "Random variable", n,
"Probability parameter", theta);
T_theta_ref theta_ref = theta;
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
1.0);
const auto& n_arr = as_array_or_scalar(n);
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
check_bounded(function, "Probability parameter", theta_arr, 0.0, 1.0);

if (size_zero(n, theta)) {
return 1.0;
}

T_partials_return P(1.0);
auto ops_partials = make_partials_propagator(theta_ref);

scalar_seq_view<T_n> n_vec(n);
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
size_t max_size_seq_view = max_size(n, theta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(n); i++) {
if (n_vec.val(i) < 0) {
return ops_partials.build(0.0);
}
if (any(n_arr < 0)) {
return ops_partials.build(0.0);
}
const auto& log1m_theta = select(theta_arr == 1, 0.0, log1m(theta_arr));
const auto& P1 = select(n_arr == 0, log1m_theta, 0.0);

for (size_t i = 0; i < max_size_seq_view; i++) {
// Explicit results for extreme values
// The gradients are technically ill-defined, but treated as zero
if (n_vec.val(i) >= 1) {
continue;
}

const T_partials_return Pi = 1 - theta_vec.val(i);

P *= Pi;

if (!is_constant_all<T_prob>::value) {
partials<0>(ops_partials)[i] += -1 / Pi;
}
}
T_partials_return P = sum(P1);

if (!is_constant_all<T_prob>::value) {
for (size_t i = 0; i < stan::math::size(theta); ++i) {
partials<0>(ops_partials)[i] *= P;
}
partials<0>(ops_partials) = select(n_arr == 0, -exp(P - P1), 0.0);
}
return ops_partials.build(P);
return ops_partials.build(exp(P));
}

} // namespace math
Expand Down
45 changes: 15 additions & 30 deletions stan/math/prim/prob/bernoulli_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/any.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>

namespace stan {
namespace math {
Expand All @@ -33,50 +30,38 @@ template <typename T_n, typename T_prob,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_n, T_prob>* = nullptr>
return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) {
using T_partials_return = partials_return_t<T_n, T_prob>;
using T_theta_ref = ref_type_t<T_prob>;
using std::log;
static const char* function = "bernoulli_lccdf";
check_consistent_sizes(function, "Random variable", n,
"Probability parameter", theta);
T_theta_ref theta_ref = theta;
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
1.0);
const auto& n_arr = as_array_or_scalar(n);
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
check_bounded(function, "Probability parameter", theta_arr, 0.0, 1.0);

if (size_zero(n, theta)) {
return 0.0;
}

T_partials_return P(0.0);
auto ops_partials = make_partials_propagator(theta_ref);

scalar_seq_view<T_n> n_vec(n);
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
size_t max_size_seq_view = max_size(n, theta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(n); i++) {
const double n_dbl = n_vec.val(i);
if (n_dbl < 0) {
return ops_partials.build(0.0);
}
if (n_dbl >= 1) {
return ops_partials.build(NEGATIVE_INFTY);
}
if (any(n_arr < 0)) {
return ops_partials.build(0.0);
} else if (any(n_arr >= 1)) {
return ops_partials.build(NEGATIVE_INFTY);
}

for (size_t i = 0; i < max_size_seq_view; i++) {
const T_partials_return Pi = theta_vec.val(i);

P += log(Pi);
size_t theta_size = math::size(theta_arr);
size_t n_size = math::size(n_arr);
double broadcast_n = theta_size == n_size ? 1 : n_size;

if (!is_constant_all<T_prob>::value) {
partials<0>(ops_partials)[i] += inv(Pi);
}
if (!is_constant_all<T_prob>::value) {
partials<0>(ops_partials) = inv(theta_arr) * broadcast_n;
}

return ops_partials.build(P);
return ops_partials.build(sum(log(theta_arr)) * broadcast_n);
}

} // namespace math
Expand Down
Loading

0 comments on commit 545cdee

Please sign in to comment.