Skip to content

Commit

Permalink
Merge branch 'feature/3299-improved-ESS-Rhat' of https://github.com/s…
Browse files Browse the repository at this point in the history
…tan-dev/stan into feature/3299-improved-ESS-Rhat
  • Loading branch information
mitzimorris committed Oct 7, 2024
2 parents 1355de8 + c184c9e commit a5fe9d2
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/stan/mcmc/chains.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ class chains {
Eigen::MatrixXd chains(n_kept_samples, n_chains);
for (size_t i = 0; i < n_chains; ++i) {
auto bottom_rows = samples_(i).col(index).bottomRows(n_kept_samples);
chains.col(i) = bottom_rows.eval();
chains.col(i) = bottom_rows.eval();
}
return analyze::split_rank_normalized_rhat(chains);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ TEST_F(ComputeRhat, compute_potential_scale_reduction_constant) {
<< "rhat for index: " << 1 << ", parameter: " << chains.param_name(1);
}


TEST_F(ComputeRhat, compute_potential_scale_reduction_nan) {
std::vector<std::string> param_names{"a"};
stan::mcmc::chains<> chains(param_names);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ TEST(RankNormalizedEss, compute_split_rank_normalized_ess) {
eight_schools_2
= stan::io::stan_csv_reader::parse(eight_schools_2_stream, &out);
eight_schools_2_stream.close();

// test against R implementation in pkg posterior (via cmdstanr)
Eigen::VectorXd ess_8_schools_bulk(10);
ess_8_schools_bulk << 348, 370, 600, 638, 765, 608, 629, 274, 517, 112;
Eigen::VectorXd ess_8_schools_tail(10);
ess_8_schools_tail << 845, 858, 874, 726, 620, 753, 826, 628, 587, 108;


Eigen::MatrixXd chains(eight_schools_1.samples.rows(), 2);
for (size_t i = 0; i < 10; ++i) {
chains.col(0) = eight_schools_1.samples.col(i + 7);
Expand All @@ -57,7 +56,6 @@ TEST(RankNormalizedEss, short_chains_fail) {
eight_schools_5iters_2
= stan::io::stan_csv_reader::parse(eight_schools_5iters_2_stream, &out);
eight_schools_5iters_2_stream.close();


Eigen::MatrixXd chains(eight_schools_5iters_1.samples.rows(), 2);
for (size_t i = 0; i < 10; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ TEST(RankNormalizedRhat, const_fail) {
bernoulli_const_2
= stan::io::stan_csv_reader::parse(bernoulli_const_2_stream, &out);
bernoulli_const_2_stream.close();

Eigen::MatrixXd chains(bernoulli_const_1.samples.rows(), 2);
chains.col(0) = bernoulli_const_1.samples.col(bernoulli_const_1.samples.cols() - 1);
chains.col(1) = bernoulli_const_2.samples.col(bernoulli_const_2.samples.cols() - 1);
chains.col(0)
= bernoulli_const_1.samples.col(bernoulli_const_1.samples.cols() - 1);
chains.col(1)
= bernoulli_const_2.samples.col(bernoulli_const_2.samples.cols() - 1);
auto rhat = stan::analyze::split_rank_normalized_rhat(chains);
EXPECT_TRUE(std::isnan(rhat.first));
EXPECT_TRUE(std::isnan(rhat.second));
Expand Down

0 comments on commit a5fe9d2

Please sign in to comment.