Skip to content

Commit

Permalink
Revert "update so scale_reduction calls scale_reduction_rank"
Browse files Browse the repository at this point in the history
This reverts commit f12f259.
  • Loading branch information
SteveBronder committed Apr 22, 2024
1 parent da5bb8d commit b3631be
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 46 deletions.
73 changes: 69 additions & 4 deletions src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,73 @@ inline std::pair<double, double> compute_potential_scale_reduction_rank(
* @return potential scale reduction for the specified parameter
*/
inline double compute_potential_scale_reduction(
const std::vector<const double*>& draws, const std::vector<size_t>& sizes) {
return compute_potential_scale_reduction_rank(draws, sizes).first;
std::vector<const double*> draws, std::vector<size_t> sizes) {
int num_chains = sizes.size();
size_t num_draws = sizes[0];
if (num_draws == 0) {
return std::numeric_limits<double>::quiet_NaN();
}
for (int chain = 1; chain < num_chains; ++chain) {
num_draws = std::min(num_draws, sizes[chain]);
}

// check if chains are constant; all equal to first draw's value
bool are_all_const = false;
Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains);

for (int chain = 0; chain < num_chains; chain++) {
Eigen::Map<const Eigen::Matrix<double, Eigen::Dynamic, 1>> draw(
draws[chain], sizes[chain]);

for (int n = 0; n < num_draws; n++) {
if (!std::isfinite(draw(n))) {
return std::numeric_limits<double>::quiet_NaN();
}
}

init_draw(chain) = draw(0);

if (draw.isApproxToConstant(draw(0))) {
are_all_const |= true;
}
}

if (are_all_const) {
// If all chains are constant then return NaN
// if they all equal the same constant value
if (init_draw.isApproxToConstant(init_draw(0))) {
return std::numeric_limits<double>::quiet_NaN();
}
}

using boost::accumulators::accumulator_set;
using boost::accumulators::stats;
using boost::accumulators::tag::mean;
using boost::accumulators::tag::variance;

Eigen::VectorXd chain_mean(num_chains);
accumulator_set<double, stats<variance>> acc_chain_mean;
Eigen::VectorXd chain_var(num_chains);
double unbiased_var_scale = num_draws / (num_draws - 1.0);

for (int chain = 0; chain < num_chains; ++chain) {
accumulator_set<double, stats<mean, variance>> acc_draw;
for (int n = 0; n < num_draws; ++n) {
acc_draw(draws[chain][n]);
}

chain_mean(chain) = boost::accumulators::mean(acc_draw);
acc_chain_mean(chain_mean(chain));
chain_var(chain)
= boost::accumulators::variance(acc_draw) * unbiased_var_scale;
}

double var_between = num_draws * boost::accumulators::variance(acc_chain_mean)
* num_chains / (num_chains - 1);
double var_within = chain_var.mean();

// rewrote [(n-1)*W/n + B/n]/W as (n-1+ B/W)/n
return sqrt((var_between / var_within + num_draws - 1) / num_draws);
}

/**
Expand Down Expand Up @@ -230,7 +295,7 @@ inline double compute_potential_scale_reduction(
std::vector<const double*> draws, size_t size) {
int num_chains = draws.size();
std::vector<size_t> sizes(num_chains, size);
return compute_potential_scale_reduction_rank(draws, sizes).first;
return compute_potential_scale_reduction(draws, sizes);
}

/**
Expand Down Expand Up @@ -298,7 +363,7 @@ inline double compute_split_potential_scale_reduction(
double half = num_draws / 2.0;
std::vector<size_t> half_sizes(2 * num_chains, std::floor(half));

return compute_potential_scale_reduction_rank(split_draws, half_sizes).first;
return compute_potential_scale_reduction(split_draws, half_sizes);
}

/**
Expand Down
7 changes: 3 additions & 4 deletions src/stan/mcmc/chains.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,17 +623,16 @@ class chains {
sizes[chain] = n_kept_samples;
}

return analyze::compute_split_potential_scale_reduction_rank(draws, sizes)
.first;
return analyze::compute_split_potential_scale_reduction(draws, sizes);
}

std::pair<double, double> split_potential_scale_reduction_rank(
const std::string& name) const {
return this->split_potential_scale_reduction_rank(index(name));
return split_potential_scale_reduction_rank(index(name));
}

double split_potential_scale_reduction(const std::string& name) const {
return this->split_potential_scale_reduction_rank(index(name)).first;
return split_potential_scale_reduction(index(name));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ TEST_F(ComputeRhat, compute_potential_scale_reduction) {
chains.add(blocker2);

Eigen::VectorXd rhat(48);
rhat << 1.00067, 0.999789, 0.999656, 1.00055, 1.0011, 1.00088, 1.00032,
0.999969, 1.00201, 0.999558, 0.999555, 0.9995, 1.00292, 1.00516, 1.00591,
0.999753, 1.00088, 1.00895, 1.00079, 0.99953, 1.00092, 1.00044, 1.01005,
0.999598, 1.00151, 0.999659, 0.999648, 0.999627, 1.00315, 1.00277,
1.00247, 1.00003, 0.999937, 1.00116, 0.999521, 1.0005, 1.00091, 1.00213,
1.00019, 0.999767, 1.0003, 0.999815, 1.00003, 0.999672, 1.00306, 1.00072,
0.999602, 0.999789;
rhat << 1.00042, 1.00036, 0.99955, 1.00047, 1.00119, 1.00089, 1.00018,
1.00019, 1.00226, 0.99954, 0.9996, 0.99951, 1.00237, 1.00515, 1.00566,
0.99957, 1.00099, 1.00853, 1.0008, 0.99961, 1.0006, 1.00046, 1.01023,
0.9996, 1.0011, 0.99967, 0.99973, 0.99958, 1.00242, 1.00213, 1.00244,
0.99998, 0.99969, 1.00079, 0.99955, 1.0009, 1.00136, 1.00288, 1.00036,
0.99989, 1.00077, 0.99997, 1.00194, 0.99972, 1.00257, 1.00109, 1.00004,
0.99955;

// replicates calls to stan::analyze::compute_effective_sample_size
// for any interface *without* access to chains class
Expand Down Expand Up @@ -129,13 +129,13 @@ TEST_F(ComputeRhat, compute_potential_scale_reduction_convenience) {
chains.add(blocker2);

Eigen::VectorXd rhat(48);
rhat << 1.00067, 0.999789, 0.999656, 1.00055, 1.0011, 1.00088, 1.00032,
0.999969, 1.00201, 0.999558, 0.999555, 0.9995, 1.00292, 1.00516, 1.00591,
0.999753, 1.00088, 1.00895, 1.00079, 0.99953, 1.00092, 1.00044, 1.01005,
0.999598, 1.00151, 0.999659, 0.999648, 0.999627, 1.00315, 1.00277,
1.00247, 1.00003, 0.999937, 1.00116, 0.999521, 1.0005, 1.00091, 1.00213,
1.00019, 0.999767, 1.0003, 0.999815, 1.00003, 0.999672, 1.00306, 1.00072,
0.999602, 0.999789;
rhat << 1.00042, 1.00036, 0.99955, 1.00047, 1.00119, 1.00089, 1.00018,
1.00019, 1.00226, 0.99954, 0.9996, 0.99951, 1.00237, 1.00515, 1.00566,
0.99957, 1.00099, 1.00853, 1.0008, 0.99961, 1.0006, 1.00046, 1.01023,
0.9996, 1.0011, 0.99967, 0.99973, 0.99958, 1.00242, 1.00213, 1.00244,
0.99998, 0.99969, 1.00079, 0.99955, 1.0009, 1.00136, 1.00288, 1.00036,
0.99989, 1.00077, 0.99997, 1.00194, 0.99972, 1.00257, 1.00109, 1.00004,
0.99955;

Eigen::Matrix<Eigen::VectorXd, Eigen::Dynamic, 1> samples(
chains.num_chains());
Expand Down Expand Up @@ -222,12 +222,14 @@ TEST_F(ComputeRhat, chains_compute_split_potential_scale_reduction) {
chains.add(blocker2);

Eigen::VectorXd rhat(48);
rhat << 1.0078, 1.0109, 0.999187, 1.001, 1.00401, 1.00992, 1.00182, 1.00519,
1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304,
1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013,
1.00193, 1.00104, 0.999383, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043,
1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118,
1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109;
rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173,
1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534,
1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017,
1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305,
1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202,
1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047,
1.00735;

for (int index = 4; index < chains.num_params(); index++) {
ASSERT_NEAR(rhat(index - 4), chains.split_potential_scale_reduction(index),
1e-4)
Expand Down Expand Up @@ -304,12 +306,13 @@ TEST_F(ComputeRhat, compute_split_potential_scale_reduction) {
chains.add(blocker2);

Eigen::VectorXd rhat(48);
rhat << 1.0078, 1.0109, 0.999187, 1.001, 1.00401, 1.00992, 1.00182, 1.00519,
1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304,
1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013,
1.00193, 1.00104, 0.999383, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043,
1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118,
1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109;
rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173,
1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534,
1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017,
1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305,
1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202,
1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047,
1.00735;

// replicates calls to stan::analyze::compute_effective_sample_size
// for any interface *without* access to chains class
Expand Down Expand Up @@ -402,12 +405,13 @@ TEST_F(ComputeRhat, compute_split_potential_scale_reduction_convenience) {
chains.add(blocker2);

Eigen::VectorXd rhat(48);
rhat << 1.0078, 1.0109, 0.999187, 1.001, 1.00401, 1.00992, 1.00182, 1.00519,
1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304,
1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013,
1.00193, 1.00104, 0.999383, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043,
1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118,
1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109;
rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173,
1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534,
1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017,
1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305,
1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202,
1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047,
1.00735;

Eigen::Matrix<Eigen::VectorXd, Eigen::Dynamic, 1> samples(
chains.num_chains());
Expand Down
22 changes: 16 additions & 6 deletions src/test/unit/mcmc/chains_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,12 +853,13 @@ TEST_F(McmcChains, blocker_split_potential_scale_reduction) {
chains.add(blocker2);

Eigen::VectorXd rhat(48);
rhat << 1.0078, 1.0109, 0.99919, 1.001, 1.00401, 1.00992, 1.00182, 1.00519,
1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546, 1.01304,
1.00166, 1.0074, 1.00178, 1.00588, 1.00406, 1.00129, 1.00976, 1.0013,
1.00193, 1.00104, 0.99938, 1.00025, 1.01082, 1.0019, 1.00354, 1.0043,
1.00111, 1.00281, 1.00436, 1.00515, 1.00325, 1.0089, 1.00222, 1.00118,
1.00191, 1.00283, 1.0003, 1.00216, 1.00335, 1.00133, 1.00023, 1.0109;
rhat << 1.00718, 1.00473, 0.999203, 1.00061, 1.00378, 1.01031, 1.00173,
1.0045, 1.00111, 1.00337, 1.00546, 1.00105, 1.00558, 1.00463, 1.00534,
1.01244, 1.00174, 1.00718, 1.00186, 1.00554, 1.00436, 1.00147, 1.01017,
1.00162, 1.00143, 1.00058, 0.999221, 1.00012, 1.01028, 1.001, 1.00305,
1.00435, 1.00055, 1.00246, 1.00447, 1.0048, 1.00209, 1.01159, 1.00202,
1.00077, 1.0021, 1.00262, 1.00308, 1.00197, 1.00246, 1.00085, 1.00047,
1.00735;

for (int index = 4; index < chains.num_params(); index++) {
ASSERT_NEAR(rhat(index - 4), chains.split_potential_scale_reduction(index),
Expand All @@ -885,6 +886,15 @@ TEST_F(McmcChains, blocker_split_potential_scale_reduction_rank) {
stan::mcmc::chains<> chains(blocker1);
chains.add(blocker2);

// Eigen::VectorXd rhat(48);
// rhat
// << 1.0078, 1.0109, 1.00731, 1.00333, 1.00401, 1.00992, 1.00734, 1.00633,
// 1.00095, 1.00906, 1.01019, 1.00075, 1.00595, 1.00473, 1.00895, 1.01304,
// 1.00166, 1.0074, 1.00236, 1.00588, 1.00414, 1.00303, 1.00976, 1.00295,
// 1.00193, 1.0044, 1.00488, 1.00178, 1.01082, 1.0019, 1.00413, 1.01303,
// 1.0024, 1.01148, 1.00436, 1.00515, 1.00712, 1.0089, 1.00222, 1.00118,
// 1.00381, 1.00283, 1.00188, 1.00225, 1.00335, 1.00133, 1.00209, 1.0109;

Eigen::VectorXd rhat_bulk(48);
rhat_bulk << 1.0078, 1.0109, 0.99919, 1.001, 1.00401, 1.00992, 1.00182,
1.00519, 1.00095, 1.00351, 1.00554, 1.00075, 1.00595, 1.00473, 1.00546,
Expand Down

0 comments on commit b3631be

Please sign in to comment.