From 6052421c74c500793addf525a1de73ec7b4e0191 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 18 Feb 2020 10:56:55 -0800 Subject: [PATCH] cross-chain degenerates into classic when only single chain exists --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 14 +++++++--- src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp | 25 ++++++++++-------- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 26 +++++++++++-------- src/stan/services/util/mpi_cross_chain.hpp | 10 +++---- .../services/util/run_adaptive_sampler.hpp | 2 +- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 5518ecfa30f..4a557f1b1bb 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -113,9 +113,13 @@ namespace mcmc { inline void write_num_cross_chain_warmup(callbacks::writer& sample_writer, - int num_thin) { - size_t n = num_cross_chain_draws(); - sample_writer("num_warmup = " + std::to_string(n / num_thin)); + int num_thin, int num_warmup) { + if (use_cross_chain_adapt()) { + size_t n = num_cross_chain_draws(); + sample_writer("num_warmup = " + std::to_string(n / num_thin)); + } else { + sample_writer("num_warmup = " + std::to_string(num_warmup)); + } } /* @@ -454,7 +458,9 @@ namespace mcmc { sampler.set_nominal_stepsize(new_stepsize); } - inline bool use_cross_chain_adapt() { return true; } + inline bool use_cross_chain_adapt() { + return num_chains_ > 1; + } }; #else // sequential version diff --git a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp index 8aa589d9689..09cb7b8299d 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp @@ -31,27 +31,30 @@ class adapt_dense_e_nuts : public dense_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update; if (this -> use_cross_chain_adapt()) { this -> add_cross_chain_sample(s.log_prob()); - update = this -> cross_chain_adaptation(logger); + bool update = this -> cross_chain_adaptation(logger); if (this -> is_cross_chain_adapted()) { update = false; } - } else { - update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, - this->z_.q); - } - if (update) { - this->init_stepsize(logger); + if (update) { + this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); - if (this -> use_cross_chain_adapt()) { this->set_cross_chain_stepsize(); } + } else { + bool update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + if (update) { + this->init_stepsize(logger); + + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + } } } return s; diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 96c34969205..e3ae02e905a 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -31,27 +31,31 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update; + if (this -> use_cross_chain_adapt()) { this -> add_cross_chain_sample(s.log_prob()); - update = this -> cross_chain_adaptation(logger); + bool update = this -> cross_chain_adaptation(logger); if (this -> is_cross_chain_adapted()) { update = false; } - } else { - update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); - } - if (update) { - this->init_stepsize(logger); + if (update) { + this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); - if (this -> use_cross_chain_adapt()) { this->set_cross_chain_stepsize(); } + } else { + bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + if (update) { + this->init_stepsize(logger); + + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + } } } return s; diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index 2f1e99f3b66..15c03a28afa 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -54,7 +54,7 @@ namespace util { static void write_num_warmup(Sampler& sampler, callbacks::writer& sample_writer, - int num_thin) {} + int num_thin, int num_warmup) {} }; /* @@ -81,8 +81,8 @@ namespace util { static void write_num_warmup(Sampler& sampler, callbacks::writer& sample_writer, - int num_thin) { - sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + int num_thin, int num_warmup) { + sampler.write_num_cross_chain_warmup(sample_writer, num_thin, num_warmup); } }; #endif @@ -111,9 +111,9 @@ namespace util { static void write_num_warmup(Sampler& sampler, callbacks::writer& sample_writer, - int num_thin) { + int num_thin, int num_warmup) { mpi_cross_chain_impl::value>:: - write_num_warmup(sampler, sample_writer, num_thin); + write_num_warmup(sampler, sample_writer, num_thin, num_warmup); } }; diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index 345e7ebde33..a3048394599 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -73,7 +73,7 @@ void run_adaptive_sampler(Sampler& sampler, Model& model, mpi_cross_chain::num_draws(sampler), num_warmup + num_samples, num_thin, refresh, save_warmup, true, writer, s, model, rng, interrupt, logger); - mpi_cross_chain::write_num_warmup(sampler, sample_writer, num_thin); + mpi_cross_chain::write_num_warmup(sampler, sample_writer, num_thin, num_warmup); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC;