Skip to content

Commit

Permalink
changes per code review
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Oct 16, 2024
1 parent 427cf70 commit 79c675a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 36 deletions.
56 changes: 56 additions & 0 deletions src/stan/analyze/mcmc/mcse.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef STAN_ANALYZE_MCMC_MCSE_HPP
#define STAN_ANALYZE_MCMC_MCSE_HPP

#include <stan/analyze/mcmc/check_chains.hpp>
#include <stan/analyze/mcmc/split_rank_normalized_ess.hpp>
#include <stan/math/prim.hpp>
#include <cmath>
#include <vector>
#include <algorithm>

namespace stan {
namespace analyze {


/**
* Computes the mean Monte Carlo error estimate for the central 90% interval.
* See https://arxiv.org/abs/1903.08008, section 4.4.
* Follows implementation in the R posterior package.
*
* @param chains matrix of draws across all chains
* @return mcse
*/
inline double mcse_mean(const Eigen::MatrixXd& chains) {
const Eigen::Index num_draws = chains.rows();
if (chains.rows() < 4
|| !is_finite_and_varies(chains))
return std::numeric_limits<double>::quiet_NaN();

double sd = (chains.array() - chains.mean()).square().sum() / (chains.size() - 1);
return std::sqrt(sd / ess(chains));
}

/**
* Computes the standard deviation of the Monte Carlo error estimate
* https://arxiv.org/abs/1903.08008, section 4.4.
* Follows implementation in the R posterior package:
* https://github.com/stan-dev/posterior/blob/98bf52329d68f3307ac4ecaaea659276ee1de8df/R/convergence.R#L478-L496
*
* @param chains matrix of draws across all chains
* @return mcse
*/
inline double mcse_sd(const Eigen::MatrixXd& chains) {
if (chains.rows() < 4
|| !is_finite_and_varies(chains))
return std::numeric_limits<double>::quiet_NaN();

Eigen::MatrixXd diffs = (chains.array() - chains.mean()).matrix();
double Evar = diffs.array().square().mean();
double varvar = (math::mean(diffs.array().pow(4) - Evar * Evar)) / ess(diffs.array().abs().matrix());
return std::sqrt(varvar / Evar / 4);
}

} // namespace analyze
} // namespace stan

#endif
54 changes: 23 additions & 31 deletions src/stan/mcmc/chainset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stan/math/prim/fun/quantile.hpp>
#include <stan/analyze/mcmc/split_rank_normalized_ess.hpp>
#include <stan/analyze/mcmc/split_rank_normalized_rhat.hpp>
#include <stan/analyze/mcmc/mcse.hpp>
#include <algorithm>
#include <cmath>
#include <iostream>
Expand All @@ -20,10 +21,9 @@

namespace stan {
namespace mcmc {
using Eigen::Dynamic;

/**
* An <code>mcmc::chainset</code> object manages the post-warmup draws
* A <code>mcmc::chainset</code> object manages the post-warmup draws
* across a set of MCMC chains, which all have the same number of samples.
*
* @note samples are stored in column major, i.e., each column corresponds to
Expand Down Expand Up @@ -290,17 +290,16 @@ class chainset {
* Compute the quantile value of the specified parameter
* at the specified probability.
*
* Throws exception if specified probability is not between 0 and 1.
* Calls stan::math::quantile which throws
* std::invalid_argument If any element of samples_vec is NaN, or size 0.
* and std::domain_error If `p<0` or `p>1`.
*
* @param index parameter index
* @param prob probability
* @return parameter value at quantile
*/
double quantile(const int index, const double prob) const {
// Ensure the probability is within [0, 1]
if (prob <= 0.0 || prob >= 1.0) {
throw std::out_of_range("Probability must be between 0 and 1.");
}
Eigen::MatrixXd draws = samples(index);
Eigen::Map<Eigen::VectorXd> map(draws.data(), draws.size());
return stan::math::quantile(map, prob);
Expand All @@ -310,8 +309,6 @@ class chainset {
* Compute the quantile value of the specified parameter
* at the specified probability.
*
* Throws exception if specified probability is not between 0 and 1.
*
* @param name parameter name
* @param prob probability
* @return parameter value at quantile
Expand All @@ -324,8 +321,6 @@ class chainset {
* Compute the quantile values of the specified parameter
* for a set of specified probabilities.
*
* Throws exception if any probability is not between 0 and 1.
*
* @param index parameter index
* @param probs vector of probabilities
* @return vector of parameter values for quantiles
Expand All @@ -334,9 +329,6 @@ class chainset {
const Eigen::VectorXd& probs) const {
if (probs.size() == 0)
return Eigen::VectorXd::Zero(0);
if (probs.minCoeff() <= 0.0 || probs.maxCoeff() >= 1.0) {
throw std::out_of_range("Probabilities must be between 0 and 1.");
}
Eigen::MatrixXd draws = samples(index);
Eigen::Map<Eigen::VectorXd> map(draws.data(), draws.size());
std::vector<double> probs_vec(probs.data(), probs.data() + probs.size());
Expand All @@ -348,8 +340,6 @@ class chainset {
* Compute the quantile values of the specified parameter
* for a set of specified probabilities.
*
* Throws exception if any probability is not between 0 and 1.
*
* @param name parameter name
* @param probs vector of probabilities
* @return vector of parameter values for quantiles
Expand Down Expand Up @@ -420,11 +410,12 @@ class chainset {
* @return mcse
*/
double mcse_mean(const int index) const {
if (num_samples() < 4
|| !stan::analyze::is_finite_and_varies(samples(index)))
return std::numeric_limits<double>::quiet_NaN();
double ess = analyze::ess(samples(index));
return sd(index) / std::sqrt(ess);
return analyze::mcse_mean(samples(index));
// if (num_samples() < 4
// || !stan::analyze::is_finite_and_varies(samples(index)))
// return std::numeric_limits<double>::quiet_NaN();
// double ess = analyze::ess(samples(index));
// return sd(index) / std::sqrt(ess);
}

/**
Expand All @@ -448,17 +439,18 @@ class chainset {
* @return mcse_sd
*/
double mcse_sd(const int index) const {
if (num_samples() < 4
|| !stan::analyze::is_finite_and_varies(samples(index)))
return std::numeric_limits<double>::quiet_NaN();
Eigen::MatrixXd s = samples(index);
Eigen::MatrixXd s2 = s.array().square();
double ess_s = analyze::ess(s);
double ess_s2 = analyze::ess(s2);
double ess_sd = std::min(ess_s, ess_s2);
return sd(index)
* std::sqrt(stan::math::e() * std::pow(1 - 1 / ess_sd, ess_sd - 1)
- 1);
return analyze::mcse_sd(samples(index));
// if (num_samples() < 4
// || !stan::analyze::is_finite_and_varies(samples(index)))
// return std::numeric_limits<double>::quiet_NaN();
// Eigen::MatrixXd s = samples(index);
// Eigen::MatrixXd s2 = s.array().square();
// double ess_s = analyze::ess(s);
// double ess_s2 = analyze::ess(s2);
// double ess_sd = std::min(ess_s, ess_s2);
// return sd(index)
// * std::sqrt(stan::math::e() * std::pow(1 - 1 / ess_sd, ess_sd - 1)
// - 1);
}

/**
Expand Down
10 changes: 5 additions & 5 deletions src/test/unit/mcmc/chainset_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ TEST_F(McmcChains, split_rank_normalized_rhat) {

for (size_t i = 0; i < 10; ++i) {
auto rhats = chain_1.split_rank_normalized_rhat(i + 7);
EXPECT_NEAR(rhats.first, rhat_8_schools_1_bulk(i), 0.05);
EXPECT_NEAR(rhats.second, rhat_8_schools_1_tail(i), 0.05);
EXPECT_NEAR(rhats.first, rhat_8_schools_1_bulk(i), 0.04);
EXPECT_NEAR(rhats.second, rhat_8_schools_1_tail(i), 0.04);
}
}

Expand Down Expand Up @@ -285,9 +285,9 @@ TEST_F(McmcChains, mcse) {

for (size_t i = 0; i < 10; ++i) {
auto mcse_mean = chain_2.mcse_mean(i + 7);
EXPECT_NEAR(mcse_mean, s8_mcse_mean(i), 0.5);
auto mcse_sd = chain_2.mcse_sd(i + 7);
EXPECT_NEAR(mcse_sd, s8_mcse_sd(i), 0.7);
EXPECT_NEAR(mcse_mean, s8_mcse_mean(i), 0.05);
EXPECT_NEAR(mcse_sd, s8_mcse_sd(i), 0.09);
}
}

Expand Down Expand Up @@ -329,6 +329,6 @@ TEST_F(McmcChains, autocorrelation) {
0.01791577080, 0.01245035817;
auto mu_ac = chain_1.autocorrelation(0, "mu");
for (size_t i = 0; i < 10; ++i) {
EXPECT_NEAR(mu_ac_posterior(i), mu_ac(i), 0.05);
EXPECT_NEAR(mu_ac_posterior(i), mu_ac(i), 0.0005);
}
}

0 comments on commit 79c675a

Please sign in to comment.