From 35bdb07570910d0ef59e11861c966f4f83077404 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 11 Oct 2024 17:15:53 -0400 Subject: [PATCH] basic_ess needed for mcse --- .../mcmc/split_rank_normalized_ess.hpp | 27 +++++++++++++++++-- src/stan/io/stan_csv_reader.hpp | 2 +- src/stan/mcmc/chainset.hpp | 4 +-- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/stan/analyze/mcmc/split_rank_normalized_ess.hpp b/src/stan/analyze/mcmc/split_rank_normalized_ess.hpp index 6a26490147..4087e497f7 100644 --- a/src/stan/analyze/mcmc/split_rank_normalized_ess.hpp +++ b/src/stan/analyze/mcmc/split_rank_normalized_ess.hpp @@ -103,7 +103,8 @@ double ess(const Eigen::MatrixXd& chains) { /** * Computes the split effective sample size (split ESS) using rank based * diagnostic for a set of per-chain draws. Based on paper - * https://arxiv.org/abs/1903.08008 + * https://arxiv.org/abs/1903.08008 Computes bulk ESS over entire sample, + * and tail ESS over the 0.05 and 0.95 quantiles. * * When the number of total draws N is odd, the last draw is ignored. * @@ -111,7 +112,7 @@ double ess(const Eigen::MatrixXd& chains) { * Scale Reduction". http://mc-stan.org/users/documentation * @param chains matrix of per-chain draws, num_iters X chain - * @return potential scale reduction + * @return pair ESS_bulk, ESS_tail */ inline std::pair split_rank_normalized_ess( const Eigen::MatrixXd& chains) { @@ -142,6 +143,28 @@ inline std::pair split_rank_normalized_ess( return std::make_pair(ess_bulk, ess_tail); } +/** + * Computes the split effective sample size (split ESS) + * diagnostic for a set of per-chain draws. + * + * When the number of total draws N is odd, the last draw is ignored. + * + * See more details in Stan reference manual section "Potential + * Scale Reduction". http://mc-stan.org/users/documentation + + * @param chains matrix of per-chain draws, num_iters X chain + * @return potential scale reduction + */ +inline double split_basic_ess( + const Eigen::MatrixXd& chains) { + Eigen::MatrixXd split_draws_matrix = split_chains(chains); + if (!is_finite_and_varies(split_draws_matrix) + || split_draws_matrix.rows() < 4) { + return std::numeric_limits::quiet_NaN(); + } + return ess(split_draws_matrix); +} + } // namespace analyze } // namespace stan diff --git a/src/stan/io/stan_csv_reader.hpp b/src/stan/io/stan_csv_reader.hpp index a5dfc9f0d0..459ee456f1 100644 --- a/src/stan/io/stan_csv_reader.hpp +++ b/src/stan/io/stan_csv_reader.hpp @@ -377,7 +377,7 @@ class stan_csv_reader { if (!read_samples(in, data.samples, data.timing)) { if (out) - *out << "Unable to parse sample" << std::endl; + *out << "no draws found" << std::endl; } return data; } diff --git a/src/stan/mcmc/chainset.hpp b/src/stan/mcmc/chainset.hpp index 3846212e95..eb33837619 100644 --- a/src/stan/mcmc/chainset.hpp +++ b/src/stan/mcmc/chainset.hpp @@ -420,8 +420,8 @@ class chainset { * @return pair (bulk_ess, tail_ess) */ double mcse_mean(const int index) const { - double ess_bulk = analyze::split_rank_normalized_ess(samples(index)).first; - return sd(index) / std::sqrt(ess_bulk); + double ess_basic = analyze::split_basic_ess(samples(index)); + return sd(index) / std::sqrt(ess_basic); } /**