diff --git a/src/stan/analyze/mcmc/mcse.hpp b/src/stan/analyze/mcmc/mcse.hpp new file mode 100644 index 0000000000..9bd9aa8341 --- /dev/null +++ b/src/stan/analyze/mcmc/mcse.hpp @@ -0,0 +1,56 @@ +#ifndef STAN_ANALYZE_MCMC_MCSE_HPP +#define STAN_ANALYZE_MCMC_MCSE_HPP + +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace analyze { + + + /** + * Computes the mean Monte Carlo error estimate for the central 90% interval. + * See https://arxiv.org/abs/1903.08008, section 4.4. + * Follows implementation in the R posterior package. + * + * @param chains matrix of draws across all chains + * @return mcse + */ + inline double mcse_mean(const Eigen::MatrixXd& chains) { + const Eigen::Index num_draws = chains.rows(); + if (chains.rows() < 4 + || !is_finite_and_varies(chains)) + return std::numeric_limits::quiet_NaN(); + + double sd = (chains.array() - chains.mean()).square().sum() / (chains.size() - 1); + return std::sqrt(sd / ess(chains)); + } + + /** + * Computes the standard deviation of the Monte Carlo error estimate + * https://arxiv.org/abs/1903.08008, section 4.4. + * Follows implementation in the R posterior package: + * https://github.com/stan-dev/posterior/blob/98bf52329d68f3307ac4ecaaea659276ee1de8df/R/convergence.R#L478-L496 + * + * @param chains matrix of draws across all chains + * @return mcse + */ + inline double mcse_sd(const Eigen::MatrixXd& chains) { + if (chains.rows() < 4 + || !is_finite_and_varies(chains)) + return std::numeric_limits::quiet_NaN(); + + Eigen::MatrixXd diffs = (chains.array() - chains.mean()).matrix(); + double Evar = diffs.array().square().mean(); + double varvar = (math::mean(diffs.array().pow(4) - Evar * Evar)) / ess(diffs.array().abs().matrix()); + return std::sqrt(varvar / Evar / 4); + } + +} // namespace analyze +} // namespace stan + +#endif diff --git a/src/stan/mcmc/chainset.hpp b/src/stan/mcmc/chainset.hpp index 264becb7f5..b1f735d344 100644 --- a/src/stan/mcmc/chainset.hpp +++ b/src/stan/mcmc/chainset.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -20,10 +21,9 @@ namespace stan { namespace mcmc { -using Eigen::Dynamic; /** - * An mcmc::chainset object manages the post-warmup draws + * A mcmc::chainset object manages the post-warmup draws * across a set of MCMC chains, which all have the same number of samples. * * @note samples are stored in column major, i.e., each column corresponds to @@ -290,7 +290,9 @@ class chainset { * Compute the quantile value of the specified parameter * at the specified probability. * - * Throws exception if specified probability is not between 0 and 1. + * Calls stan::math::quantile which throws + * std::invalid_argument If any element of samples_vec is NaN, or size 0. + * and std::domain_error If `p<0` or `p>1`. * * @param index parameter index * @param prob probability @@ -298,9 +300,6 @@ class chainset { */ double quantile(const int index, const double prob) const { // Ensure the probability is within [0, 1] - if (prob <= 0.0 || prob >= 1.0) { - throw std::out_of_range("Probability must be between 0 and 1."); - } Eigen::MatrixXd draws = samples(index); Eigen::Map map(draws.data(), draws.size()); return stan::math::quantile(map, prob); @@ -310,8 +309,6 @@ class chainset { * Compute the quantile value of the specified parameter * at the specified probability. * - * Throws exception if specified probability is not between 0 and 1. - * * @param name parameter name * @param prob probability * @return parameter value at quantile @@ -324,8 +321,6 @@ class chainset { * Compute the quantile values of the specified parameter * for a set of specified probabilities. * - * Throws exception if any probability is not between 0 and 1. - * * @param index parameter index * @param probs vector of probabilities * @return vector of parameter values for quantiles @@ -334,9 +329,6 @@ class chainset { const Eigen::VectorXd& probs) const { if (probs.size() == 0) return Eigen::VectorXd::Zero(0); - if (probs.minCoeff() <= 0.0 || probs.maxCoeff() >= 1.0) { - throw std::out_of_range("Probabilities must be between 0 and 1."); - } Eigen::MatrixXd draws = samples(index); Eigen::Map map(draws.data(), draws.size()); std::vector probs_vec(probs.data(), probs.data() + probs.size()); @@ -348,8 +340,6 @@ class chainset { * Compute the quantile values of the specified parameter * for a set of specified probabilities. * - * Throws exception if any probability is not between 0 and 1. - * * @param name parameter name * @param probs vector of probabilities * @return vector of parameter values for quantiles @@ -420,11 +410,12 @@ class chainset { * @return mcse */ double mcse_mean(const int index) const { - if (num_samples() < 4 - || !stan::analyze::is_finite_and_varies(samples(index))) - return std::numeric_limits::quiet_NaN(); - double ess = analyze::ess(samples(index)); - return sd(index) / std::sqrt(ess); + return analyze::mcse_mean(samples(index)); + // if (num_samples() < 4 + // || !stan::analyze::is_finite_and_varies(samples(index))) + // return std::numeric_limits::quiet_NaN(); + // double ess = analyze::ess(samples(index)); + // return sd(index) / std::sqrt(ess); } /** @@ -448,17 +439,18 @@ class chainset { * @return mcse_sd */ double mcse_sd(const int index) const { - if (num_samples() < 4 - || !stan::analyze::is_finite_and_varies(samples(index))) - return std::numeric_limits::quiet_NaN(); - Eigen::MatrixXd s = samples(index); - Eigen::MatrixXd s2 = s.array().square(); - double ess_s = analyze::ess(s); - double ess_s2 = analyze::ess(s2); - double ess_sd = std::min(ess_s, ess_s2); - return sd(index) - * std::sqrt(stan::math::e() * std::pow(1 - 1 / ess_sd, ess_sd - 1) - - 1); + return analyze::mcse_sd(samples(index)); + // if (num_samples() < 4 + // || !stan::analyze::is_finite_and_varies(samples(index))) + // return std::numeric_limits::quiet_NaN(); + // Eigen::MatrixXd s = samples(index); + // Eigen::MatrixXd s2 = s.array().square(); + // double ess_s = analyze::ess(s); + // double ess_s2 = analyze::ess(s2); + // double ess_sd = std::min(ess_s, ess_s2); + // return sd(index) + // * std::sqrt(stan::math::e() * std::pow(1 - 1 / ess_sd, ess_sd - 1) + // - 1); } /** diff --git a/src/test/unit/mcmc/chainset_test.cpp b/src/test/unit/mcmc/chainset_test.cpp index 9662c3159c..5b8a371f94 100644 --- a/src/test/unit/mcmc/chainset_test.cpp +++ b/src/test/unit/mcmc/chainset_test.cpp @@ -189,8 +189,8 @@ TEST_F(McmcChains, split_rank_normalized_rhat) { for (size_t i = 0; i < 10; ++i) { auto rhats = chain_1.split_rank_normalized_rhat(i + 7); - EXPECT_NEAR(rhats.first, rhat_8_schools_1_bulk(i), 0.05); - EXPECT_NEAR(rhats.second, rhat_8_schools_1_tail(i), 0.05); + EXPECT_NEAR(rhats.first, rhat_8_schools_1_bulk(i), 0.04); + EXPECT_NEAR(rhats.second, rhat_8_schools_1_tail(i), 0.04); } } @@ -285,9 +285,9 @@ TEST_F(McmcChains, mcse) { for (size_t i = 0; i < 10; ++i) { auto mcse_mean = chain_2.mcse_mean(i + 7); - EXPECT_NEAR(mcse_mean, s8_mcse_mean(i), 0.5); auto mcse_sd = chain_2.mcse_sd(i + 7); - EXPECT_NEAR(mcse_sd, s8_mcse_sd(i), 0.7); + EXPECT_NEAR(mcse_mean, s8_mcse_mean(i), 0.05); + EXPECT_NEAR(mcse_sd, s8_mcse_sd(i), 0.09); } } @@ -329,6 +329,6 @@ TEST_F(McmcChains, autocorrelation) { 0.01791577080, 0.01245035817; auto mu_ac = chain_1.autocorrelation(0, "mu"); for (size_t i = 0; i < 10; ++i) { - EXPECT_NEAR(mu_ac_posterior(i), mu_ac(i), 0.05); + EXPECT_NEAR(mu_ac_posterior(i), mu_ac(i), 0.0005); } }