From b3631beaefae57c43b3629f70ae161a2653dd334 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Mon, 22 Apr 2024 11:03:57 -0400 Subject: [PATCH] Revert "update so scale_reduction calls scale_reduction_rank" This reverts commit f12f2591348c30e3849bb9c964ce67f248fcf6ab. --- .../compute_potential_scale_reduction.hpp | 73 ++++++++++++++++++- src/stan/mcmc/chains.hpp | 7 +- ...compute_potential_scale_reduction_test.cpp | 68 +++++++++-------- src/test/unit/mcmc/chains_test.cpp | 22 ++++-- 4 files changed, 124 insertions(+), 46 deletions(-) diff --git a/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp b/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp index 47d0769573..93813775ca 100644 --- a/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp +++ b/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp @@ -183,8 +183,73 @@ inline std::pair compute_potential_scale_reduction_rank( * @return potential scale reduction for the specified parameter */ inline double compute_potential_scale_reduction( - const std::vector& draws, const std::vector& sizes) { - return compute_potential_scale_reduction_rank(draws, sizes).first; + std::vector draws, std::vector sizes) { + int num_chains = sizes.size(); + size_t num_draws = sizes[0]; + if (num_draws == 0) { + return std::numeric_limits::quiet_NaN(); + } + for (int chain = 1; chain < num_chains; ++chain) { + num_draws = std::min(num_draws, sizes[chain]); + } + + // check if chains are constant; all equal to first draw's value + bool are_all_const = false; + Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains); + + for (int chain = 0; chain < num_chains; chain++) { + Eigen::Map> draw( + draws[chain], sizes[chain]); + + for (int n = 0; n < num_draws; n++) { + if (!std::isfinite(draw(n))) { + return std::numeric_limits::quiet_NaN(); + } + } + + init_draw(chain) = draw(0); + + if (draw.isApproxToConstant(draw(0))) { + are_all_const |= true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + if (init_draw.isApproxToConstant(init_draw(0))) { + return std::numeric_limits::quiet_NaN(); + } + } + + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + Eigen::VectorXd chain_mean(num_chains); + accumulator_set> acc_chain_mean; + Eigen::VectorXd chain_var(num_chains); + double unbiased_var_scale = num_draws / (num_draws - 1.0); + + for (int chain = 0; chain < num_chains; ++chain) { + accumulator_set> acc_draw; + for (int n = 0; n < num_draws; ++n) { + acc_draw(draws[chain][n]); + } + + chain_mean(chain) = boost::accumulators::mean(acc_draw); + acc_chain_mean(chain_mean(chain)); + chain_var(chain) + = boost::accumulators::variance(acc_draw) * unbiased_var_scale; + } + + double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains / (num_chains - 1); + double var_within = chain_var.mean(); + + // rewrote [(n-1)*W/n + B/n]/W as (n-1+ B/W)/n + return sqrt((var_between / var_within + num_draws - 1) / num_draws); } /** @@ -230,7 +295,7 @@ inline double compute_potential_scale_reduction( std::vector draws, size_t size) { int num_chains = draws.size(); std::vector sizes(num_chains, size); - return compute_potential_scale_reduction_rank(draws, sizes).first; + return compute_potential_scale_reduction(draws, sizes); } /** @@ -298,7 +363,7 @@ inline double compute_split_potential_scale_reduction( double half = num_draws / 2.0; std::vector half_sizes(2 * num_chains, std::floor(half)); - return compute_potential_scale_reduction_rank(split_draws, half_sizes).first; + return compute_potential_scale_reduction(split_draws, half_sizes); } /** diff --git a/src/stan/mcmc/chains.hpp b/src/stan/mcmc/chains.hpp index 69ae7669f8..a553fd36cd 100644 --- a/src/stan/mcmc/chains.hpp +++ b/src/stan/mcmc/chains.hpp @@ -623,17 +623,16 @@ class chains { sizes[chain] = n_kept_samples; } - return analyze::compute_split_potential_scale_reduction_rank(draws, sizes) - .first; + return analyze::compute_split_potential_scale_reduction(draws, sizes); } std::pair split_potential_scale_reduction_rank( const std::string& name) const { - return this->split_potential_scale_reduction_rank(index(name)); + return split_potential_scale_reduction_rank(index(name)); } double split_potential_scale_reduction(const std::string& name) const { - return this->split_potential_scale_reduction_rank(index(name)).first; + return split_potential_scale_reduction(index(name)); } }; diff --git a/src/test/unit/analyze/mcmc/compute_potential_scale_reduction_test.cpp b/src/test/unit/analyze/mcmc/compute_potential_scale_reduction_test.cpp index 84f6efbe81..41a102e51d 100644 --- a/src/test/unit/analyze/mcmc/compute_potential_scale_reduction_test.cpp +++ b/src/test/unit/analyze/mcmc/compute_potential_scale_reduction_test.cpp @@ -31,13 +31,13 @@ TEST_F(ComputeRhat, compute_potential_scale_reduction) { chains.add(blocker2); Eigen::VectorXd rhat(48); - rhat << 1.00067, 0.999789, 0.999656, 1.00055, 1.0011, 1.00088, 1.00032, - 0.999969, 1.00201, 0.999558, 0.999555, 0.9995, 1.00292, 1.00516, 1.00591, - 0.999753, 1.00088, 1.00895, 1.00079, 0.99953, 1.00092, 1.00044, 1.01005, - 0.999598, 1.00151, 0.999659, 0.999648, 0.999627, 1.00315, 1.00277, - 1.00247, 1.00003, 0.999937, 1.00116, 0.999521, 1.0005, 1.00091, 1.00213, - 1.00019, 0.999767, 1.0003, 0.999815, 1.00003, 0.999672, 1.00306, 1.00072, - 0.999602, 0.999789; + rhat << 1.00042, 1.00036, 0.99955, 1.00047, 1.00119, 1.00089, 1.00018, + 1.00019, 1.00226, 0.99954, 0.9996, 0.99951, 1.00237, 1.00515, 1.00566, + 0.99957, 1.00099, 1.00853, 1.0008, 0.99961, 1.0006, 1.00046, 1.01023, + 0.9996, 1.0011, 0.99967, 0.99973, 0.99958, 1.00242, 1.00213, 1.00244, + 0.99998, 0.99969, 1.00079, 0.99955, 1.0009, 1.00136, 1.00288, 1.00036, + 0.99989, 1.00077, 0.99997, 1.00194, 0.99972, 1.00257, 1.00109, 1.00004, + 0.99955; // replicates calls to stan::analyze::compute_effective_sample_size // for any interface *without* access to chains class @@ -129,13 +129,13 @@ TEST_F(ComputeRhat, compute_potential_scale_reduction_convenience) { chains.add(blocker2); Eigen::VectorXd rhat(48); - rhat << 1.00067, 0.999789, 0.999656, 1.00055, 1.0011, 1.00088, 1.00032, - 0.999969, 1.00201, 0.999558, 0.999555, 0.9995, 1.00292, 1.00516, 1.00591, - 0.999753, 1.00088, 1.00895, 1.00079, 0.99953, 1.00092, 1.00044, 1.01005, - 0.999598, 1.00151, 0.999659, 0.999648, 0.999627, 1.00315, 1.00277, - 1.00247, 1.00003, 0.999937, 1.00116, 0.999521, 1.0005, 1.00091, 1.00213, - 1.00019, 0.999767, 1.0003, 0.999815, 1.00003, 0.999672, 1.00306, 1.00072, - 0.999602, 0.999789; + rhat << 1.00042, 1.00036, 0.99955, 1.00047, 1.00119, 1.00089, 1.00018, + 1.00019, 1.00226, 0.99954, 0.9996, 0.99951, 1.00237, 1.00515, 1.00566, + 0.99957, 1.00099, 1.00853, 1.0008, 0.99961, 1.0006, 1.00046, 1.01023, + 0.9996, 1.0011, 0.99967, 0.99973, 0.99958, 1.00242, 1.00213, 1.00244, + 0.99998, 0.99969, 1.00079, 0.99955, 1.0009, 1.00136, 1.00288, 1.00036, + 0.99989, 1.00077, 0.99997, 1.00194, 0.99972, 1.00257, 1.00109, 1.00004, + 0.99955; Eigen::Matrix samples( chains.num_chains()); @@ -222,12 +222,14 @@ TEST_F(ComputeRhat, chains_compute_split_potential_scale_reduction) { chains.add(blocker2); Eigen::VectorXd rhat(48); - rhat << 1.0078, 1.0109, 0.999187, 1.001, 1.00401, 1.00992, 1.00182, 1.00519, - 1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304, - 1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013, - 1.00193, 1.00104, 0.999383, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043, - 1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118, - 1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109; + rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173, + 1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534, + 1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017, + 1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305, + 1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202, + 1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047, + 1.00735; + for (int index = 4; index < chains.num_params(); index++) { ASSERT_NEAR(rhat(index - 4), chains.split_potential_scale_reduction(index), 1e-4) @@ -304,12 +306,13 @@ TEST_F(ComputeRhat, compute_split_potential_scale_reduction) { chains.add(blocker2); Eigen::VectorXd rhat(48); - rhat << 1.0078, 1.0109, 0.999187, 1.001, 1.00401, 1.00992, 1.00182, 1.00519, - 1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304, - 1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013, - 1.00193, 1.00104, 0.999383, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043, - 1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118, - 1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109; + rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173, + 1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534, + 1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017, + 1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305, + 1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202, + 1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047, + 1.00735; // replicates calls to stan::analyze::compute_effective_sample_size // for any interface *without* access to chains class @@ -402,12 +405,13 @@ TEST_F(ComputeRhat, compute_split_potential_scale_reduction_convenience) { chains.add(blocker2); Eigen::VectorXd rhat(48); - rhat << 1.0078, 1.0109, 0.999187, 1.001, 1.00401, 1.00992, 1.00182, 1.00519, - 1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304, - 1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013, - 1.00193, 1.00104, 0.999383, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043, - 1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118, - 1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109; + rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173, + 1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534, + 1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017, + 1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305, + 1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202, + 1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047, + 1.00735; Eigen::Matrix samples( chains.num_chains()); diff --git a/src/test/unit/mcmc/chains_test.cpp b/src/test/unit/mcmc/chains_test.cpp index b0d0db1d24..f2d90aaf0d 100644 --- a/src/test/unit/mcmc/chains_test.cpp +++ b/src/test/unit/mcmc/chains_test.cpp @@ -853,12 +853,13 @@ TEST_F(McmcChains, blocker_split_potential_scale_reduction) { chains.add(blocker2); Eigen::VectorXd rhat(48); - rhat << 1.0078, 1.0109, 0.99919, 1.001, 1.00401, 1.00992, 1.00182, 1.00519, - 1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304, - 1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013, - 1.00193, 1.00104, 0.99938, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043, - 1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118, - 1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109; + rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173, + 1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534, + 1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017, + 1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305, + 1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202, + 1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047, + 1.00735; for (int index = 4; index < chains.num_params(); index++) { ASSERT_NEAR(rhat(index - 4), chains.split_potential_scale_reduction(index), @@ -885,6 +886,15 @@ TEST_F(McmcChains, blocker_split_potential_scale_reduction_rank) { stan::mcmc::chains<> chains(blocker1); chains.add(blocker2); + // Eigen::VectorXd rhat(48); + // rhat + // << 1.0078, 1.0109, 1.00731, 1.00333, 1.00401, 1.00992, 1.00734, 1.00633, + // 1.00095, 1.00906, 1.01019, 1.00075, 1.00595, 1.00473, 1.00895, 1.01304, + // 1.00166, 1.0074, 1.00236, 1.00588, 1.00414, 1.00303, 1.00976, 1.00295, + // 1.00193, 1.0044, 1.00488, 1.00178, 1.01082, 1.0019, 1.00413, 1.01303, + // 1.0024, 1.01148, 1.00436, 1.00515, 1.00712, 1.0089, 1.00222, 1.00118, + // 1.00381, 1.00283, 1.00188, 1.00225, 1.00335, 1.00133, 1.00209, 1.0109; + Eigen::VectorXd rhat_bulk(48); rhat_bulk << 1.0078, 1.0109, 0.99919, 1.001, 1.00401, 1.00992, 1.00182, 1.00519, 1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546,