Skip to content

Commit

Permalink
Merge pull request #2945 from stan-dev/binomial-logit-numerics
Browse files Browse the repository at this point in the history
Improve numerical stability of binomial_logit_lpmf
  • Loading branch information
andrjohns authored Sep 26, 2023
2 parents 9f2689e + 2a81187 commit eb3b5d7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
32 changes: 9 additions & 23 deletions stan/math/prim/prob/binomial_logit_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
#include <stan/math/prim/fun/as_value_column_array_or_scalar.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
Expand Down Expand Up @@ -66,33 +63,22 @@ return_type_t<T_prob> binomial_logit_lpmf(const T_n& n, const T_N& N,
if (!include_summand<propto, T_prob>::value) {
return 0.0;
}
const auto& inv_logit_alpha
= to_ref_if<!is_constant_all<T_prob>::value>(inv_logit(alpha_val));
const auto& inv_logit_neg_alpha
= to_ref_if<!is_constant_all<T_prob>::value>(inv_logit(-alpha_val));
const auto& log_inv_logit_alpha
= to_ref_if<!is_constant_all<T_prob>::value>(log_inv_logit(alpha_val));
const auto& log1m_inv_logit_alpha
= to_ref_if<!is_constant_all<T_prob>::value>(log1m_inv_logit(alpha_val));

size_t maximum_size = max_size(n, N, alpha);
const auto& log_inv_logit_alpha = log(inv_logit_alpha);
const auto& log_inv_logit_neg_alpha = log(inv_logit_neg_alpha);
T_partials_return logp = sum(n_val * log_inv_logit_alpha
+ (N_val - n_val) * log_inv_logit_neg_alpha);
+ (N_val - n_val) * log1m_inv_logit_alpha);
if (include_summand<propto, T_n, T_N>::value) {
logp += sum(binomial_coefficient_log(N_val, n_val)) * maximum_size
/ max_size(n, N);
}

auto ops_partials = make_partials_propagator(alpha_ref);
if (!is_constant_all<T_prob>::value) {
if (is_vector<T_prob>::value) {
edge<0>(ops_partials).partials_
= n_val * inv_logit_neg_alpha - (N_val - n_val) * inv_logit_alpha;
} else {
T_partials_return sum_n = sum(n_val) * maximum_size / math::size(n);
partials<0>(ops_partials)[0] = forward_as<T_partials_return>(
sum_n * inv_logit_neg_alpha
- (sum(N_val) * maximum_size / math::size(N) - sum_n)
* inv_logit_alpha);
}
edge<0>(ops_partials).partials_ = n_val - N_val * exp(log_inv_logit_alpha);
}

return ops_partials.build(logp);
Expand Down
23 changes: 23 additions & 0 deletions test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <stan/math/mix.hpp>
#include <test/unit/math/test_ad.hpp>

TEST(mathMixScalFun, binomial_logit_lpmf) {
auto f = [](const auto n, const auto N) {
return [=](const auto& alpha) {
return stan::math::binomial_logit_lpmf(n, N, alpha);
};
};

Eigen::VectorXd alpha = Eigen::VectorXd::Random(3);
std::vector<int> n_arr{1, 4, 5};
std::vector<int> N_arr{10, 45, 25};

stan::test::expect_ad(f(5, 25), 2.11);
stan::test::expect_ad(f(5, 25), alpha);
stan::test::expect_ad(f(n_arr, 25), alpha);
stan::test::expect_ad(f(n_arr, N_arr), alpha);
stan::test::expect_ad(f(n_arr, 10), 2.11);
stan::test::expect_ad(f(n_arr, N_arr), 2.11);
stan::test::expect_ad(f(5, N_arr), 2.11);
stan::test::expect_ad(f(5, N_arr), alpha);
}

0 comments on commit eb3b5d7

Please sign in to comment.