From 7e421083f3a2a68a890ce7b544e6d6840ca93db8 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Tue, 19 Sep 2023 18:54:33 +0300 Subject: [PATCH 1/2] Improve numerical stability of binomial_logit, add tests --- stan/math/prim/prob/binomial_logit_lpmf.hpp | 32 ++++++------------- .../mix/prob/binomial_logit_lpmf_test.cpp | 22 +++++++++++++ 2 files changed, 31 insertions(+), 23 deletions(-) create mode 100644 test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp diff --git a/stan/math/prim/prob/binomial_logit_lpmf.hpp b/stan/math/prim/prob/binomial_logit_lpmf.hpp index 606f88067f0..1fc7dfd30ec 100644 --- a/stan/math/prim/prob/binomial_logit_lpmf.hpp +++ b/stan/math/prim/prob/binomial_logit_lpmf.hpp @@ -3,14 +3,11 @@ #include #include -#include -#include #include #include -#include -#include -#include -#include +#include +#include +#include #include #include #include @@ -66,16 +63,14 @@ return_type_t binomial_logit_lpmf(const T_n& n, const T_N& N, if (!include_summand::value) { return 0.0; } - const auto& inv_logit_alpha - = to_ref_if::value>(inv_logit(alpha_val)); - const auto& inv_logit_neg_alpha - = to_ref_if::value>(inv_logit(-alpha_val)); + const auto& log_inv_logit_alpha + = to_ref_if::value>(log_inv_logit(alpha_val)); + const auto& log1m_inv_logit_alpha + = to_ref_if::value>(log1m_inv_logit(alpha_val)); size_t maximum_size = max_size(n, N, alpha); - const auto& log_inv_logit_alpha = log(inv_logit_alpha); - const auto& log_inv_logit_neg_alpha = log(inv_logit_neg_alpha); T_partials_return logp = sum(n_val * log_inv_logit_alpha - + (N_val - n_val) * log_inv_logit_neg_alpha); + + (N_val - n_val) * log1m_inv_logit_alpha); if (include_summand::value) { logp += sum(binomial_coefficient_log(N_val, n_val)) * maximum_size / max_size(n, N); @@ -83,16 +78,7 @@ return_type_t binomial_logit_lpmf(const T_n& n, const T_N& N, auto ops_partials = make_partials_propagator(alpha_ref); if (!is_constant_all::value) { - if (is_vector::value) { - edge<0>(ops_partials).partials_ - = n_val * inv_logit_neg_alpha - (N_val - n_val) * inv_logit_alpha; - } else { - T_partials_return sum_n = sum(n_val) * maximum_size / math::size(n); - partials<0>(ops_partials)[0] = forward_as( - sum_n * inv_logit_neg_alpha - - (sum(N_val) * maximum_size / math::size(N) - sum_n) - * inv_logit_alpha); - } + edge<0>(ops_partials).partials_ = n_val - N_val * exp(log_inv_logit_alpha); } return ops_partials.build(logp); diff --git a/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp b/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp new file mode 100644 index 00000000000..47241944c07 --- /dev/null +++ b/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp @@ -0,0 +1,22 @@ +#include +#include + +TEST(mathMixScalFun, binomial_logit_lpmf) { + auto f = [](const auto n, const auto N) { + return [=](const auto& alpha) { + return stan::math::binomial_logit_lpmf(n, N, alpha); + }; + }; + + Eigen::VectorXd alpha = Eigen::VectorXd::Random(3); + std::vector n_arr{1, 4, 5}; + std::vector N_arr{10, 45, 25}; + + stan::test::expect_ad(f(5, 25), 2.11); + stan::test::expect_ad(f(5, 25), alpha); + stan::test::expect_ad(f(n_arr, 25), alpha); + stan::test::expect_ad(f(n_arr, N_arr), alpha); + stan::test::expect_ad(f(n_arr, 10), 2.11); + stan::test::expect_ad(f(n_arr, N_arr), 2.11); + stan::test::expect_ad(f(5, N_arr), 2.11); +} From 2a8118787b772f5bf59a483b314752c5ef3f91f3 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Tue, 19 Sep 2023 19:00:49 +0300 Subject: [PATCH 2/2] Missed test combination --- test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp b/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp index 47241944c07..250ab303e39 100644 --- a/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp +++ b/test/unit/math/mix/prob/binomial_logit_lpmf_test.cpp @@ -19,4 +19,5 @@ TEST(mathMixScalFun, binomial_logit_lpmf) { stan::test::expect_ad(f(n_arr, 10), 2.11); stan::test::expect_ad(f(n_arr, N_arr), 2.11); stan::test::expect_ad(f(5, N_arr), 2.11); + stan::test::expect_ad(f(5, N_arr), alpha); }