From 1952caf94101c363d7d7bcd530ffa59652d6d7f5 Mon Sep 17 00:00:00 2001 From: Ben Bales Date: Thu, 18 Apr 2019 11:22:36 -0700 Subject: [PATCH 01/73] Added in switching adaptation --- lib/stan_math | 2 +- .../mcmc/hmc/nuts/adapt_switching_e_nuts.hpp | 59 +++++ src/stan/mcmc/stepsize_switching_adapter.hpp | 48 +++++ src/stan/mcmc/switching_adaptation.hpp | 202 ++++++++++++++++++ .../sample/hmc_nuts_switching_e_adapt.hpp | 181 ++++++++++++++++ 5 files changed, 491 insertions(+), 1 deletion(-) create mode 100644 src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp create mode 100644 src/stan/mcmc/stepsize_switching_adapter.hpp create mode 100644 src/stan/mcmc/switching_adaptation.hpp create mode 100644 src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp diff --git a/lib/stan_math b/lib/stan_math index 48d34c56fe3..5fe4e2996cb 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 48d34c56fe3750e30122ef35bc92350c2f4a0775 +Subproject commit 5fe4e2996cb93220ccff9b426ae837073dba3338 diff --git a/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp new file mode 100644 index 00000000000..c9f660b6e01 --- /dev/null +++ b/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp @@ -0,0 +1,59 @@ +#ifndef STAN_MCMC_HMC_NUTS_ADAPT_SWITCHING_E_NUTS_HPP +#define STAN_MCMC_HMC_NUTS_ADAPT_SWITCHING_E_NUTS_HPP + +#include +#include +#include + +namespace stan { + namespace mcmc { + /** + * The No-U-Turn sampler (NUTS) with multinomial sampling + * with a Gaussian-Euclidean disintegration and adaptive + * dense metric and adaptive step size + */ + template + class adapt_switching_e_nuts : public dense_e_nuts, + public stepsize_switching_adapter { + protected: + const Model& model_; + public: + adapt_switching_e_nuts(const Model& model, BaseRNG& rng) + : model_(model), dense_e_nuts(model, rng), + stepsize_switching_adapter(model.num_params_r()) {} + + ~adapt_switching_e_nuts() {} + + sample + transition(sample& 
init_sample, callbacks::logger& logger) { + sample s = dense_e_nuts::transition(init_sample, + logger); + + if (this->adapt_flag_) { + this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, + s.accept_stat()); + + bool update = this->switching_adaptation_.learn_covariance( + model_, + this->z_.inv_e_metric_, + this->z_.q); + + if (update) { + this->init_stepsize(logger); + + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + } + } + return s; + } + + void disengage_adaptation() { + base_adapter::disengage_adaptation(); + this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + } + }; + + } // mcmc +} // stan +#endif diff --git a/src/stan/mcmc/stepsize_switching_adapter.hpp b/src/stan/mcmc/stepsize_switching_adapter.hpp new file mode 100644 index 00000000000..aaa35ad6238 --- /dev/null +++ b/src/stan/mcmc/stepsize_switching_adapter.hpp @@ -0,0 +1,48 @@ +#ifndef STAN_MCMC_STEPSIZE_SWITCHING_ADAPTER_HPP +#define STAN_MCMC_STEPSIZE_SWITCHING_ADAPTER_HPP + +#include +#include +#include +#include + +namespace stan { + + namespace mcmc { + + class stepsize_switching_adapter: public base_adapter { + public: + explicit stepsize_switching_adapter(int n) + : switching_adaptation_(n) { + } + + stepsize_adaptation& get_stepsize_adaptation() { + return stepsize_adaptation_; + } + + switching_adaptation& get_switching_adaptation() { + return switching_adaptation_; + } + + void set_window_params(unsigned int num_warmup, + unsigned int init_buffer, + unsigned int term_buffer, + unsigned int base_window, + callbacks::logger& logger) { + switching_adaptation_.set_window_params(num_warmup, + init_buffer, + term_buffer, + base_window, + logger); + } + + protected: + stepsize_adaptation stepsize_adaptation_; + switching_adaptation switching_adaptation_; + }; + + } // mcmc + +} // stan + +#endif diff --git a/src/stan/mcmc/switching_adaptation.hpp b/src/stan/mcmc/switching_adaptation.hpp new file mode 100644 index 
00000000000..7630cc7a627 --- /dev/null +++ b/src/stan/mcmc/switching_adaptation.hpp @@ -0,0 +1,202 @@ +#ifndef STAN_MCMC_SWITCHING_ADAPTATION_HPP +#define STAN_MCMC_SWITCHING_ADAPTATION_HPP + +#include +#include +#include + +namespace stan { + + namespace mcmc { + template + struct log_prob_wrapper_covar { + const Model& model_; + log_prob_wrapper_covar(const Model& model) : model_(model) {} + + template + T operator()(const Eigen::Matrix& q) const { + return model_.template log_prob(const_cast& >(q), &std::cout); + } + }; + + template + class scaled_hessian_vector { + private: + const Model& model_; + const Eigen::MatrixXd& L_; + const Eigen::VectorXd& q_; + public: + scaled_hessian_vector(const Model& model, + const Eigen::MatrixXd& L, + const Eigen::VectorXd& q) : model_(model), + L_(L), + q_(q) {} + + int rows() { return q_.size(); } + int cols() { return q_.size(); } + + void perform_op(const double* x_in, double* y_out) { + Eigen::Map x(x_in, cols()); + Eigen::Map y(y_out, rows()); + + double lp; + Eigen::VectorXd grad1; + Eigen::VectorXd grad2; + //stan::math::hessian_times_vector(log_prob_wrapper_covar(model), q, x, lp, Ax); + double dx = 1e-5; + Eigen::VectorXd dr = L_ * x * dx; + stan::math::gradient(log_prob_wrapper_covar(model_), q_ + dr / 2.0, lp, grad1); + stan::math::gradient(log_prob_wrapper_covar(model_), q_ - dr / 2.0, lp, grad2); + y = L_.transpose() * (grad1 - grad2) / dx; + } + }; + + class switching_adaptation: public windowed_adaptation { + public: + explicit switching_adaptation(int n) + : windowed_adaptation("covariance") {} + + Eigen::MatrixXd covariance(const Eigen::MatrixXd& Y) { + Eigen::MatrixXd centered = Y.colwise() - Y.rowwise().mean(); + return centered * centered.transpose() / std::max(centered.rows() - 1.0, 1.0); + } + + template + double top_eigenvalue(const Model& model, const Eigen::MatrixXd& L, const Eigen::VectorXd& q) { + Eigen::VectorXd eigenvalues; + Eigen::MatrixXd eigenvectors; + + scaled_hessian_vector op(model, L, q); 
+ + Spectra::SymEigsSolver eigs(&op, 1, 2); + eigs.init(); + eigs.compute(); + + if(eigs.info() != Spectra::SUCCESSFUL) { + throw std::domain_error("Failed to compute eigenvalue of Hessian of log density. The switching metric requires these"); + } + + return eigs.eigenvalues()(0); + } + + double bottom_eigenvalue_estimate(const Eigen::MatrixXd& L, const Eigen::MatrixXd& covar) { + Eigen::MatrixXd S = L.template triangularView(). + solve(L.template triangularView().solve(covar).transpose()).transpose(); + + Spectra::DenseSymMatProd op(S); + Spectra::SymEigsSolver eigs(&op, 1, 2); + eigs.init(); + eigs.compute(); + + if(eigs.info() != Spectra::SUCCESSFUL) { + throw std::domain_error("Failed to compute eigenvalue of covariance of log density. The switching metric requires these"); + } + + return -1.0 / eigs.eigenvalues()(0); + } + + template + bool learn_covariance(const Model& model, Eigen::MatrixXd& covar, const Eigen::VectorXd& q) { + if (adaptation_window()) + qs_.push_back(q); + + if (end_adaptation_window()) { + compute_next_window(); + + int N = q.size(); + int M = qs_.size(); + + Eigen::MatrixXd Y = Eigen::MatrixXd::Zero(N, M); + std::vector idxs(M); + for(int i = 0; i < qs_.size(); i++) + idxs[i] = i; + + std::random_shuffle(idxs.begin(), idxs.end()); + for(int i = 0; i < qs_.size(); i++) + Y.block(0, i, N, 1) = qs_[idxs[i]]; + + bool use_dense = false; + for(auto state : { "selection", "refinement" }) { + Eigen::MatrixXd Ytrain; + Eigen::MatrixXd Ytest; + + if(state == "selection") { + int Ntest; + Ntest = int(0.2 * Y.cols()); + if(Ntest < 5) { + Ntest = 5; + } + + if(Y.cols() < 10) { + throw std::runtime_error("Each warmup stage must have at least 10 samples"); + } + + std::cout << "train: " << Y.cols() - Ntest << ", test: " << Ntest << std::endl; + Ytrain = Y.block(0, 0, N, Y.cols() - Ntest); + Ytest = Y.block(0, Ytrain.cols(), N, Ntest); + } else { + Ytrain = Y; + } + + Eigen::MatrixXd cov_train = covariance(Ytrain); + Eigen::MatrixXd cov_test = 
covariance(Ytest); + + Eigen::MatrixXd dense = (N / (N + 5.0)) * cov_train + + 1e-3 * (5.0 / (N + 5.0)) * Eigen::MatrixXd::Identity(cov_train.rows(), cov_train.cols()); + Eigen::MatrixXd diag = dense.diagonal().asDiagonal(); + + covar = dense; + + if(state == "selection") { + Eigen::MatrixXd L_dense = dense.llt().matrixL(); + Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); + + double low_eigenvalue_dense = bottom_eigenvalue_estimate(L_dense, cov_test); + double low_eigenvalue_diag = bottom_eigenvalue_estimate(L_diag, cov_test); + + double c_dense = 0.0; + double c_diag = 0.0; + for(int i = 0; i < 5; i++) { + double high_eigenvalue_dense = top_eigenvalue(model, L_dense, Ytest.block(0, i, N, 1)); + double high_eigenvalue_diag = top_eigenvalue(model, L_diag, Ytest.block(0, i, N, 1)); + + c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); + c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); + } + + std::cout << "adapt: " << adapt_window_counter_ << ", which: dense, max: " << c_dense << std::endl; + std::cout << "adapt: " << adapt_window_counter_ << ", which: diag, max: " << c_diag << std::endl; + + if(c_dense < c_diag) { + use_dense = true; + } else { + use_dense = false; + } + } else { + if(use_dense) { + covar = dense; + } else { + covar = diag; + } + } + } + + ++adapt_window_counter_; + qs_.clear(); + + return true; + } + + ++adapt_window_counter_; + return false; + } + + protected: + std::vector< Eigen::VectorXd > qs_; + }; + + } // mcmc + +} // stan + +#endif diff --git a/src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp new file mode 100644 index 00000000000..dea87cc6c38 --- /dev/null +++ b/src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp @@ -0,0 +1,181 @@ +#ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_SWITCHING_E_ADAPT_HPP +#define STAN_SERVICES_SAMPLE_HMC_NUTS_SWITCHING_E_ADAPT_HPP + +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { + namespace services { + namespace sample { + + /** + * Runs HMC with NUTS with adaptation using dense Euclidean metric + * with a pre-specified Euclidean metric. + * + * @tparam Model Model class + * @param[in] model Input model to test (with data already instantiated) + * @param[in] init var context for initialization + * @param[in] init_inv_metric var context exposing an initial dense + inverse Euclidean metric (must be positive definite) + * @param[in] random_seed random seed for the random number generator + * @param[in] chain chain id to advance the pseudo random number generator + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer Writer callback for unconstrained inits + * @param[in,out] sample_writer Writer for draws + * @param[in,out] diagnostic_writer Writer for diagnostic information + * @return error_codes::OK if successful + */ + template + int 
hmc_nuts_switching_e_adapt(Model& model, stan::io::var_context& init, + stan::io::var_context& init_inv_metric, + unsigned int random_seed, unsigned int chain, + double init_radius, int num_warmup, + int num_samples, int num_thin, + bool save_warmup, int refresh, double stepsize, + double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, + double t0, unsigned int init_buffer, + unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, + callbacks::logger& logger, + callbacks::writer& init_writer, + callbacks::writer& sample_writer, + callbacks::writer& diagnostic_writer) { + boost::ecuyer1988 rng = util::create_rng(random_seed, chain); + + std::vector disc_vector; + std::vector cont_vector + = util::initialize(model, init, rng, init_radius, true, + logger, init_writer); + + Eigen::MatrixXd inv_metric; + try { + inv_metric = + util::read_dense_inv_metric(init_inv_metric, model.num_params_r(), + logger); + util::validate_dense_inv_metric(inv_metric, logger); + } catch (const std::domain_error& e) { + return error_codes::CONFIG; + } + + stan::mcmc::adapt_switching_e_nuts + sampler(model, rng); + + sampler.set_metric(inv_metric); + + sampler.set_nominal_stepsize(stepsize); + sampler.set_stepsize_jitter(stepsize_jitter); + sampler.set_max_depth(max_depth); + + sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); + sampler.get_stepsize_adaptation().set_delta(delta); + sampler.get_stepsize_adaptation().set_gamma(gamma); + sampler.get_stepsize_adaptation().set_kappa(kappa); + sampler.get_stepsize_adaptation().set_t0(t0); + + sampler.set_window_params(num_warmup, init_buffer, term_buffer, + window, logger); + + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, + rng, interrupt, logger, + sample_writer, diagnostic_writer); + + return error_codes::OK; + } + + /** + * Runs HMC with NUTS with adaptation using dense Euclidean metric, + * with identity 
matrix as initial inv_metric. + * + * @tparam Model Model class + * @param[in] model Input model to test (with data already instantiated) + * @param[in] init var context for initialization + * @param[in] random_seed random seed for the random number generator + * @param[in] chain chain id to advance the pseudo random number generator + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer Writer callback for unconstrained inits + * @param[in,out] sample_writer Writer for draws + * @param[in,out] diagnostic_writer Writer for diagnostic information + * @return error_codes::OK if successful + */ + template + int hmc_nuts_switching_e_adapt(Model& model, stan::io::var_context& init, + unsigned int random_seed, unsigned int chain, + double init_radius, int num_warmup, + int num_samples, int num_thin, + bool save_warmup, int refresh, double stepsize, + double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, + double t0, unsigned int init_buffer, + unsigned int term_buffer, unsigned int window, 
+ callbacks::interrupt& interrupt, + callbacks::logger& logger, + callbacks::writer& init_writer, + callbacks::writer& sample_writer, + callbacks::writer& diagnostic_writer) { + stan::io::dump dmp = + util::create_unit_e_dense_inv_metric(model.num_params_r()); + stan::io::var_context& unit_e_metric = dmp; + + return hmc_nuts_switching_e_adapt(model, init, unit_e_metric, + random_seed, chain, init_radius, + num_warmup, num_samples, num_thin, + save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, + init_buffer, term_buffer, window, + interrupt, logger, + init_writer, sample_writer, + diagnostic_writer); + } + + } + } +} +#endif From 4dc0fff4991d9759f9e0fb98d16f402d4f9606a8 Mon Sep 17 00:00:00 2001 From: Ben Bales Date: Sat, 20 Apr 2019 14:27:10 -0700 Subject: [PATCH 02/73] Simplified code and added some comments --- lib/stan_math | 2 +- .../hmc/hamiltonians/switching_e_metric.hpp | 76 ++++++++ .../hmc/hamiltonians/switching_e_point.hpp | 81 +++++++++ .../mcmc/hmc/nuts/adapt_switching_e_nuts.hpp | 11 +- src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp | 26 +++ src/stan/mcmc/switching_adaptation.hpp | 164 +++++++++++------- 6 files changed, 291 insertions(+), 69 deletions(-) create mode 100644 src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp create mode 100644 src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp create mode 100644 src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp diff --git a/lib/stan_math b/lib/stan_math index 5fe4e2996cb..c2fb25c3f53 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 5fe4e2996cb93220ccff9b426ae837073dba3338 +Subproject commit c2fb25c3f53d3d33157a27ce995911b0bc92d55d diff --git a/src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp new file mode 100644 index 00000000000..b86cf3a9136 --- /dev/null +++ b/src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp @@ -0,0 +1,76 @@ +#ifndef 
STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_METRIC_HPP +#define STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_METRIC_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace stan { + namespace mcmc { + + // Euclidean manifold with dense metric + template + class switching_e_metric + : public base_hamiltonian { + public: + explicit switching_e_metric(const Model& model) + : base_hamiltonian(model) {} + + double T(switching_e_point& z) { + return 0.5 * z.p.transpose() * z.inv_e_metric_ * z.p; + } + + double tau(switching_e_point& z) { + return T(z); + } + + double phi(switching_e_point& z) { + return this->V(z); + } + + double dG_dt(switching_e_point& z, callbacks::logger& logger) { + return 2 * T(z) - z.q.dot(z.g); + } + + Eigen::VectorXd dtau_dq(switching_e_point& z, callbacks::logger& logger) { + return Eigen::VectorXd::Zero(this->model_.num_params_r()); + } + + Eigen::VectorXd dtau_dp(switching_e_point& z) { + if(z.is_diagonal_) { + return z.inv_e_metric_.diagonal().cwiseProduct(z.p); + } else { + return z.inv_e_metric_ * z.p; + } + } + + Eigen::VectorXd dphi_dq(switching_e_point& z, callbacks::logger& logger) { + return z.g; + } + + void sample_p(switching_e_point& z, BaseRNG& rng) { + typedef typename stan::math::index_type::type idx_t; + boost::variate_generator > + rand_gaus(rng, boost::normal_distribution<>()); + + if(z.is_diagonal_) { + for (int i = 0; i < z.p.size(); ++i) + z.p(i) = rand_gaus() / sqrt(z.inv_e_metric_(i, i)); + } else { + Eigen::VectorXd u(z.p.size()); + + for (idx_t i = 0; i < u.size(); ++i) + u(i) = rand_gaus(); + + z.p = z.inv_e_metric_.llt().matrixU().solve(u); + } + } + }; + + } // mcmc +} // stan +#endif diff --git a/src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp new file mode 100644 index 00000000000..3d72230fab8 --- /dev/null +++ b/src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp @@ -0,0 +1,81 @@ +#ifndef 
STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_POINT_HPP +#define STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_POINT_HPP + +#include +#include + +namespace stan { + namespace mcmc { + /** + * Point in a phase space with a base + * Euclidean manifold with switching metric + */ + class switching_e_point: public ps_point { + public: + /** + * Inverse mass matrix. + */ + Eigen::MatrixXd inv_e_metric_; + + /** + * Is inv_e_metric_ diagonal or not + */ + bool is_diagonal_; + + /** + * Construct a switching point in n-dimensional phase space + * with identity matrix as inverse mass matrix. + * + * @param n number of dimensions + */ + explicit switching_e_point(int n) + : ps_point(n), inv_e_metric_(n, n), is_diagonal_(true) { + inv_e_metric_.setIdentity(); + } + + /** + * Copy constructor which does fast copy of inverse mass matrix. + * + * @param z point to copy + */ + switching_e_point(const switching_e_point& z) + : ps_point(z), inv_e_metric_(z.inv_e_metric_.rows(), + z.inv_e_metric_.cols()) { + fast_matrix_copy_(inv_e_metric_, z.inv_e_metric_); + is_diagonal_ = z.is_diagonal_; + } + + /** + * Set elements of mass matrix + * + * @param inv_e_metric initial mass matrix + */ + void + set_metric(const Eigen::MatrixXd& inv_e_metric) { + inv_e_metric_ = inv_e_metric; + is_diagonal_ = false; + } + + /** + * Write elements of mass matrix to string and handoff to writer. 
+ * + * @param writer Stan writer callback + */ + inline + void + write_metric(stan::callbacks::writer& writer) { + writer("Elements of inverse mass matrix:"); + for (int i = 0; i < inv_e_metric_.rows(); ++i) { + std::stringstream inv_e_metric_ss; + inv_e_metric_ss << inv_e_metric_(i, 0); + for (int j = 1; j < inv_e_metric_.cols(); ++j) + inv_e_metric_ss << ", " << inv_e_metric_(i, j); + writer(inv_e_metric_ss.str()); + } + } + }; + + } // mcmc +} // stan + +#endif diff --git a/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp index c9f660b6e01..b0a21c9e023 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include namespace stan { namespace mcmc { @@ -13,21 +13,21 @@ namespace stan { * dense metric and adaptive step size */ template - class adapt_switching_e_nuts : public dense_e_nuts, + class adapt_switching_e_nuts : public switching_e_nuts, public stepsize_switching_adapter { protected: const Model& model_; public: adapt_switching_e_nuts(const Model& model, BaseRNG& rng) - : model_(model), dense_e_nuts(model, rng), + : model_(model), switching_e_nuts(model, rng), stepsize_switching_adapter(model.num_params_r()) {} ~adapt_switching_e_nuts() {} sample transition(sample& init_sample, callbacks::logger& logger) { - sample s = dense_e_nuts::transition(init_sample, - logger); + sample s = switching_e_nuts::transition(init_sample, + logger); if (this->adapt_flag_) { this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, @@ -36,6 +36,7 @@ namespace stan { bool update = this->switching_adaptation_.learn_covariance( model_, this->z_.inv_e_metric_, + this->z_.is_diagonal_, this->z_.q); if (update) { diff --git a/src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp new file mode 100644 index 00000000000..cd91df6fb8a --- /dev/null +++ 
b/src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp @@ -0,0 +1,26 @@ +#ifndef STAN_MCMC_HMC_NUTS_SWITCHING_E_NUTS_HPP +#define STAN_MCMC_HMC_NUTS_SWITCHING_E_NUTS_HPP + +#include +#include +#include +#include + +namespace stan { + namespace mcmc { + /** + * The No-U-Turn sampler (NUTS) with multinomial sampling + * with a Gaussian-Euclidean disintegration and dense metric + */ + template + class switching_e_nuts : public base_nuts { + public: + switching_e_nuts(const Model& model, BaseRNG& rng) + : base_nuts(model, rng) { } + }; + + } // mcmc +} // stan +#endif diff --git a/src/stan/mcmc/switching_adaptation.hpp b/src/stan/mcmc/switching_adaptation.hpp index 7630cc7a627..088543c27b1 100644 --- a/src/stan/mcmc/switching_adaptation.hpp +++ b/src/stan/mcmc/switching_adaptation.hpp @@ -19,84 +19,121 @@ namespace stan { } }; - template - class scaled_hessian_vector { - private: - const Model& model_; - const Eigen::MatrixXd& L_; - const Eigen::VectorXd& q_; - public: - scaled_hessian_vector(const Model& model, - const Eigen::MatrixXd& L, - const Eigen::VectorXd& q) : model_(model), - L_(L), - q_(q) {} - - int rows() { return q_.size(); } - int cols() { return q_.size(); } - - void perform_op(const double* x_in, double* y_out) { - Eigen::Map x(x_in, cols()); - Eigen::Map y(y_out, rows()); - - double lp; - Eigen::VectorXd grad1; - Eigen::VectorXd grad2; - //stan::math::hessian_times_vector(log_prob_wrapper_covar(model), q, x, lp, Ax); - double dx = 1e-5; - Eigen::VectorXd dr = L_ * x * dx; - stan::math::gradient(log_prob_wrapper_covar(model_), q_ + dr / 2.0, lp, grad1); - stan::math::gradient(log_prob_wrapper_covar(model_), q_ - dr / 2.0, lp, grad2); - y = L_.transpose() * (grad1 - grad2) / dx; - } - }; - class switching_adaptation: public windowed_adaptation { public: explicit switching_adaptation(int n) : windowed_adaptation("covariance") {} + /** + * Compute the covariance of data in Y. Rows of Y are different data. Columns of Y are different variables. 
+ * + * @param Y Data + * @return Covariance of Y + */ Eigen::MatrixXd covariance(const Eigen::MatrixXd& Y) { Eigen::MatrixXd centered = Y.colwise() - Y.rowwise().mean(); return centered * centered.transpose() / std::max(centered.rows() - 1.0, 1.0); } - - template - double top_eigenvalue(const Model& model, const Eigen::MatrixXd& L, const Eigen::VectorXd& q) { - Eigen::VectorXd eigenvalues; - Eigen::MatrixXd eigenvectors; - - scaled_hessian_vector op(model, L, q); - Spectra::SymEigsSolver eigs(&op, 1, 2); - eigs.init(); - eigs.compute(); - - if(eigs.info() != Spectra::SUCCESSFUL) { - throw std::domain_error("Failed to compute eigenvalue of Hessian of log density. The switching metric requires these"); + /** + * Compute the largest magnitude eigenvalue of a symmetric matrix using the power method. The function f + * should return the product of that matrix with an abitrary vector. + * + * f should take one Eigen::VectorXd argument, x, and return the product of a matrix with x as + * an Eigen::VectorXd argument of the same size. + * + * The eigenvalue is estimated iteratively. 
If the kth estimate is e_k, then the function returns when + * either abs(e_{k + 1} - e_k) < tol * abs(e_k) or the maximum number of iterations have been performed + * + * This means the returned eigenvalue might not be computed to full precision + * + * @param initial_guess Initial guess of the eigenvector of the largest eigenvalue + * @param max_iterations Maximum number of power iterations + * @param tol Relative tolerance + * @return Largest magnitude eigenvalue of operator f + */ + template + double power_method(F& f, const Eigen::VectorXd& initial_guess, int max_iterations, double tol) { + Eigen::VectorXd v = initial_guess; + double eval = 0.0; + + for(int i = 0; i < max_iterations; i++) { + Eigen::VectorXd Av = f(v); + double v_norm = v.norm(); + double new_eval = v.dot(Av) / (v_norm * v_norm); + if(std::abs(new_eval - eval) <= tol * std::abs(eval)) { + std::cout << "Converged at i = " << i << std::endl; + eval = new_eval; + break; + } + eval = new_eval; + v = Av / Av.norm(); } - return eigs.eigenvalues()(0); + return eval; } - double bottom_eigenvalue_estimate(const Eigen::MatrixXd& L, const Eigen::MatrixXd& covar) { - Eigen::MatrixXd S = L.template triangularView(). 
- solve(L.template triangularView().solve(covar).transpose()).transpose(); + /** + * Compute the largest eigenvalue of the Hessian of the log density rescaled by a metric, + * that is, the largest eigenvalue of L^T \nabla^2_{qq} H(q) L + * + * @tparam Model Type of model + * @param model Defines the log density + * @param q Point around which to compute the Hessian + * @param L Cholesky decomposition of Metric + * @return Largest eigenvalue + */ + template + double eigenvalue_scaled_hessian(const Model& model, const Eigen::MatrixXd& L, const Eigen::VectorXd& q) { + Eigen::VectorXd eigenvalues; + Eigen::MatrixXd eigenvectors; - Spectra::DenseSymMatProd op(S); - Spectra::SymEigsSolver eigs(&op, 1, 2); - eigs.init(); - eigs.compute(); + auto hessian_vector = [&](const Eigen::VectorXd& x) -> Eigen::VectorXd { + double lp; + Eigen::VectorXd grad1; + Eigen::VectorXd grad2; + //stan::math::hessian_times_vector(log_prob_wrapper_covar(model), q, x, lp, Ax); + double dx = 1e-5; + Eigen::VectorXd dr = L * x * dx; + stan::math::gradient(log_prob_wrapper_covar(model), q + dr / 2.0, lp, grad1); + stan::math::gradient(log_prob_wrapper_covar(model), q - dr / 2.0, lp, grad2); + return L.transpose() * (grad1 - grad2) / dx; + }; + + return power_method(hessian_vector, Eigen::VectorXd::Random(q.size()), 100, 1e-3); + } - if(eigs.info() != Spectra::SUCCESSFUL) { - throw std::domain_error("Failed to compute eigenvalue of covariance of log density. The switching metric requires these"); - } + /** + * Compute the largest eigenvalue of the sample covariance rescaled by a metric, + * that is, the largest eigenvalue of L^{-T} \Sigma L^{-1} + * + * @param L Cholesky decomposition of Metric + * @param Sigma Sample covariance + * @return Largest eigenvalue + */ + double eigenvalue_scaled_covariance(const Eigen::MatrixXd& L, const Eigen::MatrixXd& Sigma) { + Eigen::MatrixXd S = L.template triangularView(). 
+ solve(L.template triangularView().solve(Sigma).transpose()).transpose(); - return -1.0 / eigs.eigenvalues()(0); + auto Sx = [&](Eigen::VectorXd x) -> Eigen::VectorXd { + return S * x; + }; + + return power_method(Sx, Eigen::VectorXd::Random(Sigma.cols()), 100, 1e-3); } + /** + * Update the metric if at the end of an adaptation window. + * + * @tparam Model Type of model + * @param model Defines the log density + * @param covar[out] New metric + * @param covar_is_diagonal[out] Set to true if metric is diagonal, false otherwise + * @param q New MCMC draw + * @return True if this was the end of an adaptation window, false otherwise + */ template - bool learn_covariance(const Model& model, Eigen::MatrixXd& covar, const Eigen::VectorXd& q) { + bool learn_covariance(const Model& model, Eigen::MatrixXd& covar, bool& covar_is_diagonal, const Eigen::VectorXd& q) { if (adaptation_window()) qs_.push_back(q); @@ -131,7 +168,6 @@ namespace stan { throw std::runtime_error("Each warmup stage must have at least 10 samples"); } - std::cout << "train: " << Y.cols() - Ntest << ", test: " << Ntest << std::endl; Ytrain = Y.block(0, 0, N, Y.cols() - Ntest); Ytest = Y.block(0, Ytrain.cols(), N, Ntest); } else { @@ -151,14 +187,14 @@ namespace stan { Eigen::MatrixXd L_dense = dense.llt().matrixL(); Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); - double low_eigenvalue_dense = bottom_eigenvalue_estimate(L_dense, cov_test); - double low_eigenvalue_diag = bottom_eigenvalue_estimate(L_diag, cov_test); + double low_eigenvalue_dense = -1.0 / eigenvalue_scaled_covariance(L_dense, cov_test); + double low_eigenvalue_diag = -1.0 / eigenvalue_scaled_covariance(L_diag, cov_test); double c_dense = 0.0; double c_diag = 0.0; for(int i = 0; i < 5; i++) { - double high_eigenvalue_dense = top_eigenvalue(model, L_dense, Ytest.block(0, i, N, 1)); - double high_eigenvalue_diag = top_eigenvalue(model, L_diag, Ytest.block(0, i, N, 1)); + double high_eigenvalue_dense = 
eigenvalue_scaled_hessian(model, L_dense, Ytest.block(0, i, N, 1)); + double high_eigenvalue_diag = eigenvalue_scaled_hessian(model, L_diag, Ytest.block(0, i, N, 1)); c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); @@ -175,8 +211,10 @@ namespace stan { } else { if(use_dense) { covar = dense; + covar_is_diagonal = false; } else { covar = diag; + covar_is_diagonal = true; } } } From 800069148021c77f173248da14af273ecd3a11c9 Mon Sep 17 00:00:00 2001 From: Ben Bales Date: Sat, 20 Apr 2019 15:28:13 -0700 Subject: [PATCH 03/73] Changed 'switching' to 'auto' --- ...ing_adaptation.hpp => auto_adaptation.hpp} | 10 +-- ...itching_e_metric.hpp => auto_e_metric.hpp} | 30 +++---- ...switching_e_point.hpp => auto_e_point.hpp} | 14 +-- ...ching_e_nuts.hpp => adapt_auto_e_nuts.hpp} | 24 +++--- .../{switching_e_nuts.hpp => auto_e_nuts.hpp} | 14 +-- ..._adapter.hpp => stepsize_auto_adapter.hpp} | 20 ++--- ..._e_adapt.hpp => hmc_nuts_auto_e_adapt.hpp} | 86 +++++++++---------- 7 files changed, 99 insertions(+), 99 deletions(-) rename src/stan/mcmc/{switching_adaptation.hpp => auto_adaptation.hpp} (96%) rename src/stan/mcmc/hmc/hamiltonians/{switching_e_metric.hpp => auto_e_metric.hpp} (61%) rename src/stan/mcmc/hmc/hamiltonians/{switching_e_point.hpp => auto_e_point.hpp} (83%) rename src/stan/mcmc/hmc/nuts/{adapt_switching_e_nuts.hpp => adapt_auto_e_nuts.hpp} (63%) rename src/stan/mcmc/hmc/nuts/{switching_e_nuts.hpp => auto_e_nuts.hpp} (55%) rename src/stan/mcmc/{stepsize_switching_adapter.hpp => stepsize_auto_adapter.hpp} (60%) rename src/stan/services/sample/{hmc_nuts_switching_e_adapt.hpp => hmc_nuts_auto_e_adapt.hpp} (74%) diff --git a/src/stan/mcmc/switching_adaptation.hpp b/src/stan/mcmc/auto_adaptation.hpp similarity index 96% rename from src/stan/mcmc/switching_adaptation.hpp rename to src/stan/mcmc/auto_adaptation.hpp index 088543c27b1..9ff87bed575 100644 --- 
a/src/stan/mcmc/switching_adaptation.hpp +++ b/src/stan/mcmc/auto_adaptation.hpp @@ -1,5 +1,5 @@ -#ifndef STAN_MCMC_SWITCHING_ADAPTATION_HPP -#define STAN_MCMC_SWITCHING_ADAPTATION_HPP +#ifndef STAN_MCMC_AUTO_ADAPTATION_HPP +#define STAN_MCMC_AUTO_ADAPTATION_HPP #include #include @@ -19,9 +19,9 @@ namespace stan { } }; - class switching_adaptation: public windowed_adaptation { + class auto_adaptation: public windowed_adaptation { public: - explicit switching_adaptation(int n) + explicit auto_adaptation(int n) : windowed_adaptation("covariance") {} /** @@ -62,7 +62,7 @@ namespace stan { double v_norm = v.norm(); double new_eval = v.dot(Av) / (v_norm * v_norm); if(std::abs(new_eval - eval) <= tol * std::abs(eval)) { - std::cout << "Converged at i = " << i << std::endl; + //std::cout << "Converged at i = " << i << std::endl; eval = new_eval; break; } diff --git a/src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp similarity index 61% rename from src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp rename to src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp index b86cf3a9136..bdacb513e72 100644 --- a/src/stan/mcmc/hmc/hamiltonians/switching_e_metric.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp @@ -1,10 +1,10 @@ -#ifndef STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_METRIC_HPP -#define STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_METRIC_HPP +#ifndef STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_METRIC_HPP +#define STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_METRIC_HPP #include #include #include -#include +#include #include #include #include @@ -14,33 +14,33 @@ namespace stan { // Euclidean manifold with dense metric template - class switching_e_metric - : public base_hamiltonian { + class auto_e_metric + : public base_hamiltonian { public: - explicit switching_e_metric(const Model& model) - : base_hamiltonian(model) {} + explicit auto_e_metric(const Model& model) + : base_hamiltonian(model) {} - double T(switching_e_point& z) { + double 
T(auto_e_point& z) { return 0.5 * z.p.transpose() * z.inv_e_metric_ * z.p; } - double tau(switching_e_point& z) { + double tau(auto_e_point& z) { return T(z); } - double phi(switching_e_point& z) { + double phi(auto_e_point& z) { return this->V(z); } - double dG_dt(switching_e_point& z, callbacks::logger& logger) { + double dG_dt(auto_e_point& z, callbacks::logger& logger) { return 2 * T(z) - z.q.dot(z.g); } - Eigen::VectorXd dtau_dq(switching_e_point& z, callbacks::logger& logger) { + Eigen::VectorXd dtau_dq(auto_e_point& z, callbacks::logger& logger) { return Eigen::VectorXd::Zero(this->model_.num_params_r()); } - Eigen::VectorXd dtau_dp(switching_e_point& z) { + Eigen::VectorXd dtau_dp(auto_e_point& z) { if(z.is_diagonal_) { return z.inv_e_metric_.diagonal().cwiseProduct(z.p); } else { @@ -48,11 +48,11 @@ namespace stan { } } - Eigen::VectorXd dphi_dq(switching_e_point& z, callbacks::logger& logger) { + Eigen::VectorXd dphi_dq(auto_e_point& z, callbacks::logger& logger) { return z.g; } - void sample_p(switching_e_point& z, BaseRNG& rng) { + void sample_p(auto_e_point& z, BaseRNG& rng) { typedef typename stan::math::index_type::type idx_t; boost::variate_generator > rand_gaus(rng, boost::normal_distribution<>()); diff --git a/src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp similarity index 83% rename from src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp rename to src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp index 3d72230fab8..69ba07fb35b 100644 --- a/src/stan/mcmc/hmc/hamiltonians/switching_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp @@ -1,5 +1,5 @@ -#ifndef STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_POINT_HPP -#define STAN_MCMC_HMC_HAMILTONIANS_SWITCHING_E_POINT_HPP +#ifndef STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_POINT_HPP +#define STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_POINT_HPP #include #include @@ -8,9 +8,9 @@ namespace stan { namespace mcmc { /** * Point in a phase space with a base - * 
Euclidean manifold with switching metric + * Euclidean manifold with auto metric */ - class switching_e_point: public ps_point { + class auto_e_point: public ps_point { public: /** * Inverse mass matrix. @@ -23,12 +23,12 @@ namespace stan { bool is_diagonal_; /** - * Construct a switching point in n-dimensional phase space + * Construct a auto point in n-dimensional phase space * with identity matrix as inverse mass matrix. * * @param n number of dimensions */ - explicit switching_e_point(int n) + explicit auto_e_point(int n) : ps_point(n), inv_e_metric_(n, n), is_diagonal_(true) { inv_e_metric_.setIdentity(); } @@ -38,7 +38,7 @@ namespace stan { * * @param z point to copy */ - switching_e_point(const switching_e_point& z) + auto_e_point(const auto_e_point& z) : ps_point(z), inv_e_metric_(z.inv_e_metric_.rows(), z.inv_e_metric_.cols()) { fast_matrix_copy_(inv_e_metric_, z.inv_e_metric_); diff --git a/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp similarity index 63% rename from src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp rename to src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp index b0a21c9e023..4335f61df9d 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_switching_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp @@ -1,9 +1,9 @@ -#ifndef STAN_MCMC_HMC_NUTS_ADAPT_SWITCHING_E_NUTS_HPP -#define STAN_MCMC_HMC_NUTS_ADAPT_SWITCHING_E_NUTS_HPP +#ifndef STAN_MCMC_HMC_NUTS_ADAPT_AUTO_E_NUTS_HPP +#define STAN_MCMC_HMC_NUTS_ADAPT_AUTO_E_NUTS_HPP #include -#include -#include +#include +#include namespace stan { namespace mcmc { @@ -13,27 +13,27 @@ namespace stan { * dense metric and adaptive step size */ template - class adapt_switching_e_nuts : public switching_e_nuts, - public stepsize_switching_adapter { + class adapt_auto_e_nuts : public auto_e_nuts, + public stepsize_auto_adapter { protected: const Model& model_; public: - adapt_switching_e_nuts(const Model& model, BaseRNG& rng) - : model_(model), 
switching_e_nuts(model, rng), - stepsize_switching_adapter(model.num_params_r()) {} + adapt_auto_e_nuts(const Model& model, BaseRNG& rng) + : model_(model), auto_e_nuts(model, rng), + stepsize_auto_adapter(model.num_params_r()) {} - ~adapt_switching_e_nuts() {} + ~adapt_auto_e_nuts() {} sample transition(sample& init_sample, callbacks::logger& logger) { - sample s = switching_e_nuts::transition(init_sample, + sample s = auto_e_nuts::transition(init_sample, logger); if (this->adapt_flag_) { this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->switching_adaptation_.learn_covariance( + bool update = this->auto_adaptation_.learn_covariance( model_, this->z_.inv_e_metric_, this->z_.is_diagonal_, diff --git a/src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/auto_e_nuts.hpp similarity index 55% rename from src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp rename to src/stan/mcmc/hmc/nuts/auto_e_nuts.hpp index cd91df6fb8a..4ed11dff61e 100644 --- a/src/stan/mcmc/hmc/nuts/switching_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/auto_e_nuts.hpp @@ -1,9 +1,9 @@ -#ifndef STAN_MCMC_HMC_NUTS_SWITCHING_E_NUTS_HPP -#define STAN_MCMC_HMC_NUTS_SWITCHING_E_NUTS_HPP +#ifndef STAN_MCMC_HMC_NUTS_AUTO_E_NUTS_HPP +#define STAN_MCMC_HMC_NUTS_AUTO_E_NUTS_HPP #include #include -#include +#include #include namespace stan { @@ -13,11 +13,11 @@ namespace stan { * with a Gaussian-Euclidean disintegration and dense metric */ template - class switching_e_nuts : public base_nuts { + class auto_e_nuts : public base_nuts { public: - switching_e_nuts(const Model& model, BaseRNG& rng) - : base_nuts(model, rng) { } }; diff --git a/src/stan/mcmc/stepsize_switching_adapter.hpp b/src/stan/mcmc/stepsize_auto_adapter.hpp similarity index 60% rename from src/stan/mcmc/stepsize_switching_adapter.hpp rename to src/stan/mcmc/stepsize_auto_adapter.hpp index aaa35ad6238..0db72870d38 100644 --- a/src/stan/mcmc/stepsize_switching_adapter.hpp +++ 
b/src/stan/mcmc/stepsize_auto_adapter.hpp @@ -1,27 +1,27 @@ -#ifndef STAN_MCMC_STEPSIZE_SWITCHING_ADAPTER_HPP -#define STAN_MCMC_STEPSIZE_SWITCHING_ADAPTER_HPP +#ifndef STAN_MCMC_STEPSIZE_AUTO_ADAPTER_HPP +#define STAN_MCMC_STEPSIZE_AUTO_ADAPTER_HPP #include #include #include -#include +#include namespace stan { namespace mcmc { - class stepsize_switching_adapter: public base_adapter { + class stepsize_auto_adapter: public base_adapter { public: - explicit stepsize_switching_adapter(int n) - : switching_adaptation_(n) { + explicit stepsize_auto_adapter(int n) + : auto_adaptation_(n) { } stepsize_adaptation& get_stepsize_adaptation() { return stepsize_adaptation_; } - switching_adaptation& get_switching_adaptation() { - return switching_adaptation_; + auto_adaptation& get_auto_adaptation() { + return auto_adaptation_; } void set_window_params(unsigned int num_warmup, @@ -29,7 +29,7 @@ namespace stan { unsigned int term_buffer, unsigned int base_window, callbacks::logger& logger) { - switching_adaptation_.set_window_params(num_warmup, + auto_adaptation_.set_window_params(num_warmup, init_buffer, term_buffer, base_window, @@ -38,7 +38,7 @@ namespace stan { protected: stepsize_adaptation stepsize_adaptation_; - switching_adaptation switching_adaptation_; + auto_adaptation auto_adaptation_; }; } // mcmc diff --git a/src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp similarity index 74% rename from src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp rename to src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp index dea87cc6c38..464880d3164 100644 --- a/src/stan/services/sample/hmc_nuts_switching_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp @@ -1,5 +1,5 @@ -#ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_SWITCHING_E_ADAPT_HPP -#define STAN_SERVICES_SAMPLE_HMC_NUTS_SWITCHING_E_ADAPT_HPP +#ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_AUTO_E_ADAPT_HPP +#define STAN_SERVICES_SAMPLE_HMC_NUTS_AUTO_E_ADAPT_HPP 
#include #include @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -54,21 +54,21 @@ namespace stan { * @return error_codes::OK if successful */ template - int hmc_nuts_switching_e_adapt(Model& model, stan::io::var_context& init, - stan::io::var_context& init_inv_metric, - unsigned int random_seed, unsigned int chain, - double init_radius, int num_warmup, - int num_samples, int num_thin, - bool save_warmup, int refresh, double stepsize, - double stepsize_jitter, int max_depth, - double delta, double gamma, double kappa, - double t0, unsigned int init_buffer, - unsigned int term_buffer, unsigned int window, - callbacks::interrupt& interrupt, - callbacks::logger& logger, - callbacks::writer& init_writer, - callbacks::writer& sample_writer, - callbacks::writer& diagnostic_writer) { + int hmc_nuts_auto_e_adapt(Model& model, stan::io::var_context& init, + stan::io::var_context& init_inv_metric, + unsigned int random_seed, unsigned int chain, + double init_radius, int num_warmup, + int num_samples, int num_thin, + bool save_warmup, int refresh, double stepsize, + double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, + double t0, unsigned int init_buffer, + unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, + callbacks::logger& logger, + callbacks::writer& init_writer, + callbacks::writer& sample_writer, + callbacks::writer& diagnostic_writer) { boost::ecuyer1988 rng = util::create_rng(random_seed, chain); std::vector disc_vector; @@ -86,7 +86,7 @@ namespace stan { return error_codes::CONFIG; } - stan::mcmc::adapt_switching_e_nuts + stan::mcmc::adapt_auto_e_nuts sampler(model, rng); sampler.set_metric(inv_metric); @@ -145,34 +145,34 @@ namespace stan { * @return error_codes::OK if successful */ template - int hmc_nuts_switching_e_adapt(Model& model, stan::io::var_context& init, - unsigned int random_seed, unsigned int chain, - double init_radius, int num_warmup, - int 
num_samples, int num_thin, - bool save_warmup, int refresh, double stepsize, - double stepsize_jitter, int max_depth, - double delta, double gamma, double kappa, - double t0, unsigned int init_buffer, - unsigned int term_buffer, unsigned int window, - callbacks::interrupt& interrupt, - callbacks::logger& logger, - callbacks::writer& init_writer, - callbacks::writer& sample_writer, - callbacks::writer& diagnostic_writer) { + int hmc_nuts_auto_e_adapt(Model& model, stan::io::var_context& init, + unsigned int random_seed, unsigned int chain, + double init_radius, int num_warmup, + int num_samples, int num_thin, + bool save_warmup, int refresh, double stepsize, + double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, + double t0, unsigned int init_buffer, + unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, + callbacks::logger& logger, + callbacks::writer& init_writer, + callbacks::writer& sample_writer, + callbacks::writer& diagnostic_writer) { stan::io::dump dmp = util::create_unit_e_dense_inv_metric(model.num_params_r()); stan::io::var_context& unit_e_metric = dmp; - return hmc_nuts_switching_e_adapt(model, init, unit_e_metric, - random_seed, chain, init_radius, - num_warmup, num_samples, num_thin, - save_warmup, refresh, - stepsize, stepsize_jitter, max_depth, - delta, gamma, kappa, t0, - init_buffer, term_buffer, window, - interrupt, logger, - init_writer, sample_writer, - diagnostic_writer); + return hmc_nuts_auto_e_adapt(model, init, unit_e_metric, + random_seed, chain, init_radius, + num_warmup, num_samples, num_thin, + save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, + delta, gamma, kappa, t0, + init_buffer, term_buffer, window, + interrupt, logger, + init_writer, sample_writer, + diagnostic_writer); } } From b1205f3422e425c5bf5e56345869cc1c70f377ca Mon Sep 17 00:00:00 2001 From: Ben Bales Date: Mon, 22 Apr 2019 15:51:47 -0700 Subject: [PATCH 04/73] Added try/catch so that if auto 
adaptation fails it falls back to diagonal --- src/stan/mcmc/auto_adaptation.hpp | 113 ++++++++++++++++-------------- 1 file changed, 61 insertions(+), 52 deletions(-) diff --git a/src/stan/mcmc/auto_adaptation.hpp b/src/stan/mcmc/auto_adaptation.hpp index 9ff87bed575..daae43b0d0f 100644 --- a/src/stan/mcmc/auto_adaptation.hpp +++ b/src/stan/mcmc/auto_adaptation.hpp @@ -152,71 +152,80 @@ namespace stan { for(int i = 0; i < qs_.size(); i++) Y.block(0, i, N, 1) = qs_[idxs[i]]; - bool use_dense = false; - for(auto state : { "selection", "refinement" }) { - Eigen::MatrixXd Ytrain; - Eigen::MatrixXd Ytest; - - if(state == "selection") { - int Ntest; - Ntest = int(0.2 * Y.cols()); - if(Ntest < 5) { - Ntest = 5; - } - - if(Y.cols() < 10) { - throw std::runtime_error("Each warmup stage must have at least 10 samples"); - } + try { + bool use_dense = false; + for(auto state : { "selection", "refinement" }) { + Eigen::MatrixXd Ytrain; + Eigen::MatrixXd Ytest; + + if(state == "selection") { + int Ntest; + Ntest = int(0.2 * Y.cols()); + if(Ntest < 5) { + Ntest = 5; + } + + if(Y.cols() < 10) { + throw std::runtime_error("Each warmup stage must have at least 10 samples"); + } - Ytrain = Y.block(0, 0, N, Y.cols() - Ntest); - Ytest = Y.block(0, Ytrain.cols(), N, Ntest); - } else { - Ytrain = Y; - } + Ytrain = Y.block(0, 0, N, Y.cols() - Ntest); + Ytest = Y.block(0, Ytrain.cols(), N, Ntest); + } else { + Ytrain = Y; + } - Eigen::MatrixXd cov_train = covariance(Ytrain); - Eigen::MatrixXd cov_test = covariance(Ytest); + Eigen::MatrixXd cov_train = covariance(Ytrain); + Eigen::MatrixXd cov_test = covariance(Ytest); - Eigen::MatrixXd dense = (N / (N + 5.0)) * cov_train + - 1e-3 * (5.0 / (N + 5.0)) * Eigen::MatrixXd::Identity(cov_train.rows(), cov_train.cols()); - Eigen::MatrixXd diag = dense.diagonal().asDiagonal(); + Eigen::MatrixXd dense = (N / (N + 5.0)) * cov_train + + 1e-3 * (5.0 / (N + 5.0)) * Eigen::MatrixXd::Identity(cov_train.rows(), cov_train.cols()); + Eigen::MatrixXd diag = 
dense.diagonal().asDiagonal(); - covar = dense; + covar = dense; - if(state == "selection") { - Eigen::MatrixXd L_dense = dense.llt().matrixL(); - Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); + if(state == "selection") { + Eigen::MatrixXd L_dense = dense.llt().matrixL(); + Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); - double low_eigenvalue_dense = -1.0 / eigenvalue_scaled_covariance(L_dense, cov_test); - double low_eigenvalue_diag = -1.0 / eigenvalue_scaled_covariance(L_diag, cov_test); + double low_eigenvalue_dense = -1.0 / eigenvalue_scaled_covariance(L_dense, cov_test); + double low_eigenvalue_diag = -1.0 / eigenvalue_scaled_covariance(L_diag, cov_test); - double c_dense = 0.0; - double c_diag = 0.0; - for(int i = 0; i < 5; i++) { - double high_eigenvalue_dense = eigenvalue_scaled_hessian(model, L_dense, Ytest.block(0, i, N, 1)); - double high_eigenvalue_diag = eigenvalue_scaled_hessian(model, L_diag, Ytest.block(0, i, N, 1)); + double c_dense = 0.0; + double c_diag = 0.0; + for(int i = 0; i < 5; i++) { + double high_eigenvalue_dense = eigenvalue_scaled_hessian(model, L_dense, Ytest.block(0, i, N, 1)); + double high_eigenvalue_diag = eigenvalue_scaled_hessian(model, L_diag, Ytest.block(0, i, N, 1)); - c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); - c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); - } + c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); + c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); + } - std::cout << "adapt: " << adapt_window_counter_ << ", which: dense, max: " << c_dense << std::endl; - std::cout << "adapt: " << adapt_window_counter_ << ", which: diag, max: " << c_diag << std::endl; + std::cout << "adapt: " << adapt_window_counter_ << ", which: dense, max: " << c_dense << std::endl; + std::cout << "adapt: " << adapt_window_counter_ << 
", which: diag, max: " << c_diag << std::endl; - if(c_dense < c_diag) { - use_dense = true; - } else { - use_dense = false; - } - } else { - if(use_dense) { - covar = dense; - covar_is_diagonal = false; + if(c_dense < c_diag) { + use_dense = true; + } else { + use_dense = false; + } } else { - covar = diag; - covar_is_diagonal = true; + if(use_dense) { + covar = dense; + covar_is_diagonal = false; + } else { + covar = diag; + covar_is_diagonal = true; + } } } + } catch(const std::exception& e) { + std::cout << e.what() << std::endl; + std::cout << "Exception while using auto adaptation, falling back to diagonal" << std::endl; + Eigen::MatrixXd cov = covariance(Y); + covar = ((M / (M + 5.0)) * cov.diagonal() + + 1e-3 * (5.0 / (M + 5.0)) * Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); + covar_is_diagonal = true; } ++adapt_window_counter_; From baacc34304202facdabe5593f531b996728ccc67 Mon Sep 17 00:00:00 2001 From: Ben Date: Wed, 11 Sep 2019 15:11:35 -0400 Subject: [PATCH 05/73] Fixed bug with regularization of automatically picked metric. 
Added tests for automatically picked metric --- src/stan/mcmc/auto_adaptation.hpp | 135 ++++++++------ .../good/model/correlated_gaussian.stan | 7 + .../good/model/independent_gaussian.stan | 7 + .../test-models/good/model/known_hessian.stan | 7 + ...ation_learn_covariance_pick_dense_test.cpp | 64 +++++++ ...tation_learn_covariance_pick_diag_test.cpp | 63 +++++++ src/test/unit/mcmc/auto_adaptation_test.cpp | 170 ++++++++++++++++++ 7 files changed, 396 insertions(+), 57 deletions(-) create mode 100644 src/test/test-models/good/model/correlated_gaussian.stan create mode 100644 src/test/test-models/good/model/independent_gaussian.stan create mode 100644 src/test/test-models/good/model/known_hessian.stan create mode 100644 src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp create mode 100644 src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp create mode 100644 src/test/unit/mcmc/auto_adaptation_test.cpp diff --git a/src/stan/mcmc/auto_adaptation.hpp b/src/stan/mcmc/auto_adaptation.hpp index daae43b0d0f..f7f85164ec2 100644 --- a/src/stan/mcmc/auto_adaptation.hpp +++ b/src/stan/mcmc/auto_adaptation.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_AUTO_ADAPTATION_HPP #define STAN_MCMC_AUTO_ADAPTATION_HPP -#include +#include #include #include @@ -19,20 +19,23 @@ namespace stan { } }; - class auto_adaptation: public windowed_adaptation { - public: - explicit auto_adaptation(int n) - : windowed_adaptation("covariance") {} - + namespace internal { /** - * Compute the covariance of data in Y. Rows of Y are different data. Columns of Y are different variables. + * Compute the covariance of data in Y. + * + * Columns of Y are different variables. Rows are different samples. + * + * When there is only one row in Y, return a covariance matrix of the expected + * size filled with zeros. 
* * @param Y Data * @return Covariance of Y */ Eigen::MatrixXd covariance(const Eigen::MatrixXd& Y) { - Eigen::MatrixXd centered = Y.colwise() - Y.rowwise().mean(); - return centered * centered.transpose() / std::max(centered.rows() - 1.0, 1.0); + stan::math::check_nonzero_size("covariance", "Y", Y); + + Eigen::MatrixXd centered = Y.rowwise() - Y.colwise().mean(); + return centered.transpose() * centered / std::max(centered.rows() - 1.0, 1.0); } /** @@ -48,31 +51,59 @@ namespace stan { * This means the returned eigenvalue might not be computed to full precision * * @param initial_guess Initial guess of the eigenvector of the largest eigenvalue - * @param max_iterations Maximum number of power iterations - * @param tol Relative tolerance + * @param[in,out] max_iterations Maximum number of power iterations, on return number of iterations used + * @param[in,out] tol Relative tolerance, on return the relative error in the eigenvalue estimate * @return Largest magnitude eigenvalue of operator f */ template - double power_method(F& f, const Eigen::VectorXd& initial_guess, int max_iterations, double tol) { + double power_method(F& f, const Eigen::VectorXd& initial_guess, int& max_iterations, double& tol) { Eigen::VectorXd v = initial_guess; double eval = 0.0; + Eigen::VectorXd Av = f(v); + stan::math::check_matching_sizes("power_method", "matrix vector product", Av, "vector", v); - for(int i = 0; i < max_iterations; i++) { - Eigen::VectorXd Av = f(v); + int i = 0; + for(; i < max_iterations; ++i) { double v_norm = v.norm(); double new_eval = v.dot(Av) / (v_norm * v_norm); - if(std::abs(new_eval - eval) <= tol * std::abs(eval)) { - //std::cout << "Converged at i = " << i << std::endl; + if(i == max_iterations - 1 || std::abs(new_eval - eval) <= tol * std::abs(eval)) { + tol = std::abs(new_eval - eval) / std::abs(eval); eval = new_eval; + max_iterations = i + 1; break; } + eval = new_eval; v = Av / Av.norm(); + + Av = f(v); } return eval; } + /** + * Compute the largest 
eigenvalue of the sample covariance rescaled by a metric, + * that is, the largest eigenvalue of L^{-1} \Sigma L^{-T} + * + * @param L Cholesky decomposition of Metric + * @param Sigma Sample covariance + * @return Largest eigenvalue + */ + double eigenvalue_scaled_covariance(const Eigen::MatrixXd& L, const Eigen::MatrixXd& Sigma) { + Eigen::MatrixXd S = L.template triangularView(). + solve(L.template triangularView().solve(Sigma).transpose()).transpose(); + + auto Sx = [&](Eigen::VectorXd x) -> Eigen::VectorXd { + return S * x; + }; + + int max_iterations = 100; + double tol = 1e-3; + + return internal::power_method(Sx, Eigen::VectorXd::Random(Sigma.cols()), max_iterations, tol); + } + /** * Compute the largest eigenvalue of the Hessian of the log density rescaled by a metric, * that is, the largest eigenvalue of L^T \nabla^2_{qq} H(q) L @@ -99,29 +130,18 @@ namespace stan { stan::math::gradient(log_prob_wrapper_covar(model), q - dr / 2.0, lp, grad2); return L.transpose() * (grad1 - grad2) / dx; }; - - return power_method(hessian_vector, Eigen::VectorXd::Random(q.size()), 100, 1e-3); - } - - /** - * Compute the largest eigenvalue of the sample covariance rescaled by a metric, - * that is, the largest eigenvalue of L^{-T} \Sigma L^{-1} - * - * @param L Cholesky decomposition of Metric - * @param Sigma Sample covariance - * @return Largest eigenvalue - */ - double eigenvalue_scaled_covariance(const Eigen::MatrixXd& L, const Eigen::MatrixXd& Sigma) { - Eigen::MatrixXd S = L.template triangularView(). 
- solve(L.template triangularView().solve(Sigma).transpose()).transpose(); - - auto Sx = [&](Eigen::VectorXd x) -> Eigen::VectorXd { - return S * x; - }; + + int max_iterations = 100; + double tol = 1e-3; - return power_method(Sx, Eigen::VectorXd::Random(Sigma.cols()), 100, 1e-3); + return internal::power_method(hessian_vector, Eigen::VectorXd::Random(q.size()), max_iterations, tol); } + } + class auto_adaptation: public windowed_adaptation { + public: + explicit auto_adaptation(int n) + : windowed_adaptation("covariance") {} /** * Update the metric if at the end of an adaptation window. * @@ -140,17 +160,17 @@ namespace stan { if (end_adaptation_window()) { compute_next_window(); - int N = q.size(); - int M = qs_.size(); + int M = q.size(); + int N = qs_.size(); - Eigen::MatrixXd Y = Eigen::MatrixXd::Zero(N, M); - std::vector idxs(M); + Eigen::MatrixXd Y = Eigen::MatrixXd::Zero(M, N); + std::vector idxs(N); for(int i = 0; i < qs_.size(); i++) idxs[i] = i; std::random_shuffle(idxs.begin(), idxs.end()); for(int i = 0; i < qs_.size(); i++) - Y.block(0, i, N, 1) = qs_[idxs[i]]; + Y.block(0, i, M, 1) = qs_[idxs[i]]; try { bool use_dense = false; @@ -159,27 +179,28 @@ namespace stan { Eigen::MatrixXd Ytest; if(state == "selection") { - int Ntest; - Ntest = int(0.2 * Y.cols()); - if(Ntest < 5) { - Ntest = 5; + int Mtest; + Mtest = int(0.2 * Y.cols()); + if(Mtest < 5) { + Mtest = 5; } if(Y.cols() < 10) { throw std::runtime_error("Each warmup stage must have at least 10 samples"); } - Ytrain = Y.block(0, 0, N, Y.cols() - Ntest); - Ytest = Y.block(0, Ytrain.cols(), N, Ntest); + Ytrain = Y.block(0, 0, M, Y.cols() - Mtest); + Ytest = Y.block(0, Ytrain.cols(), M, Mtest); } else { Ytrain = Y; } - Eigen::MatrixXd cov_train = covariance(Ytrain); - Eigen::MatrixXd cov_test = covariance(Ytest); - + Eigen::MatrixXd cov_train = (Ytrain.cols() > 0) ? internal::covariance(Ytrain.transpose()) : Eigen::MatrixXd::Zero(M, M); + Eigen::MatrixXd cov_test = (Ytest.cols() > 0) ? 
internal::covariance(Ytest.transpose()) : Eigen::MatrixXd::Zero(M, M); + Eigen::MatrixXd dense = (N / (N + 5.0)) * cov_train + 1e-3 * (5.0 / (N + 5.0)) * Eigen::MatrixXd::Identity(cov_train.rows(), cov_train.cols()); + Eigen::MatrixXd diag = dense.diagonal().asDiagonal(); covar = dense; @@ -188,14 +209,14 @@ namespace stan { Eigen::MatrixXd L_dense = dense.llt().matrixL(); Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); - double low_eigenvalue_dense = -1.0 / eigenvalue_scaled_covariance(L_dense, cov_test); - double low_eigenvalue_diag = -1.0 / eigenvalue_scaled_covariance(L_diag, cov_test); + double low_eigenvalue_dense = -1.0 / internal::eigenvalue_scaled_covariance(L_dense, cov_test); + double low_eigenvalue_diag = -1.0 / internal::eigenvalue_scaled_covariance(L_diag, cov_test); double c_dense = 0.0; double c_diag = 0.0; for(int i = 0; i < 5; i++) { - double high_eigenvalue_dense = eigenvalue_scaled_hessian(model, L_dense, Ytest.block(0, i, N, 1)); - double high_eigenvalue_diag = eigenvalue_scaled_hessian(model, L_diag, Ytest.block(0, i, N, 1)); + double high_eigenvalue_dense = internal::eigenvalue_scaled_hessian(model, L_dense, Ytest.block(0, i, M, 1)); + double high_eigenvalue_diag = internal::eigenvalue_scaled_hessian(model, L_diag, Ytest.block(0, i, M, 1)); c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); @@ -222,9 +243,9 @@ namespace stan { } catch(const std::exception& e) { std::cout << e.what() << std::endl; std::cout << "Exception while using auto adaptation, falling back to diagonal" << std::endl; - Eigen::MatrixXd cov = covariance(Y); - covar = ((M / (M + 5.0)) * cov.diagonal() - + 1e-3 * (5.0 / (M + 5.0)) * Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); + Eigen::MatrixXd cov = internal::covariance(Y.transpose()); + covar = ((N / (N + 5.0)) * cov.diagonal() + + 1e-3 * (5.0 / (N + 5.0)) * 
Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); covar_is_diagonal = true; } diff --git a/src/test/test-models/good/model/correlated_gaussian.stan b/src/test/test-models/good/model/correlated_gaussian.stan new file mode 100644 index 00000000000..fc55ebd83e1 --- /dev/null +++ b/src/test/test-models/good/model/correlated_gaussian.stan @@ -0,0 +1,7 @@ +parameters { + vector[2] x; +} + +model { + x ~ multi_normal([0.0, 0.0], [[1.0, 0.99], [0.99, 1.0]]); +} diff --git a/src/test/test-models/good/model/independent_gaussian.stan b/src/test/test-models/good/model/independent_gaussian.stan new file mode 100644 index 00000000000..451a1251c83 --- /dev/null +++ b/src/test/test-models/good/model/independent_gaussian.stan @@ -0,0 +1,7 @@ +parameters { + vector[2] x; +} + +model { + x ~ normal(0.0, 1.0); +} diff --git a/src/test/test-models/good/model/known_hessian.stan b/src/test/test-models/good/model/known_hessian.stan new file mode 100644 index 00000000000..da35f674436 --- /dev/null +++ b/src/test/test-models/good/model/known_hessian.stan @@ -0,0 +1,7 @@ +parameters { + real x[3]; +} + +model { + x ~ normal(0, 1); +} diff --git a/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp b/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp new file mode 100644 index 00000000000..0a8bd02139f --- /dev/null +++ b/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include + +TEST(McmcVarAdaptation, learn_covariance_pick_dense) { + std::fstream data_stream(std::string("").c_str(), std::fstream::in); + stan::io::dump data_var_context(data_stream); + data_stream.close(); + + std::stringstream output; + correlated_gaussian_model_namespace::correlated_gaussian_model + correlated_gaussian_model(data_var_context, &output); + + stan::test::unit::instrumented_logger logger; + + const int M = 2; + const int N = 20; + Eigen::MatrixXd qs(N, M); + qs << 0.256173753306128, 
-0.0238087093098673, + -1.63218152810157, -1.5309929638363, + -0.812451331685826, -1.15062373620068, + -1.49814775191801, -1.51310110681996, + 0.738630631536685, 1.03588205799336, + 0.472288580035284, 0.250286770328584, + -1.63634486169493, -1.6222798835089, + -0.400790615207103, -0.337669147200631, + -0.568779612417544, -0.424833495378187, + 0.103690913176746, 0.272885200284842, + -0.453017424229528, -0.504634004215693, + 3.34484533887237, 3.29418872328382, + -1.3376507113241, -1.32724775403694, + -0.137543235057544, -0.0290938109919368, + -1.58194496352741, -1.39338740677379, + 0.312166136194586, 0.336989933768233, + -0.628941448228566, -0.850758612234264, + -0.766816808981044, -0.645020468024267, + -0.75078110234827, -0.502544092120385, + -0.00694807494461906, -0.186748159558166; + + Eigen::MatrixXd covar(M, M); + bool covar_is_diagonal; + + Eigen::MatrixXd target_covar(M, M); + + target_covar << 1.0311414783609130, 1.0100577463968425, + 1.0100577463968425, 1.0148380697138280; + + stan::mcmc::auto_adaptation adapter(M); + adapter.set_window_params(50, 0, 0, N, logger); + + for (int i = 0; i < N; ++i) { + Eigen::VectorXd q = qs.block(i, 0, 1, M).transpose(); + adapter.learn_covariance(correlated_gaussian_model, covar, covar_is_diagonal, q); + } + + for (int i = 0; i < covar.size(); ++i) { + EXPECT_FLOAT_EQ(target_covar(i), covar(i)); + } + + EXPECT_EQ(covar_is_diagonal, false); + + EXPECT_EQ(0, logger.call_count()); +} diff --git a/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp b/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp new file mode 100644 index 00000000000..26f3da56cef --- /dev/null +++ b/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include + +TEST(McmcVarAdaptation, learn_covariance_pick_diagonal) { + std::fstream data_stream(std::string("").c_str(), std::fstream::in); + stan::io::dump data_var_context(data_stream); + 
data_stream.close(); + + std::stringstream output; + independent_gaussian_model_namespace::independent_gaussian_model + independent_gaussian_model(data_var_context, &output); + + stan::test::unit::instrumented_logger logger; + + const int M = 2; + const int N = 20; + Eigen::MatrixXd qs(N, M); + qs << 0.607446257145326, 0.338465765807058, + 1.47389672467345, -1.0577986841911, + 1.02886652895522, 0.364277500948572, + 0.492316893603469, 2.19693408641558, + -0.931854393410476, 1.62634580968769, + -0.443145375724188, 0.902790875582656, + 0.517782110245233, -1.56724331755861, + -1.7556390097031, 0.310274990315213, + 0.0394975482340945, 0.366999438969482, + 1.29372950054929, 0.361369734821582, + -0.258301497542829, 0.166994731172984, + 0.492639248874412, -0.659502589885556, + 0.913729457222598, 1.99580706461809, + 0.669655370469707, -0.509028392475839, + -0.626041244059129, -0.771981104624195, + -0.842385483586737, 0.337166271031201, + 0.548177804329155, -0.0462961925005498, + 0.955748803092952, 1.3141117316189, + 0.335670079140694, 1.09112083087171, + 0.759245358940033, -1.11318882201676; + + Eigen::MatrixXd covar(M, M); + bool covar_is_diagonal; + + Eigen::MatrixXd target_covar(M, M); + + target_covar << 0.55350038163333048, 0.0, 0.0, 0.86122545968912112; + + stan::mcmc::auto_adaptation adapter(M); + adapter.set_window_params(50, 0, 0, N, logger); + + for (int i = 0; i < N; ++i) { + Eigen::VectorXd q = qs.block(i, 0, 1, M).transpose(); + adapter.learn_covariance(independent_gaussian_model, covar, covar_is_diagonal, q); + } + + for (int i = 0; i < covar.size(); ++i) { + EXPECT_FLOAT_EQ(target_covar(i), covar(i)); + } + + EXPECT_EQ(covar_is_diagonal, true); + + EXPECT_EQ(0, logger.call_count()); +} diff --git a/src/test/unit/mcmc/auto_adaptation_test.cpp b/src/test/unit/mcmc/auto_adaptation_test.cpp new file mode 100644 index 00000000000..69890163251 --- /dev/null +++ b/src/test/unit/mcmc/auto_adaptation_test.cpp @@ -0,0 +1,170 @@ +#include +#include +#include +#include + 
+TEST(McmcAutoAdaptation, test_covariance_zero_rows_zero_cols) { + Eigen::MatrixXd X1(0, 5); + + EXPECT_THROW(stan::mcmc::internal::covariance(X1), std::invalid_argument); + + Eigen::MatrixXd X2(1, 0); + + EXPECT_THROW(stan::mcmc::internal::covariance(X2), std::invalid_argument); +} + +TEST(McmcAutoAdaptation, test_covariance_one_row_one_col) { + Eigen::MatrixXd X1(1, 2); + Eigen::MatrixXd X2(3, 1); + + X1 << 1.0, 2.0; + X2 << 1.0, 2.0, 3.0; + + Eigen::MatrixXd cov1 = stan::mcmc::internal::covariance(X1); + Eigen::MatrixXd cov2 = stan::mcmc::internal::covariance(X2); + + ASSERT_EQ(cov1.rows(), 2); + ASSERT_EQ(cov1.cols(), 2); + + ASSERT_EQ(cov2.rows(), 1); + ASSERT_EQ(cov2.cols(), 1); + + for(int i = 0; i < cov1.size(); ++i) { + ASSERT_FLOAT_EQ(cov1(i), 0.0); + } + + ASSERT_FLOAT_EQ(cov2(0), 1.0); +} + +TEST(McmcAutoAdaptation, test_covariance) { + Eigen::MatrixXd X1(3, 2); + Eigen::MatrixXd X2(2, 3); + + X1 << 0.0, -1.0, 0.5, -2.7, 3.0, 5.0; + X2 << 0.0, 3, -2.7, 0.5, -1, 5.0; + + Eigen::MatrixXd cov1 = stan::mcmc::internal::covariance(X1); + Eigen::MatrixXd cov2 = stan::mcmc::internal::covariance(X2); + + Eigen::MatrixXd cov1_ref(2, 2); + Eigen::MatrixXd cov2_ref(3, 3); + + cov1_ref << 2.5833333333333335, 6.0666666666666664, + 6.0666666666666664, 16.3633333333333333; + + cov2_ref << 0.125, -1.0, 1.925, + -1.000, 8.0, -15.4, + 1.925, -15.4, 29.645; + + ASSERT_EQ(cov1.rows(), cov1_ref.rows()); + ASSERT_EQ(cov1.cols(), cov1_ref.cols()); + + ASSERT_EQ(cov2.rows(), cov2_ref.rows()); + ASSERT_EQ(cov2.cols(), cov2_ref.cols()); + + for(int i = 0; i < cov1_ref.size(); ++i) { + ASSERT_FLOAT_EQ(cov1(i), cov1_ref(i)); + } + + for(int i = 0; i < cov2_ref.size(); ++i) { + ASSERT_FLOAT_EQ(cov2(i), cov2_ref(i)); + } +} + +TEST(McmcAutoAdaptation, power_method) { + Eigen::MatrixXd X(2, 2); + Eigen::VectorXd x0(2); + + X << 2.0, 0.5, 0.5, 1.0; + x0 << 1.0, 0.0; + + const int max_iterations = 10; + const double tol = 1e-10; + + auto Av = [&](const Eigen::VectorXd& v) { return X * 
v; }; + + int max_iterations_1 = max_iterations; + double tol_1 = tol; + + double eval = stan::mcmc::internal::power_method(Av, x0, max_iterations_1, tol_1); + + EXPECT_FLOAT_EQ(eval, 2.20710678118654746); +} + +TEST(McmcAutoAdaptation, power_method_tol_check) { + Eigen::MatrixXd X(2, 2); + Eigen::VectorXd x0(2); + + X << 2.0, 0.5, 0.5, 1.0; + x0 << 1.0, 0.0; + + const int max_iterations = 1000; + const double tol = 1e-12; + + auto Av = [&](const Eigen::VectorXd& v) { return X * v; }; + + int max_iterations_1 = max_iterations; + double tol_1 = tol; + double eval = stan::mcmc::internal::power_method(Av, x0, max_iterations_1, tol_1); + + EXPECT_LT(tol_1, tol); +} + +TEST(McmcAutoAdaptation, power_method_iter_check) { + Eigen::MatrixXd X(2, 2); + Eigen::VectorXd x0(2); + + X << 2.0, 0.5, 0.5, 1.0; + x0 << 1.0, 0.0; + + const int max_iterations = 10; + const double tol = 1e-50; + + auto Av = [&](const Eigen::VectorXd& v) { return X * v; }; + + int max_iterations_1 = max_iterations; + double tol_1 = tol; + double eval = stan::mcmc::internal::power_method(Av, x0, max_iterations_1, tol_1); + + EXPECT_GT(tol_1, tol); + EXPECT_EQ(max_iterations_1, max_iterations); +} + +// The checks in here are very coarse because eigenvalue_scaled_covariance +// only estimates things with low precision +TEST(McmcAutoAdaptation, eigenvalue_scaled_covariance) { + Eigen::MatrixXd L(2, 2), Sigma(2, 2); + + L << 1.0, 0.0, 0.5, 1.0; + Sigma << 2.0, 0.7, 0.7, 1.3; + + double eval = stan::mcmc::internal::eigenvalue_scaled_covariance(L, Sigma); + + EXPECT_LT(std::abs(eval - 2.0908326913195983) / eval, 1e-2); + + L << 2.0, 0.0, 0.7, 1.3; + + eval = stan::mcmc::internal::eigenvalue_scaled_covariance(L, Sigma); + + EXPECT_LT(std::abs(eval - 0.62426035502958577) / eval, 1e-2); +} + +// The checks in here are very coarse because eigenvalue_scaled_hessian +// only estimates things with low precision +TEST(McmcAutoAdaptation, eigenvalue_scaled_hessian) { + std::fstream 
data_stream(std::string("").c_str(), std::fstream::in); + stan::io::dump data_var_context(data_stream); + data_stream.close(); + + std::stringstream output; + known_hessian_model_namespace::known_hessian_model known_hessian_model(data_var_context, &output); + + Eigen::MatrixXd L(3, 3); + Eigen::VectorXd q(3); + L << 2.0, 0.0, 0.0, 0.7, 1.3, 0.0, -1.5, 2.0, 4.0; + q << 0.0, 0.0, 0.0; + + double eval = stan::mcmc::internal::eigenvalue_scaled_hessian(known_hessian_model, L, q); + + EXPECT_LT(std::abs(eval - 22.8141075806892850) / eval, 1e-2); +} From 76bfd27eeee41521fd625b796b1798e03e292abb Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 30 Oct 2019 13:20:08 -0700 Subject: [PATCH 06/73] first draft of master-slave communication setup for warmup --- src/stan/services/util/mpi.hpp | 296 ++++++++++++++++++ .../unit/services/util/mpi_warmup_test.cpp | 67 ++++ 2 files changed, 363 insertions(+) create mode 100644 src/stan/services/util/mpi.hpp create mode 100644 src/test/unit/services/util/mpi_warmup_test.cpp diff --git a/src/stan/services/util/mpi.hpp b/src/stan/services/util/mpi.hpp new file mode 100644 index 00000000000..ec4c09c8fd5 --- /dev/null +++ b/src/stan/services/util/mpi.hpp @@ -0,0 +1,296 @@ +#ifndef STAN_SERVICES_UTIL_MPI_WARMUP_HPP +#define STAN_SERVICES_UTIL_MPI_WARMUP_HPP + +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace util { +namespace mpi { + + /* + * MPI Evionment that initializes and finalizes the MPI + */ + struct Envionment { + Envionment() { + init(); + } + ~Envionment() { + finalize(); + } + + static void init() { +#ifdef STAN_LANG_MPI + int flag; + MPI_Initialized(&flag); + if(!flag) { + MPI_Init(NULL, NULL); + } +#endif + } + + static void finalize() { +#ifdef STAN_LANG_MPI + int flag; + MPI_Finalized(&flag); + if(!flag) MPI_Finalize(); +#endif + } + }; + + /* + * MPI Communicators. 
With default constructor disabled, + * a communicator can only be created through duplication. + */ + struct Communicator { + private: + Communicator(); + + const Envionment& env_; + + public: + MPI_Comm comm; + int size; + int rank; + + /* + * communicator constructor using @c Envionment and @c MPI_Comm + */ + Communicator(const Envionment& env, MPI_Comm other) : + env_(env), comm(MPI_COMM_NULL) { + MPI_Comm_dup(other, &comm); + MPI_Comm_size(comm, &size); + MPI_Comm_rank(comm, &rank); + } + + /* + * copy constructor is deep + */ + explicit Communicator(const Communicator& other) : + Communicator(other.env_, other.comm) + {} + + /* + * type-cast to MPI_Comm object + */ + operator MPI_Comm() { + return this -> comm; + } + + /* + * destructor needs to free MPI_Comm + */ + ~Communicator() { + MPI_Comm_free(&comm); + } + }; + +#define NUM_STAN_LANG_MPI_COMM 1 +#define STAN_LANG_MPI_COMM_WARMUP 0 + + /* + * MPI communicator wrapper for RAII. Note that no + * MPI's predfined comm such as @c MPI_COMM_WOLRD are allowed. + */ + template + struct Session { + static Envionment env; + static std::vector comms; + }; + + template + Envionment Session::env; + + template + std::vector Session::comms(N_comm, Communicator(Session::env, MPI_COMM_WORLD)); + + /** + * Dynamic loader that manages master & slave + * communication and data assembly. + */ + struct warmup_dynamic_loader_base { + static const int work_tag = 1; + static const int err_tag = 2; + static const int adapt_tag = 3; + + //! communicator wrapper for warmup + const Communicator& warmup_comm; + //! MPI communicator + const MPI_Comm comm; + //! communication interval + int interval; + //! double workspace + Eigen::MatrixXd workspace_r; + + //! construct loader given MPI communicator + warmup_dynamic_loader_base(const Communicator& comm_in, int inter) : + warmup_comm(comm_in), comm(warmup_comm.comm), + interval(inter) + { + // make sure there are slave chains. 
+ static const char* caller = "warmup_dynamic_loader"; + stan::math::check_greater(caller, "MPI comm size", warmup_comm.size, 1); + } + }; + + /** + * master receives adaptation info from slave chains and + * process that info through an external functor @c ensemble_func + * in order to improve the quality of adaptation, before + * sending the improved adapt info to slave chains. + */ + struct warmup_dynamic_loader_master : warmup_dynamic_loader_base { + //! construct loader given MPI communicator + warmup_dynamic_loader_master(const Communicator& comm_in, + int inter) : + warmup_dynamic_loader_base(comm_in, inter) + {} + + /* + * helper function to master node (rank = 0) to recv + * results. + * @return array {tag, source}. + */ + template + MPI_Status + recv(std::vector& req, const Recv_processor& chain_func, + Sampler& sampler, Model& model) { + MPI_Status stat; + MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, comm, &stat); + int source = stat.MPI_SOURCE; + if (stat.MPI_TAG == err_tag) { + double dummy; + MPI_Irecv(&dummy, 0, MPI_DOUBLE, source, err_tag, comm, &req[source]); + } else { + int n = chain_func.recv_size(sampler, model); + MPI_Irecv(&workspace_r((source - 1) * n), n, + MPI_DOUBLE, source, work_tag, comm, &req[source]); + } + return stat; + } + + /* + * master node (rank = 0) recv results and send + * available tasks to vacant slaves. 
+ */ + template + void operator()(Sampler& sampler, Model& model, + const Send_processor& ensemble_func, + const Recv_processor& chain_func) { + static const char* caller = "warmup_dynamic_loader_master::master"; + stan::math::check_less(caller, "MPI comm rank", warmup_comm.rank, 1); + + std::vector req(warmup_comm.size); + std::array recv_out; + + int recved = 0; + int irecve = 0; + int source; + bool is_invalid = false; + while (irecve != warmup_comm.size || (!is_invalid)) { + // recv adaption results from certain chain + MPI_Status stat(recv(req, chain_func, sampler, model)); + is_invalid = stat.MPI_TAG == err_tag; + source = stat.MPI_SOURCE; + irecve++; + + // processing recieved data + if (!is_invalid) { + int index, flag = 0; + MPI_Testany(warmup_comm.size, req.data(), &index, + &flag, MPI_STATUS_IGNORE); + if(flag) { + recved++; + chain_func(sampler, model, workspace_r, index); + } + } + } + + if (is_invalid) { + for (int i = 1; i < warmup_comm.size; ++i) { + MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, err_tag, comm); + } + while (irecve != warmup_comm.size) { + recv(req, chain_func, sampler, model); + irecve++; + } + MPI_Waitall(warmup_comm.size, req.data(), MPI_STATUSES_IGNORE); + std::ostringstream chain_adapt_fail_msg; + chain_adapt_fail_msg << "Invalid adaptation data in Chain " << source; + throw std::runtime_error(chain_adapt_fail_msg.str()); + } else { + for (int i = 1; i < warmup_comm.size; ++i) { + MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, adapt_tag, comm); + } + while (recved != warmup_comm.size) { + int index, flag = 0; + MPI_Testany(warmup_comm.size, req.data(), &index, + &flag, MPI_STATUS_IGNORE); + if(flag) { + recved++; + chain_func(sampler, model, workspace_r, index); + } + } + ensemble_func(sampler, model, workspace_r); + MPI_Bcast(workspace_r.data(), ensemble_func.send_size, MPI_DOUBLE, 0, comm); + } + } + }; + + struct warmup_dynamic_loader_slave : warmup_dynamic_loader_base { + //! 
construct loader given MPI communicator + warmup_dynamic_loader_slave(const Communicator& comm_in, + int inter) : + warmup_dynamic_loader_base(comm_in, inter) + {} + + /* + * master node (rank = 0) recv results and send + * available tasks to vacant slaves. + */ + template + void operator()(Sampler& sampler, Model& model, + const Send_processor& chain_func, + const Recv_processor& adapt_func) { + using Eigen::MatrixXd; + using Eigen::Matrix; + + static const char* caller = "warmup_dynamic_loader_slave::slave"; + stan::math::check_greater(caller, "MPI comm rank", warmup_comm.rank, 0); + + // process adapt info before sending out to master + workspace_r.resize(chain_func.send_size, 1); + chain_func(sampler, model, workspace_r, warmup_comm.rank); + MPI_Send(workspace_r.data(), chain_func.send_size, MPI_DOUBLE, 0, work_tag, comm); + + MPI_Status stat; + MPI_Recv(workspace_r.data(), 0, MPI_DOUBLE, 0, MPI_ANY_TAG, comm, &stat); + if (stat.MPI_TAG == err_tag) { + std::ostringstream chain_adapt_fail_msg; + chain_adapt_fail_msg << "Invalid adaptation data in ensemble"; + throw std::runtime_error(chain_adapt_fail_msg.str()); + } else if (stat.MPI_TAG == adapt_tag) { + workspace_r.resize(adapt_func.recv_size, 1); + MPI_Bcast(workspace_r.data(), adapt_func.recv_size, MPI_DOUBLE, 0, comm); + adapt_func(sampler, model, workspace_r); + } + } + }; + +} // mpi +} // namespace util +} // namespace services +} // namespace stan +#endif + + +#endif diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp new file mode 100644 index 00000000000..f9eb17b012c --- /dev/null +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -0,0 +1,67 @@ +#ifdef STAN_LANG_MPI + +#include +#include + +using Eigen::MatrixXd; +using Eigen::Matrix; +using std::vector; +using stan::services::util::mpi::Communicator; +using stan::services::util::mpi::Session; +using stan::services::util::mpi::warmup_dynamic_loader_base; +using 
stan::services::util::mpi::warmup_dynamic_loader_master; +using stan::services::util::mpi::warmup_dynamic_loader_slave; + +struct dummy_sampler {}; +struct dummy_model {}; +struct dummy_master_ensemble_processor { + template + int send_size(Sampler& sampler, Model& model) { + return 10; + } + + template + void operator()(Sampler& sampler, Model& model, + Eigen::MatrixXd& workspace_r) { + } +}; + +struct dummy_master_chain_processor { + template + int recv_size(Sampler& sampler, Model& model) { + return 10; + } + + template + void operator()(Sampler& sampler, Model& model, + Eigen::MatrixXd& workspace_r, int index) { + } +}; + +struct dummy_slave_chain_processor { + template + void operator()(Sampler& sampler, Model& model, + Eigen::MatrixXd& workspace_r, int index) { + } +}; + +TEST(mpi_warmup_test, mpi_session) { + const Communicator& warmup_comm = + Session::comms[0]; + + warmup_dynamic_loader_base load(warmup_comm, 10); +} + +TEST(mpi_warmup_test, mpi_master_slave) { + const Communicator& warmup_comm = + Session::comms[0]; + + if (warmup_comm.rank == 0) { + warmup_dynamic_loader_master master(warmup_comm, 10); + } else { + warmup_dynamic_loader_slave slave(warmup_comm, 10); + } + +} + +#endif From 7bfe044bc84fb5002bf9a9d49d1df1ee4f91ab46 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 30 Oct 2019 14:59:34 -0700 Subject: [PATCH 07/73] unit test for mpi warmup communication --- src/stan/services/util/mpi.hpp | 68 ++++++++++--------- .../unit/services/util/mpi_warmup_test.cpp | 47 ++++++++++++- 2 files changed, 81 insertions(+), 34 deletions(-) diff --git a/src/stan/services/util/mpi.hpp b/src/stan/services/util/mpi.hpp index ec4c09c8fd5..47c332bbf4d 100644 --- a/src/stan/services/util/mpi.hpp +++ b/src/stan/services/util/mpi.hpp @@ -157,22 +157,23 @@ namespace mpi { /* * helper function to master node (rank = 0) to recv * results. - * @return array {tag, source}. 
+ * @return MPI_Status of recv operation */ template MPI_Status - recv(std::vector& req, const Recv_processor& chain_func, + recv(std::vector& req, Recv_processor& chain_func, Sampler& sampler, Model& model) { MPI_Status stat; MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, comm, &stat); int source = stat.MPI_SOURCE; + int ireq = source - 1; if (stat.MPI_TAG == err_tag) { double dummy; - MPI_Irecv(&dummy, 0, MPI_DOUBLE, source, err_tag, comm, &req[source]); + MPI_Irecv(&dummy, 0, MPI_DOUBLE, source, err_tag, comm, &req[ireq]); } else { int n = chain_func.recv_size(sampler, model); MPI_Irecv(&workspace_r((source - 1) * n), n, - MPI_DOUBLE, source, work_tag, comm, &req[source]); + MPI_DOUBLE, source, work_tag, comm, &req[ireq]); } return stat; } @@ -182,48 +183,44 @@ namespace mpi { * available tasks to vacant slaves. */ template + typename Send_processor, typename Recv_processor, + typename Post_processor> void operator()(Sampler& sampler, Model& model, - const Send_processor& ensemble_func, - const Recv_processor& chain_func) { + Send_processor& ensemble_func, + Recv_processor& chain_func, + Post_processor& post_func) { static const char* caller = "warmup_dynamic_loader_master::master"; stan::math::check_less(caller, "MPI comm rank", warmup_comm.rank, 1); - std::vector req(warmup_comm.size); + int nslave = warmup_comm.size - 1; + std::vector req(nslave); std::array recv_out; int recved = 0; int irecve = 0; int source; bool is_invalid = false; - while (irecve != warmup_comm.size || (!is_invalid)) { + while (irecve != nslave && (!is_invalid)) { // recv adaption results from certain chain + workspace_r.resize(chain_func.recv_size(sampler, model), + nslave); MPI_Status stat(recv(req, chain_func, sampler, model)); is_invalid = stat.MPI_TAG == err_tag; source = stat.MPI_SOURCE; irecve++; - - // processing recieved data - if (!is_invalid) { - int index, flag = 0; - MPI_Testany(warmup_comm.size, req.data(), &index, - &flag, MPI_STATUS_IGNORE); - if(flag) { - recved++; - 
chain_func(sampler, model, workspace_r, index); - } - } } + MPI_Request bcast_req; + if (is_invalid) { for (int i = 1; i < warmup_comm.size; ++i) { MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, err_tag, comm); } - while (irecve != warmup_comm.size) { + while (irecve != nslave) { recv(req, chain_func, sampler, model); irecve++; } - MPI_Waitall(warmup_comm.size, req.data(), MPI_STATUSES_IGNORE); + MPI_Waitall(nslave, req.data(), MPI_STATUSES_IGNORE); std::ostringstream chain_adapt_fail_msg; chain_adapt_fail_msg << "Invalid adaptation data in Chain " << source; throw std::runtime_error(chain_adapt_fail_msg.str()); @@ -231,18 +228,21 @@ namespace mpi { for (int i = 1; i < warmup_comm.size; ++i) { MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, adapt_tag, comm); } - while (recved != warmup_comm.size) { + while (recved != nslave) { int index, flag = 0; - MPI_Testany(warmup_comm.size, req.data(), &index, - &flag, MPI_STATUS_IGNORE); + MPI_Testany(nslave, req.data(), &index, &flag, MPI_STATUS_IGNORE); if(flag) { recved++; - chain_func(sampler, model, workspace_r, index); + chain_func(sampler, model, workspace_r, index + 1); } } ensemble_func(sampler, model, workspace_r); - MPI_Bcast(workspace_r.data(), ensemble_func.send_size, MPI_DOUBLE, 0, comm); + MPI_Ibcast(workspace_r.data(), ensemble_func.send_size(sampler, model), + MPI_DOUBLE, 0, comm, &bcast_req); } + + post_func(sampler, model, workspace_r); + MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); } }; @@ -259,8 +259,8 @@ namespace mpi { */ template void operator()(Sampler& sampler, Model& model, - const Send_processor& chain_func, - const Recv_processor& adapt_func) { + Send_processor& chain_func, + Recv_processor& adapt_func) { using Eigen::MatrixXd; using Eigen::Matrix; @@ -268,19 +268,21 @@ namespace mpi { stan::math::check_greater(caller, "MPI comm rank", warmup_comm.rank, 0); // process adapt info before sending out to master - workspace_r.resize(chain_func.send_size, 1); chain_func(sampler, model, workspace_r, 
warmup_comm.rank); - MPI_Send(workspace_r.data(), chain_func.send_size, MPI_DOUBLE, 0, work_tag, comm); + MPI_Send(workspace_r.data(), chain_func.send_size(sampler, model), MPI_DOUBLE, 0, work_tag, comm); MPI_Status stat; + MPI_Request bcast_req; MPI_Recv(workspace_r.data(), 0, MPI_DOUBLE, 0, MPI_ANY_TAG, comm, &stat); if (stat.MPI_TAG == err_tag) { std::ostringstream chain_adapt_fail_msg; chain_adapt_fail_msg << "Invalid adaptation data in ensemble"; throw std::runtime_error(chain_adapt_fail_msg.str()); } else if (stat.MPI_TAG == adapt_tag) { - workspace_r.resize(adapt_func.recv_size, 1); - MPI_Bcast(workspace_r.data(), adapt_func.recv_size, MPI_DOUBLE, 0, comm); + workspace_r.resize(adapt_func.recv_size(sampler, model), 1); + MPI_Ibcast(workspace_r.data(), adapt_func.recv_size(sampler, model), + MPI_DOUBLE, 0, comm, &bcast_req); + MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); adapt_func(sampler, model, workspace_r); } } diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index f9eb17b012c..26bbce068e9 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -23,6 +23,9 @@ struct dummy_master_ensemble_processor { template void operator()(Sampler& sampler, Model& model, Eigen::MatrixXd& workspace_r) { + for (int i = 0; i < workspace_r.cols(); ++i) { + workspace_r(i, 0) = workspace_r(i + 1, i); + } } }; @@ -35,13 +38,42 @@ struct dummy_master_chain_processor { template void operator()(Sampler& sampler, Model& model, Eigen::MatrixXd& workspace_r, int index) { + workspace_r(index, index - 1) *= 2.5; + } +}; + +struct dummy_master_post_processor { + template + void operator()(Sampler& sampler, Model& model, + Eigen::MatrixXd& workspace_r) { } }; struct dummy_slave_chain_processor { + template + int send_size(Sampler& sampler, Model& model) { + return 10; + } + template void operator()(Sampler& sampler, Model& model, Eigen::MatrixXd& workspace_r, int index) 
{ + workspace_r.resize(send_size(sampler, model), 1); + workspace_r.setZero(); + workspace_r(index) = index; + } +}; + +struct dummy_slave_adapt_processor { + template + int recv_size(Sampler& sampler, Model& model) { + return 10; + } + + template + void operator()(Sampler& sampler, Model& model, + Eigen::MatrixXd& workspace_r) { + workspace_r *= 1.5; } }; @@ -56,12 +88,25 @@ TEST(mpi_warmup_test, mpi_master_slave) { const Communicator& warmup_comm = Session::comms[0]; + dummy_sampler sampler; + dummy_model model; + if (warmup_comm.rank == 0) { warmup_dynamic_loader_master master(warmup_comm, 10); + dummy_master_ensemble_processor f; + dummy_master_chain_processor g; + dummy_master_post_processor h; + EXPECT_NO_THROW(master(sampler, model, f, g, h)); + EXPECT_FLOAT_EQ(master.workspace_r(0), 2.5); + EXPECT_FLOAT_EQ(master.workspace_r(1), 5.0); } else { warmup_dynamic_loader_slave slave(warmup_comm, 10); + dummy_slave_chain_processor f; + dummy_slave_adapt_processor g; + EXPECT_NO_THROW(slave(sampler, model, f, g)); + EXPECT_FLOAT_EQ(slave.workspace_r(0), 3.75); + EXPECT_FLOAT_EQ(slave.workspace_r(1), 7.50); } - } #endif From d57e3d991b959b69a40a79f79b797df3151a2469 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 6 Nov 2019 11:12:23 -0800 Subject: [PATCH 08/73] inter chain and intra chain MPI communicators --- src/stan/services/util/mpi.hpp | 239 ++++++++++++------ .../unit/services/util/mpi_warmup_test.cpp | 198 ++++++++++++--- 2 files changed, 328 insertions(+), 109 deletions(-) diff --git a/src/stan/services/util/mpi.hpp b/src/stan/services/util/mpi.hpp index 47c332bbf4d..14e49760f98 100644 --- a/src/stan/services/util/mpi.hpp +++ b/src/stan/services/util/mpi.hpp @@ -4,13 +4,25 @@ #ifdef STAN_LANG_MPI #include +#include +#include #include +#include #include #include #include #include #include +// default comm to world comm, in case stan needs to be +// called as library. +#define MPI_COMM_STAN MPI_COMM_WORLD + +// by default there are no warmup-pulling chains. 
+#ifndef NUM_MPI_CHAINS +#define NUM_MPI_CHAINS 1 +#endif + namespace stan { namespace services { namespace util { @@ -20,32 +32,41 @@ namespace mpi { * MPI Evionment that initializes and finalizes the MPI */ struct Envionment { - Envionment() { - init(); - } - ~Envionment() { - finalize(); - } + struct Envionment_ { + Envionment_() { + init(); + } + ~Envionment_() { + finalize(); + } - static void init() { + static void init() { #ifdef STAN_LANG_MPI - int flag; - MPI_Initialized(&flag); - if(!flag) { - MPI_Init(NULL, NULL); - } + int flag; + MPI_Initialized(&flag); + if(!flag) { + int provided; + MPI_Init_thread(NULL, NULL, MPI_THREAD_SINGLE, &provided); + // print provided when needed + } #endif - } + } - static void finalize() { + static void finalize() { #ifdef STAN_LANG_MPI - int flag; - MPI_Finalized(&flag); - if(!flag) MPI_Finalize(); + int flag; + MPI_Finalized(&flag); + if(!flag) MPI_Finalize(); #endif - } + } + }; + + static const Envionment_ env; }; + // out-of-line initilization + const Envionment::Envionment_ Envionment::env; + /* * MPI Communicators. With default constructor disabled, * a communicator can only be created through duplication. 
@@ -54,8 +75,6 @@ namespace mpi { private: Communicator(); - const Envionment& env_; - public: MPI_Comm comm; int size; @@ -64,18 +83,20 @@ namespace mpi { /* * communicator constructor using @c Envionment and @c MPI_Comm */ - Communicator(const Envionment& env, MPI_Comm other) : - env_(env), comm(MPI_COMM_NULL) { - MPI_Comm_dup(other, &comm); - MPI_Comm_size(comm, &size); - MPI_Comm_rank(comm, &rank); + explicit Communicator(MPI_Comm other) : + comm(MPI_COMM_NULL), size(0), rank(-1) { + if (other != MPI_COMM_NULL) { + MPI_Comm_dup(other, &comm); + MPI_Comm_size(comm, &size); + MPI_Comm_rank(comm, &rank); + } } /* * copy constructor is deep */ explicit Communicator(const Communicator& other) : - Communicator(other.env_, other.comm) + Communicator(other.comm) {} /* @@ -89,55 +110,123 @@ namespace mpi { * destructor needs to free MPI_Comm */ ~Communicator() { - MPI_Comm_free(&comm); + if (comm != MPI_COMM_NULL) { + MPI_Comm_free(&comm); + } } }; -#define NUM_STAN_LANG_MPI_COMM 1 -#define STAN_LANG_MPI_COMM_WARMUP 0 + MPI_Comm inter_chain_comm(int num_mpi_chains) { + Envionment::env.init(); + + int world_size; + MPI_Comm_size(MPI_COMM_STAN, &world_size); + stan::math::check_greater_or_equal("MPI inter-chain session", + "number of procs", world_size, + num_mpi_chains); + + MPI_Group stan_group, new_group; + MPI_Comm_group(MPI_COMM_STAN, &stan_group); + int num_chain_with_extra_proc = world_size % num_mpi_chains; + int num_proc_per_chain = world_size / num_mpi_chains; + std::vector ranks(num_mpi_chains); + if (num_chain_with_extra_proc == 0) { + for (int i = 0, j = 0; i < world_size; i += num_proc_per_chain, ++j) { + ranks[j] = i; + } + } else { + num_proc_per_chain++; + int i = 0; + for (int j = 0; j < num_chain_with_extra_proc; ++j) { + ranks[j] = i; + i += num_proc_per_chain; + } + num_proc_per_chain--; + for (int j = num_chain_with_extra_proc; j < num_mpi_chains; ++j) { + ranks[j] = i; + i += num_proc_per_chain; + } + } - /* - * MPI communicator wrapper for RAII. 
Note that no - * MPI's predfined comm such as @c MPI_COMM_WOLRD are allowed. - */ - template + MPI_Group_incl(stan_group, num_mpi_chains, ranks.data(), &new_group); + MPI_Comm new_inter_comm, new_intra_comm; + MPI_Comm_create_group(MPI_COMM_STAN, new_group, 99, &new_inter_comm); + MPI_Group_free(&new_group); + MPI_Group_free(&stan_group); + return new_inter_comm; + } + + MPI_Comm intra_chain_comm(int num_mpi_chains) { + Envionment::env.init(); + + int world_size, world_rank, color; + MPI_Comm_size(MPI_COMM_STAN, &world_size); + MPI_Comm_rank(MPI_COMM_STAN, &world_rank); + + int num_chain_with_extra_proc = world_size % num_mpi_chains; + const int n_proc = world_size / num_mpi_chains; + if (num_chain_with_extra_proc == 0) { + color = world_rank / n_proc; + } else { + int i = 0; + for (int j = 0; j < num_mpi_chains; ++j) { + const int n = j < num_chain_with_extra_proc ? (n_proc + 1) : n_proc; + if (world_rank >= i && world_rank < i + n) { + color = i; + break; + } + i += n; + } + } + + MPI_Comm new_intra_comm; + MPI_Comm_split(MPI_COMM_STAN, color, world_rank, &new_intra_comm); + return new_intra_comm; + } + + // * MPI communicator wrapper for RAII. Note that no + // * MPI's predfined comm such as @c MPI_COMM_WOLRD are allowed. + template struct Session { - static Envionment env; - static std::vector comms; + static const Communicator stan_comm; + static const MPI_Comm MPI_COMM_INTER_CHAIN; + static const MPI_Comm MPI_COMM_INTRA_CHAIN; }; - template - Envionment Session::env; + template + const Communicator Session::stan_comm(MPI_COMM_WORLD); - template - std::vector Session::comms(N_comm, Communicator(Session::env, MPI_COMM_WORLD)); + template + const MPI_Comm Session:: + MPI_COMM_INTER_CHAIN(inter_chain_comm(num_mpi_chains)); + + template + const MPI_Comm Session:: + MPI_COMM_INTRA_CHAIN(intra_chain_comm(num_mpi_chains)); /** * Dynamic loader that manages master & slave * communication and data assembly. 
*/ - struct warmup_dynamic_loader_base { - static const int work_tag = 1; - static const int err_tag = 2; - static const int adapt_tag = 3; + struct mpi_loader_base { + static const int work_tag = 1; + static const int err_tag = 2; + static const int done_tag = 3; //! communicator wrapper for warmup - const Communicator& warmup_comm; + const Communicator& comm; //! MPI communicator - const MPI_Comm comm; - //! communication interval - int interval; + const MPI_Comm mpi_comm; //! double workspace Eigen::MatrixXd workspace_r; //! construct loader given MPI communicator - warmup_dynamic_loader_base(const Communicator& comm_in, int inter) : - warmup_comm(comm_in), comm(warmup_comm.comm), - interval(inter) + mpi_loader_base(const Communicator& comm_in) : + comm(comm_in), mpi_comm(comm.comm) { // make sure there are slave chains. - static const char* caller = "warmup_dynamic_loader"; - stan::math::check_greater(caller, "MPI comm size", warmup_comm.size, 1); + static const char* caller = "MPI load balance initialization"; + stan::math::check_greater(caller, "MPI comm size", comm.size, 1); } }; @@ -147,13 +236,20 @@ namespace mpi { * in order to improve the quality of adaptation, before * sending the improved adapt info to slave chains. */ - struct warmup_dynamic_loader_master : warmup_dynamic_loader_base { + struct warmup_dynamic_loader_master : mpi_loader_base { + MPI_Request bcast_req; + int interval; + //! construct loader given MPI communicator - warmup_dynamic_loader_master(const Communicator& comm_in, - int inter) : - warmup_dynamic_loader_base(comm_in, inter) + warmup_dynamic_loader_master(const Communicator& comm_in, int inter) : + mpi_loader_base(comm_in), interval(inter) {} + //! during destruction ensure MPI request is fulfilled. + ~warmup_dynamic_loader_master() { + MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); + } + /* * helper function to master node (rank = 0) to recv * results. @@ -183,12 +279,12 @@ namespace mpi { * available tasks to vacant slaves. 
*/ template + typename Send_processor, typename Recv_processor> void operator()(Sampler& sampler, Model& model, + const stan::mcmc::sample& sample, + stan::callbacks::logger& logger, Send_processor& ensemble_func, - Recv_processor& chain_func, - Post_processor& post_func) { + Recv_processor& chain_func) { static const char* caller = "warmup_dynamic_loader_master::master"; stan::math::check_less(caller, "MPI comm rank", warmup_comm.rank, 1); @@ -210,8 +306,6 @@ namespace mpi { irecve++; } - MPI_Request bcast_req; - if (is_invalid) { for (int i = 1; i < warmup_comm.size; ++i) { MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, err_tag, comm); @@ -226,27 +320,24 @@ namespace mpi { throw std::runtime_error(chain_adapt_fail_msg.str()); } else { for (int i = 1; i < warmup_comm.size; ++i) { - MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, adapt_tag, comm); + MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, done_tag, comm); } while (recved != nslave) { int index, flag = 0; MPI_Testany(nslave, req.data(), &index, &flag, MPI_STATUS_IGNORE); if(flag) { recved++; - chain_func(sampler, model, workspace_r, index + 1); + chain_func(sampler, model, sample, logger, workspace_r, index + 1); } } - ensemble_func(sampler, model, workspace_r); + ensemble_func(sampler, model, sample, logger, workspace_r); MPI_Ibcast(workspace_r.data(), ensemble_func.send_size(sampler, model), MPI_DOUBLE, 0, comm, &bcast_req); } - - post_func(sampler, model, workspace_r); - MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); } }; - struct warmup_dynamic_loader_slave : warmup_dynamic_loader_base { + struct warmup_dynamic_loader_slave : mpi_loader_base { //! 
construct loader given MPI communicator warmup_dynamic_loader_slave(const Communicator& comm_in, int inter) : @@ -259,6 +350,8 @@ namespace mpi { */ template void operator()(Sampler& sampler, Model& model, + const stan::mcmc::sample& sample, + stan::callbacks::logger& logger, Send_processor& chain_func, Recv_processor& adapt_func) { using Eigen::MatrixXd; @@ -268,7 +361,7 @@ namespace mpi { stan::math::check_greater(caller, "MPI comm rank", warmup_comm.rank, 0); // process adapt info before sending out to master - chain_func(sampler, model, workspace_r, warmup_comm.rank); + chain_func(sampler, model, sample, logger, workspace_r, warmup_comm.rank); MPI_Send(workspace_r.data(), chain_func.send_size(sampler, model), MPI_DOUBLE, 0, work_tag, comm); MPI_Status stat; @@ -278,12 +371,12 @@ namespace mpi { std::ostringstream chain_adapt_fail_msg; chain_adapt_fail_msg << "Invalid adaptation data in ensemble"; throw std::runtime_error(chain_adapt_fail_msg.str()); - } else if (stat.MPI_TAG == adapt_tag) { + } else if (stat.MPI_TAG == done_tag) { workspace_r.resize(adapt_func.recv_size(sampler, model), 1); MPI_Ibcast(workspace_r.data(), adapt_func.recv_size(sampler, model), MPI_DOUBLE, 0, comm, &bcast_req); MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); - adapt_func(sampler, model, workspace_r); + adapt_func(sampler, model, sample, logger, workspace_r); } } }; diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index 26bbce068e9..cdd5421a389 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -22,6 +22,8 @@ struct dummy_master_ensemble_processor { template void operator()(Sampler& sampler, Model& model, + const stan::mcmc::sample& sample, + stan::callbacks::logger& logger, Eigen::MatrixXd& workspace_r) { for (int i = 0; i < workspace_r.cols(); ++i) { workspace_r(i, 0) = workspace_r(i + 1, i); @@ -37,18 +39,13 @@ struct dummy_master_chain_processor { template 
void operator()(Sampler& sampler, Model& model, + const stan::mcmc::sample& sample, + stan::callbacks::logger& logger, Eigen::MatrixXd& workspace_r, int index) { workspace_r(index, index - 1) *= 2.5; } }; -struct dummy_master_post_processor { - template - void operator()(Sampler& sampler, Model& model, - Eigen::MatrixXd& workspace_r) { - } -}; - struct dummy_slave_chain_processor { template int send_size(Sampler& sampler, Model& model) { @@ -57,6 +54,8 @@ struct dummy_slave_chain_processor { template void operator()(Sampler& sampler, Model& model, + const stan::mcmc::sample& sample, + stan::callbacks::logger& logger, Eigen::MatrixXd& workspace_r, int index) { workspace_r.resize(send_size(sampler, model), 1); workspace_r.setZero(); @@ -72,41 +71,168 @@ struct dummy_slave_adapt_processor { template void operator()(Sampler& sampler, Model& model, + const stan::mcmc::sample& sample, + stan::callbacks::logger& logger, Eigen::MatrixXd& workspace_r) { workspace_r *= 1.5; } }; -TEST(mpi_warmup_test, mpi_session) { - const Communicator& warmup_comm = - Session::comms[0]; - - warmup_dynamic_loader_base load(warmup_comm, 10); -} - -TEST(mpi_warmup_test, mpi_master_slave) { - const Communicator& warmup_comm = - Session::comms[0]; - - dummy_sampler sampler; - dummy_model model; - - if (warmup_comm.rank == 0) { - warmup_dynamic_loader_master master(warmup_comm, 10); - dummy_master_ensemble_processor f; - dummy_master_chain_processor g; - dummy_master_post_processor h; - EXPECT_NO_THROW(master(sampler, model, f, g, h)); - EXPECT_FLOAT_EQ(master.workspace_r(0), 2.5); - EXPECT_FLOAT_EQ(master.workspace_r(1), 5.0); - } else { - warmup_dynamic_loader_slave slave(warmup_comm, 10); - dummy_slave_chain_processor f; - dummy_slave_adapt_processor g; - EXPECT_NO_THROW(slave(sampler, model, f, g)); - EXPECT_FLOAT_EQ(slave.workspace_r(0), 3.75); - EXPECT_FLOAT_EQ(slave.workspace_r(1), 7.50); +TEST(mpi_warmup_test, mpi_inter_intra_comms) { + const Communicator world_comm(MPI_COMM_STAN); + 
const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); + const Communicator intra_comm(Session<3>::MPI_COMM_INTRA_CHAIN); + if (world_comm.size == 3) { + switch (world_comm.rank) { + case 0: + EXPECT_EQ(inter_comm.rank, 0); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 1: + EXPECT_EQ(inter_comm.rank, 1); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 2: + EXPECT_EQ(inter_comm.rank, 2); + EXPECT_EQ(intra_comm.rank, 0); + break; + } + } else if (world_comm.size == 4) { + switch (world_comm.rank) { + case 0: + EXPECT_EQ(inter_comm.rank, 0); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 1: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 2: + EXPECT_EQ(inter_comm.rank, 1); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 3: + EXPECT_EQ(inter_comm.rank, 2); + EXPECT_EQ(intra_comm.rank, 0); + break; + } + } else if (world_comm.size == 5) { + switch (world_comm.rank) { + case 0: + EXPECT_EQ(inter_comm.rank, 0); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 1: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 2: + EXPECT_EQ(inter_comm.rank, 1); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 3: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 4: + EXPECT_EQ(inter_comm.rank, 2); + EXPECT_EQ(intra_comm.rank, 0); + break; + } + } else if (world_comm.size == 6) { + switch (world_comm.rank) { + case 0: + EXPECT_EQ(inter_comm.rank, 0); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 1: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 2: + EXPECT_EQ(inter_comm.rank, 1); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 3: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 4: + EXPECT_EQ(inter_comm.rank, 2); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 5: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + } + } else if (world_comm.size == 7) { + switch 
(world_comm.rank) { + case 0: + EXPECT_EQ(inter_comm.rank, 0); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 1: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 2: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 2); + break; + case 3: + EXPECT_EQ(inter_comm.rank, 1); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 4: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + case 5: + EXPECT_EQ(inter_comm.rank, 2); + EXPECT_EQ(intra_comm.rank, 0); + break; + case 6: + EXPECT_EQ(inter_comm.rank, -1); + EXPECT_EQ(intra_comm.rank, 1); + break; + } } } +// TEST(mpi_warmup_test, mpi_warmup) { +// const Communicator world_comm(MPI_COMM_STAN); +// const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); +// const Communicator intra_comm(Session<3>::MPI_COMM_INTRA_CHAIN); +// // +// // warmup_dynamic_loader_base load(warmup_comm, 10); + +// } + +// TEST(mpi_warmup_test, mpi_master_slave) { +// const Communicator& warmup_comm = +// Session::comms[0]; + +// dummy_sampler sampler; +// dummy_model model; +// stan::mcmc::sample sample(Eigen::VectorXd(0), 0, 0); +// stan::callbacks::logger logger; + +// if (warmup_comm.rank == 0) { +// warmup_dynamic_loader_master master(warmup_comm, 10); +// dummy_master_ensemble_processor f; +// dummy_master_chain_processor g; +// EXPECT_NO_THROW(master(sampler, model, sample, logger, f, g)); +// EXPECT_FLOAT_EQ(master.workspace_r(0), 2.5); +// EXPECT_FLOAT_EQ(master.workspace_r(1), 5.0); +// } else { +// warmup_dynamic_loader_slave slave(warmup_comm, 10); +// dummy_slave_chain_processor f; +// dummy_slave_adapt_processor g; +// EXPECT_NO_THROW(slave(sampler, model, sample, logger, f, g)); +// EXPECT_FLOAT_EQ(slave.workspace_r(0), 3.75); +// EXPECT_FLOAT_EQ(slave.workspace_r(1), 7.50); +// } +// } + #endif From a98744d60c3618a2836d766ae7cffad5156314ca Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 13 Nov 2019 17:23:53 -0800 Subject: [PATCH 09/73] warmup loader MPI 
test --- src/stan/services/util/mpi.hpp | 205 +++++++----------- .../unit/services/util/mpi_warmup_test.cpp | 173 ++++++--------- 2 files changed, 149 insertions(+), 229 deletions(-) diff --git a/src/stan/services/util/mpi.hpp b/src/stan/services/util/mpi.hpp index 14e49760f98..5e0520da2fe 100644 --- a/src/stan/services/util/mpi.hpp +++ b/src/stan/services/util/mpi.hpp @@ -225,8 +225,11 @@ namespace mpi { comm(comm_in), mpi_comm(comm.comm) { // make sure there are slave chains. - static const char* caller = "MPI load balance initialization"; - stan::math::check_greater(caller, "MPI comm size", comm.size, 1); + // comm.rank == -1 indicates non inter-comm node + if (comm.rank >= 0) { + static const char* caller = "MPI load balance initialization"; + stan::math::check_greater(caller, "MPI comm size", comm.size, 1); + } } }; @@ -236,147 +239,97 @@ namespace mpi { * in order to improve the quality of adaptation, before * sending the improved adapt info to slave chains. */ - struct warmup_dynamic_loader_master : mpi_loader_base { - MPI_Request bcast_req; + struct mpi_warmup { + mpi_loader_base& loader; + Eigen::MatrixXd& workspace_r; int interval; + MPI_Request req; + bool is_inter_comm_node; //! construct loader given MPI communicator - warmup_dynamic_loader_master(const Communicator& comm_in, int inter) : - mpi_loader_base(comm_in), interval(inter) + mpi_warmup(mpi_loader_base& l, int inter) : + loader(l), workspace_r(l.workspace_r), interval(inter), + is_inter_comm_node(loader.comm.size > 0) {} - //! during destruction ensure MPI request is fulfilled. - ~warmup_dynamic_loader_master() { - MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); - } + ~mpi_warmup() {} /* - * helper function to master node (rank = 0) to recv - * results. 
- * @return MPI_Status of recv operation - */ - template - MPI_Status - recv(std::vector& req, Recv_processor& chain_func, - Sampler& sampler, Model& model) { - MPI_Status stat; - MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, comm, &stat); - int source = stat.MPI_SOURCE; - int ireq = source - 1; - if (stat.MPI_TAG == err_tag) { - double dummy; - MPI_Irecv(&dummy, 0, MPI_DOUBLE, source, err_tag, comm, &req[ireq]); - } else { - int n = chain_func.recv_size(sampler, model); - MPI_Irecv(&workspace_r((source - 1) * n), n, - MPI_DOUBLE, source, work_tag, comm, &req[ireq]); - } - return stat; - } - - /* - * master node (rank = 0) recv results and send - * available tasks to vacant slaves. + * run transitions and process each chain's adaptation + * information before sending it to others. + * + * @tparam Sampler sampler used + * @tparam Model model struct + * @tparam S functor that process the adaptation + * information and return it as a vector. + * @tparam F functor that does transitions. + * @tparam Ts args of @c F. */ template + typename S, + typename F, typename... Ts> void operator()(Sampler& sampler, Model& model, - const stan::mcmc::sample& sample, - stan::callbacks::logger& logger, - Send_processor& ensemble_func, - Recv_processor& chain_func) { - static const char* caller = "warmup_dynamic_loader_master::master"; - stan::math::check_less(caller, "MPI comm rank", warmup_comm.rank, 1); - - int nslave = warmup_comm.size - 1; - std::vector req(nslave); - std::array recv_out; - - int recved = 0; - int irecve = 0; - int source; - bool is_invalid = false; - while (irecve != nslave && (!is_invalid)) { - // recv adaption results from certain chain - workspace_r.resize(chain_func.recv_size(sampler, model), - nslave); - MPI_Status stat(recv(req, chain_func, sampler, model)); - is_invalid = stat.MPI_TAG == err_tag; - source = stat.MPI_SOURCE; - irecve++; + stan::mcmc::sample& sample, + const S& fs, F& f, Ts... 
pars) { + if (is_inter_comm_node) { + f(pars...); + + const int rank = loader.comm.rank; + const int mpi_size = loader.comm.size; + const int size = S::size(sampler, model, sample); + workspace_r.resize(size, mpi_size); + Eigen::VectorXd work(fs(sampler, model, sample)); + MPI_Iallgather(work.data(), size, MPI_DOUBLE, + workspace_r.data(), size, MPI_DOUBLE, + loader.mpi_comm, &req); + } + } - if (is_invalid) { - for (int i = 1; i < warmup_comm.size; ++i) { - MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, err_tag, comm); - } - while (irecve != nslave) { - recv(req, chain_func, sampler, model); - irecve++; - } - MPI_Waitall(nslave, req.data(), MPI_STATUSES_IGNORE); - std::ostringstream chain_adapt_fail_msg; - chain_adapt_fail_msg << "Invalid adaptation data in Chain " << source; - throw std::runtime_error(chain_adapt_fail_msg.str()); - } else { - for (int i = 1; i < warmup_comm.size; ++i) { - MPI_Send(workspace_r.data(), 0, MPI_DOUBLE, i, done_tag, comm); - } - while (recved != nslave) { - int index, flag = 0; - MPI_Testany(nslave, req.data(), &index, &flag, MPI_STATUS_IGNORE); - if(flag) { - recved++; - chain_func(sampler, model, sample, logger, workspace_r, index + 1); - } - } - ensemble_func(sampler, model, sample, logger, workspace_r); - MPI_Ibcast(workspace_r.data(), ensemble_func.send_size(sampler, model), - MPI_DOUBLE, 0, comm, &bcast_req); + /* + * check if the MPI communication is finished. While + * waiting, keep doing transitions. When communication + * is done, generate updated adaptation information and + * update sampler. + * + * @tparam Sampler sampler used + * @tparam Model model struct + * @tparam S functor that update sampler with new adaptation . + * @tparam F functor that does transitions. + * @tparam Ts args of @c F. + */ + void finalize() { + if (is_inter_comm_node) { + MPI_Wait(&req, MPI_STATUS_IGNORE); } } - }; - - struct warmup_dynamic_loader_slave : mpi_loader_base { - //! 
construct loader given MPI communicator - warmup_dynamic_loader_slave(const Communicator& comm_in, - int inter) : - warmup_dynamic_loader_base(comm_in, inter) - {} /* - * master node (rank = 0) recv results and send - * available tasks to vacant slaves. + * check if the MPI communication is finished. While + * waiting, keep doing transitions. When communication + * is done, generate updated adaptation information and + * update sampler. + * + * @tparam Sampler sampler used + * @tparam Model model struct + * @tparam S functor that update sampler with new adaptation . + * @tparam F functor that does transitions. + * @tparam Ts args of @c F. */ - template - void operator()(Sampler& sampler, Model& model, - const stan::mcmc::sample& sample, - stan::callbacks::logger& logger, - Send_processor& chain_func, - Recv_processor& adapt_func) { - using Eigen::MatrixXd; - using Eigen::Matrix; - - static const char* caller = "warmup_dynamic_loader_slave::slave"; - stan::math::check_greater(caller, "MPI comm rank", warmup_comm.rank, 0); - - // process adapt info before sending out to master - chain_func(sampler, model, sample, logger, workspace_r, warmup_comm.rank); - MPI_Send(workspace_r.data(), chain_func.send_size(sampler, model), MPI_DOUBLE, 0, work_tag, comm); - - MPI_Status stat; - MPI_Request bcast_req; - MPI_Recv(workspace_r.data(), 0, MPI_DOUBLE, 0, MPI_ANY_TAG, comm, &stat); - if (stat.MPI_TAG == err_tag) { - std::ostringstream chain_adapt_fail_msg; - chain_adapt_fail_msg << "Invalid adaptation data in ensemble"; - throw std::runtime_error(chain_adapt_fail_msg.str()); - } else if (stat.MPI_TAG == done_tag) { - workspace_r.resize(adapt_func.recv_size(sampler, model), 1); - MPI_Ibcast(workspace_r.data(), adapt_func.recv_size(sampler, model), - MPI_DOUBLE, 0, comm, &bcast_req); - MPI_Wait(&bcast_req, MPI_STATUS_IGNORE); - adapt_func(sampler, model, sample, logger, workspace_r); + template + void finalize(Sampler& sampler, Model& model, + stan::mcmc::sample& sample, + const 
S& fs, F& f, Ts... pars) { + if (is_inter_comm_node) { + int flag = 0; + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + while (flag == 0) { + f(pars...); + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + } + fs(workspace_r, sampler, model, sample); } } }; diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index cdd5421a389..a41f9d040b6 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -8,75 +8,8 @@ using Eigen::Matrix; using std::vector; using stan::services::util::mpi::Communicator; using stan::services::util::mpi::Session; -using stan::services::util::mpi::warmup_dynamic_loader_base; -using stan::services::util::mpi::warmup_dynamic_loader_master; -using stan::services::util::mpi::warmup_dynamic_loader_slave; - -struct dummy_sampler {}; -struct dummy_model {}; -struct dummy_master_ensemble_processor { - template - int send_size(Sampler& sampler, Model& model) { - return 10; - } - - template - void operator()(Sampler& sampler, Model& model, - const stan::mcmc::sample& sample, - stan::callbacks::logger& logger, - Eigen::MatrixXd& workspace_r) { - for (int i = 0; i < workspace_r.cols(); ++i) { - workspace_r(i, 0) = workspace_r(i + 1, i); - } - } -}; - -struct dummy_master_chain_processor { - template - int recv_size(Sampler& sampler, Model& model) { - return 10; - } - - template - void operator()(Sampler& sampler, Model& model, - const stan::mcmc::sample& sample, - stan::callbacks::logger& logger, - Eigen::MatrixXd& workspace_r, int index) { - workspace_r(index, index - 1) *= 2.5; - } -}; - -struct dummy_slave_chain_processor { - template - int send_size(Sampler& sampler, Model& model) { - return 10; - } - - template - void operator()(Sampler& sampler, Model& model, - const stan::mcmc::sample& sample, - stan::callbacks::logger& logger, - Eigen::MatrixXd& workspace_r, int index) { - workspace_r.resize(send_size(sampler, model), 1); - 
workspace_r.setZero(); - workspace_r(index) = index; - } -}; - -struct dummy_slave_adapt_processor { - template - int recv_size(Sampler& sampler, Model& model) { - return 10; - } - - template - void operator()(Sampler& sampler, Model& model, - const stan::mcmc::sample& sample, - stan::callbacks::logger& logger, - Eigen::MatrixXd& workspace_r) { - workspace_r *= 1.5; - } -}; +using stan::services::util::mpi::mpi_loader_base; +using stan::services::util::mpi::mpi_warmup; TEST(mpi_warmup_test, mpi_inter_intra_comms) { const Communicator world_comm(MPI_COMM_STAN); @@ -200,39 +133,73 @@ TEST(mpi_warmup_test, mpi_inter_intra_comms) { } } -// TEST(mpi_warmup_test, mpi_warmup) { -// const Communicator world_comm(MPI_COMM_STAN); -// const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); -// const Communicator intra_comm(Session<3>::MPI_COMM_INTRA_CHAIN); -// // -// // warmup_dynamic_loader_base load(warmup_comm, 10); - -// } - -// TEST(mpi_warmup_test, mpi_master_slave) { -// const Communicator& warmup_comm = -// Session::comms[0]; - -// dummy_sampler sampler; -// dummy_model model; -// stan::mcmc::sample sample(Eigen::VectorXd(0), 0, 0); -// stan::callbacks::logger logger; - -// if (warmup_comm.rank == 0) { -// warmup_dynamic_loader_master master(warmup_comm, 10); -// dummy_master_ensemble_processor f; -// dummy_master_chain_processor g; -// EXPECT_NO_THROW(master(sampler, model, sample, logger, f, g)); -// EXPECT_FLOAT_EQ(master.workspace_r(0), 2.5); -// EXPECT_FLOAT_EQ(master.workspace_r(1), 5.0); -// } else { -// warmup_dynamic_loader_slave slave(warmup_comm, 10); -// dummy_slave_chain_processor f; -// dummy_slave_adapt_processor g; -// EXPECT_NO_THROW(slave(sampler, model, sample, logger, f, g)); -// EXPECT_FLOAT_EQ(slave.workspace_r(0), 3.75); -// EXPECT_FLOAT_EQ(slave.workspace_r(1), 7.50); -// } -// } +struct send_processor { + const Communicator& comm; + + send_processor(const Communicator& comm_in) : + comm(comm_in) + {} + + template + static int 
size(const Sampler& sampler, const Model& model, + stan::mcmc::sample& sample) { + return 10; + } + + template + Eigen::VectorXd operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { + Eigen::VectorXd x(Eigen::VectorXd::Zero(size(sampler, model, sample))); + x(comm.rank) = comm.rank; + return x; + } +}; + +struct adapt_processor { + const Communicator& comm; + + adapt_processor(const Communicator& comm_in) : + comm(comm_in) + {} + + template + void operator()(const Eigen::MatrixXd& workspace_r, Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { + for (int i = 0; i < workspace_r.cols(); ++i) { + EXPECT_FLOAT_EQ(workspace_r(i, i), double(i)); + } + double sum1 = 0.5 * comm.size * (comm.size - 1); + double sum2 = workspace_r.sum(); + EXPECT_FLOAT_EQ(sum1, sum2); + } +}; + +struct dummy_transition { + template + void operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) { + } + + void operator()() { + } +}; + +TEST(mpi_warmup_test, mpi_warmup_loader) { + const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); + mpi_loader_base loader(inter_comm); + + Eigen::MatrixXd dummy_sampler; + Eigen::MatrixXd dummy_model; + stan::mcmc::sample sample(Eigen::VectorXd(0), 0, 0); + mpi_warmup mpi_warmup_adapt(loader, 10); + + send_processor fs(inter_comm); + adapt_processor fd(inter_comm); + dummy_transition f; + + mpi_warmup_adapt(dummy_sampler, dummy_model, sample, fs, + f, dummy_sampler, dummy_model, sample); + + mpi_warmup_adapt.finalize(dummy_sampler, dummy_model, sample, fd, f); + + mpi_warmup_adapt.finalize(); +} #endif From 8d09106c8123e3574f69728ba6a83b7e74735599 Mon Sep 17 00:00:00 2001 From: yiz Date: Fri, 15 Nov 2019 07:57:44 -0800 Subject: [PATCH 10/73] unit test with mpi gathering of stepsize --- src/stan/services/util/mpi.hpp | 23 +--- .../unit/services/util/mpi_warmup_test.cpp | 117 +++++++++++++++++- 2 files changed, 118 insertions(+), 22 deletions(-) diff --git a/src/stan/services/util/mpi.hpp 
b/src/stan/services/util/mpi.hpp index 5e0520da2fe..2b6b93df1f3 100644 --- a/src/stan/services/util/mpi.hpp +++ b/src/stan/services/util/mpi.hpp @@ -244,7 +244,7 @@ namespace mpi { Eigen::MatrixXd& workspace_r; int interval; MPI_Request req; - bool is_inter_comm_node; + const bool is_inter_comm_node; //! construct loader given MPI communicator mpi_warmup(mpi_loader_base& l, int inter) : @@ -290,25 +290,8 @@ namespace mpi { * check if the MPI communication is finished. While * waiting, keep doing transitions. When communication * is done, generate updated adaptation information and - * update sampler. - * - * @tparam Sampler sampler used - * @tparam Model model struct - * @tparam S functor that update sampler with new adaptation . - * @tparam F functor that does transitions. - * @tparam Ts args of @c F. - */ - void finalize() { - if (is_inter_comm_node) { - MPI_Wait(&req, MPI_STATUS_IGNORE); - } - } - - /* - * check if the MPI communication is finished. While - * waiting, keep doing transitions. When communication - * is done, generate updated adaptation information and - * update sampler. + * update sampler. This function must be called before + * exiting the scope in which @c mpi_warmup obj is declared. 
* * @tparam Sampler sampler used * @tparam Model model struct diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index a41f9d040b6..70b29f63edf 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -3,6 +3,16 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + using Eigen::MatrixXd; using Eigen::Matrix; using std::vector; @@ -143,7 +153,7 @@ struct send_processor { template static int size(const Sampler& sampler, const Model& model, stan::mcmc::sample& sample) { - return 10; + return 1000000; } template @@ -198,8 +208,111 @@ TEST(mpi_warmup_test, mpi_warmup_loader) { f, dummy_sampler, dummy_model, sample); mpi_warmup_adapt.finalize(dummy_sampler, dummy_model, sample, fd, f); +} + + +struct send_adapt_processor { + const Communicator& comm; + + send_adapt_processor(const Communicator& comm_in) : + comm(comm_in) + {} + + template + static int size(const Sampler& sampler, const Model& model, + stan::mcmc::sample& sample) { + return 1; + } + + template + Eigen::VectorXd operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { + Eigen::VectorXd x(Eigen::VectorXd::Zero(size(sampler, model, sample))); + x(0) = sampler.get_nominal_stepsize() + 0.01 * comm.rank; + return x; + } +}; + +struct warmup_processor { +template +void operator()(stan::mcmc::base_mcmc& sampler, int num_iterations, + int start, int finish, int num_thin, int refresh, bool save, + stan::services::util::mcmc_writer& mcmc_writer, + stan::mcmc::sample& s, Model& model, + RNG& base_rng, stan::callbacks::interrupt& callback, + stan::callbacks::logger& logger) { + stan::services::util::generate_transitions(sampler, num_iterations, start, finish, + num_thin, refresh, save, true, mcmc_writer, s, + model, base_rng, callback, logger); +} +}; + +struct collect_adapt_processor { + const Communicator& comm; + + 
collect_adapt_processor(const Communicator& comm_in) : + comm(comm_in) + {} + + template + void operator()(const Eigen::MatrixXd& workspace_r, Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { + EXPECT_EQ(workspace_r.cols(), comm.size); + for (int i = 0; i < comm.size; ++i) { + EXPECT_FLOAT_EQ(workspace_r(0, i), 0.01 * i + workspace_r(0, 0)); + } + } +}; + +TEST(mpi_warmup_test, unit_e_nuts) { + using Model = gauss3D_model_namespace::gauss3D_model; + using Sampler = stan::mcmc::adapt_unit_e_nuts; + boost::ecuyer1988 rng(4839294); + + stan::mcmc::unit_e_point z_init(3); + z_init.q(0) = 1; + z_init.q(1) = -1; + z_init.q(2) = 1; + z_init.p(0) = -1; + z_init.p(1) = 1; + z_init.p(2) = -1; + + std::stringstream debug, info, warn, error, fatal; + stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); + + std::fstream empty_stream("", std::fstream::in); + stan::io::dump data_var_context(empty_stream); + Model model(data_var_context); + + Sampler sampler(model, rng); + sampler.z() = z_init; + sampler.init_hamiltonian(logger); + sampler.set_nominal_stepsize(0.1); + sampler.set_stepsize_jitter(0); + sampler.sample_stepsize(); + + stan::mcmc::sample s(z_init.q, 0, 0); + + stan::callbacks::writer sample_writer; + stan::callbacks::writer diagnostic_writer; + stan::services::util::mcmc_writer writer(sample_writer, diagnostic_writer, logger); + stan::callbacks::interrupt interrupt; + + stan::services::util::generate_transitions(sampler, 10, 0, 20, + 1, 0, false, true, writer, s, + model, rng, interrupt, logger); + + const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); + mpi_loader_base loader(inter_comm); + mpi_warmup mpi_warmup_adapt(loader, 10); + warmup_processor f_warmup; + send_adapt_processor fs(inter_comm); + mpi_warmup_adapt(sampler, model, s, fs, + f_warmup, + sampler, 10, 0, 20, 1, 0, false, writer, s, model, rng, interrupt, logger); - mpi_warmup_adapt.finalize(); + collect_adapt_processor fd(inter_comm); + 
mpi_warmup_adapt.finalize(sampler, model, s, fd, + f_warmup, + sampler, 10, 0, 20, 1, 0, false, writer, s, model, rng, interrupt, logger); } #endif From b69492ba98ef00399a296b7b848cdf0ab09a6d22 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 18 Dec 2019 14:24:34 -0800 Subject: [PATCH 11/73] update math submodule to 7f4e3a4af5 --- lib/stan_math | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/stan_math b/lib/stan_math index 025a142ec01..7f4e3a4af56 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 025a142ec01b68e91adf339a9b86d67e6d0e20ee +Subproject commit 7f4e3a4af56d8471606b414ed59364898638b323 From 1e116504220875c8b2ac81e028f3f9324f750d19 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 7 Jan 2020 16:35:29 -0800 Subject: [PATCH 12/73] rhat as convergence critierior for mpi warmup --- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 6 + src/stan/services/util/campfire_warmup.hpp | 175 ++++++++++++++++++ .../util/run_mpi_adaptive_sampler.hpp | 90 +++++++++ 3 files changed, 271 insertions(+) create mode 100644 src/stan/services/util/campfire_warmup.hpp create mode 100644 src/stan/services/util/run_mpi_adaptive_sampler.hpp diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index adc84353d61..ef255e0b63c 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -71,9 +71,15 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_kappa(kappa); sampler.get_stepsize_adaptation().set_t0(t0); +#ifdef MPI_ADAPTED_WARMUP + util::run_mpi_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); +#else util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); +#endif return 
error_codes::OK; } diff --git a/src/stan/services/util/campfire_warmup.hpp b/src/stan/services/util/campfire_warmup.hpp new file mode 100644 index 00000000000..c68ed41704e --- /dev/null +++ b/src/stan/services/util/campfire_warmup.hpp @@ -0,0 +1,175 @@ +#ifndef STAN_SERVICES_UTIL_CAMPFIRE_WARMUP_HPP +#define STAN_SERVICES_UTIL_CAMPFIRE_WARMUP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace util { + +/** + * Generates MCMC transitions. + * + * @tparam Model model class + * @tparam RNG random number generator class + * @param[in,out] sampler MCMC sampler used to generate transitions + * @param[in] num_iterations number of MCMC transitions + * @param[in] start starting iteration number used for printing messages + * @param[in] finish end iteration number used for printing messages + * @param[in] num_thin when save is true, a draw will be written to the + * mcmc_writer every num_thin iterations + * @param[in] refresh number of iterations to print a message. If + * refresh is zero, iteration number messages will not be printed + * @param[in] save if save is true, the transitions will be written + * to the mcmc_writer. If false, transitions will not be written + * @param[in] warmup indicates whether these transitions are warmup. Used + * for printing iteration number messages + * @param[in,out] mcmc_writer writer to handle mcmc otuput + * @param[in,out] init_s starts as the initial unconstrained parameter + * values. 
When the function completes, this will have the final + * iteration's unconstrained parameter values + * @param[in] model model + * @param[in,out] base_rng random number generator + * @param[in,out] callback interrupt callback called once an iteration + * @param[in,out] logger logger for messages + */ +template +void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, + int start, int finish, int num_thin, int refresh, + bool save, bool warmup, + util::mcmc_writer& mcmc_writer, + stan::mcmc::sample& init_s, Model& model, + RNG& base_rng, callbacks::interrupt& callback, + callbacks::logger& logger) { + // for prototyping, we have @c max_num_windows fixed + const int window_size = 100; + const int max_num_windows = num_iterations / window_size; + const int num_chains; + + // rhat for each window combination, e.g. ABCD, BCD, CD, D for 4 windows. + std::vector adapt_quantity(max_num_windows, 0.0); + bool is_adapted = false; + + const int target_rhat = 1.05; + const int target_ess = 50; + + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + accumulator_set> acc_log(max_num_windows); + + for (int m = 0; m < num_iterations; ++m) { + callback(); + + if (refresh > 0 & (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) { + int it_print_width = std::ceil(std::log10(static_cast(finish))); + std::stringstream message; + message << "Iteration: "; + message << std::setw(it_print_width) << m + 1 + start << " / " << finish; + message << " [" << std::setw(3) + << static_cast((100.0 * (start + m + 1)) / finish) << "%] "; + message << (warmup ? 
" (Warmup)" : " (Sampling)"); + + logger.info(message); + } + + init_s = sampler.transition(init_s, logger); + + if (save & ((m % num_thin) == 0)) { + mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); + mcmc_writer.write_diagnostic_params(init_s, sampler); + } + + double stepsize = -999.0; + + // check adaptation by examining rhat and ess + if(stan::math::mpi::Session::is_in_inter_chain_comm()) { + if (!boost::math::isfinite(init_s.log_prob())) { + return std::numeric_limits::quiet_NaN(); + } + + int m_win = m / window_size + 1; + for (int i = 0; i < m_win; ++i) { + acc_log[i](init_s.log_prob()); + } + + // though @c boost::acc gives population var instead + // of sample var, the nb. of draws is supposed to be + // large enough to make it irrelevant. But for + // between-chain variance we must correct it because + // the nb. of chains is not large + + if (m >= window_size && (m + 1) % window_size == 0) { + int n_gather = 3 * m_win; // mean, variance, stepsize + std::vector chain_gather(n_gather, 0.0); + for (int i = 0; i < m_win; ++i) { + chain_gather[3 * i] = boost::accumulators::mean(acc_log[i]); + chain_gather[3 * i + 1] = boost::accumulators::variance(acc_log[i]); + chain_gather[3 * i + 2] = sampler.get_nominal_stepsize(); + } + + const Communicator& comm = Session::inter_chain_comm(); + if (comm.rank() == 0) { + std::vector rhat(m_win), ess(m_win); + std::vector all_chain_gather(n_gather * num_chains); + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + for (int i = 0; i < m_win; ++i) { + accumulator_set> acc_chain_mean; + accumulator_set> acc_chain_var; + for (int chain = 0; chain < num_chains; ++chain) { + acc_chain_mean(all_chain_gather[chain * n_gather + 3 * i]); + acc_chain_var(all_chain_gather[chain * n_gather + 3 * i + 1]); + } + int n_draws = (m_win - i) * window_size; + double var_between = n_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains / 
(num_chains - 1); + double var_within = boost::accumulators::mean(acc_chain_var); + rhat[i] = sqrt((var_between / var_within + num_draws - 1) / num_draws); + // TODO also calculate ess + is_adapted = (rhat[i]) < target_rhat; + if (is_adapted) { + accumulator_set> acc_step; + for (int chain = 0; chain < num_chains; ++chain) { + acc_step(all_chain_gather[chain * n_gather + 3 * i + 2]); + } + stepsize = boost::accumulators::mean(acc_step); + break; + } + } + MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + } else { + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + NULL, 0, MPI_DOUBLE, 0, comm.comm()); + MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + } + } + } + + const Communicator& intra_comm = Session::intra_chain_comm(); + MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + if (stepsize > 0.0) { + sampler.set_nominal_stepsize(stepsize); + break; + } + } +} + +} // namespace util +} // namespace services +} // namespace stan + +#endif diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp new file mode 100644 index 00000000000..b6ef6250f1a --- /dev/null +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -0,0 +1,90 @@ +#ifndef STAN_SERVICES_UTIL_RUN_MPI_ADAPTIVE_SAMPLER_HPP +#define STAN_SERVICES_UTIL_RUN_MPI_ADAPTIVE_SAMPLER_HPP + +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace util { + +/** + * Runs the sampler with adaptation. + * + * @tparam Sampler Type of adaptive sampler. + * @tparam Model Type of model + * @tparam RNG Type of random number generator + * @param[in,out] sampler the mcmc sampler to use on the model + * @param[in] model the model concept to use for computing log probability + * @param[in] cont_vector initial parameter values + * @param[in] num_warmup number of warmup draws + * @param[in] num_samples number of post warmup draws + * @param[in] num_thin number to thin the draws. 
Must be greater than + * or equal to 1. + * @param[in] refresh controls output to the logger + * @param[in] save_warmup indicates whether the warmup draws should be + * sent to the sample writer + * @param[in,out] rng random number generator + * @param[in,out] interrupt interrupt callback + * @param[in,out] logger logger for messages + * @param[in,out] sample_writer writer for draws + * @param[in,out] diagnostic_writer writer for diagnostic information + */ +template +void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, + std::vector& cont_vector, int num_warmup, + int num_samples, int num_thin, int refresh, + bool save_warmup, RNG& rng, + callbacks::interrupt& interrupt, + callbacks::logger& logger, + callbacks::writer& sample_writer, + callbacks::writer& diagnostic_writer) { + Eigen::Map cont_params(cont_vector.data(), + cont_vector.size()); + + sampler.engage_adaptation(); + try { + sampler.z().q = cont_params; + sampler.init_stepsize(logger); + } catch (const std::exception& e) { + logger.info("Exception initializing step size."); + logger.info(e.what()); + return; + } + + services::util::mcmc_writer writer(sample_writer, diagnostic_writer, logger); + stan::mcmc::sample s(cont_params, 0, 0); + + // Headers + writer.write_sample_names(s, sampler, model); + writer.write_diagnostic_names(s, sampler, model); + + // warmup + clock_t start = clock(); + util::campfire_warmup(sampler, num_warmup, 0, num_warmup + num_samples, + num_thin, refresh, save_warmup, true, writer, s, + model, rng, interrupt, logger); + clock_t end = clock(); + double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; + + sampler.disengage_adaptation(); + writer.write_adapt_finish(sampler); + sampler.write_sampler_state(sample_writer); + + start = clock(); + util::generate_transitions(sampler, num_samples, num_warmup, + num_warmup + num_samples, num_thin, refresh, true, + false, writer, s, model, rng, interrupt, logger); + end = clock(); + double sample_delta_t = static_cast(end 
- start) / CLOCKS_PER_SEC; + + writer.write_timing(warm_delta_t, sample_delta_t); +} +} // namespace util +} // namespace services +} // namespace stan +#endif From bae89e1f804fc29cd9637229d0a62b323f85a23a Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 8 Jan 2020 15:57:34 -0800 Subject: [PATCH 13/73] campfire warmup with rhat and synced stepsize --- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 7 +++ src/stan/services/util/campfire_warmup.hpp | 50 ++++++++++--------- .../util/run_mpi_adaptive_sampler.hpp | 1 + 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index e349c25a681..cd84e74ec95 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -95,9 +96,15 @@ int hmc_nuts_diag_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); +#ifdef MPI_ADAPTED_WARMUP + util::run_mpi_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); +#else util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); +#endif return error_codes::OK; } diff --git a/src/stan/services/util/campfire_warmup.hpp b/src/stan/services/util/campfire_warmup.hpp index c68ed41704e..346e3d0669a 100644 --- a/src/stan/services/util/campfire_warmup.hpp +++ b/src/stan/services/util/campfire_warmup.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -41,8 +42,8 @@ namespace util { * @param[in,out] callback interrupt callback called once an iteration * @param[in,out] logger logger for messages */ -template -void campfire_warmup(stan::mcmc::base_mcmc& 
sampler, int num_iterations, + template +void campfire_warmup(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, @@ -50,16 +51,10 @@ void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { // for prototyping, we have @c max_num_windows fixed + const int num_mpi_chains = 4; const int window_size = 100; const int max_num_windows = num_iterations / window_size; - const int num_chains; - - // rhat for each window combination, e.g. ABCD, BCD, CD, D for 4 windows. - std::vector adapt_quantity(max_num_windows, 0.0); - bool is_adapted = false; - - const int target_rhat = 1.05; - const int target_ess = 50; + const int num_chains = 4; using boost::accumulators::accumulator_set; using boost::accumulators::stats; @@ -69,12 +64,21 @@ void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, using stan::math::mpi::Session; using stan::math::mpi::Communicator; - accumulator_set> acc_log(max_num_windows); + std::vector>> acc_log(max_num_windows); + + // Session::inter_chain_comm(num_mpi_chains); + // Session::intra_chain_comm(num_mpi_chains); + + bool is_adapted = false; + const double target_rhat = 1.05; + const double target_ess = 50.0; - for (int m = 0; m < num_iterations; ++m) { + int m = 0; + while (m < num_iterations && (!is_adapted)) { callback(); - if (refresh > 0 & (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) { + if (refresh > 0 + && (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) { int it_print_width = std::ceil(std::log10(static_cast(finish))); std::stringstream message; message << "Iteration: "; @@ -88,19 +92,15 @@ void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, init_s = sampler.transition(init_s, logger); - if (save & ((m % num_thin) == 0)) { + if (save && ((m % num_thin) == 0)) { 
mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); mcmc_writer.write_diagnostic_params(init_s, sampler); } double stepsize = -999.0; + bool is_inter_rank = Session::is_in_inter_chain_comm(num_mpi_chains); - // check adaptation by examining rhat and ess - if(stan::math::mpi::Session::is_in_inter_chain_comm()) { - if (!boost::math::isfinite(init_s.log_prob())) { - return std::numeric_limits::quiet_NaN(); - } - + if (is_inter_rank && boost::math::isfinite(init_s.log_prob())) { int m_win = m / window_size + 1; for (int i = 0; i < m_win; ++i) { acc_log[i](init_s.log_prob()); @@ -121,7 +121,7 @@ void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, chain_gather[3 * i + 2] = sampler.get_nominal_stepsize(); } - const Communicator& comm = Session::inter_chain_comm(); + const Communicator& comm = Session::inter_chain_comm(num_mpi_chains); if (comm.rank() == 0) { std::vector rhat(m_win), ess(m_win); std::vector all_chain_gather(n_gather * num_chains); @@ -138,7 +138,8 @@ void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, double var_between = n_draws * boost::accumulators::variance(acc_chain_mean) * num_chains / (num_chains - 1); double var_within = boost::accumulators::mean(acc_chain_var); - rhat[i] = sqrt((var_between / var_within + num_draws - 1) / num_draws); + rhat[i] = sqrt((var_between / var_within + n_draws - 1) / n_draws); + // TODO also calculate ess is_adapted = (rhat[i]) < target_rhat; if (is_adapted) { @@ -159,12 +160,15 @@ void campfire_warmup(stan::mcmc::base_mcmc& sampler, int num_iterations, } } - const Communicator& intra_comm = Session::intra_chain_comm(); + const Communicator& intra_comm = Session::intra_chain_comm(num_mpi_chains); MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); if (stepsize > 0.0) { + is_adapted = true; sampler.set_nominal_stepsize(stepsize); break; } + + m++; } } diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp 
b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index b6ef6250f1a..dc72e29aea0 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include From 6e1fce86f3c51e661b42da5ce77460816ea445b9 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 8 Jan 2020 16:57:38 -0800 Subject: [PATCH 14/73] pass num_chains through run_mpi_adaptive_sampler so that cmdstan has access to the parameter --- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 5 +++-- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 3 ++- src/stan/services/util/campfire_warmup.hpp | 17 +++++++---------- .../services/util/run_mpi_adaptive_sampler.hpp | 6 ++++-- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index cd84e74ec95..4aba49937e2 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -97,8 +97,9 @@ int hmc_nuts_diag_e_adapt( logger); #ifdef MPI_ADAPTED_WARMUP - util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + const int num_chains = 4; + util::run_mpi_adaptive_sampler(sampler, + model, cont_vector, num_chains, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index ef255e0b63c..810a0a3be8f 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -72,8 +72,9 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_t0(t0); #ifdef MPI_ADAPTED_WARMUP + const int num_chains = 4; util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, 
num_warmup, num_samples, num_thin, refresh, + sampler, model, cont_vector, num_chains, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( diff --git a/src/stan/services/util/campfire_warmup.hpp b/src/stan/services/util/campfire_warmup.hpp index 346e3d0669a..774da8a1157 100644 --- a/src/stan/services/util/campfire_warmup.hpp +++ b/src/stan/services/util/campfire_warmup.hpp @@ -42,8 +42,9 @@ namespace util { * @param[in,out] callback interrupt callback called once an iteration * @param[in,out] logger logger for messages */ - template -void campfire_warmup(Sampler& sampler, int num_iterations, +template +void campfire_warmup(Sampler& sampler, int num_chains, + int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, @@ -51,10 +52,8 @@ void campfire_warmup(Sampler& sampler, int num_iterations, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { // for prototyping, we have @c max_num_windows fixed - const int num_mpi_chains = 4; const int window_size = 100; const int max_num_windows = num_iterations / window_size; - const int num_chains = 4; using boost::accumulators::accumulator_set; using boost::accumulators::stats; @@ -66,9 +65,6 @@ void campfire_warmup(Sampler& sampler, int num_iterations, std::vector>> acc_log(max_num_windows); - // Session::inter_chain_comm(num_mpi_chains); - // Session::intra_chain_comm(num_mpi_chains); - bool is_adapted = false; const double target_rhat = 1.05; const double target_ess = 50.0; @@ -98,7 +94,7 @@ void campfire_warmup(Sampler& sampler, int num_iterations, } double stepsize = -999.0; - bool is_inter_rank = Session::is_in_inter_chain_comm(num_mpi_chains); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains); if (is_inter_rank && boost::math::isfinite(init_s.log_prob())) { int m_win = m / window_size + 1; @@ -121,7 +117,7 @@ void 
campfire_warmup(Sampler& sampler, int num_iterations, chain_gather[3 * i + 2] = sampler.get_nominal_stepsize(); } - const Communicator& comm = Session::inter_chain_comm(num_mpi_chains); + const Communicator& comm = Session::inter_chain_comm(num_chains); if (comm.rank() == 0) { std::vector rhat(m_win), ess(m_win); std::vector all_chain_gather(n_gather * num_chains); @@ -148,6 +144,7 @@ void campfire_warmup(Sampler& sampler, int num_iterations, acc_step(all_chain_gather[chain * n_gather + 3 * i + 2]); } stepsize = boost::accumulators::mean(acc_step); + std::cout << "taki test rhat: " << rhat[i] << "\n"; break; } } @@ -160,7 +157,7 @@ void campfire_warmup(Sampler& sampler, int num_iterations, } } - const Communicator& intra_comm = Session::intra_chain_comm(num_mpi_chains); + const Communicator& intra_comm = Session::intra_chain_comm(num_chains); MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); if (stepsize > 0.0) { is_adapted = true; diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index dc72e29aea0..bb54e3f9637 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -37,7 +37,8 @@ namespace util { */ template void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, - std::vector& cont_vector, int num_warmup, + std::vector& cont_vector, + int num_chains, int num_warmup, int num_samples, int num_thin, int refresh, bool save_warmup, RNG& rng, callbacks::interrupt& interrupt, @@ -66,7 +67,8 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); - util::campfire_warmup(sampler, num_warmup, 0, num_warmup + num_samples, + util::campfire_warmup(sampler, num_chains, + num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, writer, s, model, rng, interrupt, logger); clock_t end = clock(); From 45ed1d45989c2169b278bc9b9878ea9b00182f1f Mon Sep 17 00:00:00 2001 
From: yiz Date: Wed, 8 Jan 2020 17:14:25 -0800 Subject: [PATCH 15/73] update submodule --- lib/stan_math | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/stan_math b/lib/stan_math index 7f4e3a4af56..89cee61d436 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 7f4e3a4af56d8471606b414ed59364898638b323 +Subproject commit 89cee61d43607d2ce011701b69e8ddffc1db2aeb From 54f5fa34e91bbfe44af6de2fb5db4df087f720d6 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 9 Jan 2020 21:50:56 -0800 Subject: [PATCH 16/73] cross chain rhat calculation unit test --- .../services/util/mpi_cross_chain_adapt.hpp | 103 +++++ .../unit/services/util/mpi_warmup_test.cpp | 363 ++++-------------- 2 files changed, 171 insertions(+), 295 deletions(-) create mode 100644 src/stan/services/util/mpi_cross_chain_adapt.hpp diff --git a/src/stan/services/util/mpi_cross_chain_adapt.hpp b/src/stan/services/util/mpi_cross_chain_adapt.hpp new file mode 100644 index 00000000000..e4aa4a79919 --- /dev/null +++ b/src/stan/services/util/mpi_cross_chain_adapt.hpp @@ -0,0 +1,103 @@ +#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPT_HPP +#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPT_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace util { + /* + * @tparam Sampler sampler class + * @param[in] m_win number of windows + * @param[in] window_size window size + * @param[in] num_chains number of chains + * @param[in,out] chain_gather gathered information from each chain, + * must have enough capacity to store up to + * maximum windows for all chains. 
+ # @return vector {stepsize, rhat(only in rank 0)} + */ + template + std::vector + mpi_cross_chain_adapt(const std::vector& acc, + const std::vector& chain_stepsize, + int num_current_window, int max_num_window, + int window_size, int num_chains, + double target_rhat, + std::vector& chain_gather) { + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + const Communicator& comm = Session::inter_chain_comm(num_chains); + + const int nd_win = 3; // mean, variance, chain_stepsize + int n_gather = nd_win * num_current_window; + for (int win = 0; win < num_current_window; ++win) { + int n_draws = (num_current_window - win) * window_size; + double unbiased_var_scale = n_draws / (n_draws - 1.0); + chain_gather[nd_win * win] = boost::accumulators::mean(acc[win]); + chain_gather[nd_win * win + 1] = boost::accumulators::variance(acc[win]) * + unbiased_var_scale; + chain_gather[nd_win * win + 2] = chain_stepsize[win]; + } + + std::vector res; + double stepsize = -999.0; + + if (comm.rank() == 0) { + std::vector rhat(num_current_window), ess(num_current_window); + std::vector all_chain_gather(n_gather * num_chains); + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + for (int win = 0; win < num_current_window; ++win) { + accumulator_set> acc_chain_mean; + accumulator_set> acc_chain_var; + for (int chain = 0; chain < num_chains; ++chain) { + acc_chain_mean(all_chain_gather[chain * n_gather + nd_win * win]); + acc_chain_var(all_chain_gather[chain * n_gather + nd_win * win + 1]); + } + int n_draws = (num_current_window - win) * window_size; + double var_between = n_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains / (num_chains - 1); + double var_within = boost::accumulators::mean(acc_chain_var); + 
rhat[win] = sqrt((var_between / var_within + n_draws - 1) / n_draws); + + // TODO also calculate ess + bool is_adapted = (rhat[win]) < target_rhat; + if (is_adapted) { + accumulator_set> acc_step; + for (int chain = 0; chain < num_chains; ++chain) { + acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); + } + stepsize = boost::accumulators::mean(acc_step); + res.push_back(stepsize); + res.push_back(rhat[win]); + break; + } + } + MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + } else { + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + NULL, 0, MPI_DOUBLE, 0, comm.comm()); + MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + res.push_back(stepsize); + } + return res; + } +} +} +} +#endif diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index 70b29f63edf..57eea909f61 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -1,9 +1,9 @@ #ifdef STAN_LANG_MPI #include -#include - -#include +#include +#include +#include #include #include #include @@ -16,303 +16,76 @@ using Eigen::MatrixXd; using Eigen::Matrix; using std::vector; -using stan::services::util::mpi::Communicator; -using stan::services::util::mpi::Session; -using stan::services::util::mpi::mpi_loader_base; -using stan::services::util::mpi::mpi_warmup; - -TEST(mpi_warmup_test, mpi_inter_intra_comms) { - const Communicator world_comm(MPI_COMM_STAN); - const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); - const Communicator intra_comm(Session<3>::MPI_COMM_INTRA_CHAIN); - if (world_comm.size == 3) { - switch (world_comm.rank) { - case 0: - EXPECT_EQ(inter_comm.rank, 0); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 1: - EXPECT_EQ(inter_comm.rank, 1); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 2: - EXPECT_EQ(inter_comm.rank, 2); - EXPECT_EQ(intra_comm.rank, 0); - break; - } - } else if (world_comm.size == 4) { - switch (world_comm.rank) { - 
case 0: - EXPECT_EQ(inter_comm.rank, 0); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 1: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 2: - EXPECT_EQ(inter_comm.rank, 1); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 3: - EXPECT_EQ(inter_comm.rank, 2); - EXPECT_EQ(intra_comm.rank, 0); - break; - } - } else if (world_comm.size == 5) { - switch (world_comm.rank) { - case 0: - EXPECT_EQ(inter_comm.rank, 0); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 1: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 2: - EXPECT_EQ(inter_comm.rank, 1); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 3: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 4: - EXPECT_EQ(inter_comm.rank, 2); - EXPECT_EQ(intra_comm.rank, 0); - break; - } - } else if (world_comm.size == 6) { - switch (world_comm.rank) { - case 0: - EXPECT_EQ(inter_comm.rank, 0); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 1: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 2: - EXPECT_EQ(inter_comm.rank, 1); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 3: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 4: - EXPECT_EQ(inter_comm.rank, 2); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 5: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - } - } else if (world_comm.size == 7) { - switch (world_comm.rank) { - case 0: - EXPECT_EQ(inter_comm.rank, 0); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 1: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 2: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 2); - break; - case 3: - EXPECT_EQ(inter_comm.rank, 1); - EXPECT_EQ(intra_comm.rank, 0); - break; - case 4: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - case 5: - EXPECT_EQ(inter_comm.rank, 2); - EXPECT_EQ(intra_comm.rank, 0); - 
break; - case 6: - EXPECT_EQ(inter_comm.rank, -1); - EXPECT_EQ(intra_comm.rank, 1); - break; - } - } -} - -struct send_processor { - const Communicator& comm; - - send_processor(const Communicator& comm_in) : - comm(comm_in) - {} - - template - static int size(const Sampler& sampler, const Model& model, - stan::mcmc::sample& sample) { - return 1000000; - } - - template - Eigen::VectorXd operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { - Eigen::VectorXd x(Eigen::VectorXd::Zero(size(sampler, model, sample))); - x(comm.rank) = comm.rank; - return x; - } -}; - -struct adapt_processor { - const Communicator& comm; - - adapt_processor(const Communicator& comm_in) : - comm(comm_in) - {} - - template - void operator()(const Eigen::MatrixXd& workspace_r, Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { - for (int i = 0; i < workspace_r.cols(); ++i) { - EXPECT_FLOAT_EQ(workspace_r(i, i), double(i)); - } - double sum1 = 0.5 * comm.size * (comm.size - 1); - double sum2 = workspace_r.sum(); - EXPECT_FLOAT_EQ(sum1, sum2); - } -}; - -struct dummy_transition { - template - void operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) { - } - - void operator()() { - } -}; - -TEST(mpi_warmup_test, mpi_warmup_loader) { - const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); - mpi_loader_base loader(inter_comm); - - Eigen::MatrixXd dummy_sampler; - Eigen::MatrixXd dummy_model; - stan::mcmc::sample sample(Eigen::VectorXd(0), 0, 0); - mpi_warmup mpi_warmup_adapt(loader, 10); - - send_processor fs(inter_comm); - adapt_processor fd(inter_comm); - dummy_transition f; - - mpi_warmup_adapt(dummy_sampler, dummy_model, sample, fs, - f, dummy_sampler, dummy_model, sample); - - mpi_warmup_adapt.finalize(dummy_sampler, dummy_model, sample, fd, f); -} - - -struct send_adapt_processor { - const Communicator& comm; - - send_adapt_processor(const Communicator& comm_in) : - comm(comm_in) - {} - - template - static int size(const 
Sampler& sampler, const Model& model, - stan::mcmc::sample& sample) { - return 1; - } - - template - Eigen::VectorXd operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { - Eigen::VectorXd x(Eigen::VectorXd::Zero(size(sampler, model, sample))); - x(0) = sampler.get_nominal_stepsize() + 0.01 * comm.rank; - return x; +using stan::math::mpi::Session; +using stan::math::mpi::Communicator; +using Eigen::MatrixXd; +using Eigen::VectorXd; +using boost::accumulators::accumulator_set; +using boost::accumulators::stats; +using boost::accumulators::tag::mean; +using boost::accumulators::tag::variance; + +// 4 chains with 4 cores, each chain run on a core +TEST(mpi_warmup_test, rhat_adaption) { + const int num_chains = 4; + const int max_num_windows = num_chains; + const size_t s = 25; + const std::vector sizes(num_chains, s); + std::vector draw_vecs(num_chains, Eigen::VectorXd(s)); + draw_vecs[0] << + -276.606 , -277.168 , -272.621 , -271.142 , -271.950 , + -269.749 , -267.016 , -273.508 , -268.650 , -265.904 , + -264.629 , -260.797 , -263.184 , -263.892 , -268.810 , + -272.563 , -268.320 , -266.297 , -265.787 , -266.073 , + -265.788 , -262.260 , -265.073 , -265.511 , -264.318; + draw_vecs[1] << + -264.318 , -266.261 , -265.633 , -265.323 , -265.633 , + -265.426 , -265.690 , -266.122 , -264.876 , -264.829 , + -264.238 , -265.822 , -262.979 , -264.012 , -263.801 , + -264.745 , -263.940 , -263.586 , -263.284 , -262.566 , + -261.816 , -265.308 , -266.467 , -265.915 , -266.122; + draw_vecs[2] << + -266.122 , -265.903 , -265.903 , -265.717 , -271.780 , + -271.780 , -271.712 , -271.712 , -271.011 , -273.137 , + -272.125 , -265.535 , -265.168 , -267.824 , -262.983 , + -262.985 , -261.967 , -265.455 , -265.900 , -265.623 , + -262.111 , -262.111 , -262.111 , -266.586 , -266.545; + draw_vecs[3] << + -266.545 , -263.267 , -268.256 , -270.425 , -268.454 , + -268.807 , -269.154 , -269.154 , -269.528 , -268.206 , + -271.774 , -269.453 , -267.725 , -266.435 , 
-269.434 , + -267.838 , -267.676 , -267.925 , -268.343 , -267.824 , + -267.824 , -267.050 , -268.138 , -268.072 , -267.321; + + const std::vector draws{draw_vecs[0].data(), + draw_vecs[1].data(), draw_vecs[2].data(), draw_vecs[3].data()}; + double rhat = stan::analyze::compute_potential_scale_reduction(draws, sizes); + + std::vector>> acc(max_num_windows); + std::vector chain_stepsize{1.1, 1.2, 1.3, 1.4}; + const Communicator& comm = Session::inter_chain_comm(num_chains); + // each rank has different draws + for (int j = 0; j < s; ++j) { acc[0](draw_vecs[comm.rank()](j)); } + // each rank's stepsize is jittered + for (int j = 0; j < num_chains; ++j) {chain_stepsize[j] += 0.1 * comm.rank();} + + std::vector chain_gather(2 * num_chains * max_num_windows, 0.0); + std::vector output = stan::services::util::mpi_cross_chain_adapt(acc, + chain_stepsize, + 1, + max_num_windows, + s, num_chains, + 1.5, chain_gather); + if (comm.rank() == 0) { + EXPECT_EQ(output[1], rhat); } -}; + EXPECT_FLOAT_EQ(output[0], 1.25); -struct warmup_processor { -template -void operator()(stan::mcmc::base_mcmc& sampler, int num_iterations, - int start, int finish, int num_thin, int refresh, bool save, - stan::services::util::mcmc_writer& mcmc_writer, - stan::mcmc::sample& s, Model& model, - RNG& base_rng, stan::callbacks::interrupt& callback, - stan::callbacks::logger& logger) { - stan::services::util::generate_transitions(sampler, num_iterations, start, finish, - num_thin, refresh, save, true, mcmc_writer, s, - model, base_rng, callback, logger); + // } } -}; - -struct collect_adapt_processor { - const Communicator& comm; - collect_adapt_processor(const Communicator& comm_in) : - comm(comm_in) - {} - - template - void operator()(const Eigen::MatrixXd& workspace_r, Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { - EXPECT_EQ(workspace_r.cols(), comm.size); - for (int i = 0; i < comm.size; ++i) { - EXPECT_FLOAT_EQ(workspace_r(0, i), 0.01 * i + workspace_r(0, 0)); - } - } -}; - 
-TEST(mpi_warmup_test, unit_e_nuts) { - using Model = gauss3D_model_namespace::gauss3D_model; - using Sampler = stan::mcmc::adapt_unit_e_nuts; - boost::ecuyer1988 rng(4839294); - - stan::mcmc::unit_e_point z_init(3); - z_init.q(0) = 1; - z_init.q(1) = -1; - z_init.q(2) = 1; - z_init.p(0) = -1; - z_init.p(1) = 1; - z_init.p(2) = -1; - - std::stringstream debug, info, warn, error, fatal; - stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); - - std::fstream empty_stream("", std::fstream::in); - stan::io::dump data_var_context(empty_stream); - Model model(data_var_context); - - Sampler sampler(model, rng); - sampler.z() = z_init; - sampler.init_hamiltonian(logger); - sampler.set_nominal_stepsize(0.1); - sampler.set_stepsize_jitter(0); - sampler.sample_stepsize(); - - stan::mcmc::sample s(z_init.q, 0, 0); +#endif - stan::callbacks::writer sample_writer; - stan::callbacks::writer diagnostic_writer; - stan::services::util::mcmc_writer writer(sample_writer, diagnostic_writer, logger); - stan::callbacks::interrupt interrupt; - stan::services::util::generate_transitions(sampler, 10, 0, 20, - 1, 0, false, true, writer, s, - model, rng, interrupt, logger); - const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); - mpi_loader_base loader(inter_comm); - mpi_warmup mpi_warmup_adapt(loader, 10); - warmup_processor f_warmup; - send_adapt_processor fs(inter_comm); - mpi_warmup_adapt(sampler, model, s, fs, - f_warmup, - sampler, 10, 0, 20, 1, 0, false, writer, s, model, rng, interrupt, logger); - collect_adapt_processor fd(inter_comm); - mpi_warmup_adapt.finalize(sampler, model, s, fd, - f_warmup, - sampler, 10, 0, 20, 1, 0, false, writer, s, model, rng, interrupt, logger); -} - -#endif From ef318513280ecb0ade3a1a43e999c89d55573080 Mon Sep 17 00:00:00 2001 From: yiz Date: Mon, 13 Jan 2020 22:14:42 -0800 Subject: [PATCH 17/73] unit test for cross-chain adapted warmup --- .../services/util/mpi_cross_chain_adapt.hpp | 82 +++-- 
.../unit/services/util/mpi_warmup_test.cpp | 329 +++++++++++++++--- 2 files changed, 335 insertions(+), 76 deletions(-) diff --git a/src/stan/services/util/mpi_cross_chain_adapt.hpp b/src/stan/services/util/mpi_cross_chain_adapt.hpp index e4aa4a79919..726b20d43a0 100644 --- a/src/stan/services/util/mpi_cross_chain_adapt.hpp +++ b/src/stan/services/util/mpi_cross_chain_adapt.hpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include #include @@ -15,6 +17,32 @@ namespace stan { namespace services { namespace util { + /* + * Computes the effective sample size (ESS) for the specified + * parameter across all kept samples. The value returned is the + * minimum of ESS and the number_total_draws * + * log10(number_total_draws). + * + * This version is based on the one at + * stan/analyze/mcmc/compute_effective_sample_size.hpp + * but assuming the chain_mean and chain_var has been + * calculated(on the fly during adaptation) + * + */ +inline double +single_chain_ess(const double* draw, size_t num_draws) { + Eigen::Map > d(draw, num_draws); + Eigen::Matrix acov; + stan::math::autocorrelation(d, acov); + double rhos = 0.0; + int i = 1; + while (i < num_draws && acov(i) > 0.05) { + rhos += acov(i); + i++; + } + return double(num_draws) / (1.0 + 2.0 * rhos); +} + /* * @tparam Sampler sampler class * @param[in] m_win number of windows @@ -27,12 +55,12 @@ namespace util { */ template std::vector - mpi_cross_chain_adapt(const std::vector& acc, + mpi_cross_chain_adapt(const double* draw_p, + const std::vector& acc, const std::vector& chain_stepsize, int num_current_window, int max_num_window, int window_size, int num_chains, - double target_rhat, - std::vector& chain_gather) { + double target_rhat, double target_ess) { using boost::accumulators::accumulator_set; using boost::accumulators::stats; using boost::accumulators::tag::mean; @@ -41,59 +69,63 @@ namespace util { using stan::math::mpi::Session; using stan::math::mpi::Communicator; + const 
Communicator& comm = Session::inter_chain_comm(num_chains); - const int nd_win = 3; // mean, variance, chain_stepsize + const int nd_win = 4; // mean, variance, chain_stepsize int n_gather = nd_win * num_current_window; + std::vector chain_gather(n_gather, 0.0); for (int win = 0; win < num_current_window; ++win) { - int n_draws = (num_current_window - win) * window_size; - double unbiased_var_scale = n_draws / (n_draws - 1.0); + int num_draws = (num_current_window - win) * window_size; + double unbiased_var_scale = num_draws / (num_draws - 1.0); chain_gather[nd_win * win] = boost::accumulators::mean(acc[win]); chain_gather[nd_win * win + 1] = boost::accumulators::variance(acc[win]) * unbiased_var_scale; chain_gather[nd_win * win + 2] = chain_stepsize[win]; + chain_gather[nd_win * win + 3] = + single_chain_ess(draw_p + win * window_size, num_draws); } - std::vector res; double stepsize = -999.0; + std::vector res(1 + max_num_window, stepsize); if (comm.rank() == 0) { - std::vector rhat(num_current_window), ess(num_current_window); std::vector all_chain_gather(n_gather * num_chains); MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); for (int win = 0; win < num_current_window; ++win) { accumulator_set> acc_chain_mean; accumulator_set> acc_chain_var; + accumulator_set> acc_step; + Eigen::VectorXd chain_mean(num_chains); + Eigen::VectorXd chain_var(num_chains); + Eigen::ArrayXd chain_ess(num_chains); for (int chain = 0; chain < num_chains; ++chain) { - acc_chain_mean(all_chain_gather[chain * n_gather + nd_win * win]); - acc_chain_var(all_chain_gather[chain * n_gather + nd_win * win + 1]); + chain_mean(chain) = all_chain_gather[chain * n_gather + nd_win * win]; + acc_chain_mean(chain_mean(chain)); + chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; + acc_chain_var(chain_var(chain)); + acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); + chain_ess(chain) = 
all_chain_gather[chain * n_gather + nd_win * win + 3]; } - int n_draws = (num_current_window - win) * window_size; - double var_between = n_draws * boost::accumulators::variance(acc_chain_mean) + size_t num_draws = (num_current_window - win) * window_size; + double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) * num_chains / (num_chains - 1); double var_within = boost::accumulators::mean(acc_chain_var); - rhat[win] = sqrt((var_between / var_within + n_draws - 1) / n_draws); - - // TODO also calculate ess - bool is_adapted = (rhat[win]) < target_rhat; + double rhat = sqrt((var_between / var_within + num_draws - 1) / num_draws); + res[win + 1] = rhat; + bool is_adapted = rhat < target_rhat && (chain_ess > target_ess).all(); if (is_adapted) { - accumulator_set> acc_step; - for (int chain = 0; chain < num_chains; ++chain) { - acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); - } stepsize = boost::accumulators::mean(acc_step); - res.push_back(stepsize); - res.push_back(rhat[win]); + res[0] = stepsize; break; } } - MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + MPI_Bcast(res.data(), 1, MPI_DOUBLE, 0, comm.comm()); } else { MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, NULL, 0, MPI_DOUBLE, 0, comm.comm()); - MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); - res.push_back(stepsize); + MPI_Bcast(res.data(), 1, MPI_DOUBLE, 0, comm.comm()); } return res; } diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index 57eea909f61..3d68f4b10d9 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -28,60 +29,286 @@ using boost::accumulators::tag::variance; // 4 chains with 4 cores, each chain run on a core TEST(mpi_warmup_test, rhat_adaption) { const int num_chains = 4; - const int max_num_windows = num_chains; - const 
size_t s = 25; - const std::vector sizes(num_chains, s); - std::vector draw_vecs(num_chains, Eigen::VectorXd(s)); - draw_vecs[0] << - -276.606 , -277.168 , -272.621 , -271.142 , -271.950 , - -269.749 , -267.016 , -273.508 , -268.650 , -265.904 , - -264.629 , -260.797 , -263.184 , -263.892 , -268.810 , - -272.563 , -268.320 , -266.297 , -265.787 , -266.073 , - -265.788 , -262.260 , -265.073 , -265.511 , -264.318; - draw_vecs[1] << - -264.318 , -266.261 , -265.633 , -265.323 , -265.633 , - -265.426 , -265.690 , -266.122 , -264.876 , -264.829 , - -264.238 , -265.822 , -262.979 , -264.012 , -263.801 , - -264.745 , -263.940 , -263.586 , -263.284 , -262.566 , - -261.816 , -265.308 , -266.467 , -265.915 , -266.122; - draw_vecs[2] << - -266.122 , -265.903 , -265.903 , -265.717 , -271.780 , - -271.780 , -271.712 , -271.712 , -271.011 , -273.137 , - -272.125 , -265.535 , -265.168 , -267.824 , -262.983 , - -262.985 , -261.967 , -265.455 , -265.900 , -265.623 , - -262.111 , -262.111 , -262.111 , -266.586 , -266.545; - draw_vecs[3] << - -266.545 , -263.267 , -268.256 , -270.425 , -268.454 , - -268.807 , -269.154 , -269.154 , -269.528 , -268.206 , - -271.774 , -269.453 , -267.725 , -266.435 , -269.434 , - -267.838 , -267.676 , -267.925 , -268.343 , -267.824 , - -267.824 , -267.050 , -268.138 , -268.072 , -267.321; - - const std::vector draws{draw_vecs[0].data(), - draw_vecs[1].data(), draw_vecs[2].data(), draw_vecs[3].data()}; - double rhat = stan::analyze::compute_potential_scale_reduction(draws, sizes); - - std::vector>> acc(max_num_windows); + const int max_num_windows = 5; + const int window_size = 50; + std::vector + draw_vecs(num_chains, Eigen::VectorXd(window_size * max_num_windows)); +draw_vecs[0] << +-276.606, -277.168, -272.621, -271.142, -271.95 , +-269.749, -267.016, -273.508, -268.65 , -265.904, +-264.629, -260.797, -263.184, -263.892, -268.81 , +-272.563, -268.32 , -266.297, -265.787, -266.073, +-265.788, -262.26 , -265.073, -265.511, -264.318, +-264.318, -266.261, 
-265.633, -265.323, -265.633, +-265.426, -265.69 , -266.122, -264.876, -264.829, +-264.238, -265.822, -262.979, -264.012, -263.801, +-264.745, -263.94 , -263.586, -263.284, -262.566, +-261.816, -265.308, -266.467, -265.915, -266.122, +-266.122, -265.903, -265.903, -265.717, -271.78 , +-271.78 , -271.712, -271.712, -271.011, -273.137, +-272.125, -265.535, -265.168, -267.824, -262.983, +-262.985, -261.967, -265.455, -265.9 , -265.623, +-262.111, -262.111, -262.111, -266.586, -266.545, +-266.545, -263.267, -268.256, -270.425, -268.454, +-268.807, -269.154, -269.154, -269.528, -268.206, +-271.774, -269.453, -267.725, -266.435, -269.434, +-267.838, -267.676, -267.925, -268.343, -267.824, +-267.824, -267.05 , -268.138, -268.072, -267.321, +-267.529, -267.481, -267.118, -267.872, -269.605, +-269.974, -269.347, -269.806, -273.444, -272.257, +-269.983, -271.206, -271.453, -268.328, -268.185, +-268.817, -266.788, -264.052, -270.256, -269.739, +-271.512, -266.883, -266.736, -266.872, -267.525, +-266.845, -267.412, -267.754, -267.754, -267.625, +-266.819, -266.978, -267.949, -266.816, -267.641, +-268.377, -267.13 , -266.892, -269.544, -270.316, +-270.461, -270.989, -273.724, -273.155, -272.725, +-272.082, -264.071, -265.269, -263.945, -261.799, +-261.854, -264.487, -267.127, -265.134, -264.052, +-269.239, -263.838, -264.494, -261.844, -264.41 , +-261.969, -264.178, -265.37 , -266.054, -264.703, +-266.988, -267.21 , -265.177, -263.338, -266.309, +-272.157, -269.383, -266.892, -266.822, -268.786, +-271.036, -266.955, -267.356, -270.616, -265.706, +-264.444, -263.224, -263.313, -265.252, -263.874, +-265.89 , -260.837, -262.717, -262.073, -264.779, +-264.05 , -265.203, -262.597, -261.822, -264.143, +-268.655, -269.055, -270.736, -265.17 , -265.217, +-269.879, -270.83 , -271.194, -269.754, -263.825, +-263.737, -265.485, -264.626, -264.713, -265.561, +-266.183, -262.944, -263.938, -263.534, -263.802, +-262.138, -262.138, -261.331, -261.777, -261.62 , +-263.027, -263.062, -262.453, 
-263.18 , -264.445, +-266.134, -265.103, -264.626, -264.427, -265.528, +-263.938, -263.587, -263.358, -264.897, -265.179, +-264.573, -270.805, -270.824, -268.878, -268.878, +-269.638, -269.536, -276.973, -274.614, -277.589, +-273.321, -273.301, -271.049, -273.554, -269.292; + +draw_vecs[1] << +-270.284, -266.874, -263.361, -260.089, -262.139, +-262.139, -265.862, -265.862, -264.475, -264.475, +-263.834, -262.765, -263.039, -265.855, -267.63 , +-266.903, -267.004, -262.547, -262.196, -259.266, +-259.185, -258.549, -259.665, -258.64 , -258.7 , +-260.475, -260.475, -261.463, -261.483, -261.603, +-261.016, -262.461, -262.359, -262.543, -261.563, +-261.563, -262.017, -262.425, -262.895, -263.366, +-267.275, -265.694, -266.102, -265.527, -261.296, +-261.296, -263.983, -262.662, -263.794, -268.656, +-267.987, -268.543, -267.519, -265.96 , -267.899, +-268.445, -269.063, -270.79 , -264.644, -265.781, +-268.941, -269.489, -269.419, -272.76 , -267.807, +-270.202, -267.557, -265.109, -265.497, -268.019, +-266.981, -268.117, -265.153, -265.451, -271.16 , +-266.011, -265.764, -267.551, -267.334, -264.686, +-266.051, -267.103, -268.63 , -269.366, -269.251, +-267.918, -267.476, -265.557, -266.437, -264.879, +-265.035, -266.154, -268.055, -265.552, -265.48 , +-264.146, -264.952, -264.283, -264.768, -262.73 , +-263.659, -263.659, -270.215, -269.461, -271.369, +-276.308, -270.751, -267.617, -268.485, -266.52 , +-266.03 , -263.784, -263.786, -264.258, -264.176, +-265.869, -264.22 , -265.61 , -264.198, -263.931, +-264.291, -266.761, -267.04 , -268.286, -266.905, +-268.179, -266.984, -268.538, -267.756, -267.756, +-269.435, -269.391, -264.409, -264.465, -266.222, +-270.104, -269.579, -267.769, -263.89 , -264.642, +-264.289, -262.59 , -260.718, -266.693, -272.436, +-272.034, -269.769, -262.51 , -265.406, -272.021, +-270.796, -267.703, -267.549, -265.936, -265.205, +-268.592, -265.528, -261.628, -262.462, -262.253, +-262.93 , -262.932, -262.872, -263.627, -264.512, +-263.074, -263.795, 
-263.434, -265.15 , -264.709, +-262.096, -263.163, -259.09 , -259.09 , -261.602, +-261.602, -264.06 , -263.836, -260.945, -261.985, +-262.039, -261.927, -268.013, -272.047, -273.161, +-268.47 , -269.855, -269.855, -267.957, -267.957, +-261.807, -261.807, -261.807, -261.807, -262.876, +-262.905, -262.086, -262.461, -262.948, -261.041, +-258.394, -258.675, -259.269, -262.313, -261.776, -259 +, -257.015, -258.733, -259.681, -261.881, -263.303, +-263.303, -264.105, -263.857, -264.845, -265.121, +-265.121, -265.121, -265.121, -264.857, -263.811, +-263.796, -263.796, -265.803, -263.442, -263.442, +-262.237, -262.237, -262.237, -261.879, -262.177, +-264.667, -265.174, -265.174, -264.707, -264.707, +-264.535, -265.637, -261.316, -261.456, -262.575, +-265.12 , -263.7 , -263.7 , -263.7 , -263.7 , +-264.145, -268.846, -261.643, -261.561; + +draw_vecs[2] << +-262.744, -261.998, -261.994, -262.239, -264.747, +-263.467, -266.498, -266.158, -266.158, -266.158, +-266.334, -268.21 , -266.863, -265.772, -267.149, +-266.097, -266.097, -265.845, -265.976, -267.216, +-269.566, -269.566, -270.05 , -269.622, -269.981, +-270.698, -270.698, -270.698, -268.899, -268.504, +-268.814, -267.439, -268.08 , -267.438, -268.135, +-268.135, -268.135, -268.135, -267.767, -267.448, +-268.76 , -268.76 , -267.301, -268.337, -267.902, +-269.79 , -267.688, -265.888, -266.014, -266.122, +-266.953, -266.722, -267.119, -267.119, -267.047, +-267.047, -266.797, -266.797, -266.269, -265.727, +-266.522, -266.522, -267.202, -265.66 , -265.66 , +-268.183, -266.952, -267.373, -264.304, -264.59 , +-263.903, -263.988, -264.204, -264.204, -265.643, +-265.643, -264.296, -264.457, -265.484, -265.378, +-265.128, -265.128, -265.128, -265.128, -264.844, +-266.096, -265.205, -265.205, -265.205, -265.198, +-265.198, -265.741, -265.314, -265.903, -265.903, +-266.001, -266.001, -265.504, -265.313, -265.949, +-265.194, -264.247, -264.247, -264.455, -264.455, +-264.455, -264.303, -264.303, -264.303, -264.303, +-265.261, 
-265.885, -265.097, -264.3 , -264.3 , +-264.555, -264.798, -264.404, -265.854, -265.858, +-265.858, -265.858, -265.858, -265.122, -264.49 , +-264.49 , -264.066, -264.129, -265.102, -264.296, +-264.324, -264.273, -264.078, -264.07 , -263.923, +-263.923, -264.11 , -264.3 , -264.3 , -264.051, +-264.051, -264.864, -265.61 , -265.61 , -264.518, +-264.555, -264.711, -264.711, -265.346, -265.346, +-264.946, -265.135, -265.102, -265.625, -265.482, +-265.482, -265.482, -265.194, -264.499, -264.178, +-264.848, -264.848, -264.155, -264.155, -264.117, +-264.45 , -264.45 , -265.476, -265.476, -267.104, +-264.804, -264.496, -264.565, -264.637, -264.426, +-264.574, -264.659, -265.509, -265.509, -264.669, +-264.669, -264.669, -264.669, -265.014, -265.014, +-264.797, -266.052, -266.052, -267.349, -266.748, +-266.266, -267.778, -266.736, -266.736, -268.388, +-265.949, -265.949, -266.144, -267.147, -267.147, +-265.965, -265.329, -265.411, -267.016, -265.516, +-265.516, -265.516, -265.516, -265.516, -267.111, +-266.987, -266.662, -265.979, -265.517, -265.495, +-265.898, -266.085, -266.085, -265.282, -265.337, +-265.337, -265.337, -265.873, -265.044, -265.044, +-265.044, -267.565, -265.853, -266.693, -265.688, +-265.92 , -266.021, -266.147, -266.802, -266.84 , +-266.84 , -266.84 , -266.84 , -266.84 , -267.173, +-266.361, -266.361, -266.825, -266.444, -266.444, +-266.444, -266.444, -267.152, -266.515, -266.533; + +draw_vecs[3] << + -266.014,-265.791, -265.791, -266.053, -266.196, +-265.953, -265.787, -265.787, -266.336, -266.658, +-267.189, -267.232, -270.645, -270.645, -270.645, +-270.645, -272.083, -269 , -266.46 , -266.786, +-267.246, -266.353, -266.782, -266.782, -266.823, +-266.781, -266.781, -266.781, -266.781, -266.718, +-266.562, -266.837, -268.308, -267.161, -267.081, +-267.889, -267.103, -267.103, -265.488, -265.73 , +-265.73 , -266.46 , -266.46 , -267.288, -265.69 , +-265.69 , -265.69 , -266.302, -266.107, -266.107, +-266.107, -266.107, -266.107, -264.795, -264.659, 
+-265.365, -266.233, -265.995, -265.995, -266.013, +-266.025, -265.512, -265.512, -265.512, -265.512, +-265.605, -265.605, -265.605, -265.571, -265.916, +-265.325, -266.295, -265.598, -266.856, -266.856, +-266.19 , -266.19 , -265.077, -265.249, -265.43 , +-265.43 , -265.429, -265.481, -265.408, -265.408, +-265.991, -265.595, -266.051, -266.051, -266.525, +-267.047, -265.283, -265.167, -265.223, -265.223, +-265.526, -265.158, -265.11 , -265.11 , -265.24 , +-265.293, -265.45 , -265.45 , -265.109, -265.863, +-265.112, -265.112, -265.112, -265.112, -265.112, +-264.967, -264.967, -266.176, -265.038, -265.238, +-265.238, -265.238, -265.531, -265.531, -265.461, +-265.882, -265.882, -265.301, -265.301, -266.118, +-266.254, -264.316, -264.316, -264.316, -265.241, +-264.463, -264.658, -265.323, -264.331, -264.331, +-266.603, -264.131, -264.131, -264.289, -264.289, +-265.96 , -264.685, -264.731, -265.294, -264.663, +-264.831, -264.288, -265.753, -265.753, -265.925, +-265.925, -268.329, -266.288, -266.288, -266.288, +-266.288, -264.796, -264.573, -265.464, -264.95 , +-264.966, -264.602, -264.602, -265.338, -265.549, +-265.575, -267.306, -266.802, -266.268, -265.888, +-265.746, -265.746, -265.746, -266.17 , -266.134, +-265.365, -265.365, -265.484, -266.118, -265.285, +-265.285, -265.285, -265.285, -265.285, -266.041, +-266.041, -267.61 , -267.557, -267.557, -266.593, +-266.132, -265.76 , -265.757, -265.793, -265.793, +-265.793, -265.598, -265.354, -267.131, -265.039, +-265.039, -265.039, -265.039, -265.039, -265.869, +-265.869, -266.309, -265.897, -265.727, -265.958, +-267.231, -266.862, -267.255, -267.545, -267.009, +-265.8 , -265.551, -266.254, -265.394, -264.825, +-264.825, -265.245, -265.245, -264.312, -264.365, +-264.253, -264.514, -264.413, -264.413, -264.413, +-264.413, -264.413, -264.589, -265.277, -265.378, +-265.69 , -265.69 , -264.972, -264.972, -264.972, +-264.972, -265.283, -265.237, -264.671, -264.88 , +-265.099, -266.919, -265.878, -264.653, -264.653; + + const 
Communicator& comm = Session::inter_chain_comm(num_chains); + + const std::vector draws{draw_vecs[0].data(), + draw_vecs[1].data(), draw_vecs[2].data(), draw_vecs[3].data()}; + std::vector chain_stepsize{1.1, 1.2, 1.3, 1.4}; - const Communicator& comm = Session::inter_chain_comm(num_chains); - // each rank has different draws - for (int j = 0; j < s; ++j) { acc[0](draw_vecs[comm.rank()](j)); } - // each rank's stepsize is jittered - for (int j = 0; j < num_chains; ++j) {chain_stepsize[j] += 0.1 * comm.rank();} - - std::vector chain_gather(2 * num_chains * max_num_windows, 0.0); - std::vector output = stan::services::util::mpi_cross_chain_adapt(acc, - chain_stepsize, - 1, - max_num_windows, - s, num_chains, - 1.5, chain_gather); - if (comm.rank() == 0) { - EXPECT_EQ(output[1], rhat); + for (int j = 0; j < num_chains; ++j) { + chain_stepsize[j] += 0.1 * comm.rank(); } - EXPECT_FLOAT_EQ(output[0], 1.25); - // } + // a large ESS target should make all windows fail to pass tests + for (int curr_num_win = 1; curr_num_win < 6; ++curr_num_win) { + double target_ess = 40.0; + std::vector>> acc(curr_num_win); + for (int win = 0; win < curr_num_win; ++win) { + for (int j = win * window_size; j < curr_num_win * window_size; ++j) { + acc[win](draw_vecs[comm.rank()](j)); + } + } + + std::vector output = + stan::services::util::mpi_cross_chain_adapt(draws[comm.rank()], acc, + chain_stepsize, + curr_num_win, max_num_windows, + window_size, num_chains, 1.1, target_ess); + for (int win = 0; win < curr_num_win; ++win) { + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, output[win + 1]); + } + } + } + + // a target_ess that 4-window tests should pass + { + int curr_num_win = 4; + double target_ess = 15.0; + std::vector>> 
acc(curr_num_win); + for (int win = 0; win < curr_num_win; ++win) { + for (int j = win * window_size; j < curr_num_win * window_size; ++j) { + acc[win](draw_vecs[comm.rank()](j)); + } + } + + std::vector output = + stan::services::util::mpi_cross_chain_adapt(draws[comm.rank()], acc, + chain_stepsize, + curr_num_win, max_num_windows, + window_size, num_chains, 1.1, target_ess); + + int win = 1; // win = 1 @c is_adapted + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, output[1 + 1]); + } + } } #endif From d54c26a21b10391deac68471faa732309c7b8912 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 14 Jan 2020 13:38:40 -0800 Subject: [PATCH 18/73] use cross chain warmup for nuts adapted --- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 17 +++- src/stan/services/util/campfire_warmup.hpp | 98 ++++++------------- .../services/util/mpi_cross_chain_adapt.hpp | 8 +- .../util/run_mpi_adaptive_sampler.hpp | 7 +- .../unit/services/util/mpi_warmup_test.cpp | 5 +- 5 files changed, 55 insertions(+), 80 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 4aba49937e2..ded741e711d 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -65,6 +65,22 @@ int hmc_nuts_diag_e_adapt( unsigned int window, callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& init_writer, callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) { + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + +#ifdef MPI_ADAPTED_WARMUP + const int num_chains = 4; + const Communicator& inter_comm = Session::inter_chain_comm(num_chains); + const 
Communicator& intra_comm = Session::intra_chain_comm(num_chains); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains); + if (is_inter_rank) { + random_seed += inter_comm.rank(); + } + MPI_Bcast(&random_seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); + int rank; + MPI_Comm_rank(MPI_COMM_STAN, &rank); +#endif boost::ecuyer1988 rng = util::create_rng(random_seed, chain); std::vector disc_vector; @@ -97,7 +113,6 @@ int hmc_nuts_diag_e_adapt( logger); #ifdef MPI_ADAPTED_WARMUP - const int num_chains = 4; util::run_mpi_adaptive_sampler(sampler, model, cont_vector, num_chains, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); diff --git a/src/stan/services/util/campfire_warmup.hpp b/src/stan/services/util/campfire_warmup.hpp index 774da8a1157..0cda4d59bc4 100644 --- a/src/stan/services/util/campfire_warmup.hpp +++ b/src/stan/services/util/campfire_warmup.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -47,14 +48,11 @@ void campfire_warmup(Sampler& sampler, int num_chains, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, + int window_size, double target_rhat, double target_ess, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { - // for prototyping, we have @c max_num_windows fixed - const int window_size = 100; - const int max_num_windows = num_iterations / window_size; - using boost::accumulators::accumulator_set; using boost::accumulators::stats; using boost::accumulators::tag::mean; @@ -63,13 +61,17 @@ void campfire_warmup(Sampler& sampler, int num_chains, using stan::math::mpi::Session; using stan::math::mpi::Communicator; - std::vector>> acc_log(max_num_windows); + const int max_num_windows = num_iterations / window_size; + std::vector>> + acc_log(max_num_windows); + std::vector acov(max_num_windows, 0.0); 
bool is_adapted = false; - const double target_rhat = 1.05; - const double target_ess = 50.0; int m = 0; + std::vector draw; + draw.reserve(num_iterations); + double stepsize = -999.0; while (m < num_iterations && (!is_adapted)) { callback(); @@ -93,80 +95,38 @@ void campfire_warmup(Sampler& sampler, int num_chains, mcmc_writer.write_diagnostic_params(init_s, sampler); } - double stepsize = -999.0; + const Communicator& inter_comm = Session::inter_chain_comm(num_chains); bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains); + int m_win = m / window_size + 1; - if (is_inter_rank && boost::math::isfinite(init_s.log_prob())) { - int m_win = m / window_size + 1; + // incrementally add data + if (is_inter_rank) { + draw.push_back(init_s.log_prob()); for (int i = 0; i < m_win; ++i) { acc_log[i](init_s.log_prob()); } + } - // though @c boost::acc gives population var instead - // of sample var, the nb. of draws is supposed to be - // large enough to make it irrelevant. But for - // between-chain variance we must correct it because - // the nb. 
of chains is not large - - if (m >= window_size && (m + 1) % window_size == 0) { - int n_gather = 3 * m_win; // mean, variance, stepsize - std::vector chain_gather(n_gather, 0.0); - for (int i = 0; i < m_win; ++i) { - chain_gather[3 * i] = boost::accumulators::mean(acc_log[i]); - chain_gather[3 * i + 1] = boost::accumulators::variance(acc_log[i]); - chain_gather[3 * i + 2] = sampler.get_nominal_stepsize(); + if (boost::math::isfinite(init_s.log_prob())) { + const Communicator& intra_comm = Session::intra_chain_comm(num_chains); + if ((m + 1) % window_size == 0) { + if (is_inter_rank) { + std::vector adapt_result = + stan::services::util::mpi_cross_chain_adapt(draw.data(), acc_log, + sampler.get_nominal_stepsize(), + m_win, max_num_windows, + window_size, num_chains, target_rhat, target_ess); + stepsize = adapt_result[0]; } - - const Communicator& comm = Session::inter_chain_comm(num_chains); - if (comm.rank() == 0) { - std::vector rhat(m_win), ess(m_win); - std::vector all_chain_gather(n_gather * num_chains); - MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, - all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); - for (int i = 0; i < m_win; ++i) { - accumulator_set> acc_chain_mean; - accumulator_set> acc_chain_var; - for (int chain = 0; chain < num_chains; ++chain) { - acc_chain_mean(all_chain_gather[chain * n_gather + 3 * i]); - acc_chain_var(all_chain_gather[chain * n_gather + 3 * i + 1]); - } - int n_draws = (m_win - i) * window_size; - double var_between = n_draws * boost::accumulators::variance(acc_chain_mean) - * num_chains / (num_chains - 1); - double var_within = boost::accumulators::mean(acc_chain_var); - rhat[i] = sqrt((var_between / var_within + n_draws - 1) / n_draws); - - // TODO also calculate ess - is_adapted = (rhat[i]) < target_rhat; - if (is_adapted) { - accumulator_set> acc_step; - for (int chain = 0; chain < num_chains; ++chain) { - acc_step(all_chain_gather[chain * n_gather + 3 * i + 2]); - } - stepsize = 
boost::accumulators::mean(acc_step); - std::cout << "taki test rhat: " << rhat[i] << "\n"; - break; - } - } - MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); - } else { - MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, - NULL, 0, MPI_DOUBLE, 0, comm.comm()); - MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + if (stepsize > 0.0) { + is_adapted = true; } } } - - const Communicator& intra_comm = Session::intra_chain_comm(num_chains); - MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); - if (stepsize > 0.0) { - is_adapted = true; - sampler.set_nominal_stepsize(stepsize); - break; - } - m++; } + sampler.set_nominal_stepsize(stepsize); } } // namespace util diff --git a/src/stan/services/util/mpi_cross_chain_adapt.hpp b/src/stan/services/util/mpi_cross_chain_adapt.hpp index 726b20d43a0..90fe418d175 100644 --- a/src/stan/services/util/mpi_cross_chain_adapt.hpp +++ b/src/stan/services/util/mpi_cross_chain_adapt.hpp @@ -57,7 +57,7 @@ single_chain_ess(const double* draw, size_t num_draws) { std::vector mpi_cross_chain_adapt(const double* draw_p, const std::vector& acc, - const std::vector& chain_stepsize, + double chain_stepsize, int num_current_window, int max_num_window, int window_size, int num_chains, double target_rhat, double target_ess) { @@ -69,7 +69,6 @@ single_chain_ess(const double* draw, size_t num_draws) { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - const Communicator& comm = Session::inter_chain_comm(num_chains); const int nd_win = 4; // mean, variance, chain_stepsize @@ -81,7 +80,7 @@ single_chain_ess(const double* draw, size_t num_draws) { chain_gather[nd_win * win] = boost::accumulators::mean(acc[win]); chain_gather[nd_win * win + 1] = boost::accumulators::variance(acc[win]) * unbiased_var_scale; - chain_gather[nd_win * win + 2] = chain_stepsize[win]; + chain_gather[nd_win * win + 2] = chain_stepsize; chain_gather[nd_win * win + 3] = 
single_chain_ess(draw_p + win * window_size, num_draws); } @@ -121,12 +120,11 @@ single_chain_ess(const double* draw, size_t num_draws) { break; } } - MPI_Bcast(res.data(), 1, MPI_DOUBLE, 0, comm.comm()); } else { MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, NULL, 0, MPI_DOUBLE, 0, comm.comm()); - MPI_Bcast(res.data(), 1, MPI_DOUBLE, 0, comm.comm()); } + MPI_Bcast(res.data(), 1, MPI_DOUBLE, 0, comm.comm()); return res; } } diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index bb54e3f9637..1c5714035dc 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -67,9 +67,14 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); + const double target_rhat = 1.1; + const double target_ess = 50; + const int window_size = 100; util::campfire_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, writer, s, + num_thin, refresh, save_warmup, true, + window_size, target_rhat, target_ess, + writer, s, model, rng, interrupt, logger); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index 3d68f4b10d9..07a94b1ffbf 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -246,10 +246,7 @@ draw_vecs[3] << const std::vector draws{draw_vecs[0].data(), draw_vecs[1].data(), draw_vecs[2].data(), draw_vecs[3].data()}; - std::vector chain_stepsize{1.1, 1.2, 1.3, 1.4}; - for (int j = 0; j < num_chains; ++j) { - chain_stepsize[j] += 0.1 * comm.rank(); - } + double chain_stepsize = 1.1 + 0.1 * comm.rank(); // a large ESS target should make all windows fail to pass tests for (int curr_num_win = 1; curr_num_win < 6; ++curr_num_win) { From 
1a0f7434d7669663d5e603b587d5407dc9ffd5cd Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 14 Jan 2020 16:16:50 -0800 Subject: [PATCH 19/73] stream writer only writes for rank == 0 --- src/stan/callbacks/stream_writer.hpp | 58 +++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 6519531c0ac..897c3fafaf3 100644 --- a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -2,6 +2,7 @@ #define STAN_CALLBACKS_STREAM_WRITER_HPP #include +#include #include #include #include @@ -52,12 +53,24 @@ class stream_writer : public writer { * * @param[in] state Values in a std::vector */ - void operator()(const std::vector& state) { write_vector(state); } + void operator()(const std::vector& state) { + write_vector(state); + } /** * Writes the comment_prefix to the stream followed by a newline. */ - void operator()() { output_ << comment_prefix_ << std::endl; } + void operator()() { +#ifdef MPI_ADAPTED_WARMUP + int rank; + MPI_Comm_rank(MPI_COMM_STAN, &rank); + if (rank == 0) { + output_ << comment_prefix_ << std::endl; + } +#else + output_ << comment_prefix_ << std::endl; +#endif + } /** * Writes the comment_prefix then the message followed by a newline. 
@@ -65,7 +78,15 @@ class stream_writer : public writer { * @param[in] message A string */ void operator()(const std::string& message) { +#ifdef MPI_ADAPTED_WARMUP + int rank; + MPI_Comm_rank(MPI_COMM_STAN, &rank); + if (rank == 0) { + output_ << comment_prefix_ << message << std::endl; + } +#else output_ << comment_prefix_ << message << std::endl; +#endif } private: @@ -89,16 +110,33 @@ class stream_writer : public writer { */ template void write_vector(const std::vector& v) { - if (v.empty()) - return; +#ifdef MPI_ADAPTED_WARMUP + int rank; + MPI_Comm_rank(MPI_COMM_STAN, &rank); + if (rank == 0) { + if (v.empty()) + return; - typename std::vector::const_iterator last = v.end(); - --last; + typename std::vector::const_iterator last = v.end(); + --last; - for (typename std::vector::const_iterator it = v.begin(); it != last; - ++it) - output_ << *it << ","; - output_ << v.back() << std::endl; + for (typename std::vector::const_iterator it = v.begin(); it != last; + ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; + } +#else + if (v.empty()) + return; + + typename std::vector::const_iterator last = v.end(); + --last; + + for (typename std::vector::const_iterator it = v.begin(); it != last; + ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; +#endif } }; From 84d9807135117bfe3e575c980f971182fac51366 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 14 Jan 2020 22:53:45 -0800 Subject: [PATCH 20/73] use degenerated version of compute ess for warmup --- .../services/util/mpi_cross_chain_adapt.hpp | 122 ++++++++++++++++-- 1 file changed, 109 insertions(+), 13 deletions(-) diff --git a/src/stan/services/util/mpi_cross_chain_adapt.hpp b/src/stan/services/util/mpi_cross_chain_adapt.hpp index 90fe418d175..21005a7d1d1 100644 --- a/src/stan/services/util/mpi_cross_chain_adapt.hpp +++ b/src/stan/services/util/mpi_cross_chain_adapt.hpp @@ -17,6 +17,110 @@ namespace stan { namespace services { namespace util { + +inline double 
compute_effective_sample_size(std::vector draws, + std::vector sizes) { + int num_chains = sizes.size(); + size_t num_draws = sizes[0]; + for (int chain = 1; chain < num_chains; ++chain) { + num_draws = std::min(num_draws, sizes[chain]); + } + + // check if chains are constant; all equal to first draw's value + bool are_all_const = false; + Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains); + + for (int chain_idx = 0; chain_idx < num_chains; chain_idx++) { + Eigen::Map> draw( + draws[chain_idx], sizes[chain_idx]); + + for (int n = 0; n < num_draws; n++) { + if (!boost::math::isfinite(draw(n))) { + return std::numeric_limits::quiet_NaN(); + } + } + + init_draw(chain_idx) = draw(0); + + if (draw.isApproxToConstant(draw(0))) { + are_all_const |= true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + if (init_draw.isApproxToConstant(init_draw(0))) { + return std::numeric_limits::quiet_NaN(); + } + } + + Eigen::Matrix acov(num_chains); + Eigen::VectorXd chain_mean(num_chains); + Eigen::VectorXd chain_var(num_chains); + for (int chain = 0; chain < num_chains; ++chain) { + Eigen::Map> draw( + draws[chain], sizes[chain]); + stan::analyze::autocovariance(draw, acov(chain)); + chain_mean(chain) = draw.mean(); + chain_var(chain) = acov(chain)(0) * num_draws / (num_draws - 1); + } + + double mean_var = chain_var.mean(); + double var_plus = mean_var * (num_draws - 1) / num_draws; + if (num_chains > 1) + var_plus += math::variance(chain_mean); + Eigen::VectorXd rho_hat_s(num_draws); + rho_hat_s.setZero(); + Eigen::VectorXd acov_s(num_chains); + for (int chain = 0; chain < num_chains; ++chain) + acov_s(chain) = acov(chain)(1); + double rho_hat_even = 1.0; + rho_hat_s(0) = rho_hat_even; + double rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; + rho_hat_s(1) = rho_hat_odd; + + // Convert raw autocovariance estimators into Geyer's initial + // positive sequence. 
Loop only until num_draws - 4 to + // leave the last pair of autocorrelations as a bias term that + // reduces variance in the case of antithetical chains. + size_t s = 1; + while (s < (num_draws - 4) && (rho_hat_even + rho_hat_odd) > 0) { + for (int chain = 0; chain < num_chains; ++chain) + acov_s(chain) = acov(chain)(s + 1); + rho_hat_even = 1 - (mean_var - acov_s.mean()) / var_plus; + for (int chain = 0; chain < num_chains; ++chain) + acov_s(chain) = acov(chain)(s + 2); + rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; + if ((rho_hat_even + rho_hat_odd) >= 0) { + rho_hat_s(s + 1) = rho_hat_even; + rho_hat_s(s + 2) = rho_hat_odd; + } + s += 2; + } + + int max_s = s; + // this is used in the improved estimate, which reduces variance + // in antithetic case -- see tau_hat below + if (rho_hat_even > 0) + rho_hat_s(max_s + 1) = rho_hat_even; + + // Convert Geyer's initial positive sequence into an initial + // monotone sequence + for (int s = 1; s <= max_s - 3; s += 2) { + if (rho_hat_s(s + 1) + rho_hat_s(s + 2) > rho_hat_s(s - 1) + rho_hat_s(s)) { + rho_hat_s(s + 1) = (rho_hat_s(s - 1) + rho_hat_s(s)) / 2; + rho_hat_s(s + 2) = rho_hat_s(s + 1); + } + } + + double num_total_draws = num_chains * num_draws; + // Geyer's truncated estimator for the asymptotic variance + // Improved estimate reduces variance in antithetic case + double tau_hat = -1 + 2 * rho_hat_s.head(max_s).sum() + rho_hat_s(max_s + 1); + return std::min(num_total_draws / tau_hat, + num_total_draws * std::log10(num_total_draws)); +} /* * Computes the effective sample size (ESS) for the specified * parameter across all kept samples. 
The value returned is the @@ -29,18 +133,10 @@ namespace util { * calculated(on the fly during adaptation) * */ -inline double -single_chain_ess(const double* draw, size_t num_draws) { - Eigen::Map > d(draw, num_draws); - Eigen::Matrix acov; - stan::math::autocorrelation(d, acov); - double rhos = 0.0; - int i = 1; - while (i < num_draws && acov(i) > 0.05) { - rhos += acov(i); - i++; - } - return double(num_draws) / (1.0 + 2.0 * rhos); +inline double compute_effective_sample_size(const double* draw, size_t size) { + std::vector draws{draw}; + std::vector sizes{size}; + return compute_effective_sample_size(draws, sizes); } /* @@ -82,7 +178,7 @@ single_chain_ess(const double* draw, size_t num_draws) { unbiased_var_scale; chain_gather[nd_win * win + 2] = chain_stepsize; chain_gather[nd_win * win + 3] = - single_chain_ess(draw_p + win * window_size, num_draws); + compute_effective_sample_size(draw_p + win * window_size, num_draws); } double stepsize = -999.0; From e4c6b50f574b20f960cfde933ee8c7920ffe94c1 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 15 Jan 2020 13:23:48 -0800 Subject: [PATCH 21/73] fix branch of gitmodules --- .gitmodules | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitmodules b/.gitmodules index 235095d569a..8bc3af6a4a0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "lib/stan_math"] path = lib/stan_math url = https://github.com/stan-dev/math.git + branch = mpi_warmup_framework From 731b310ea7c2666ab0e9e773342593291d5ade3c Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 15 Jan 2020 13:31:43 -0800 Subject: [PATCH 22/73] rename campfire warmup to cross chain warmup --- .../util/{campfire_warmup.hpp => mpi_cross_chain_warmup} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename src/stan/services/util/{campfire_warmup.hpp => mpi_cross_chain_warmup} (96%) diff --git a/src/stan/services/util/campfire_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup similarity index 96% rename from 
src/stan/services/util/campfire_warmup.hpp rename to src/stan/services/util/mpi_cross_chain_warmup index 0cda4d59bc4..cda514d8c5e 100644 --- a/src/stan/services/util/campfire_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup @@ -1,5 +1,5 @@ -#ifndef STAN_SERVICES_UTIL_CAMPFIRE_WARMUP_HPP -#define STAN_SERVICES_UTIL_CAMPFIRE_WARMUP_HPP +#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_WARMUP_HPP +#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_WARMUP_HPP #include #include @@ -44,7 +44,7 @@ namespace util { * @param[in,out] logger logger for messages */ template -void campfire_warmup(Sampler& sampler, int num_chains, +void mpi_cross_chain_warmup(Sampler& sampler, int num_chains, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, From b393841388d470223be2b17321f0f9b8d60e6770 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 15 Jan 2020 13:37:52 -0800 Subject: [PATCH 23/73] rename campfire warmup to cross chain --- .../{mpi_cross_chain_warmup => mpi_cross_chain_warmup.hpp} | 0 src/stan/services/util/run_mpi_adaptive_sampler.hpp | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename src/stan/services/util/{mpi_cross_chain_warmup => mpi_cross_chain_warmup.hpp} (100%) diff --git a/src/stan/services/util/mpi_cross_chain_warmup b/src/stan/services/util/mpi_cross_chain_warmup.hpp similarity index 100% rename from src/stan/services/util/mpi_cross_chain_warmup rename to src/stan/services/util/mpi_cross_chain_warmup.hpp diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 1c5714035dc..01f0d532a14 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include @@ -70,7 +70,7 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, const double target_rhat = 1.1; const double target_ess = 50; const int window_size 
= 100; - util::campfire_warmup(sampler, num_chains, + util::mpi_cross_chain_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, window_size, target_rhat, target_ess, From 524f7500603b345db19439bbe6ed3aab2315dcd6 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 21 Jan 2020 17:16:44 -0800 Subject: [PATCH 24/73] add mpi cross chain adapter --- .../services/util/mpi_cross_chain_adapter.hpp | 318 ++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 src/stan/services/util/mpi_cross_chain_adapter.hpp diff --git a/src/stan/services/util/mpi_cross_chain_adapter.hpp b/src/stan/services/util/mpi_cross_chain_adapter.hpp new file mode 100644 index 00000000000..0b9df2a9457 --- /dev/null +++ b/src/stan/services/util/mpi_cross_chain_adapter.hpp @@ -0,0 +1,318 @@ +#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPTER_HPP +#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPTER_HPP + +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace util { + + class mpi_cross_chain_adapter { + protected: + bool is_adapted_; + int window_size_; + int num_chains_; + int max_num_windows_; + double target_rhat_; + double target_ess_; + std::vector log_prob_draws_; + std::vector>> log_prob_accumulators_; // NOLINT + Eigen::ArrayXd rhat_; + Eigen::ArrayXd ess_; + + public: + mpi_cross_chain_adapter() = default; + + inline void set_cross_chain_adaptation_params(int num_iterations, + int window_size, + int num_chains, + double target_rhat, double target_ess) { + window_size_ = window_size; + num_chains_ = num_chains; + max_num_windows_ = num_iterations / window_size; + target_rhat_ = target_rhat; + target_ess_ = target_ess; + log_prob_draws_.clear(); + log_prob_draws_.reserve(num_iterations); + log_prob_accumulators_.clear(); + log_prob_accumulators_.resize(max_num_windows_); + rhat_ = 
Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(num_chains_); + } + + inline void reset_cross_chain_adaptation() { + log_prob_draws_.clear(); + log_prob_accumulators_.clear(); + log_prob_accumulators_.resize(max_num_windows_); + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(num_chains_); + } + + inline int current_cross_chain_window_counter() { + return (log_prob_draws_.size() - 1) / window_size_ + 1; + } + + inline void add_cross_chain_sample(const Eigen::VectorXd& q, double s) { + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + if (is_inter_rank) { + log_prob_draws_.push_back(s); + int n = current_cross_chain_window_counter(); + for (int i = 0; i < n; ++i) { + log_prob_accumulators_[i](s); + } + } + } + + inline double compute_effective_sample_size(std::vector draws, + std::vector sizes) { + int num_chains = sizes.size(); + size_t num_draws = sizes[0]; + for (int chain = 1; chain < num_chains; ++chain) { + num_draws = std::min(num_draws, sizes[chain]); + } + + // check if chains are constant; all equal to first draw's value + bool are_all_const = false; + Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains); + + for (int chain_idx = 0; chain_idx < num_chains; chain_idx++) { + Eigen::Map> draw( + draws[chain_idx], sizes[chain_idx]); + + for (int n = 0; n < num_draws; n++) { + if (!boost::math::isfinite(draw(n))) { + return std::numeric_limits::quiet_NaN(); + } + } + + init_draw(chain_idx) = draw(0); + + if (draw.isApproxToConstant(draw(0))) { + are_all_const |= true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + if (init_draw.isApproxToConstant(init_draw(0))) { + return std::numeric_limits::quiet_NaN(); + } + } + + Eigen::Matrix acov(num_chains); + Eigen::VectorXd chain_mean(num_chains); + Eigen::VectorXd chain_var(num_chains); 
+ for (int chain = 0; chain < num_chains; ++chain) { + Eigen::Map> draw( + draws[chain], sizes[chain]); + stan::analyze::autocovariance(draw, acov(chain)); + chain_mean(chain) = draw.mean(); + chain_var(chain) = acov(chain)(0) * num_draws / (num_draws - 1); + } + + double mean_var = chain_var.mean(); + double var_plus = mean_var * (num_draws - 1) / num_draws; + if (num_chains > 1) + var_plus += math::variance(chain_mean); + Eigen::VectorXd rho_hat_s(num_draws); + rho_hat_s.setZero(); + Eigen::VectorXd acov_s(num_chains); + for (int chain = 0; chain < num_chains; ++chain) + acov_s(chain) = acov(chain)(1); + double rho_hat_even = 1.0; + rho_hat_s(0) = rho_hat_even; + double rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; + rho_hat_s(1) = rho_hat_odd; + + // Convert raw autocovariance estimators into Geyer's initial + // positive sequence. Loop only until num_draws - 4 to + // leave the last pair of autocorrelations as a bias term that + // reduces variance in the case of antithetical chains. 
+ size_t s = 1; + while (s < (num_draws - 4) && (rho_hat_even + rho_hat_odd) > 0) { + for (int chain = 0; chain < num_chains; ++chain) + acov_s(chain) = acov(chain)(s + 1); + rho_hat_even = 1 - (mean_var - acov_s.mean()) / var_plus; + for (int chain = 0; chain < num_chains; ++chain) + acov_s(chain) = acov(chain)(s + 2); + rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; + if ((rho_hat_even + rho_hat_odd) >= 0) { + rho_hat_s(s + 1) = rho_hat_even; + rho_hat_s(s + 2) = rho_hat_odd; + } + s += 2; + } + + int max_s = s; + // this is used in the improved estimate, which reduces variance + // in antithetic case -- see tau_hat below + if (rho_hat_even > 0) + rho_hat_s(max_s + 1) = rho_hat_even; + + // Convert Geyer's initial positive sequence into an initial + // monotone sequence + for (int s = 1; s <= max_s - 3; s += 2) { + if (rho_hat_s(s + 1) + rho_hat_s(s + 2) > rho_hat_s(s - 1) + rho_hat_s(s)) { + rho_hat_s(s + 1) = (rho_hat_s(s - 1) + rho_hat_s(s)) / 2; + rho_hat_s(s + 2) = rho_hat_s(s + 1); + } + } + + double num_total_draws = num_chains * num_draws; + // Geyer's truncated estimator for the asymptotic variance + // Improved estimate reduces variance in antithetic case + double tau_hat = -1 + 2 * rho_hat_s.head(max_s).sum() + rho_hat_s(max_s + 1); + return std::min(num_total_draws / tau_hat, + num_total_draws * std::log10(num_total_draws)); + } + + /* + * Computes the effective sample size (ESS) for the specified + * parameter across all kept samples. The value returned is the + * minimum of ESS and the number_total_draws * + * log10(number_total_draws). 
+ * + * This version is based on the one at + * stan/analyze/mcmc/compute_effective_sample_size.hpp + * but assuming the chain_mean and chain_var has been + * calculated(on the fly during adaptation) + * + */ + inline double compute_effective_sample_size(size_t i_begin, size_t i_size) { + std::vector draws{log_prob_draws_.data() + i_begin}; + std::vector sizes{i_size}; + return compute_effective_sample_size(draws, sizes); + } + + inline const Eigen::ArrayXd& cross_chain_adapt_rhat() { + return rhat_; + } + + inline const Eigen::ArrayXd& cross_chain_adapt_ess() { + return ess_; + } + + inline bool is_cross_chain_adapt_window_end() { + return (!log_prob_draws_.empty()) && + (log_prob_draws_.size() % window_size_ == 0); + } + + inline void debug() { + std::cout << "taki test debug: " << compute_effective_sample_size(0, 50) << "\n"; + } + + /* + * @tparam Sampler sampler class + * @param[in] m_win number of windows + * @param[in] window_size window size + * @param[in] num_chains number of chains + * @param[in,out] chain_gather gathered information from each chain, + * must have enough capacity to store up to + * maximum windows for all chains. 
+ # @return vector {stepsize, rhat(only in rank 0)} + */ + inline bool cross_chain_adaptation(double& chain_stepsize, + Eigen::VectorXd& inv_e_metric) { + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + bool is_adapted = false; + + if (is_cross_chain_adapt_window_end()) { + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + if (is_inter_rank) { + const Communicator& comm = Session::inter_chain_comm(num_chains_); + + const int nd_win = 4; // mean, variance, chain_stepsize + const int win_count = current_cross_chain_window_counter(); + int n_gather = nd_win * win_count; + std::vector chain_gather(n_gather, 0.0); + for (int win = 0; win < win_count; ++win) { + int num_draws = (win_count - win) * window_size_; + double unbiased_var_scale = num_draws / (num_draws - 1.0); + chain_gather[nd_win * win] = boost::accumulators::mean(log_prob_accumulators_[win]); + chain_gather[nd_win * win + 1] = boost::accumulators::variance(log_prob_accumulators_[win]) * + unbiased_var_scale; + chain_gather[nd_win * win + 2] = chain_stepsize; + chain_gather[nd_win * win + 3] = + compute_effective_sample_size(win * window_size_, num_draws); + } + + double invalid_stepsize = -999.0; + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(num_chains_); + + if (comm.rank() == 0) { + std::vector all_chain_gather(n_gather * num_chains_); + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + for (int win = 0; win < win_count; ++win) { + accumulator_set> acc_chain_mean; + accumulator_set> acc_chain_var; + accumulator_set> acc_step; + Eigen::VectorXd chain_mean(num_chains_); + Eigen::VectorXd chain_var(num_chains_); + for (int chain = 0; chain < num_chains_; ++chain) { + chain_mean(chain) = 
all_chain_gather[chain * n_gather + nd_win * win]; + acc_chain_mean(chain_mean(chain)); + chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; + acc_chain_var(chain_var(chain)); + acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); + ess_(chain) = all_chain_gather[chain * n_gather + nd_win * win + 3]; + } + size_t num_draws = (win_count - win) * window_size_; + double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains_ / (num_chains_ - 1); + double var_within = boost::accumulators::mean(acc_chain_var); + rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); + is_adapted = rhat_(win) < target_rhat_ && (ess_ > target_ess_).all(); + chain_stepsize = invalid_stepsize; + if (is_adapted) { + chain_stepsize = boost::accumulators::mean(acc_step); + break; + } + } + } else { + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + NULL, 0, MPI_DOUBLE, 0, comm.comm()); + } + MPI_Bcast(&chain_stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + } + const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); + MPI_Bcast(&chain_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + is_adapted = chain_stepsize > 0.0; + if (is_adapted) { + // MPI_Bcast(mpi_inv_e_metric.data(), num_params, MPI_DOUBLE, 0, intra_comm.comm()) + } + } + return is_adapted; + } + + }; +} +} +} +#endif From c67a5e0cb0923bba1ab38f0323a03b8497e51831 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 21 Jan 2020 17:21:55 -0800 Subject: [PATCH 25/73] rename mpi adapter; add unit tests --- .../util => mcmc}/mpi_cross_chain_adapter.hpp | 8 +- .../unit/services/util/mpi_warmup_test.cpp | 91 ++++++++++--------- 2 files changed, 52 insertions(+), 47 deletions(-) rename src/stan/{services/util => mcmc}/mpi_cross_chain_adapter.hpp (98%) diff --git a/src/stan/services/util/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp similarity index 98% rename from src/stan/services/util/mpi_cross_chain_adapter.hpp 
rename to src/stan/mcmc/mpi_cross_chain_adapter.hpp index 0b9df2a9457..88dfae161d2 100644 --- a/src/stan/services/util/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -1,5 +1,5 @@ -#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPTER_HPP -#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPTER_HPP +#ifndef STAN_MCMC_MPI_CROSS_CHAIN_ADAPTER_HPP +#define STAN_MCMC_MPI_CROSS_CHAIN_ADAPTER_HPP #include #include @@ -17,8 +17,7 @@ #include namespace stan { -namespace services { -namespace util { +namespace mcmc { class mpi_cross_chain_adapter { protected: @@ -314,5 +313,4 @@ namespace util { }; } } -} #endif diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index 07a94b1ffbf..4b7e7779100 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -2,8 +2,8 @@ #include #include +#include #include -#include #include #include #include @@ -27,7 +27,7 @@ using boost::accumulators::tag::mean; using boost::accumulators::tag::variance; // 4 chains with 4 cores, each chain run on a core -TEST(mpi_warmup_test, rhat_adaption) { +TEST(mpi_warmup_test, mpi_cross_chain_adapter) { const int num_chains = 4; const int max_num_windows = 5; const int window_size = 50; @@ -248,52 +248,55 @@ draw_vecs[3] << double chain_stepsize = 1.1 + 0.1 * comm.rank(); + const int num_iterations = window_size * max_num_windows; + stan::mcmc::mpi_cross_chain_adapter cc_adapter; + cc_adapter.set_cross_chain_adaptation_params(num_iterations, + window_size, + num_chains, 1.1, 40); + + Eigen::VectorXd dummy; + // a large ESS target should make all windows fail to pass tests - for (int curr_num_win = 1; curr_num_win < 6; ++curr_num_win) { - double target_ess = 40.0; - std::vector>> acc(curr_num_win); - for (int win = 0; win < curr_num_win; ++win) { - for (int j = win * window_size; j < curr_num_win * window_size; ++j) { - acc[win](draw_vecs[comm.rank()](j)); - } - } + 
for (int i = 0; i < num_iterations; ++i) { + cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); - std::vector output = - stan::services::util::mpi_cross_chain_adapt(draws[comm.rank()], acc, - chain_stepsize, - curr_num_win, max_num_windows, - window_size, num_chains, 1.1, target_ess); - for (int win = 0; win < curr_num_win; ++win) { - const std::vector p{ - draws[0] + win * window_size, - draws[1] + win * window_size, - draws[2] + win * window_size, - draws[3] + win * window_size}; - double rhat = - stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); - if (comm.rank() == 0) { - EXPECT_FLOAT_EQ(rhat, output[win + 1]); - } - } - } + double step = chain_stepsize; + bool is_adapted = cc_adapter.cross_chain_adaptation(step, dummy); + + EXPECT_FALSE(is_adapted); + + if (cc_adapter.is_cross_chain_adapt_window_end()) { + int curr_num_win = cc_adapter.current_cross_chain_window_counter(); + for (int win = 0; win < curr_num_win; ++win) { + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + } + } + } + } // a target_ess that 4-window tests should pass + cc_adapter.set_cross_chain_adaptation_params(num_iterations, + window_size, + num_chains, 1.1, 15); + { int curr_num_win = 4; double target_ess = 15.0; - std::vector>> acc(curr_num_win); - for (int win = 0; win < curr_num_win; ++win) { - for (int j = win * window_size; j < curr_num_win * window_size; ++j) { - acc[win](draw_vecs[comm.rank()](j)); - } - } - - std::vector output = - stan::services::util::mpi_cross_chain_adapt(draws[comm.rank()], acc, - chain_stepsize, - curr_num_win, max_num_windows, - window_size, num_chains, 1.1, target_ess); + for (int i = 0; i < 
num_iterations; ++i) { + cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); + double step = chain_stepsize; + bool is_adapted = cc_adapter.cross_chain_adaptation(step, dummy); + if (is_adapted) break; + } int win = 1; // win = 1 @c is_adapted const std::vector p{ draws[0] + win * window_size, @@ -302,9 +305,13 @@ draw_vecs[3] << draws[3] + win * window_size}; double rhat = stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); - if (comm.rank() == 0) { - EXPECT_FLOAT_EQ(rhat, output[1 + 1]); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + for (int i = win + 1; i < max_num_windows; ++i) { + EXPECT_FLOAT_EQ(0.0, cc_adapter.cross_chain_adapt_rhat()(i)); + } } + } } From 4547a62d9eb47abf7976ec0f8c5b46f78f8bfcdc Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 21 Jan 2020 21:45:31 -0800 Subject: [PATCH 26/73] adapt unit/diag nuts inherit cross chain adapter --- lib/stan_math | 2 +- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 12 ++++ src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp | 10 ++++ .../services/sample/hmc_nuts_diag_e_adapt.hpp | 2 - .../services/util/mpi_cross_chain_warmup.hpp | 58 +++---------------- .../util/run_mpi_adaptive_sampler.hpp | 3 + 6 files changed, 34 insertions(+), 53 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index 89cee61d436..fccbead62a2 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 89cee61d43607d2ce011701b69e8ddffc1db2aeb +Subproject commit fccbead62a2f40bafd698f77c5a8f39f108c7dd7 diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 45e92380f57..5f931cd87af 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -5,6 +5,10 @@ #include #include +#ifdef MPI_ADAPTED_WARMUP +#include +#endif + namespace stan { namespace mcmc { /** @@ -13,8 +17,14 @@ namespace mcmc { * diagonal metric and adaptive 
step size */ template +#ifdef MPI_ADAPTED_WARMUP +class adapt_diag_e_nuts : public diag_e_nuts, + public stepsize_var_adapter, + public mpi_cross_chain_adapter { +#else class adapt_diag_e_nuts : public diag_e_nuts, public stepsize_var_adapter { +#endif public: adapt_diag_e_nuts(const Model& model, BaseRNG& rng) : diag_e_nuts(model, rng), @@ -38,6 +48,8 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); } + + this -> add_cross_chain_sample(this->z_.q, s.log_prob()); } return s; } diff --git a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp index 3929fc7ed12..7f8c7df7048 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp @@ -5,6 +5,10 @@ #include #include +#ifdef MPI_ADAPTED_WARMUP +#include +#endif + namespace stan { namespace mcmc { /** @@ -13,8 +17,14 @@ namespace mcmc { * and adaptive step size */ template +#ifdef MPI_ADAPTED_WARMUP +class adapt_unit_e_nuts : public unit_e_nuts, + public stepsize_adapter, + public mpi_cross_chain_adapter { +#else class adapt_unit_e_nuts : public unit_e_nuts, public stepsize_adapter { +#endif public: adapt_unit_e_nuts(const Model& model, BaseRNG& rng) : unit_e_nuts(model, rng) {} diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index ded741e711d..ef8cfb45120 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -78,8 +78,6 @@ int hmc_nuts_diag_e_adapt( random_seed += inter_comm.rank(); } MPI_Bcast(&random_seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); - int rank; - MPI_Comm_rank(MPI_COMM_STAN, &rank); #endif boost::ecuyer1988 rng = util::create_rng(random_seed, chain); diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index 
cda514d8c5e..e2832de78de 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -53,26 +53,7 @@ void mpi_cross_chain_warmup(Sampler& sampler, int num_chains, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { - using boost::accumulators::accumulator_set; - using boost::accumulators::stats; - using boost::accumulators::tag::mean; - using boost::accumulators::tag::variance; - - using stan::math::mpi::Session; - using stan::math::mpi::Communicator; - - const int max_num_windows = num_iterations / window_size; - std::vector>> - acc_log(max_num_windows); - std::vector acov(max_num_windows, 0.0); - - bool is_adapted = false; - - int m = 0; - std::vector draw; - draw.reserve(num_iterations); - double stepsize = -999.0; - while (m < num_iterations && (!is_adapted)) { + for (int m = 0; m < num_iterations; ++m) { callback(); if (refresh > 0 @@ -95,38 +76,15 @@ void mpi_cross_chain_warmup(Sampler& sampler, int num_chains, mcmc_writer.write_diagnostic_params(init_s, sampler); } - const Communicator& inter_comm = Session::inter_chain_comm(num_chains); - bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains); - int m_win = m / window_size + 1; - - // incrementally add data - if (is_inter_rank) { - draw.push_back(init_s.log_prob()); - for (int i = 0; i < m_win; ++i) { - acc_log[i](init_s.log_prob()); - } - } - - if (boost::math::isfinite(init_s.log_prob())) { - const Communicator& intra_comm = Session::intra_chain_comm(num_chains); - if ((m + 1) % window_size == 0) { - if (is_inter_rank) { - std::vector adapt_result = - stan::services::util::mpi_cross_chain_adapt(draw.data(), acc_log, - sampler.get_nominal_stepsize(), - m_win, max_num_windows, - window_size, num_chains, target_rhat, target_ess); - stepsize = adapt_result[0]; - } - MPI_Bcast(&stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); - if (stepsize > 0.0) { - is_adapted = true; - } - 
} + // check cross-chain convergence + double stepsize = sampler.get_nominal_stepsize(); + Eigen::VectorXd dummy; // for future diag metric + bool is_adapted = sampler.cross_chain_adaptation(stepsize, dummy); + if (is_adapted) { + sampler.set_nominal_stepsize(stepsize); + break; } - m++; } - sampler.set_nominal_stepsize(stepsize); } } // namespace util diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 01f0d532a14..634c8e23e6e 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -70,6 +70,9 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, const double target_rhat = 1.1; const double target_ess = 50; const int window_size = 100; + sampler.set_cross_chain_adaptation_params(num_warmup, + window_size, num_chains, + target_rhat, target_ess); util::mpi_cross_chain_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, From c6d98f7c18c74e74506c7c31996b8d680ce67b36 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 21 Jan 2020 22:05:10 -0800 Subject: [PATCH 27/73] check cross-chain convergence in during sampler transition --- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 6 ++++++ src/stan/mcmc/mpi_cross_chain_adapter.hpp | 18 +++++++++--------- .../services/util/mpi_cross_chain_warmup.hpp | 6 +----- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 5f931cd87af..01c02446748 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -49,7 +49,13 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.restart(); } + // check cross chain convergence this -> add_cross_chain_sample(this->z_.q, s.log_prob()); + double stepsize = this -> get_nominal_stepsize(); + this -> 
cross_chain_adaptation(stepsize, this->z_.inv_e_metric_); + if (this -> is_cross_chain_adapted()) { + this -> set_nominal_stepsize(stepsize); + } } return s; } diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index 88dfae161d2..6a73fd0404a 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -41,6 +41,7 @@ namespace mcmc { int window_size, int num_chains, double target_rhat, double target_ess) { + is_adapted_ = false; window_size_ = window_size; num_chains_ = num_chains; max_num_windows_ = num_iterations / window_size; @@ -55,6 +56,7 @@ namespace mcmc { } inline void reset_cross_chain_adaptation() { + is_adapted_ = false; log_prob_draws_.clear(); log_prob_accumulators_.clear(); log_prob_accumulators_.resize(max_num_windows_); @@ -214,8 +216,8 @@ namespace mcmc { (log_prob_draws_.size() % window_size_ == 0); } - inline void debug() { - std::cout << "taki test debug: " << compute_effective_sample_size(0, 50) << "\n"; + inline bool is_cross_chain_adapted() { + return is_adapted_; } /* @@ -238,8 +240,6 @@ namespace mcmc { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - bool is_adapted = false; - if (is_cross_chain_adapt_window_end()) { bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); if (is_inter_rank) { @@ -287,9 +287,9 @@ namespace mcmc { * num_chains_ / (num_chains_ - 1); double var_within = boost::accumulators::mean(acc_chain_var); rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); - is_adapted = rhat_(win) < target_rhat_ && (ess_ > target_ess_).all(); + is_adapted_ = rhat_(win) < target_rhat_ && (ess_ > target_ess_).all(); chain_stepsize = invalid_stepsize; - if (is_adapted) { + if (is_adapted_) { chain_stepsize = boost::accumulators::mean(acc_step); break; } @@ -302,12 +302,12 @@ namespace mcmc { } const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); MPI_Bcast(&chain_stepsize, 
1, MPI_DOUBLE, 0, intra_comm.comm()); - is_adapted = chain_stepsize > 0.0; - if (is_adapted) { + is_adapted_ = chain_stepsize > 0.0; + if (is_adapted_) { // MPI_Bcast(mpi_inv_e_metric.data(), num_params, MPI_DOUBLE, 0, intra_comm.comm()) } } - return is_adapted; + return is_adapted_; } }; diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index e2832de78de..72a58c2aec0 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -77,11 +77,7 @@ void mpi_cross_chain_warmup(Sampler& sampler, int num_chains, } // check cross-chain convergence - double stepsize = sampler.get_nominal_stepsize(); - Eigen::VectorXd dummy; // for future diag metric - bool is_adapted = sampler.cross_chain_adaptation(stepsize, dummy); - if (is_adapted) { - sampler.set_nominal_stepsize(stepsize); + if (sampler.is_cross_chain_adapted()) { break; } } From ab8ab7ea613ea3e5fe64dcc0f1ccf036ebdab841 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 21 Jan 2020 23:48:21 -0800 Subject: [PATCH 28/73] 1st draft of adapting diag metric --- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 30 ++++++++++++++-- src/stan/mcmc/mpi_var_adaptation.hpp | 34 +++++++++++++++++++ .../util/run_mpi_adaptive_sampler.hpp | 5 ++- .../unit/services/util/mpi_warmup_test.cpp | 2 ++ 4 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 src/stan/mcmc/mpi_var_adaptation.hpp diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index 6a73fd0404a..e2e72c442f5 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -4,9 +4,9 @@ #include #include #include +#include #include #include -// #include #include #include #include @@ -33,10 +33,16 @@ namespace mcmc { boost::accumulators::tag::variance>>> log_prob_accumulators_; // NOLINT Eigen::ArrayXd rhat_; Eigen::ArrayXd ess_; + mpi_var_adaptation* var_adapt; 
public: mpi_cross_chain_adapter() = default; + inline void set_cross_chain_var_adaptation(mpi_var_adaptation& adapt) + { + var_adapt = &adapt; + } + inline void set_cross_chain_adaptation_params(int num_iterations, int window_size, int num_chains, @@ -62,6 +68,7 @@ namespace mcmc { log_prob_accumulators_.resize(max_num_windows_); rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(num_chains_); + var_adapt -> estimator.restart(); } inline int current_cross_chain_window_counter() { @@ -71,6 +78,13 @@ namespace mcmc { inline void add_cross_chain_sample(const Eigen::VectorXd& q, double s) { using stan::math::mpi::Session; using stan::math::mpi::Communicator; + + // every rank needs num_params through q's size + if (log_prob_draws_.empty()) { + var_adapt -> estimator.restart(q.size()); + } + + // only add samples to inter-chain ranks bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); if (is_inter_rank) { log_prob_draws_.push_back(s); @@ -78,6 +92,13 @@ namespace mcmc { for (int i = 0; i < n; ++i) { log_prob_accumulators_[i](s); } + + // we only keep @c window_size q's + if (is_cross_chain_adapt_window_begin()) { + var_adapt -> estimator.restart(q.size()); + } + + var_adapt -> estimator.add_sample(q); } } @@ -216,6 +237,10 @@ namespace mcmc { (log_prob_draws_.size() % window_size_ == 0); } + inline bool is_cross_chain_adapt_window_begin() { + return (log_prob_draws_.size() - 1) % window_size_ == 0; + } + inline bool is_cross_chain_adapted() { return is_adapted_; } @@ -304,7 +329,8 @@ namespace mcmc { MPI_Bcast(&chain_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); is_adapted_ = chain_stepsize > 0.0; if (is_adapted_) { - // MPI_Bcast(mpi_inv_e_metric.data(), num_params, MPI_DOUBLE, 0, intra_comm.comm()) + var_adapt -> learn_variance(inv_e_metric); + MPI_Bcast(inv_e_metric.data(), var_adapt -> estimator.num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } } return is_adapted_; diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp 
b/src/stan/mcmc/mpi_var_adaptation.hpp new file mode 100644 index 00000000000..0edda73e6a0 --- /dev/null +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -0,0 +1,34 @@ +#ifndef STAN_MCMC_MPI_VAR_ADAPTATION_HPP +#define STAN_MCMC_MPI_VAR_ADAPTATION_HPP + +#include +#include +#include + +namespace stan { + +namespace mcmc { + +class mpi_var_adaptation { + public: + stan::math::mpi::mpi_var_estimator estimator; + + explicit mpi_var_adaptation(int n_params, + const stan::math::mpi::Communicator& comm) + : estimator(n_params, comm) {} + + explicit mpi_var_adaptation(int num_chains) + : estimator(0, stan::math::mpi::Session::inter_chain_comm(num_chains)) {} + + void learn_variance(Eigen::VectorXd& var) { + double n = static_cast(estimator.sample_variance(var)); + var = (n / (n + 5.0)) * var + + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); + estimator.restart(); + } +}; + +} // namespace mcmc + +} // namespace stan +#endif diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 634c8e23e6e..a6e60eb8122 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -71,8 +71,11 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, const double target_ess = 50; const int window_size = 100; sampler.set_cross_chain_adaptation_params(num_warmup, - window_size, num_chains, + window_size, num_chains, target_rhat, target_ess); + stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), + stan::math::mpi::Session::inter_chain_comm(num_chains)); + sampler.set_cross_chain_var_adaptation(var_adapt); util::mpi_cross_chain_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index 4b7e7779100..caec230a7b3 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp 
+++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -253,6 +253,8 @@ draw_vecs[3] << cc_adapter.set_cross_chain_adaptation_params(num_iterations, window_size, num_chains, 1.1, 40); + stan::mcmc::mpi_var_adaptation var_adapt(0, comm); + cc_adapter.set_cross_chain_var_adaptation(var_adapt); Eigen::VectorXd dummy; From 536d789111cad274896f54a7a94d1afa3dd31076 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 22 Jan 2020 09:24:54 -0800 Subject: [PATCH 29/73] use harmonic mean for ESS test --- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index e2e72c442f5..6a15835c8ed 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -312,7 +312,8 @@ namespace mcmc { * num_chains_ / (num_chains_ - 1); double var_within = boost::accumulators::mean(acc_chain_var); rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); - is_adapted_ = rhat_(win) < target_rhat_ && (ess_ > target_ess_).all(); + double ess_hmean = ess_.size()/((1.0/ess_).sum()); // harmonic mean + is_adapted_ = rhat_(win) < target_rhat_ && ess_hmean > target_ess_; chain_stepsize = invalid_stepsize; if (is_adapted_) { chain_stepsize = boost::accumulators::mean(acc_step); From fc8c3ceb5edaf9ddbc117b836ca62f9adba4c333 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 22 Jan 2020 12:01:39 -0800 Subject: [PATCH 30/73] fix broadcasting metrics issue --- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index 6a15835c8ed..06263ce796f 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -31,6 +31,8 @@ namespace mcmc { std::vector>> log_prob_accumulators_; // NOLINT + boost::accumulators::accumulator_set 
> draw_counter_acc_; Eigen::ArrayXd rhat_; Eigen::ArrayXd ess_; mpi_var_adaptation* var_adapt; @@ -57,6 +59,7 @@ namespace mcmc { log_prob_draws_.reserve(num_iterations); log_prob_accumulators_.clear(); log_prob_accumulators_.resize(max_num_windows_); + draw_counter_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(num_chains_); } @@ -66,6 +69,7 @@ namespace mcmc { log_prob_draws_.clear(); log_prob_accumulators_.clear(); log_prob_accumulators_.resize(max_num_windows_); + draw_counter_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(num_chains_); var_adapt -> estimator.restart(); @@ -84,6 +88,9 @@ namespace mcmc { var_adapt -> estimator.restart(q.size()); } + // all procs keep a counter + draw_counter_acc_(0); + // only add samples to inter-chain ranks bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); if (is_inter_rank) { @@ -233,8 +240,8 @@ namespace mcmc { } inline bool is_cross_chain_adapt_window_end() { - return (!log_prob_draws_.empty()) && - (log_prob_draws_.size() % window_size_ == 0); + size_t n = boost::accumulators::count(draw_counter_acc_); + return n > 0 && (n % window_size_ == 0); } inline bool is_cross_chain_adapt_window_begin() { @@ -330,7 +337,9 @@ namespace mcmc { MPI_Bcast(&chain_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); is_adapted_ = chain_stepsize > 0.0; if (is_adapted_) { - var_adapt -> learn_variance(inv_e_metric); + if (is_inter_rank) { + var_adapt -> learn_variance(inv_e_metric); + } MPI_Bcast(inv_e_metric.data(), var_adapt -> estimator.num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } } From a58e2a925078a63d9ed8c9d2dd63ea95e5b9ec90 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 22 Jan 2020 12:32:09 -0800 Subject: [PATCH 31/73] rm old mpi adapt implementation --- .../services/util/mpi_cross_chain_adapt.hpp | 229 ------------------ .../services/util/mpi_cross_chain_warmup.hpp | 1 - .../unit/services/util/mpi_warmup_test.cpp | 1 - 3 files changed, 
231 deletions(-) delete mode 100644 src/stan/services/util/mpi_cross_chain_adapt.hpp diff --git a/src/stan/services/util/mpi_cross_chain_adapt.hpp b/src/stan/services/util/mpi_cross_chain_adapt.hpp deleted file mode 100644 index 21005a7d1d1..00000000000 --- a/src/stan/services/util/mpi_cross_chain_adapt.hpp +++ /dev/null @@ -1,229 +0,0 @@ -#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPT_HPP -#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_ADAPT_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace stan { -namespace services { -namespace util { - -inline double compute_effective_sample_size(std::vector draws, - std::vector sizes) { - int num_chains = sizes.size(); - size_t num_draws = sizes[0]; - for (int chain = 1; chain < num_chains; ++chain) { - num_draws = std::min(num_draws, sizes[chain]); - } - - // check if chains are constant; all equal to first draw's value - bool are_all_const = false; - Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains); - - for (int chain_idx = 0; chain_idx < num_chains; chain_idx++) { - Eigen::Map> draw( - draws[chain_idx], sizes[chain_idx]); - - for (int n = 0; n < num_draws; n++) { - if (!boost::math::isfinite(draw(n))) { - return std::numeric_limits::quiet_NaN(); - } - } - - init_draw(chain_idx) = draw(0); - - if (draw.isApproxToConstant(draw(0))) { - are_all_const |= true; - } - } - - if (are_all_const) { - // If all chains are constant then return NaN - // if they all equal the same constant value - if (init_draw.isApproxToConstant(init_draw(0))) { - return std::numeric_limits::quiet_NaN(); - } - } - - Eigen::Matrix acov(num_chains); - Eigen::VectorXd chain_mean(num_chains); - Eigen::VectorXd chain_var(num_chains); - for (int chain = 0; chain < num_chains; ++chain) { - Eigen::Map> draw( - draws[chain], sizes[chain]); - stan::analyze::autocovariance(draw, acov(chain)); - chain_mean(chain) = draw.mean(); - chain_var(chain) = acov(chain)(0) * 
num_draws / (num_draws - 1); - } - - double mean_var = chain_var.mean(); - double var_plus = mean_var * (num_draws - 1) / num_draws; - if (num_chains > 1) - var_plus += math::variance(chain_mean); - Eigen::VectorXd rho_hat_s(num_draws); - rho_hat_s.setZero(); - Eigen::VectorXd acov_s(num_chains); - for (int chain = 0; chain < num_chains; ++chain) - acov_s(chain) = acov(chain)(1); - double rho_hat_even = 1.0; - rho_hat_s(0) = rho_hat_even; - double rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; - rho_hat_s(1) = rho_hat_odd; - - // Convert raw autocovariance estimators into Geyer's initial - // positive sequence. Loop only until num_draws - 4 to - // leave the last pair of autocorrelations as a bias term that - // reduces variance in the case of antithetical chains. - size_t s = 1; - while (s < (num_draws - 4) && (rho_hat_even + rho_hat_odd) > 0) { - for (int chain = 0; chain < num_chains; ++chain) - acov_s(chain) = acov(chain)(s + 1); - rho_hat_even = 1 - (mean_var - acov_s.mean()) / var_plus; - for (int chain = 0; chain < num_chains; ++chain) - acov_s(chain) = acov(chain)(s + 2); - rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; - if ((rho_hat_even + rho_hat_odd) >= 0) { - rho_hat_s(s + 1) = rho_hat_even; - rho_hat_s(s + 2) = rho_hat_odd; - } - s += 2; - } - - int max_s = s; - // this is used in the improved estimate, which reduces variance - // in antithetic case -- see tau_hat below - if (rho_hat_even > 0) - rho_hat_s(max_s + 1) = rho_hat_even; - - // Convert Geyer's initial positive sequence into an initial - // monotone sequence - for (int s = 1; s <= max_s - 3; s += 2) { - if (rho_hat_s(s + 1) + rho_hat_s(s + 2) > rho_hat_s(s - 1) + rho_hat_s(s)) { - rho_hat_s(s + 1) = (rho_hat_s(s - 1) + rho_hat_s(s)) / 2; - rho_hat_s(s + 2) = rho_hat_s(s + 1); - } - } - - double num_total_draws = num_chains * num_draws; - // Geyer's truncated estimator for the asymptotic variance - // Improved estimate reduces variance in antithetic case - double tau_hat = 
-1 + 2 * rho_hat_s.head(max_s).sum() + rho_hat_s(max_s + 1); - return std::min(num_total_draws / tau_hat, - num_total_draws * std::log10(num_total_draws)); -} - /* - * Computes the effective sample size (ESS) for the specified - * parameter across all kept samples. The value returned is the - * minimum of ESS and the number_total_draws * - * log10(number_total_draws). - * - * This version is based on the one at - * stan/analyze/mcmc/compute_effective_sample_size.hpp - * but assuming the chain_mean and chain_var has been - * calculated(on the fly during adaptation) - * - */ -inline double compute_effective_sample_size(const double* draw, size_t size) { - std::vector draws{draw}; - std::vector sizes{size}; - return compute_effective_sample_size(draws, sizes); -} - - /* - * @tparam Sampler sampler class - * @param[in] m_win number of windows - * @param[in] window_size window size - * @param[in] num_chains number of chains - * @param[in,out] chain_gather gathered information from each chain, - * must have enough capacity to store up to - * maximum windows for all chains. 
- # @return vector {stepsize, rhat(only in rank 0)} - */ - template - std::vector - mpi_cross_chain_adapt(const double* draw_p, - const std::vector& acc, - double chain_stepsize, - int num_current_window, int max_num_window, - int window_size, int num_chains, - double target_rhat, double target_ess) { - using boost::accumulators::accumulator_set; - using boost::accumulators::stats; - using boost::accumulators::tag::mean; - using boost::accumulators::tag::variance; - - using stan::math::mpi::Session; - using stan::math::mpi::Communicator; - - const Communicator& comm = Session::inter_chain_comm(num_chains); - - const int nd_win = 4; // mean, variance, chain_stepsize - int n_gather = nd_win * num_current_window; - std::vector chain_gather(n_gather, 0.0); - for (int win = 0; win < num_current_window; ++win) { - int num_draws = (num_current_window - win) * window_size; - double unbiased_var_scale = num_draws / (num_draws - 1.0); - chain_gather[nd_win * win] = boost::accumulators::mean(acc[win]); - chain_gather[nd_win * win + 1] = boost::accumulators::variance(acc[win]) * - unbiased_var_scale; - chain_gather[nd_win * win + 2] = chain_stepsize; - chain_gather[nd_win * win + 3] = - compute_effective_sample_size(draw_p + win * window_size, num_draws); - } - - double stepsize = -999.0; - std::vector res(1 + max_num_window, stepsize); - - if (comm.rank() == 0) { - std::vector all_chain_gather(n_gather * num_chains); - MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, - all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); - for (int win = 0; win < num_current_window; ++win) { - accumulator_set> acc_chain_mean; - accumulator_set> acc_chain_var; - accumulator_set> acc_step; - Eigen::VectorXd chain_mean(num_chains); - Eigen::VectorXd chain_var(num_chains); - Eigen::ArrayXd chain_ess(num_chains); - for (int chain = 0; chain < num_chains; ++chain) { - chain_mean(chain) = all_chain_gather[chain * n_gather + nd_win * win]; - acc_chain_mean(chain_mean(chain)); - 
chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; - acc_chain_var(chain_var(chain)); - acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); - chain_ess(chain) = all_chain_gather[chain * n_gather + nd_win * win + 3]; - } - size_t num_draws = (num_current_window - win) * window_size; - double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) - * num_chains / (num_chains - 1); - double var_within = boost::accumulators::mean(acc_chain_var); - double rhat = sqrt((var_between / var_within + num_draws - 1) / num_draws); - res[win + 1] = rhat; - bool is_adapted = rhat < target_rhat && (chain_ess > target_ess).all(); - if (is_adapted) { - stepsize = boost::accumulators::mean(acc_step); - res[0] = stepsize; - break; - } - } - } else { - MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, - NULL, 0, MPI_DOUBLE, 0, comm.comm()); - } - MPI_Bcast(res.data(), 1, MPI_DOUBLE, 0, comm.comm()); - return res; - } -} -} -} -#endif diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index 72a58c2aec0..f9db36531b1 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index caec230a7b3..08ff13a1bb7 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -1,7 +1,6 @@ #ifdef STAN_LANG_MPI #include -#include #include #include #include From 7befdde64705fda6509e20b8a8515951f847b174 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 23 Jan 2020 09:52:34 -0800 Subject: [PATCH 32/73] stream writer only writes to inter chain ranks --- src/stan/callbacks/stream_writer.hpp | 12 +++--------- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 3 ++- 2 files changed, 5 
insertions(+), 10 deletions(-) diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 897c3fafaf3..4700c591979 100644 --- a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -62,9 +62,7 @@ class stream_writer : public writer { */ void operator()() { #ifdef MPI_ADAPTED_WARMUP - int rank; - MPI_Comm_rank(MPI_COMM_STAN, &rank); - if (rank == 0) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(4)) { output_ << comment_prefix_ << std::endl; } #else @@ -79,9 +77,7 @@ class stream_writer : public writer { */ void operator()(const std::string& message) { #ifdef MPI_ADAPTED_WARMUP - int rank; - MPI_Comm_rank(MPI_COMM_STAN, &rank); - if (rank == 0) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(4)) { output_ << comment_prefix_ << message << std::endl; } #else @@ -111,9 +107,7 @@ class stream_writer : public writer { template void write_vector(const std::vector& v) { #ifdef MPI_ADAPTED_WARMUP - int rank; - MPI_Comm_rank(MPI_COMM_STAN, &rank); - if (rank == 0) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(4)) { if (v.empty()) return; diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 01c02446748..fc2a4e463ce 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -49,13 +49,14 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.restart(); } - // check cross chain convergence +#ifdef MPI_ADAPTED_WARMUP this -> add_cross_chain_sample(this->z_.q, s.log_prob()); double stepsize = this -> get_nominal_stepsize(); this -> cross_chain_adaptation(stepsize, this->z_.inv_e_metric_); if (this -> is_cross_chain_adapted()) { this -> set_nominal_stepsize(stepsize); } +#endif } return s; } From 99cc23658662a142f8b3082095d72a98ad0da7c1 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 23 Jan 2020 14:41:05 -0800 Subject: [PATCH 33/73] rng seed should be from 
cmdstan --- src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp | 7 ------- src/stan/services/util/run_mpi_adaptive_sampler.hpp | 1 + 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index ef8cfb45120..86f6ba0283a 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -71,13 +71,6 @@ int hmc_nuts_diag_e_adapt( #ifdef MPI_ADAPTED_WARMUP const int num_chains = 4; - const Communicator& inter_comm = Session::inter_chain_comm(num_chains); - const Communicator& intra_comm = Session::intra_chain_comm(num_chains); - bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains); - if (is_inter_rank) { - random_seed += inter_comm.rank(); - } - MPI_Bcast(&random_seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); #endif boost::ecuyer1988 rng = util::create_rng(random_seed, chain); diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index a6e60eb8122..63106586326 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include From 32f552e1f6db44feaa9fbdc7a21ff43690959335 Mon Sep 17 00:00:00 2001 From: yiz Date: Fri, 24 Jan 2020 07:51:24 -0800 Subject: [PATCH 34/73] pass `num_cross_chains` to diag_e_adapt function --- src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp | 11 ++++------- src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp | 5 ++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 86f6ba0283a..25ce994e132 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -58,7 +58,7 @@ template int 
hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, int num_cross_chains, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -69,9 +69,6 @@ int hmc_nuts_diag_e_adapt( using stan::math::mpi::Session; using stan::math::mpi::Communicator; -#ifdef MPI_ADAPTED_WARMUP - const int num_chains = 4; -#endif boost::ecuyer1988 rng = util::create_rng(random_seed, chain); std::vector disc_vector; @@ -105,7 +102,7 @@ int hmc_nuts_diag_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler(sampler, - model, cont_vector, num_chains, num_warmup, num_samples, num_thin, refresh, + model, cont_vector, num_cross_chains, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( @@ -150,7 +147,7 @@ int hmc_nuts_diag_e_adapt( template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, int num_cross_chains, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -162,7 +159,7 @@ int hmc_nuts_diag_e_adapt( stan::io::var_context& unit_e_metric = dmp; return hmc_nuts_diag_e_adapt( - model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup, + model, init, unit_e_metric, random_seed, chain, init_radius, num_cross_chains, num_warmup, num_samples, num_thin, 
save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 810a0a3be8f..c9a14c1a0ac 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -48,7 +48,7 @@ namespace sample { template int hmc_nuts_unit_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, int num_cross_chains, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, callbacks::interrupt& interrupt, @@ -72,9 +72,8 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_t0(t0); #ifdef MPI_ADAPTED_WARMUP - const int num_chains = 4; util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, num_chains, num_warmup, num_samples, num_thin, refresh, + sampler, model, cont_vector, num_cross_chains, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( From 16eac209da4f44f592bc4f041ff5ab9afa1f41e1 Mon Sep 17 00:00:00 2001 From: yiz Date: Fri, 24 Jan 2020 10:08:55 -0800 Subject: [PATCH 35/73] fix sequential compile failure when no MPI flags issued --- src/stan/callbacks/stream_writer.hpp | 5 ++++- src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp | 9 ++++----- src/stan/services/util/run_mpi_adaptive_sampler.hpp | 13 +++++++++++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 4700c591979..23483e27250 100644 --- 
a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -2,11 +2,14 @@ #define STAN_CALLBACKS_STREAM_WRITER_HPP #include -#include #include #include #include +#ifdef STAN_LANG_MPI +#include +#endif + namespace stan { namespace callbacks { diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 25ce994e132..9c8f5eb5ba9 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -10,12 +10,15 @@ #include #include #include -#include #include #include #include #include +#ifdef MPI_ADAPTED_WARMUP +#include +#endif + namespace stan { namespace services { namespace sample { @@ -65,10 +68,6 @@ int hmc_nuts_diag_e_adapt( unsigned int window, callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& init_writer, callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) { - - using stan::math::mpi::Session; - using stan::math::mpi::Communicator; - boost::ecuyer1988 rng = util::create_rng(random_seed, chain); std::vector disc_vector; diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 63106586326..b0a99734533 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -3,13 +3,16 @@ #include #include -#include #include #include -#include #include #include +#ifdef MPI_ADAPTED_WARMUP +#include +#include +#endif + namespace stan { namespace services { namespace util { @@ -68,6 +71,7 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); +#ifdef MPI_ADAPTED_WARMUP const double target_rhat = 1.1; const double target_ess = 50; const int window_size = 100; @@ -83,6 +87,11 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, window_size, target_rhat, target_ess, writer, s, model, rng, interrupt, 
logger); +#else + util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, + num_thin, refresh, save_warmup, true, writer, s, + model, rng, interrupt, logger); +#endif clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; From 26ca4f39c467a675a2d5b9363373e72bf10daec2 Mon Sep 17 00:00:00 2001 From: yiz Date: Sat, 25 Jan 2020 21:52:30 -0800 Subject: [PATCH 36/73] turn off metric aggregation. update submoduel --- lib/stan_math | 2 +- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index fccbead62a2..4824e30732c 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit fccbead62a2f40bafd698f77c5a8f39f108c7dd7 +Subproject commit 4824e30732cc4d0c49922788b71feaedc15329ee diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index 06263ce796f..d54215e690f 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -338,9 +338,9 @@ namespace mcmc { is_adapted_ = chain_stepsize > 0.0; if (is_adapted_) { if (is_inter_rank) { - var_adapt -> learn_variance(inv_e_metric); + // var_adapt -> learn_variance(inv_e_metric); } - MPI_Bcast(inv_e_metric.data(), var_adapt -> estimator.num_params(), MPI_DOUBLE, 0, intra_comm.comm()); + // MPI_Bcast(inv_e_metric.data(), var_adapt -> estimator.num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } } return is_adapted_; From a8da9363da3e1571104e00ec52d09d3e36503c78 Mon Sep 17 00:00:00 2001 From: yiz Date: Mon, 27 Jan 2020 10:53:30 -0800 Subject: [PATCH 37/73] unit test for mpi var adaptation --- .../unit/mcmc/mpi_var_adaptation_test.cpp | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 src/test/unit/mcmc/mpi_var_adaptation_test.cpp diff --git a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp new file mode 100644 
index 00000000000..53b286c93ff --- /dev/null +++ b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include + +TEST(McmcVarAdaptation, mpi_learn_variance) { + stan::test::unit::instrumented_logger logger; + + const int n = 10; + Eigen::VectorXd q(Eigen::VectorXd::Zero(n)), var(q); + const int n_learn = 12; + Eigen::VectorXd target_var(Eigen::VectorXd::Ones(n)); + target_var *= 1e-3 * 5.0 / (n_learn + 5.0); + + stan::mcmc::var_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + for (int i = 0; i < n_learn; ++i) + adapter.learn_variance(var, q); + + for (int i = 0; i < n; ++i) + EXPECT_FLOAT_EQ(target_var(i), var(i)); + + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_var_adaptation mpi_adapter(n, comm); + Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); + for (int i = 0; i < n_learn_chain; ++i) + mpi_adapter.estimator.add_sample(q); + + mpi_adapter.learn_variance(mpi_var); + + for (int i = 0; i < n; ++i) { + EXPECT_FLOAT_EQ(var(i), mpi_var(i)); + } +} + +TEST(McmcVarAdaptation, mpi_data_learn_variance) { + stan::test::unit::instrumented_logger logger; + + const int n = 5; + const int n_learn = 12; + Eigen::VectorXd q_all(n * n_learn); + q_all << + -276.606, -277.168, -272.621, -271.142, -271.95 , + -269.749, -267.016, -273.508, -268.65 , -265.904, + -264.629, -260.797, -263.184, -263.892, -268.81 , + -272.563, -268.32 , -266.297, -265.787, -266.073, + -265.788, -262.26 , -265.073, -265.511, -264.318, + -264.318, -266.261, -265.633, -265.323, -265.633, + -265.426, -265.69 , -266.122, -264.876, -264.829, + -264.238, -265.822, -262.979, -264.012, -263.801, + -264.745, -263.94 , -263.586, -263.284, -262.566, + -261.816, -265.308, -266.467, -265.915, -266.122, + -266.122, -265.903, -265.903, -265.717, -271.78 , + -271.78 , -271.712, 
-271.712, -271.011, -273.137; + + stan::mcmc::var_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + Eigen::VectorXd var(Eigen::VectorXd::Zero(n)); + for (int i = 0; i < n_learn; ++i) { + Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n), n); + adapter.learn_variance(var, q); + } + + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_var_adaptation mpi_adapter(n, comm); + Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); + for (int i = 0; i < n_learn_chain; ++i) { + Eigen::VectorXd q = + Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); + mpi_adapter.estimator.add_sample(q); + } + mpi_adapter.learn_variance(mpi_var); + + for (int i = 0; i < n; ++i) { + EXPECT_FLOAT_EQ(var(i), mpi_var(i)); + } +} From a703c4ddbc449ae20d05cdb65e6a8b940739489f Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 28 Jan 2020 12:42:32 -0800 Subject: [PATCH 38/73] adapter with logging function and new mpi_var_adaptation --- lib/stan_math | 2 +- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 2 +- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 21 +++++++++++++++++-- src/stan/mcmc/mpi_var_adaptation.hpp | 15 ++++++------- .../util/run_mpi_adaptive_sampler.hpp | 11 ++-------- .../unit/mcmc/mpi_var_adaptation_test.cpp | 8 +++---- 6 files changed, 35 insertions(+), 24 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index 4824e30732c..69f31322c18 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 4824e30732cc4d0c49922788b71feaedc15329ee +Subproject commit 69f31322c180f555f156fed8e482bb2cd7fd604c diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index fc2a4e463ce..7c49830497d 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -52,7 +52,7 @@ class 
adapt_diag_e_nuts : public diag_e_nuts, #ifdef MPI_ADAPTED_WARMUP this -> add_cross_chain_sample(this->z_.q, s.log_prob()); double stepsize = this -> get_nominal_stepsize(); - this -> cross_chain_adaptation(stepsize, this->z_.inv_e_metric_); + this -> cross_chain_adaptation(stepsize, this->z_.inv_e_metric_, logger); if (this -> is_cross_chain_adapted()) { this -> set_nominal_stepsize(stepsize); } diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index d54215e690f..3dea7029f95 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -252,6 +252,18 @@ namespace mcmc { return is_adapted_; } + inline void log_adaptation(int win, callbacks::logger& logger) { + std::stringstream message; + message << "iteration: "; + message << boost::accumulators::count(draw_counter_acc_); + message << " window: " << win + 1 << " / " << current_cross_chain_window_counter(); + message << " Rhat: " << std::setprecision(3) << cross_chain_adapt_rhat()[win]; + const Eigen::ArrayXd& ess(cross_chain_adapt_ess()); + message << " ESS: " << std::setprecision(3) << ess.size()/((1.0/ess).sum()); + + logger.info(message); + } + /* * @tparam Sampler sampler class * @param[in] m_win number of windows @@ -263,7 +275,8 @@ namespace mcmc { # @return vector {stepsize, rhat(only in rank 0)} */ inline bool cross_chain_adaptation(double& chain_stepsize, - Eigen::VectorXd& inv_e_metric) { + Eigen::VectorXd& inv_e_metric, + callbacks::logger& logger) { using boost::accumulators::accumulator_set; using boost::accumulators::stats; using boost::accumulators::tag::mean; @@ -321,6 +334,9 @@ namespace mcmc { rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); double ess_hmean = ess_.size()/((1.0/ess_).sum()); // harmonic mean is_adapted_ = rhat_(win) < target_rhat_ && ess_hmean > target_ess_; + + log_adaptation(win, logger); + chain_stepsize = invalid_stepsize; if (is_adapted_) { chain_stepsize = 
boost::accumulators::mean(acc_step); @@ -338,7 +354,8 @@ namespace mcmc { is_adapted_ = chain_stepsize > 0.0; if (is_adapted_) { if (is_inter_rank) { - // var_adapt -> learn_variance(inv_e_metric); + const Communicator& comm = Session::inter_chain_comm(num_chains_); + // var_adapt -> learn_variance(inv_e_metric, comm); } // MPI_Bcast(inv_e_metric.data(), var_adapt -> estimator.num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 0edda73e6a0..71dab8c0006 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -13,15 +13,16 @@ class mpi_var_adaptation { public: stan::math::mpi::mpi_var_estimator estimator; - explicit mpi_var_adaptation(int n_params, - const stan::math::mpi::Communicator& comm) - : estimator(n_params, comm) {} + explicit mpi_var_adaptation(int n_params) + : estimator(n_params) + {} - explicit mpi_var_adaptation(int num_chains) - : estimator(0, stan::math::mpi::Session::inter_chain_comm(num_chains)) {} + // explicit mpi_var_adaptation(int num_chains) + // : estimator(0, stan::math::mpi::Session::inter_chain_comm(num_chains)) {} - void learn_variance(Eigen::VectorXd& var) { - double n = static_cast(estimator.sample_variance(var)); + void learn_variance(Eigen::VectorXd& var, + const stan::math::mpi::Communicator& comm) { + double n = static_cast(estimator.sample_variance(var, comm)); var = (n / (n + 5.0)) * var + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); estimator.restart(); diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index b0a99734533..48a0435b890 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -71,15 +71,13 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); -#ifdef MPI_ADAPTED_WARMUP - const double 
target_rhat = 1.1; + const double target_rhat = 1.05; const double target_ess = 50; const int window_size = 100; sampler.set_cross_chain_adaptation_params(num_warmup, window_size, num_chains, target_rhat, target_ess); - stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), - stan::math::mpi::Session::inter_chain_comm(num_chains)); + stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size()); sampler.set_cross_chain_var_adaptation(var_adapt); util::mpi_cross_chain_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, @@ -87,11 +85,6 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, window_size, target_rhat, target_ess, writer, s, model, rng, interrupt, logger); -#else - util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, writer, s, - model, rng, interrupt, logger); -#endif clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; diff --git a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp index 53b286c93ff..68bd9988d77 100644 --- a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp +++ b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp @@ -26,12 +26,12 @@ TEST(McmcVarAdaptation, mpi_learn_variance) { stan::math::mpi::Communicator comm(MPI_COMM_STAN); const int num_chains = comm.size(); const int n_learn_chain = n_learn / num_chains; - stan::mcmc::mpi_var_adaptation mpi_adapter(n, comm); + stan::mcmc::mpi_var_adaptation mpi_adapter(n); Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); for (int i = 0; i < n_learn_chain; ++i) mpi_adapter.estimator.add_sample(q); - mpi_adapter.learn_variance(mpi_var); + mpi_adapter.learn_variance(mpi_var, comm); for (int i = 0; i < n; ++i) { EXPECT_FLOAT_EQ(var(i), mpi_var(i)); @@ -71,14 +71,14 @@ TEST(McmcVarAdaptation, mpi_data_learn_variance) { stan::math::mpi::Communicator comm(MPI_COMM_STAN); const int num_chains = comm.size(); const int n_learn_chain = 
n_learn / num_chains; - stan::mcmc::mpi_var_adaptation mpi_adapter(n, comm); + stan::mcmc::mpi_var_adaptation mpi_adapter(n); Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); for (int i = 0; i < n_learn_chain; ++i) { Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); mpi_adapter.estimator.add_sample(q); } - mpi_adapter.learn_variance(mpi_var); + mpi_adapter.learn_variance(mpi_var, comm); for (int i = 0; i < n; ++i) { EXPECT_FLOAT_EQ(var(i), mpi_var(i)); From 1a5be755e722a188557599959b491dbb19b397bf Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 28 Jan 2020 17:06:44 -0800 Subject: [PATCH 39/73] track all windows for end-of-warmup metric update --- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 12 ++-- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 58 +++++++++---------- src/stan/mcmc/mpi_var_adaptation.hpp | 22 ++++--- .../services/util/mpi_cross_chain_warmup.hpp | 16 +++++ .../util/run_mpi_adaptive_sampler.hpp | 3 +- 5 files changed, 64 insertions(+), 47 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 7c49830497d..0ed54b75350 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -50,11 +50,13 @@ class adapt_diag_e_nuts : public diag_e_nuts, } #ifdef MPI_ADAPTED_WARMUP - this -> add_cross_chain_sample(this->z_.q, s.log_prob()); - double stepsize = this -> get_nominal_stepsize(); - this -> cross_chain_adaptation(stepsize, this->z_.inv_e_metric_, logger); - if (this -> is_cross_chain_adapted()) { - this -> set_nominal_stepsize(stepsize); + if (!this -> is_cross_chain_adapted()) { + this -> add_cross_chain_sample(this->z_.q, s.log_prob()); + double stepsize = this -> get_nominal_stepsize(); + this -> cross_chain_adaptation(stepsize, this->z_.inv_e_metric_, logger); + if (this -> is_cross_chain_adapted()) { + this -> set_nominal_stepsize(stepsize); + } } #endif } diff --git 
a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index 3dea7029f95..1c1f12f901d 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -72,7 +72,7 @@ namespace mcmc { draw_counter_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(num_chains_); - var_adapt -> estimator.restart(); + var_adapt -> restart(); } inline int current_cross_chain_window_counter() { @@ -83,11 +83,6 @@ namespace mcmc { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - // every rank needs num_params through q's size - if (log_prob_draws_.empty()) { - var_adapt -> estimator.restart(q.size()); - } - // all procs keep a counter draw_counter_acc_(0); @@ -95,17 +90,11 @@ namespace mcmc { bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); if (is_inter_rank) { log_prob_draws_.push_back(s); - int n = current_cross_chain_window_counter(); - for (int i = 0; i < n; ++i) { - log_prob_accumulators_[i](s); - } - - // we only keep @c window_size q's - if (is_cross_chain_adapt_window_begin()) { - var_adapt -> estimator.restart(q.size()); + int n_win = current_cross_chain_window_counter(); + for (int win = 0; win < n_win; ++win) { + log_prob_accumulators_[win](s); + var_adapt -> estimators[win].add_sample(q); } - - var_adapt -> estimator.add_sample(q); } } @@ -252,14 +241,16 @@ namespace mcmc { return is_adapted_; } - inline void log_adaptation(int win, callbacks::logger& logger) { + inline void msg_adaptation(int win, callbacks::logger& logger) { std::stringstream message; message << "iteration: "; + message << std::setw(3); message << boost::accumulators::count(draw_counter_acc_); message << " window: " << win + 1 << " / " << current_cross_chain_window_counter(); - message << " Rhat: " << std::setprecision(3) << cross_chain_adapt_rhat()[win]; + message << std::setw(5) << std::setprecision(2); + message << " Rhat: " << std::fixed << 
cross_chain_adapt_rhat()[win]; const Eigen::ArrayXd& ess(cross_chain_adapt_ess()); - message << " ESS: " << std::setprecision(3) << ess.size()/((1.0/ess).sum()); + message << " ESS: " << std::fixed << ess.size()/((1.0/ess).sum()); logger.info(message); } @@ -287,6 +278,8 @@ namespace mcmc { if (is_cross_chain_adapt_window_end()) { bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + double invalid_stepsize = -999.0; + double new_stepsize = invalid_stepsize; if (is_inter_rank) { const Communicator& comm = Session::inter_chain_comm(num_chains_); @@ -305,9 +298,10 @@ namespace mcmc { compute_effective_sample_size(win * window_size_, num_draws); } - double invalid_stepsize = -999.0; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(num_chains_); + const int invalid_win = -999; + int adapted_win = invalid_win; if (comm.rank() == 0) { std::vector all_chain_gather(n_gather * num_chains_); @@ -335,11 +329,10 @@ namespace mcmc { double ess_hmean = ess_.size()/((1.0/ess_).sum()); // harmonic mean is_adapted_ = rhat_(win) < target_rhat_ && ess_hmean > target_ess_; - log_adaptation(win, logger); + msg_adaptation(win, logger); - chain_stepsize = invalid_stepsize; if (is_adapted_) { - chain_stepsize = boost::accumulators::mean(acc_step); + adapted_win = win; break; } } @@ -347,22 +340,23 @@ namespace mcmc { MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, NULL, 0, MPI_DOUBLE, 0, comm.comm()); } - MPI_Bcast(&chain_stepsize, 1, MPI_DOUBLE, 0, comm.comm()); + MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); + if (adapted_win >= 0) { + MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); + new_stepsize /= num_chains_; + var_adapt -> learn_variance(inv_e_metric, adapted_win, comm); + } } const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); - MPI_Bcast(&chain_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); - is_adapted_ = chain_stepsize > 0.0; + MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 
0, intra_comm.comm()); + is_adapted_ = new_stepsize > 0.0; if (is_adapted_) { - if (is_inter_rank) { - const Communicator& comm = Session::inter_chain_comm(num_chains_); - // var_adapt -> learn_variance(inv_e_metric, comm); - } - // MPI_Bcast(inv_e_metric.data(), var_adapt -> estimator.num_params(), MPI_DOUBLE, 0, intra_comm.comm()); + chain_stepsize = new_stepsize; + MPI_Bcast(inv_e_metric.data(), var_adapt -> estimators[0].num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } } return is_adapted_; } - }; } } diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 71dab8c0006..55337bfedf5 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -11,21 +11,25 @@ namespace mcmc { class mpi_var_adaptation { public: - stan::math::mpi::mpi_var_estimator estimator; + std::vector estimators; - explicit mpi_var_adaptation(int n_params) - : estimator(n_params) + mpi_var_adaptation(int n_params, int num_iterations, int window_size) + : estimators(num_iterations / window_size, + stan::math::mpi::mpi_var_estimator(n_params)) {} - // explicit mpi_var_adaptation(int num_chains) - // : estimator(0, stan::math::mpi::Session::inter_chain_comm(num_chains)) {} - - void learn_variance(Eigen::VectorXd& var, + void learn_variance(Eigen::VectorXd& var, int win, const stan::math::mpi::Communicator& comm) { - double n = static_cast(estimator.sample_variance(var, comm)); + double n = static_cast(estimators[win].sample_variance(var, comm)); var = (n / (n + 5.0)) * var + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); - estimator.restart(); + restart(); + } + + void restart() { + for (auto&& adapt : estimators) { + adapt.restart(); + } } }; diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index f9db36531b1..b91e4b1e0b3 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ 
-77,6 +77,22 @@ void mpi_cross_chain_warmup(Sampler& sampler, int num_chains, // check cross-chain convergence if (sampler.is_cross_chain_adapted()) { + for (int j = m + 1; j < m + 50; ++j) { + if (refresh > 0 + && (start + j + 1 == finish || j == 0 || (j + 1) % refresh == 0)) { + int it_print_width = std::ceil(std::log10(static_cast(finish))); + std::stringstream message; + message << "Iteration: "; + message << std::setw(it_print_width) << j + 1 + start << " / " << finish; + message << " [" << std::setw(3) + << static_cast((100.0 * (start + j + 1)) / finish) << "%] "; + message << (warmup ? " (Warmup)" : " (Sampling)"); + + logger.info(message); + } + + init_s = sampler.transition(init_s, logger); + } break; } } diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 48a0435b890..85ee55532b5 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -77,7 +77,8 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, sampler.set_cross_chain_adaptation_params(num_warmup, window_size, num_chains, target_rhat, target_ess); - stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size()); + stan::mcmc::mpi_var_adaptation + var_adapt(sampler.z().q.size(), num_warmup, window_size); sampler.set_cross_chain_var_adaptation(var_adapt); util::mpi_cross_chain_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, From 36bd48089c7ad95b84f7bc01bbf3b9704c1e08f3 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 29 Jan 2020 11:57:32 -0800 Subject: [PATCH 40/73] use minimum instead of harmonic mean for ESS test --- src/stan/mcmc/mpi_cross_chain_adapter.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/mpi_cross_chain_adapter.hpp index 1c1f12f901d..b9f97bb8a18 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ 
b/src/stan/mcmc/mpi_cross_chain_adapter.hpp @@ -250,7 +250,7 @@ namespace mcmc { message << std::setw(5) << std::setprecision(2); message << " Rhat: " << std::fixed << cross_chain_adapt_rhat()[win]; const Eigen::ArrayXd& ess(cross_chain_adapt_ess()); - message << " ESS: " << std::fixed << ess.size()/((1.0/ess).sum()); + message << " ESS: " << std::fixed << ess_.matrix().mean(); logger.info(message); } @@ -326,8 +326,8 @@ namespace mcmc { * num_chains_ / (num_chains_ - 1); double var_within = boost::accumulators::mean(acc_chain_var); rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); - double ess_hmean = ess_.size()/((1.0/ess_).sum()); // harmonic mean - is_adapted_ = rhat_(win) < target_rhat_ && ess_hmean > target_ess_; + // double ess_hmean = ess_.size()/((1.0/ess_).sum()); // harmonic mean + is_adapted_ = rhat_(win) < target_rhat_ && (ess_ > target_ess_).all(); msg_adaptation(win, logger); From a34a6617bc6b0d224315bb2a19186fcc40a9d7ba Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 29 Jan 2020 12:43:56 -0800 Subject: [PATCH 41/73] mpi_var_adaptatin under #ifdef STAN_LANG_MPI --- lib/stan_math | 2 +- make/mpi_warmup.mk | 5 +++++ makefile | 1 + src/stan/mcmc/mpi_var_adaptation.hpp | 18 ++++++++++++++---- src/test/unit/mcmc/mpi_var_adaptation_test.cpp | 16 ++++++++++------ 5 files changed, 31 insertions(+), 11 deletions(-) create mode 100644 make/mpi_warmup.mk diff --git a/lib/stan_math b/lib/stan_math index 69f31322c18..ad1525ab441 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 69f31322c180f555f156fed8e482bb2cd7fd604c +Subproject commit ad1525ab44178ff781a5130d93920a0e6d2923f1 diff --git a/make/mpi_warmup.mk b/make/mpi_warmup.mk new file mode 100644 index 00000000000..7b047a025ad --- /dev/null +++ b/make/mpi_warmup.mk @@ -0,0 +1,5 @@ +ifdef MPI_ADAPTED_WARMUP + CXXFLAGS += -DSTAN_LANG_MPI -DMPI_ADAPTED_WARMUP + CC=mpicxx + CXX=mpicxx +endif diff --git a/makefile b/makefile index da551f3baaa..0e6d98ab23f 100644 
--- a/makefile +++ b/makefile @@ -13,6 +13,7 @@ help: -include $(HOME)/.config/stan/make.local # user-defined variables -include make/local # user-defined variables +-include make/mpi_warmup.mk MATH ?= lib/stan_math/ ifeq ($(OS),Windows_NT) diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 55337bfedf5..82016e81346 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -1,6 +1,8 @@ #ifndef STAN_MCMC_MPI_VAR_ADAPTATION_HPP #define STAN_MCMC_MPI_VAR_ADAPTATION_HPP +#ifdef STAN_LANG_MPI + #include #include #include @@ -10,12 +12,17 @@ namespace stan { namespace mcmc { class mpi_var_adaptation { - public: - std::vector estimators; + using est_t = stan::math::mpi::mpi_var_estimator; + +public: + std::vector estimators; + + mpi_var_adaptation(int n_params, int max_num_windows) + : estimators(max_num_windows, est_t(n_params)) + {} mpi_var_adaptation(int n_params, int num_iterations, int window_size) - : estimators(num_iterations / window_size, - stan::math::mpi::mpi_var_estimator(n_params)) + : mpi_var_adaptation(n_params, num_iterations / window_size) {} void learn_variance(Eigen::VectorXd& var, int win, @@ -36,4 +43,7 @@ class mpi_var_adaptation { } // namespace mcmc } // namespace stan + +#endif + #endif diff --git a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp index 68bd9988d77..b04b06bd84a 100644 --- a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp +++ b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp @@ -1,3 +1,5 @@ +#ifdef STAN_LANG_MPI + #include #include #include @@ -26,12 +28,12 @@ TEST(McmcVarAdaptation, mpi_learn_variance) { stan::math::mpi::Communicator comm(MPI_COMM_STAN); const int num_chains = comm.size(); const int n_learn_chain = n_learn / num_chains; - stan::mcmc::mpi_var_adaptation mpi_adapter(n); + stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1); Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); for (int i = 0; i < 
n_learn_chain; ++i) - mpi_adapter.estimator.add_sample(q); + mpi_adapter.estimators[0].add_sample(q); - mpi_adapter.learn_variance(mpi_var, comm); + mpi_adapter.learn_variance(mpi_var, 0, comm); for (int i = 0; i < n; ++i) { EXPECT_FLOAT_EQ(var(i), mpi_var(i)); @@ -71,16 +73,18 @@ TEST(McmcVarAdaptation, mpi_data_learn_variance) { stan::math::mpi::Communicator comm(MPI_COMM_STAN); const int num_chains = comm.size(); const int n_learn_chain = n_learn / num_chains; - stan::mcmc::mpi_var_adaptation mpi_adapter(n); + stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1); Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); for (int i = 0; i < n_learn_chain; ++i) { Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); - mpi_adapter.estimator.add_sample(q); + mpi_adapter.estimators[0].add_sample(q); } - mpi_adapter.learn_variance(mpi_var, comm); + mpi_adapter.learn_variance(mpi_var, 0, comm); for (int i = 0; i < n; ++i) { EXPECT_FLOAT_EQ(var(i), mpi_var(i)); } } + +#endif From a299c5075ca4e132a29e8af971e88c2f73ffa061 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 29 Jan 2020 15:12:57 -0800 Subject: [PATCH 42/73] clean up cross-chain adapter calls in sampler --- .../{ => hmc}/mpi_cross_chain_adapter.hpp | 44 ++++++++++++------- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 19 +++----- src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp | 11 ++--- .../util => mcmc/hmc}/mpi_warmup_test.cpp | 32 +++++++++++--- 4 files changed, 61 insertions(+), 45 deletions(-) rename src/stan/mcmc/{ => hmc}/mpi_cross_chain_adapter.hpp (93%) rename src/test/unit/{services/util => mcmc/hmc}/mpi_warmup_test.cpp (94%) diff --git a/src/stan/mcmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp similarity index 93% rename from src/stan/mcmc/mpi_cross_chain_adapter.hpp rename to src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index b9f97bb8a18..8947e805600 100644 --- a/src/stan/mcmc/mpi_cross_chain_adapter.hpp +++ 
b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -1,9 +1,11 @@ -#ifndef STAN_MCMC_MPI_CROSS_CHAIN_ADAPTER_HPP -#define STAN_MCMC_MPI_CROSS_CHAIN_ADAPTER_HPP +#ifndef STAN_MCMC_HMC_MPI_CROSS_CHAIN_ADAPTER_HPP +#define STAN_MCMC_HMC_MPI_CROSS_CHAIN_ADAPTER_HPP + +#ifdef MPI_ADAPTED_WARMUP #include #include -#include +#include #include #include #include @@ -83,17 +85,19 @@ namespace mcmc { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - // all procs keep a counter - draw_counter_acc_(0); - - // only add samples to inter-chain ranks - bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); - if (is_inter_rank) { - log_prob_draws_.push_back(s); - int n_win = current_cross_chain_window_counter(); - for (int win = 0; win < n_win; ++win) { - log_prob_accumulators_[win](s); - var_adapt -> estimators[win].add_sample(q); + if (!is_adapted_) { + // all procs keep a counter + draw_counter_acc_(0); + + // only add samples to inter-chain ranks + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + if (is_inter_rank) { + log_prob_draws_.push_back(s); + int n_win = current_cross_chain_window_counter(); + for (int win = 0; win < n_win; ++win) { + log_prob_accumulators_[win](s); + var_adapt -> estimators[win].add_sample(q); + } } } } @@ -265,7 +269,8 @@ namespace mcmc { * maximum windows for all chains. 
# @return vector {stepsize, rhat(only in rank 0)} */ - inline bool cross_chain_adaptation(double& chain_stepsize, + template + inline void cross_chain_adaptation(Sampler* hmc_sampler, Eigen::VectorXd& inv_e_metric, callbacks::logger& logger) { using boost::accumulators::accumulator_set; @@ -276,7 +281,8 @@ namespace mcmc { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - if (is_cross_chain_adapt_window_end()) { + if ((!is_adapted_) && is_cross_chain_adapt_window_end()) { + double chain_stepsize = hmc_sampler -> get_nominal_stepsize(); bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); double invalid_stepsize = -999.0; double new_stepsize = invalid_stepsize; @@ -354,10 +360,14 @@ namespace mcmc { chain_stepsize = new_stepsize; MPI_Bcast(inv_e_metric.data(), var_adapt -> estimators[0].num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } + if (is_adapted_) { + hmc_sampler -> set_nominal_stepsize(chain_stepsize); + } } - return is_adapted_; } }; } } #endif + +#endif diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 0ed54b75350..7dfb7bcedf5 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -6,7 +6,7 @@ #include #ifdef MPI_ADAPTED_WARMUP -#include +#include #endif namespace stan { @@ -17,14 +17,11 @@ namespace mcmc { * diagonal metric and adaptive step size */ template -#ifdef MPI_ADAPTED_WARMUP -class adapt_diag_e_nuts : public diag_e_nuts, - public stepsize_var_adapter, - public mpi_cross_chain_adapter { -#else class adapt_diag_e_nuts : public diag_e_nuts, - public stepsize_var_adapter { +#ifdef MPI_ADAPTED_WARMUP + public mpi_cross_chain_adapter, #endif + public stepsize_var_adapter { public: adapt_diag_e_nuts(const Model& model, BaseRNG& rng) : diag_e_nuts(model, rng), @@ -50,14 +47,8 @@ class adapt_diag_e_nuts : public diag_e_nuts, } #ifdef MPI_ADAPTED_WARMUP - if (!this -> is_cross_chain_adapted()) { this -> 
add_cross_chain_sample(this->z_.q, s.log_prob()); - double stepsize = this -> get_nominal_stepsize(); - this -> cross_chain_adaptation(stepsize, this->z_.inv_e_metric_, logger); - if (this -> is_cross_chain_adapted()) { - this -> set_nominal_stepsize(stepsize); - } - } + this -> cross_chain_adaptation(this, this->z_.inv_e_metric_, logger); #endif } return s; diff --git a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp index 7f8c7df7048..a0f3ed62ff9 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp @@ -6,7 +6,7 @@ #include #ifdef MPI_ADAPTED_WARMUP -#include +#include #endif namespace stan { @@ -17,14 +17,11 @@ namespace mcmc { * and adaptive step size */ template -#ifdef MPI_ADAPTED_WARMUP -class adapt_unit_e_nuts : public unit_e_nuts, - public stepsize_adapter, - public mpi_cross_chain_adapter { -#else class adapt_unit_e_nuts : public unit_e_nuts, - public stepsize_adapter { +#ifdef MPI_ADAPTED_WARMUP + public mpi_cross_chain_adapter, #endif + public stepsize_adapter { public: adapt_unit_e_nuts(const Model& model, BaseRNG& rng) : unit_e_nuts(model, rng) {} diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp similarity index 94% rename from src/test/unit/services/util/mpi_warmup_test.cpp rename to src/test/unit/mcmc/hmc/mpi_warmup_test.cpp index 08ff13a1bb7..f17f83e8ef5 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp @@ -1,7 +1,7 @@ #ifdef STAN_LANG_MPI #include -#include +#include #include #include #include @@ -25,8 +25,25 @@ using boost::accumulators::stats; using boost::accumulators::tag::mean; using boost::accumulators::tag::variance; +struct dummy_sampler { + double stepsize; + + dummy_sampler(double step) : stepsize(step) {} + + double get_nominal_stepsize() { + return stepsize; + } + + void set_nominal_stepsize(double new_stepsize) { + 
stepsize = new_stepsize; + } +}; + // 4 chains with 4 cores, each chain run on a core TEST(mpi_warmup_test, mpi_cross_chain_adapter) { + stan::callbacks::stream_logger logger(std::cout, std::cout, std::cout, + std::cerr, std::cerr); + const int num_chains = 4; const int max_num_windows = 5; const int window_size = 50; @@ -252,7 +269,7 @@ draw_vecs[3] << cc_adapter.set_cross_chain_adaptation_params(num_iterations, window_size, num_chains, 1.1, 40); - stan::mcmc::mpi_var_adaptation var_adapt(0, comm); + stan::mcmc::mpi_var_adaptation var_adapt(0, max_num_windows); cc_adapter.set_cross_chain_var_adaptation(var_adapt); Eigen::VectorXd dummy; @@ -261,10 +278,10 @@ draw_vecs[3] << for (int i = 0; i < num_iterations; ++i) { cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); - double step = chain_stepsize; - bool is_adapted = cc_adapter.cross_chain_adaptation(step, dummy); + dummy_sampler sampler(chain_stepsize); + cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); - EXPECT_FALSE(is_adapted); + EXPECT_FALSE(cc_adapter.is_cross_chain_adapted()); if (cc_adapter.is_cross_chain_adapt_window_end()) { int curr_num_win = cc_adapter.current_cross_chain_window_counter(); @@ -292,11 +309,12 @@ draw_vecs[3] << int curr_num_win = 4; double target_ess = 15.0; for (int i = 0; i < num_iterations; ++i) { + dummy_sampler sampler(chain_stepsize); cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); double step = chain_stepsize; - bool is_adapted = cc_adapter.cross_chain_adaptation(step, dummy); - if (is_adapted) break; + cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); + if (cc_adapter.is_cross_chain_adapted()) break; } int win = 1; // win = 1 @c is_adapted const std::vector p{ From 77065b3615ddddaa24745334e7b624ff6cc98d1a Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 29 Jan 2020 21:48:18 -0800 Subject: [PATCH 43/73] pass cross chain window size as function call arg --- src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp | 11 +++++++---- 
src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp | 6 ++++-- src/stan/services/util/run_mpi_adaptive_sampler.hpp | 9 ++++----- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 9c8f5eb5ba9..cda456a0605 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -61,7 +61,8 @@ template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, - unsigned int chain, double init_radius, int num_cross_chains, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -101,7 +102,7 @@ int hmc_nuts_diag_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler(sampler, - model, cont_vector, num_cross_chains, num_warmup, num_samples, num_thin, refresh, + model, cont_vector, num_cross_chains, cross_chain_window, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( @@ -146,7 +147,8 @@ int hmc_nuts_diag_e_adapt( template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_cross_chains, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, 
unsigned int term_buffer, @@ -158,7 +160,8 @@ int hmc_nuts_diag_e_adapt( stan::io::var_context& unit_e_metric = dmp; return hmc_nuts_diag_e_adapt( - model, init, unit_e_metric, random_seed, chain, init_radius, num_cross_chains, num_warmup, + model, init, unit_e_metric, random_seed, chain, init_radius, + num_cross_chains, cross_chain_window, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index c9a14c1a0ac..bb442588476 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -48,7 +48,8 @@ namespace sample { template int hmc_nuts_unit_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_cross_chains, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, callbacks::interrupt& interrupt, @@ -73,7 +74,8 @@ int hmc_nuts_unit_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, num_cross_chains, num_warmup, num_samples, num_thin, refresh, + sampler, model, cont_vector, num_cross_chains, cross_chain_window, + num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 85ee55532b5..5feb0237742 100644 --- 
a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -42,7 +42,7 @@ namespace util { template void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, std::vector& cont_vector, - int num_chains, int num_warmup, + int num_chains, int cross_chain_window, int num_warmup, int num_samples, int num_thin, int refresh, bool save_warmup, RNG& rng, callbacks::interrupt& interrupt, @@ -73,17 +73,16 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, clock_t start = clock(); const double target_rhat = 1.05; const double target_ess = 50; - const int window_size = 100; sampler.set_cross_chain_adaptation_params(num_warmup, - window_size, num_chains, + cross_chain_window, num_chains, target_rhat, target_ess); stan::mcmc::mpi_var_adaptation - var_adapt(sampler.z().q.size(), num_warmup, window_size); + var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); sampler.set_cross_chain_var_adaptation(var_adapt); util::mpi_cross_chain_warmup(sampler, num_chains, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, - window_size, target_rhat, target_ess, + cross_chain_window, target_rhat, target_ess, writer, s, model, rng, interrupt, logger); clock_t end = clock(); From 1410f2c306189d8f00f69d17fe96f15c7b3b75d4 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 30 Jan 2020 09:27:33 -0800 Subject: [PATCH 44/73] add target_ess to arg list --- src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp | 11 +++++++---- src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp | 5 +++-- src/stan/services/util/mpi_cross_chain_warmup.hpp | 4 +--- src/stan/services/util/run_mpi_adaptive_sampler.hpp | 12 +++++------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index cda456a0605..634bac20541 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ 
b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -62,7 +62,8 @@ int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, unsigned int chain, double init_radius, - int num_cross_chains, int cross_chain_window, int num_warmup, int num_samples, + int num_cross_chains, int cross_chain_window, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -102,7 +103,8 @@ int hmc_nuts_diag_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler(sampler, - model, cont_vector, num_cross_chains, cross_chain_window, num_warmup, num_samples, num_thin, refresh, + model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_ess, + num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else util::run_adaptive_sampler( @@ -148,7 +150,8 @@ template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, unsigned int chain, double init_radius, - int num_cross_chains, int cross_chain_window, int num_warmup, int num_samples, + int num_cross_chains, int cross_chain_window, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -161,7 +164,7 @@ int hmc_nuts_diag_e_adapt( return hmc_nuts_diag_e_adapt( model, init, unit_e_metric, random_seed, chain, init_radius, - num_cross_chains, cross_chain_window, num_warmup, + num_cross_chains, cross_chain_window, cross_chain_ess, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, 
kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index bb442588476..02e40cc1325 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -49,7 +49,8 @@ template int hmc_nuts_unit_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, unsigned int chain, double init_radius, - int num_cross_chains, int cross_chain_window, int num_warmup, int num_samples, + int num_cross_chains, int cross_chain_window, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, callbacks::interrupt& interrupt, @@ -74,7 +75,7 @@ int hmc_nuts_unit_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, num_cross_chains, cross_chain_window, + sampler, model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_ess, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index b91e4b1e0b3..3537faff87d 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -43,11 +43,9 @@ namespace util { * @param[in,out] logger logger for messages */ template -void mpi_cross_chain_warmup(Sampler& sampler, int num_chains, - int num_iterations, +void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, - int window_size, double target_rhat, double target_ess, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, 
Model& model, RNG& base_rng, callbacks::interrupt& callback, diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 5feb0237742..19c7fc60758 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -42,8 +42,8 @@ namespace util { template void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, std::vector& cont_vector, - int num_chains, int cross_chain_window, int num_warmup, - int num_samples, int num_thin, int refresh, + int num_chains, int cross_chain_window, int cross_chain_ess, + int num_warmup,int num_samples, int num_thin, int refresh, bool save_warmup, RNG& rng, callbacks::interrupt& interrupt, callbacks::logger& logger, @@ -72,17 +72,15 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); const double target_rhat = 1.05; - const double target_ess = 50; sampler.set_cross_chain_adaptation_params(num_warmup, - cross_chain_window, num_chains, - target_rhat, target_ess); + cross_chain_window, num_chains, + target_rhat, cross_chain_ess); stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); sampler.set_cross_chain_var_adaptation(var_adapt); - util::mpi_cross_chain_warmup(sampler, num_chains, + util::mpi_cross_chain_warmup(sampler, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, - cross_chain_window, target_rhat, target_ess, writer, s, model, rng, interrupt, logger); clock_t end = clock(); From c35985eabb4b2c530a03ae097c02e7bd8fe6f4a4 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 30 Jan 2020 14:28:52 -0800 Subject: [PATCH 45/73] cross_chain_rhat argument --- src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp | 8 ++++---- src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp | 4 ++-- src/stan/services/util/run_mpi_adaptive_sampler.hpp | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git 
a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 634bac20541..ad4c051dc47 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -62,7 +62,7 @@ int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, unsigned int chain, double init_radius, - int num_cross_chains, int cross_chain_window, int cross_chain_ess, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, @@ -103,7 +103,7 @@ int hmc_nuts_diag_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler(sampler, - model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_ess, + model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else @@ -150,7 +150,7 @@ template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, unsigned int chain, double init_radius, - int num_cross_chains, int cross_chain_window, int cross_chain_ess, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, @@ -164,7 +164,7 @@ int hmc_nuts_diag_e_adapt( return hmc_nuts_diag_e_adapt( model, init, unit_e_metric, random_seed, chain, init_radius, - num_cross_chains, cross_chain_window, cross_chain_ess, num_warmup, + num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, num_warmup, num_samples, num_thin, 
save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 02e40cc1325..14bb154bf62 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -49,7 +49,7 @@ template int hmc_nuts_unit_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, unsigned int chain, double init_radius, - int num_cross_chains, int cross_chain_window, int cross_chain_ess, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, @@ -75,7 +75,7 @@ int hmc_nuts_unit_e_adapt( #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_ess, + sampler, model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 19c7fc60758..8ca021661ae 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -42,7 +42,8 @@ namespace util { template void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, std::vector& cont_vector, - int num_chains, int cross_chain_window, int cross_chain_ess, + int num_chains, int cross_chain_window, + double cross_chain_rhat, int cross_chain_ess, int num_warmup,int num_samples, int num_thin, int refresh, bool save_warmup, RNG& 
rng, callbacks::interrupt& interrupt, @@ -71,10 +72,9 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); - const double target_rhat = 1.05; sampler.set_cross_chain_adaptation_params(num_warmup, cross_chain_window, num_chains, - target_rhat, cross_chain_ess); + cross_chain_rhat, cross_chain_ess); stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); sampler.set_cross_chain_var_adaptation(var_adapt); From efe33a5701021bd4dd3304bbc730dba4577247e6 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 30 Jan 2020 14:43:19 -0800 Subject: [PATCH 46/73] output min ESS among the chains --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 8947e805600..9670e564b7a 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -254,7 +254,7 @@ namespace mcmc { message << std::setw(5) << std::setprecision(2); message << " Rhat: " << std::fixed << cross_chain_adapt_rhat()[win]; const Eigen::ArrayXd& ess(cross_chain_adapt_ess()); - message << " ESS: " << std::fixed << ess_.matrix().mean(); + message << " ESS: " << std::fixed << ess_.matrix().minCoeff(); logger.info(message); } From 57949e3cd7f3d125f7d140ed46f56691c0492608 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 30 Jan 2020 23:30:32 -0800 Subject: [PATCH 47/73] mpi_stream_writer --- src/stan/callbacks/mpi_stream_writer.hpp | 138 ++++++++++++++ src/stan/callbacks/stream_writer.hpp | 228 ++++++++++------------- 2 files changed, 236 insertions(+), 130 deletions(-) create mode 100644 src/stan/callbacks/mpi_stream_writer.hpp diff --git a/src/stan/callbacks/mpi_stream_writer.hpp b/src/stan/callbacks/mpi_stream_writer.hpp new file mode 100644 index 00000000000..1329ef498d6 --- /dev/null +++ b/src/stan/callbacks/mpi_stream_writer.hpp @@ -0,0 
+1,138 @@ +#ifndef STAN_CALLBACKS_MPI_STREAM_WRITER_HPP +#define STAN_CALLBACKS_MPI_STREAM_WRITER_HPP + +#ifdef MPI_ADAPTED_WARMUP + +#include +#include +#include +#include +#include + +namespace stan { + namespace callbacks { + /** + * mpi_stream_writer is an implementation + * of writer that writes to a stream. + */ + class mpi_stream_writer : public writer { + public: + /** + * Constructs a stream writer with an output stream + * and an optional prefix for comments. + * + * @param[in, out] output stream to write + * @param[in] comment_prefix string to stream before + * each comment line. Default is "". + */ + mpi_stream_writer(int num_chains, std::ostream& output, + const std::string& comment_prefix = "") + : num_chains_(num_chains), output_(output), + comment_prefix_(comment_prefix) + {} + + /** + * Virtual destructor + */ + virtual ~mpi_stream_writer() {} + + /** + * Set new value for @c num_chains_. + * + * @param[in] n new value of @c num_chains_ + */ + void set_num_chains(int n) { + num_chains_ = n; + } + + /** + * Writes a set of names on a single line in csv format followed + * by a newline. + * + * Note: the names are not escaped. + * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) { + write_vector(names); + } + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] state Values in a std::vector + */ + void operator()(const std::vector& state) { + write_vector(state); + } + + /** + * Writes the comment_prefix to the stream followed by a newline. + */ + void operator()() { + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) { + output_ << comment_prefix_ << std::endl; + } + } + + /** + * Writes the comment_prefix then the message followed by a newline. 
+ * + * @param[in] message A string + */ + void operator()(const std::string& message) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) { + output_ << comment_prefix_ << message << std::endl; + } + } + + private: + + /** + * nb. of chains that have its own output stream + */ + int num_chains_; + + /** + * Output stream + */ + std::ostream& output_; + + /** + * Comment prefix to use when printing comments: strings and blank lines + */ + std::string comment_prefix_; + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] v Values in a std::vector + */ + template + void write_vector(const std::vector& v) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) { + if (v.empty()) return; + + typename std::vector::const_iterator last = v.end(); + --last; + + for (typename std::vector::const_iterator it = v.begin(); + it != last; ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; + } + } + }; + + } +} + +#endif + +#endif diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 23483e27250..3191ced10bd 100644 --- a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -6,137 +6,105 @@ #include #include -#ifdef STAN_LANG_MPI -#include -#endif - namespace stan { -namespace callbacks { - -/** - * stream_writer is an implementation - * of writer that writes to a stream. - */ -class stream_writer : public writer { - public: - /** - * Constructs a stream writer with an output stream - * and an optional prefix for comments. - * - * @param[in, out] output stream to write - * @param[in] comment_prefix string to stream before - * each comment line. Default is "". 
- */ - explicit stream_writer(std::ostream& output, - const std::string& comment_prefix = "") - : output_(output), comment_prefix_(comment_prefix) {} - - /** - * Virtual destructor - */ - virtual ~stream_writer() {} - - /** - * Writes a set of names on a single line in csv format followed - * by a newline. - * - * Note: the names are not escaped. - * - * @param[in] names Names in a std::vector - */ - void operator()(const std::vector& names) { - write_vector(names); - } - - /** - * Writes a set of values in csv format followed by a newline. - * - * Note: the precision of the output is determined by the settings - * of the stream on construction. - * - * @param[in] state Values in a std::vector - */ - void operator()(const std::vector& state) { - write_vector(state); - } - - /** - * Writes the comment_prefix to the stream followed by a newline. - */ - void operator()() { -#ifdef MPI_ADAPTED_WARMUP - if (stan::math::mpi::Session::is_in_inter_chain_comm(4)) { - output_ << comment_prefix_ << std::endl; - } -#else - output_ << comment_prefix_ << std::endl; -#endif - } + namespace callbacks { + + /** + * stream_writer is an implementation + * of writer that writes to a stream. + */ + class stream_writer : public writer { + public: + /** + * Constructs a stream writer with an output stream + * and an optional prefix for comments. + * + * @param[in, out] output stream to write + * @param[in] comment_prefix string to stream before + * each comment line. Default is "". + */ + stream_writer(std::ostream& output, + const std::string& comment_prefix = ""): + output_(output), comment_prefix_(comment_prefix) {} + + /** + * Virtual destructor + */ + virtual ~stream_writer() {} + + /** + * Writes a set of names on a single line in csv format followed + * by a newline. + * + * Note: the names are not escaped. 
+ * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) { + write_vector(names); + } + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] state Values in a std::vector + */ + void operator()(const std::vector& state) { + write_vector(state); + } + + /** + * Writes the comment_prefix to the stream followed by a newline. + */ + void operator()() { + output_ << comment_prefix_ << std::endl; + } + + /** + * Writes the comment_prefix then the message followed by a newline. + * + * @param[in] message A string + */ + void operator()(const std::string& message) { + output_ << comment_prefix_ << message << std::endl; + } + + private: + /** + * Output stream + */ + std::ostream& output_; + + /** + * Comment prefix to use when printing comments: strings and blank lines + */ + std::string comment_prefix_; + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] v Values in a std::vector + */ + template + void write_vector(const std::vector& v) { + if (v.empty()) return; + + typename std::vector::const_iterator last = v.end(); + --last; + + for (typename std::vector::const_iterator it = v.begin(); + it != last; ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; + } + }; - /** - * Writes the comment_prefix then the message followed by a newline. 
- * - * @param[in] message A string - */ - void operator()(const std::string& message) { -#ifdef MPI_ADAPTED_WARMUP - if (stan::math::mpi::Session::is_in_inter_chain_comm(4)) { - output_ << comment_prefix_ << message << std::endl; - } -#else - output_ << comment_prefix_ << message << std::endl; -#endif } - - private: - /** - * Output stream - */ - std::ostream& output_; - - /** - * Comment prefix to use when printing comments: strings and blank lines - */ - std::string comment_prefix_; - - /** - * Writes a set of values in csv format followed by a newline. - * - * Note: the precision of the output is determined by the settings - * of the stream on construction. - * - * @param[in] v Values in a std::vector - */ - template - void write_vector(const std::vector& v) { -#ifdef MPI_ADAPTED_WARMUP - if (stan::math::mpi::Session::is_in_inter_chain_comm(4)) { - if (v.empty()) - return; - - typename std::vector::const_iterator last = v.end(); - --last; - - for (typename std::vector::const_iterator it = v.begin(); it != last; - ++it) - output_ << *it << ","; - output_ << v.back() << std::endl; - } -#else - if (v.empty()) - return; - - typename std::vector::const_iterator last = v.end(); - --last; - - for (typename std::vector::const_iterator it = v.begin(); it != last; - ++it) - output_ << *it << ","; - output_ << v.back() << std::endl; -#endif - } -}; - -} // namespace callbacks -} // namespace stan +} #endif From f53fa46eb64495bd48dae62cf84919013cd8964f Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 4 Feb 2020 00:53:34 -0800 Subject: [PATCH 48/73] hack: replace num_warmup in output.csv with cross_chain num_warmup --- src/stan/callbacks/mpi_stream_writer.hpp | 27 +++++++++++++++++++ .../services/util/mpi_cross_chain_warmup.hpp | 11 ++++++-- .../util/run_mpi_adaptive_sampler.hpp | 18 +++++++++---- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/stan/callbacks/mpi_stream_writer.hpp b/src/stan/callbacks/mpi_stream_writer.hpp index 1329ef498d6..4b7bf6dab8e 
100644 --- a/src/stan/callbacks/mpi_stream_writer.hpp +++ b/src/stan/callbacks/mpi_stream_writer.hpp @@ -130,6 +130,33 @@ namespace stan { } }; + class mpi_fstream_writer : public mpi_stream_writer { + public: + /** + * hack: foutput stream so we can replace warmup num + */ + const std::string& file_name; + + /** + * Constructs a stream writer with an output stream + * and an optional prefix for comments. + * + * @param[in, out] output stream to write + * @param[in] comment_prefix string to stream before + * each comment line. Default is "". + */ + mpi_fstream_writer(int num_chains, std::ostream& output, + const std::string& fname, + const std::string& comment_prefix = "") + : mpi_stream_writer(num_chains, output, comment_prefix), + file_name(fname) + {} + + /** + * Virtual destructor + */ + virtual ~mpi_fstream_writer() {} + }; } } diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index 3537faff87d..f34d710a551 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -43,13 +43,14 @@ namespace util { * @param[in,out] logger logger for messages */ template -void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, +int mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { + int n_cross_chain_warmup = num_iterations; for (int m = 0; m < num_iterations; ++m) { callback(); @@ -90,10 +91,16 @@ void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, } init_s = sampler.transition(init_s, logger); - } + + // always save post-adjustment draws + mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); + mcmc_writer.write_diagnostic_params(init_s, sampler); + } + n_cross_chain_warmup = m + 50; 
break; } } + return n_cross_chain_warmup; } } // namespace util diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 8ca021661ae..544156ced60 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -78,11 +78,11 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); sampler.set_cross_chain_var_adaptation(var_adapt); - util::mpi_cross_chain_warmup(sampler, - num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, - writer, s, - model, rng, interrupt, logger); + int n_cross_chain_warmup = util::mpi_cross_chain_warmup(sampler, + num_warmup, 0, num_warmup + num_samples, + num_thin, refresh, save_warmup, true, + writer, s, + model, rng, interrupt, logger); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; @@ -98,6 +98,14 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, double sample_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; writer.write_timing(warm_delta_t, sample_delta_t); + + // replace num_warmup with actual one + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains)) { + const std::string& file_name = dynamic_cast(sample_writer).file_name; + int i = stan::math::mpi::Session::inter_chain_comm(num_chains).rank(); + std::string sys_call = "awk '{ gsub(\"num_warmup = [0-9]*\", \"num_warmup = " + std::to_string(n_cross_chain_warmup) + "\") ; print $0 }' " + file_name + " > " + std::to_string(i) + ".temp.csv && mv " + std::to_string(i) + ".temp.csv " + file_name; + system(sys_call.c_str()); + } } } // namespace util } // namespace services From 8f0f0b1d2ec4b1ef0409a59ca3335dac30f5bb78 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 4 Feb 2020 11:03:33 -0800 Subject: [PATCH 49/73] Revert "hack: replace num_warmup in output.csv with 
cross_chain num_warmup" This reverts commit f53fa46eb64495bd48dae62cf84919013cd8964f. --- src/stan/callbacks/mpi_stream_writer.hpp | 27 ------------------- .../services/util/mpi_cross_chain_warmup.hpp | 11 ++------ .../util/run_mpi_adaptive_sampler.hpp | 18 ++++--------- 3 files changed, 7 insertions(+), 49 deletions(-) diff --git a/src/stan/callbacks/mpi_stream_writer.hpp b/src/stan/callbacks/mpi_stream_writer.hpp index 4b7bf6dab8e..1329ef498d6 100644 --- a/src/stan/callbacks/mpi_stream_writer.hpp +++ b/src/stan/callbacks/mpi_stream_writer.hpp @@ -130,33 +130,6 @@ namespace stan { } }; - class mpi_fstream_writer : public mpi_stream_writer { - public: - /** - * hack: foutput stream so we can replace warmup num - */ - const std::string& file_name; - - /** - * Constructs a stream writer with an output stream - * and an optional prefix for comments. - * - * @param[in, out] output stream to write - * @param[in] comment_prefix string to stream before - * each comment line. Default is "". - */ - mpi_fstream_writer(int num_chains, std::ostream& output, - const std::string& fname, - const std::string& comment_prefix = "") - : mpi_stream_writer(num_chains, output, comment_prefix), - file_name(fname) - {} - - /** - * Virtual destructor - */ - virtual ~mpi_fstream_writer() {} - }; } } diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index f34d710a551..3537faff87d 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -43,14 +43,13 @@ namespace util { * @param[in,out] logger logger for messages */ template -int mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, +void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& 
logger) { - int n_cross_chain_warmup = num_iterations; for (int m = 0; m < num_iterations; ++m) { callback(); @@ -91,16 +90,10 @@ int mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, } init_s = sampler.transition(init_s, logger); - - // always save post-adjustment draws - mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); - mcmc_writer.write_diagnostic_params(init_s, sampler); - } - n_cross_chain_warmup = m + 50; + } break; } } - return n_cross_chain_warmup; } } // namespace util diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 544156ced60..8ca021661ae 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -78,11 +78,11 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); sampler.set_cross_chain_var_adaptation(var_adapt); - int n_cross_chain_warmup = util::mpi_cross_chain_warmup(sampler, - num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, - writer, s, - model, rng, interrupt, logger); + util::mpi_cross_chain_warmup(sampler, + num_warmup, 0, num_warmup + num_samples, + num_thin, refresh, save_warmup, true, + writer, s, + model, rng, interrupt, logger); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; @@ -98,14 +98,6 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, double sample_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; writer.write_timing(warm_delta_t, sample_delta_t); - - // replace num_warmup with actual one - if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains)) { - const std::string& file_name = dynamic_cast(sample_writer).file_name; - int i = stan::math::mpi::Session::inter_chain_comm(num_chains).rank(); - std::string sys_call = "awk '{ gsub(\"num_warmup = [0-9]*\", 
\"num_warmup = " + std::to_string(n_cross_chain_warmup) + "\") ; print $0 }' " + file_name + " > " + std::to_string(i) + ".temp.csv && mv " + std::to_string(i) + ".temp.csv " + file_name; - system(sys_call.c_str()); - } } } // namespace util } // namespace services From ad706bcfb58f101a2f57a03099d88750fb0ebdd5 Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 4 Feb 2020 13:20:03 -0800 Subject: [PATCH 50/73] write `num_warmup` before `adaption terminated` in csv --- .../services/util/mpi_cross_chain_warmup.hpp | 19 +++++++++++++++++-- .../util/run_mpi_adaptive_sampler.hpp | 11 ++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index 3537faff87d..d3fa786546d 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -43,13 +43,14 @@ namespace util { * @param[in,out] logger logger for messages */ template -void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, +int mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { + int num_cross_chain_warmup = 0; for (int m = 0; m < num_iterations; ++m) { callback(); @@ -73,9 +74,13 @@ void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, mcmc_writer.write_diagnostic_params(init_s, sampler); } + if (m % num_thin == 0) { + num_cross_chain_warmup++; + } + // check cross-chain convergence if (sampler.is_cross_chain_adapted()) { - for (int j = m + 1; j < m + 50; ++j) { + for (int j = m + 1; j < m + 51; ++j) { if (refresh > 0 && (start + j + 1 == finish || j == 0 || (j + 1) % refresh == 0)) { int it_print_width = std::ceil(std::log10(static_cast(finish))); @@ -90,10 +95,20 @@ void 
mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, } init_s = sampler.transition(init_s, logger); + + if (save && ((j % num_thin) == 0)) { + mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); + mcmc_writer.write_diagnostic_params(init_s, sampler); + } + + if (m % num_thin == 0) { + num_cross_chain_warmup++; + } } break; } } + return num_cross_chain_warmup; } } // namespace util diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index 8ca021661ae..fe20b3b8cb3 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -78,13 +78,14 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, stan::mcmc::mpi_var_adaptation var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); sampler.set_cross_chain_var_adaptation(var_adapt); - util::mpi_cross_chain_warmup(sampler, - num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, - writer, s, - model, rng, interrupt, logger); + int num_cross_chain_warmup = util::mpi_cross_chain_warmup(sampler, + num_warmup, 0, num_warmup + num_samples, + num_thin, refresh, save_warmup, true, + writer, s, + model, rng, interrupt, logger); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; + sample_writer("num_warmup = " + std::to_string(num_cross_chain_warmup)); sampler.disengage_adaptation(); writer.write_adapt_finish(sampler); From 630159595cacafa5c28a253dc4ef9685ef7523c2 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 5 Feb 2020 13:01:39 -0800 Subject: [PATCH 51/73] use Stan's ESS for cross-chain ESS test --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 212 ++++++------------ src/test/unit/mcmc/hmc/mpi_warmup_test.cpp | 114 +++++----- 2 files changed, 127 insertions(+), 199 deletions(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 
9670e564b7a..1d79ebf602c 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -29,10 +29,11 @@ namespace mcmc { int max_num_windows_; double target_rhat_; double target_ess_; - std::vector log_prob_draws_; + std::vector lp_draws_; + Eigen::MatrixXd all_lp_draws_; std::vector>> log_prob_accumulators_; // NOLINT + boost::accumulators::tag::variance>>> lp_acc_; // NOLINT boost::accumulators::accumulator_set > draw_counter_acc_; Eigen::ArrayXd rhat_; @@ -42,8 +43,25 @@ namespace mcmc { public: mpi_cross_chain_adapter() = default; - inline void set_cross_chain_var_adaptation(mpi_var_adaptation& adapt) - { + mpi_cross_chain_adapter(int num_iterations, int window_size, + int num_chains, + double target_rhat, double target_ess) : + is_adapted_(false), + window_size_(window_size), + num_chains_(num_chains), + max_num_windows_(num_iterations / window_size), + target_rhat_(target_rhat), + target_ess_(target_ess), + lp_draws_(window_size), + all_lp_draws_(window_size_ * max_num_windows_, num_chains_), + lp_acc_(max_num_windows_), + draw_counter_acc_(), + rhat_(Eigen::ArrayXd::Zero(max_num_windows_)), + ess_(Eigen::ArrayXd::Zero(max_num_windows_)) + {} + + + inline void set_cross_chain_var_adaptation(mpi_var_adaptation& adapt) { var_adapt = &adapt; } @@ -57,28 +75,35 @@ namespace mcmc { max_num_windows_ = num_iterations / window_size; target_rhat_ = target_rhat; target_ess_ = target_ess; - log_prob_draws_.clear(); - log_prob_draws_.reserve(num_iterations); - log_prob_accumulators_.clear(); - log_prob_accumulators_.resize(max_num_windows_); + lp_draws_.resize(window_size); + all_lp_draws_.resize(window_size_ * max_num_windows_, num_chains_); + lp_acc_.clear(); + lp_acc_.resize(max_num_windows_); draw_counter_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); - ess_ = Eigen::ArrayXd::Zero(num_chains_); + ess_ = 
Eigen::ArrayXd::Zero(max_num_windows_); } inline void reset_cross_chain_adaptation() { is_adapted_ = false; - log_prob_draws_.clear(); - log_prob_accumulators_.clear(); - log_prob_accumulators_.resize(max_num_windows_); + lp_draws_.clear(); + lp_acc_.clear(); + lp_acc_.resize(max_num_windows_); draw_counter_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); - ess_ = Eigen::ArrayXd::Zero(num_chains_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); var_adapt -> restart(); } + inline int max_num_windows() {return max_num_windows_;} + + /* + * Calculate the number of active windows when NEXT + * sample is added. + */ inline int current_cross_chain_window_counter() { - return (log_prob_draws_.size() - 1) / window_size_ + 1; + size_t n = boost::accumulators::count(draw_counter_acc_) - 1; + return n / window_size_ + 1; } inline void add_cross_chain_sample(const Eigen::VectorXd& q, double s) { @@ -86,126 +111,25 @@ namespace mcmc { using stan::math::mpi::Communicator; if (!is_adapted_) { + + int i = boost::accumulators::count(draw_counter_acc_) % window_size_; + // all procs keep a counter draw_counter_acc_(0); + int n_win = current_cross_chain_window_counter(); // only add samples to inter-chain ranks bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); if (is_inter_rank) { - log_prob_draws_.push_back(s); - int n_win = current_cross_chain_window_counter(); + lp_draws_[i] = s; for (int win = 0; win < n_win; ++win) { - log_prob_accumulators_[win](s); + lp_acc_[win](s); var_adapt -> estimators[win].add_sample(q); } } } } - inline double compute_effective_sample_size(std::vector draws, - std::vector sizes) { - int num_chains = sizes.size(); - size_t num_draws = sizes[0]; - for (int chain = 1; chain < num_chains; ++chain) { - num_draws = std::min(num_draws, sizes[chain]); - } - - // check if chains are constant; all equal to first draw's value - bool are_all_const = false; - Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains); - - for (int 
chain_idx = 0; chain_idx < num_chains; chain_idx++) { - Eigen::Map> draw( - draws[chain_idx], sizes[chain_idx]); - - for (int n = 0; n < num_draws; n++) { - if (!boost::math::isfinite(draw(n))) { - return std::numeric_limits::quiet_NaN(); - } - } - - init_draw(chain_idx) = draw(0); - - if (draw.isApproxToConstant(draw(0))) { - are_all_const |= true; - } - } - - if (are_all_const) { - // If all chains are constant then return NaN - // if they all equal the same constant value - if (init_draw.isApproxToConstant(init_draw(0))) { - return std::numeric_limits::quiet_NaN(); - } - } - - Eigen::Matrix acov(num_chains); - Eigen::VectorXd chain_mean(num_chains); - Eigen::VectorXd chain_var(num_chains); - for (int chain = 0; chain < num_chains; ++chain) { - Eigen::Map> draw( - draws[chain], sizes[chain]); - stan::analyze::autocovariance(draw, acov(chain)); - chain_mean(chain) = draw.mean(); - chain_var(chain) = acov(chain)(0) * num_draws / (num_draws - 1); - } - - double mean_var = chain_var.mean(); - double var_plus = mean_var * (num_draws - 1) / num_draws; - if (num_chains > 1) - var_plus += math::variance(chain_mean); - Eigen::VectorXd rho_hat_s(num_draws); - rho_hat_s.setZero(); - Eigen::VectorXd acov_s(num_chains); - for (int chain = 0; chain < num_chains; ++chain) - acov_s(chain) = acov(chain)(1); - double rho_hat_even = 1.0; - rho_hat_s(0) = rho_hat_even; - double rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; - rho_hat_s(1) = rho_hat_odd; - - // Convert raw autocovariance estimators into Geyer's initial - // positive sequence. Loop only until num_draws - 4 to - // leave the last pair of autocorrelations as a bias term that - // reduces variance in the case of antithetical chains. 
- size_t s = 1; - while (s < (num_draws - 4) && (rho_hat_even + rho_hat_odd) > 0) { - for (int chain = 0; chain < num_chains; ++chain) - acov_s(chain) = acov(chain)(s + 1); - rho_hat_even = 1 - (mean_var - acov_s.mean()) / var_plus; - for (int chain = 0; chain < num_chains; ++chain) - acov_s(chain) = acov(chain)(s + 2); - rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; - if ((rho_hat_even + rho_hat_odd) >= 0) { - rho_hat_s(s + 1) = rho_hat_even; - rho_hat_s(s + 2) = rho_hat_odd; - } - s += 2; - } - - int max_s = s; - // this is used in the improved estimate, which reduces variance - // in antithetic case -- see tau_hat below - if (rho_hat_even > 0) - rho_hat_s(max_s + 1) = rho_hat_even; - - // Convert Geyer's initial positive sequence into an initial - // monotone sequence - for (int s = 1; s <= max_s - 3; s += 2) { - if (rho_hat_s(s + 1) + rho_hat_s(s + 2) > rho_hat_s(s - 1) + rho_hat_s(s)) { - rho_hat_s(s + 1) = (rho_hat_s(s - 1) + rho_hat_s(s)) / 2; - rho_hat_s(s + 2) = rho_hat_s(s + 1); - } - } - - double num_total_draws = num_chains * num_draws; - // Geyer's truncated estimator for the asymptotic variance - // Improved estimate reduces variance in antithetic case - double tau_hat = -1 + 2 * rho_hat_s.head(max_s).sum() + rho_hat_s(max_s + 1); - return std::min(num_total_draws / tau_hat, - num_total_draws * std::log10(num_total_draws)); - } - /* * Computes the effective sample size (ESS) for the specified * parameter across all kept samples. 
The value returned is the @@ -218,10 +142,13 @@ namespace mcmc { * calculated(on the fly during adaptation) * */ - inline double compute_effective_sample_size(size_t i_begin, size_t i_size) { - std::vector draws{log_prob_draws_.data() + i_begin}; - std::vector sizes{i_size}; - return compute_effective_sample_size(draws, sizes); + inline double compute_effective_sample_size(int win, int win_count) { + std::vector draws(num_chains_); + size_t num_draws = (win_count - win) * window_size_; + for (int chain = 0; chain < num_chains_; ++chain) { + draws[chain] = &all_lp_draws_(win * window_size_, chain); + } + return stan::analyze::compute_effective_sample_size(draws, num_draws); } inline const Eigen::ArrayXd& cross_chain_adapt_rhat() { @@ -237,10 +164,6 @@ namespace mcmc { return n > 0 && (n % window_size_ == 0); } - inline bool is_cross_chain_adapt_window_begin() { - return (log_prob_draws_.size() - 1) % window_size_ == 0; - } - inline bool is_cross_chain_adapted() { return is_adapted_; } @@ -254,7 +177,7 @@ namespace mcmc { message << std::setw(5) << std::setprecision(2); message << " Rhat: " << std::fixed << cross_chain_adapt_rhat()[win]; const Eigen::ArrayXd& ess(cross_chain_adapt_ess()); - message << " ESS: " << std::fixed << ess_.matrix().minCoeff(); + message << " ESS: " << std::fixed << ess_[win]; logger.info(message); } @@ -289,23 +212,23 @@ namespace mcmc { if (is_inter_rank) { const Communicator& comm = Session::inter_chain_comm(num_chains_); - const int nd_win = 4; // mean, variance, chain_stepsize + const int nd_win = 3; // mean, variance, chain_stepsize const int win_count = current_cross_chain_window_counter(); - int n_gather = nd_win * win_count; + int n_gather = nd_win * win_count + window_size_; std::vector chain_gather(n_gather, 0.0); for (int win = 0; win < win_count; ++win) { int num_draws = (win_count - win) * window_size_; double unbiased_var_scale = num_draws / (num_draws - 1.0); - chain_gather[nd_win * win] = 
boost::accumulators::mean(log_prob_accumulators_[win]); - chain_gather[nd_win * win + 1] = boost::accumulators::variance(log_prob_accumulators_[win]) * + chain_gather[nd_win * win] = boost::accumulators::mean(lp_acc_[win]); + chain_gather[nd_win * win + 1] = boost::accumulators::variance(lp_acc_[win]) * unbiased_var_scale; chain_gather[nd_win * win + 2] = chain_stepsize; - chain_gather[nd_win * win + 3] = - compute_effective_sample_size(win * window_size_, num_draws); } + std::copy(lp_draws_.begin(), lp_draws_.end(), + chain_gather.begin() + nd_win * win_count); rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); - ess_ = Eigen::ArrayXd::Zero(num_chains_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); const int invalid_win = -999; int adapted_win = invalid_win; @@ -313,6 +236,14 @@ namespace mcmc { std::vector all_chain_gather(n_gather * num_chains_); MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + int begin_row = (win_count - 1) * window_size_; + for (int chain = 0; chain < num_chains_; ++chain) { + int j = n_gather * chain + nd_win * win_count; + for (int i = 0; i < window_size_; ++i) { + all_lp_draws_(begin_row + i, chain) = all_chain_gather[j + i]; + } + } + for (int win = 0; win < win_count; ++win) { accumulator_set> acc_chain_mean; accumulator_set> acc_chain_var; @@ -325,15 +256,14 @@ namespace mcmc { chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; acc_chain_var(chain_var(chain)); acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); - ess_(chain) = all_chain_gather[chain * n_gather + nd_win * win + 3]; } size_t num_draws = (win_count - win) * window_size_; double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) * num_chains_ / (num_chains_ - 1); double var_within = boost::accumulators::mean(acc_chain_var); rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); - // double ess_hmean = 
ess_.size()/((1.0/ess_).sum()); // harmonic mean - is_adapted_ = rhat_(win) < target_rhat_ && (ess_ > target_ess_).all(); + ess_[win] = compute_effective_sample_size(win, win_count); + is_adapted_ = rhat_(win) < target_rhat_ && ess_[win] > target_ess_; msg_adaptation(win, logger); diff --git a/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp b/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp index f17f83e8ef5..54fba6b9a55 100644 --- a/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp +++ b/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp @@ -39,7 +39,6 @@ struct dummy_sampler { } }; -// 4 chains with 4 cores, each chain run on a core TEST(mpi_warmup_test, mpi_cross_chain_adapter) { stan::callbacks::stream_logger logger(std::cout, std::cout, std::cout, std::cerr, std::cerr); @@ -265,77 +264,76 @@ draw_vecs[3] << double chain_stepsize = 1.1 + 0.1 * comm.rank(); const int num_iterations = window_size * max_num_windows; - stan::mcmc::mpi_cross_chain_adapter cc_adapter; - cc_adapter.set_cross_chain_adaptation_params(num_iterations, - window_size, - num_chains, 1.1, 40); - stan::mcmc::mpi_var_adaptation var_adapt(0, max_num_windows); - cc_adapter.set_cross_chain_var_adaptation(var_adapt); Eigen::VectorXd dummy; // a large ESS target should make all windows fail to pass tests - for (int i = 0; i < num_iterations; ++i) { - cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); + { + stan::mcmc::mpi_cross_chain_adapter cc_adapter(num_iterations, window_size, num_chains, 1.1, 100); + stan::mcmc::mpi_var_adaptation var_adapt(0, cc_adapter.max_num_windows()); + cc_adapter.set_cross_chain_var_adaptation(var_adapt); + for (int i = 0; i < num_iterations; ++i) { + cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); - dummy_sampler sampler(chain_stepsize); - cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); + dummy_sampler sampler(chain_stepsize); + cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); - EXPECT_FALSE(cc_adapter.is_cross_chain_adapted()); + 
EXPECT_FALSE(cc_adapter.is_cross_chain_adapted()); - if (cc_adapter.is_cross_chain_adapt_window_end()) { - int curr_num_win = cc_adapter.current_cross_chain_window_counter(); - for (int win = 0; win < curr_num_win; ++win) { - const std::vector p{ - draws[0] + win * window_size, - draws[1] + win * window_size, - draws[2] + win * window_size, - draws[3] + win * window_size}; - double rhat = - stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); - if (comm.rank() == 0) { - EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + if (cc_adapter.is_cross_chain_adapt_window_end()) { + int curr_win_count = cc_adapter.current_cross_chain_window_counter(); + for (int win = 0; win < curr_win_count; ++win) { + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_win_count - win) * window_size); + double ess = + stan::analyze::compute_effective_sample_size(p, (curr_win_count - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + EXPECT_FLOAT_EQ(ess, cc_adapter.cross_chain_adapt_ess()(win)); + } } - } + } } } // a target_ess that 4-window tests should pass - cc_adapter.set_cross_chain_adaptation_params(num_iterations, - window_size, - num_chains, 1.1, 15); - - { - int curr_num_win = 4; - double target_ess = 15.0; - for (int i = 0; i < num_iterations; ++i) { - dummy_sampler sampler(chain_stepsize); - cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); - - double step = chain_stepsize; - cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); - if (cc_adapter.is_cross_chain_adapted()) break; - } - int win = 1; // win = 1 @c is_adapted - const std::vector p{ - draws[0] + win * window_size, - draws[1] + win * window_size, - draws[2] + win * window_size, - draws[3] + win * 
window_size}; - double rhat = - stan::analyze::compute_potential_scale_reduction(p, (curr_num_win - win) * window_size); - if (comm.rank() == 0) { - EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); - for (int i = win + 1; i < max_num_windows; ++i) { - EXPECT_FLOAT_EQ(0.0, cc_adapter.cross_chain_adapt_rhat()(i)); - } - } + { + int curr_win_count = 4; + stan::mcmc::mpi_cross_chain_adapter cc_adapter(num_iterations, window_size, num_chains, 1.1, 50); + stan::mcmc::mpi_var_adaptation var_adapt(0, cc_adapter.max_num_windows()); + cc_adapter.set_cross_chain_var_adaptation(var_adapt); + for (int i = 0; i < num_iterations; ++i) { + dummy_sampler sampler(chain_stepsize); + cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); + double step = chain_stepsize; + cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); + if (cc_adapter.is_cross_chain_adapted()) break; + } + int win = 0; // win = 0 @c is_adapted + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_win_count - win) * window_size); + double ess = + stan::analyze::compute_effective_sample_size(p, (curr_win_count - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + EXPECT_FLOAT_EQ(ess, cc_adapter.cross_chain_adapt_ess()(win)); + for (int i = win + 1; i < max_num_windows; ++i) { + EXPECT_FLOAT_EQ(cc_adapter.cross_chain_adapt_rhat()(i) ,0.0); + EXPECT_FLOAT_EQ(cc_adapter.cross_chain_adapt_ess()(i) ,0.0); + } + } } } #endif - - - - From 0a1c7ed1b22c2d9af950009798c1f9e54e5edb63 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 6 Feb 2020 10:28:42 -0800 Subject: [PATCH 52/73] no stepsize reset after cross-chain converges --- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git 
a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 7dfb7bcedf5..fc30725e2bd 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -39,17 +39,18 @@ class adapt_diag_e_nuts : public diag_e_nuts, bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, this->z_.q); +#ifdef MPI_ADAPTED_WARMUP + this -> add_cross_chain_sample(this->z_.q, s.log_prob()); + this -> cross_chain_adaptation(this, this->z_.inv_e_metric_, logger); + if (this -> is_cross_chain_adapted()) update = false; +#endif + if (update) { this->init_stepsize(logger); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); } - -#ifdef MPI_ADAPTED_WARMUP - this -> add_cross_chain_sample(this->z_.q, s.log_prob()); - this -> cross_chain_adaptation(this, this->z_.inv_e_metric_, logger); -#endif } return s; } From 2c75a626d49421b80e9d6fe6554724a2148979fa Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 6 Feb 2020 21:35:44 -0800 Subject: [PATCH 53/73] init commit mpi_warmup_v2 --- .gitmodules | 2 +- lib/stan_math | 2 +- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 96 ++++++++++++------- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 6 +- src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp | 2 +- src/stan/mcmc/mpi_var_adaptation.hpp | 13 ++- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 9 +- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 9 +- .../services/util/generate_transitions.hpp | 12 ++- src/stan/services/util/mpi_cross_chain.hpp | 79 +++++++++++++++ .../services/util/mpi_cross_chain_warmup.hpp | 37 +------ .../util/run_mpi_adaptive_sampler.hpp | 21 ++-- 12 files changed, 197 insertions(+), 91 deletions(-) create mode 100644 src/stan/services/util/mpi_cross_chain.hpp diff --git a/.gitmodules b/.gitmodules index 8bc3af6a4a0..0162c3dc7cf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "lib/stan_math"] path = lib/stan_math 
url = https://github.com/stan-dev/math.git - branch = mpi_warmup_framework + branch = mpi_warmup_v2 diff --git a/lib/stan_math b/lib/stan_math index ad1525ab441..6b8a0e8e7e2 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit ad1525ab44178ff781a5130d93920a0e6d2923f1 +Subproject commit 6b8a0e8e7e209ff77c38fb1fcb4c38445a601f1e diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 1d79ebf602c..f5eea4fd7bc 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -21,9 +21,11 @@ namespace stan { namespace mcmc { + template class mpi_cross_chain_adapter { protected: bool is_adapted_; + bool is_post_warmup_; int window_size_; int num_chains_; int max_num_windows_; @@ -35,18 +37,21 @@ namespace mcmc { boost::accumulators::stats>> lp_acc_; // NOLINT boost::accumulators::accumulator_set > draw_counter_acc_; + boost::accumulators::features > draw_count_acc_; Eigen::ArrayXd rhat_; Eigen::ArrayXd ess_; - mpi_var_adaptation* var_adapt; + std::shared_ptr var_adapt; public: + const static int num_post_warmup = 50; + mpi_cross_chain_adapter() = default; mpi_cross_chain_adapter(int num_iterations, int window_size, int num_chains, double target_rhat, double target_ess) : is_adapted_(false), + is_post_warmup_(false), window_size_(window_size), num_chains_(num_chains), max_num_windows_(num_iterations / window_size), @@ -55,14 +60,13 @@ namespace mcmc { lp_draws_(window_size), all_lp_draws_(window_size_ * max_num_windows_, num_chains_), lp_acc_(max_num_windows_), - draw_counter_acc_(), + draw_count_acc_(), rhat_(Eigen::ArrayXd::Zero(max_num_windows_)), ess_(Eigen::ArrayXd::Zero(max_num_windows_)) {} - - inline void set_cross_chain_var_adaptation(mpi_var_adaptation& adapt) { - var_adapt = &adapt; + inline void set_cross_chain_var_adaptation(int num_params, int num_iterations, int window_size) { + var_adapt = std::shared_ptr(new 
mpi_var_adaptation(num_params, num_iterations, window_size)); } inline void set_cross_chain_adaptation_params(int num_iterations, @@ -70,6 +74,7 @@ namespace mcmc { int num_chains, double target_rhat, double target_ess) { is_adapted_ = false; + is_post_warmup_ = false, window_size_ = window_size; num_chains_ = num_chains; max_num_windows_ = num_iterations / window_size; @@ -79,17 +84,18 @@ namespace mcmc { all_lp_draws_.resize(window_size_ * max_num_windows_, num_chains_); lp_acc_.clear(); lp_acc_.resize(max_num_windows_); - draw_counter_acc_ = {}; + draw_count_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(max_num_windows_); } inline void reset_cross_chain_adaptation() { is_adapted_ = false; + is_post_warmup_ = false, lp_draws_.clear(); lp_acc_.clear(); lp_acc_.resize(max_num_windows_); - draw_counter_acc_ = {}; + draw_count_acc_ = {}; rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); ess_ = Eigen::ArrayXd::Zero(max_num_windows_); var_adapt -> restart(); @@ -97,34 +103,52 @@ namespace mcmc { inline int max_num_windows() {return max_num_windows_;} + /* + * Calculate the number of draws. + */ + inline int num_cross_chain_draws() { + return boost::accumulators::count(draw_count_acc_); + } + + inline void + write_num_cross_chain_warmup(callbacks::writer& sample_writer, + int num_thin) { + size_t n = num_cross_chain_draws(); + sample_writer("num_warmup = " + std::to_string(n / num_thin)); + } + /* * Calculate the number of active windows when NEXT * sample is added. 
*/ inline int current_cross_chain_window_counter() { - size_t n = boost::accumulators::count(draw_counter_acc_) - 1; + size_t n = num_cross_chain_draws() - 1; return n / window_size_ + 1; } - inline void add_cross_chain_sample(const Eigen::VectorXd& q, double s) { + // only add samples to inter-chain ranks + // CRTP to sampler + inline void add_cross_chain_sample(double s) { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - if (!is_adapted_) { + Sampler& sampler = static_cast(*this); - int i = boost::accumulators::count(draw_counter_acc_) % window_size_; + if (sampler.adapting()) { + int i = num_cross_chain_draws() % window_size_; + draw_count_acc_(0); - // all procs keep a counter - draw_counter_acc_(0); - int n_win = current_cross_chain_window_counter(); + if (!is_adapted_) { + int n_win = current_cross_chain_window_counter(); - // only add samples to inter-chain ranks - bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); - if (is_inter_rank) { - lp_draws_[i] = s; - for (int win = 0; win < n_win; ++win) { - lp_acc_[win](s); - var_adapt -> estimators[win].add_sample(q); + Sampler& sampler = static_cast(*this); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + if (is_inter_rank) { + lp_draws_[i] = s; + for (int win = 0; win < n_win; ++win) { + lp_acc_[win](s); + var_adapt -> estimators[win].add_sample(sampler.z().q); + } } } } @@ -160,7 +184,7 @@ namespace mcmc { } inline bool is_cross_chain_adapt_window_end() { - size_t n = boost::accumulators::count(draw_counter_acc_); + size_t n = num_cross_chain_draws(); return n > 0 && (n % window_size_ == 0); } @@ -168,11 +192,19 @@ namespace mcmc { return is_adapted_; } + inline bool is_post_cross_chain() { + return is_post_warmup_; + } + + inline void set_post_cross_chain() { + is_post_warmup_ = !is_post_warmup_; + } + inline void msg_adaptation(int win, callbacks::logger& logger) { std::stringstream message; message << "iteration: "; message << std::setw(3); - 
message << boost::accumulators::count(draw_counter_acc_); + message << num_cross_chain_draws(); message << " window: " << win + 1 << " / " << current_cross_chain_window_counter(); message << std::setw(5) << std::setprecision(2); message << " Rhat: " << std::fixed << cross_chain_adapt_rhat()[win]; @@ -192,10 +224,7 @@ namespace mcmc { * maximum windows for all chains. # @return vector {stepsize, rhat(only in rank 0)} */ - template - inline void cross_chain_adaptation(Sampler* hmc_sampler, - Eigen::VectorXd& inv_e_metric, - callbacks::logger& logger) { + inline void cross_chain_adaptation(callbacks::logger& logger) { using boost::accumulators::accumulator_set; using boost::accumulators::stats; using boost::accumulators::tag::mean; @@ -204,8 +233,10 @@ namespace mcmc { using stan::math::mpi::Session; using stan::math::mpi::Communicator; + Sampler& sampler = static_cast(*this); + if ((!is_adapted_) && is_cross_chain_adapt_window_end()) { - double chain_stepsize = hmc_sampler -> get_nominal_stepsize(); + double chain_stepsize = sampler.get_nominal_stepsize(); bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); double invalid_stepsize = -999.0; double new_stepsize = invalid_stepsize; @@ -280,7 +311,7 @@ namespace mcmc { if (adapted_win >= 0) { MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); new_stepsize /= num_chains_; - var_adapt -> learn_variance(inv_e_metric, adapted_win, comm); + var_adapt -> learn_variance(sampler.z().inv_e_metric_, adapted_win, comm); } } const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); @@ -288,10 +319,11 @@ namespace mcmc { is_adapted_ = new_stepsize > 0.0; if (is_adapted_) { chain_stepsize = new_stepsize; - MPI_Bcast(inv_e_metric.data(), var_adapt -> estimators[0].num_params(), MPI_DOUBLE, 0, intra_comm.comm()); + MPI_Bcast(sampler.z().inv_e_metric_.data(), + var_adapt -> estimators[0].num_params(), MPI_DOUBLE, 0, intra_comm.comm()); } if (is_adapted_) { - hmc_sampler -> 
set_nominal_stepsize(chain_stepsize); + sampler.set_nominal_stepsize(chain_stepsize); } } } diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index fc30725e2bd..4699ce0dbae 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -19,7 +19,7 @@ namespace mcmc { template class adapt_diag_e_nuts : public diag_e_nuts, #ifdef MPI_ADAPTED_WARMUP - public mpi_cross_chain_adapter, + public mpi_cross_chain_adapter>, #endif public stepsize_var_adapter { public: @@ -40,8 +40,8 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->z_.q); #ifdef MPI_ADAPTED_WARMUP - this -> add_cross_chain_sample(this->z_.q, s.log_prob()); - this -> cross_chain_adaptation(this, this->z_.inv_e_metric_, logger); + this -> add_cross_chain_sample(s.log_prob()); + this -> cross_chain_adaptation(logger); if (this -> is_cross_chain_adapted()) update = false; #endif diff --git a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp index a0f3ed62ff9..b6cd7541278 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp @@ -19,7 +19,7 @@ namespace mcmc { template class adapt_unit_e_nuts : public unit_e_nuts, #ifdef MPI_ADAPTED_WARMUP - public mpi_cross_chain_adapter, + public mpi_cross_chain_adapter>, #endif public stepsize_adapter { public: diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 82016e81346..68e0b3462d6 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -17,6 +17,8 @@ class mpi_var_adaptation { public: std::vector estimators; + mpi_var_adaptation() = default; + mpi_var_adaptation(int n_params, int max_num_windows) : estimators(max_num_windows, est_t(n_params)) {} @@ -34,10 +36,17 @@ class mpi_var_adaptation { } void restart() { - for (auto&& adapt : estimators) { - adapt.restart(); + for (auto&& e : 
estimators) { + e.restart(); } } + + // void restart(int n_params, int num_iterations, int window_size) { + // estimators.resize(num_iterations / window_size); + // for (auto&& e : estimators) { + // e.restart(n_params); + // } + // } }; } // namespace mcmc diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index ad4c051dc47..c3c8c8c0f9c 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -101,9 +101,16 @@ int hmc_nuts_diag_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); + // cross chain adaptation + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + sampler.set_cross_chain_var_adaptation(model.num_params_r(), + num_warmup, cross_chain_window); + #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler(sampler, - model, cont_vector, num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, + model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 14bb154bf62..83bf5199590 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -73,9 +73,16 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_kappa(kappa); sampler.get_stepsize_adaptation().set_t0(t0); + // cross chain adaptation + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + sampler.set_cross_chain_var_adaptation(model.num_params_r(), + num_warmup, cross_chain_window); + #ifdef MPI_ADAPTED_WARMUP util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, 
num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); #else diff --git a/src/stan/services/util/generate_transitions.hpp b/src/stan/services/util/generate_transitions.hpp index 2c72f2e1138..50373e2ff0a 100644 --- a/src/stan/services/util/generate_transitions.hpp +++ b/src/stan/services/util/generate_transitions.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include namespace stan { @@ -37,8 +38,9 @@ namespace util { * @param[in,out] callback interrupt callback called once an iteration * @param[in,out] logger logger for messages */ -template -void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations, +template ::value>* = nullptr> +void generate_transitions(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, @@ -67,6 +69,12 @@ void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations, mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); mcmc_writer.write_diagnostic_params(init_s, sampler); } + + // check cross-chain convergence + if (mpi_cross_chain::end_transitions(sampler)) { + mpi_cross_chain::set_post_iter(sampler); + break; + } } } diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp new file mode 100644 index 00000000000..a0586d27d20 --- /dev/null +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -0,0 +1,79 @@ +#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_HPP +#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_HPP + +#include +#include +#include +#include +#include +#include + +#ifdef STAN_LANG_MPI +#include +#endif + +namespace stan { +namespace services { +namespace util { + + struct mpi_cross_chain { + template + static bool end_transitions(Sampler& sampler) {return false;} + + template + static void 
set_post_iter(Sampler& sampler) {} + + static void set_seed(unsigned int& seed, int num_chains) { +#ifdef MPI_ADAPTED_WARMUP + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + const Communicator& inter_comm = Session::inter_chain_comm(num_chains); + const Communicator& intra_comm = Session::intra_chain_comm(num_chains); + MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_STAN); + seed += inter_comm.rank(); + MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); +#endif + } + + static void set_file(std::string& file_name, int num_chains) { +#ifdef MPI_ADAPTED_WARMUP + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + // hard-coded nb. of chains + if (Session::is_in_inter_chain_comm(num_chains)) { + const Communicator& comm = Session::inter_chain_comm(num_chains); + file_name = "mpi." + std::to_string(comm.rank()) + "." + file_name; + } +#endif + } + +// MPI versions +#ifdef MPI_ADAPTED_WARMUP + template + static bool end_transitions(mcmc::adapt_diag_e_nuts& sampler) { + return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); + } + + template + static bool end_transitions(mcmc::adapt_unit_e_nuts& sampler) { + return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); + } + + template + static void set_post_iter(mcmc::adapt_diag_e_nuts& sampler) { + sampler.set_post_cross_chain(); + } + + template + static void set_post_iter(mcmc::adapt_unit_e_nuts& sampler) { + sampler.set_post_cross_chain(); + } +#endif + }; +} +} +} + +#endif diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp index d3fa786546d..aa032046476 100644 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ b/src/stan/services/util/mpi_cross_chain_warmup.hpp @@ -43,14 +43,13 @@ namespace util { * @param[in,out] logger logger for messages */ template -int mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, +void mpi_cross_chain_warmup(Sampler& 
sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, callbacks::logger& logger) { - int num_cross_chain_warmup = 0; for (int m = 0; m < num_iterations; ++m) { callback(); @@ -74,41 +73,13 @@ int mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, mcmc_writer.write_diagnostic_params(init_s, sampler); } - if (m % num_thin == 0) { - num_cross_chain_warmup++; - } - // check cross-chain convergence - if (sampler.is_cross_chain_adapted()) { - for (int j = m + 1; j < m + 51; ++j) { - if (refresh > 0 - && (start + j + 1 == finish || j == 0 || (j + 1) % refresh == 0)) { - int it_print_width = std::ceil(std::log10(static_cast(finish))); - std::stringstream message; - message << "Iteration: "; - message << std::setw(it_print_width) << j + 1 + start << " / " << finish; - message << " [" << std::setw(3) - << static_cast((100.0 * (start + j + 1)) / finish) << "%] "; - message << (warmup ? 
" (Warmup)" : " (Sampling)"); - - logger.info(message); - } - - init_s = sampler.transition(init_s, logger); - - if (save && ((j % num_thin) == 0)) { - mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); - mcmc_writer.write_diagnostic_params(init_s, sampler); - } - - if (m % num_thin == 0) { - num_cross_chain_warmup++; - } - } + if ((!sampler.is_post_cross_chain()) && sampler.is_cross_chain_adapted()) { + std::cout << "taki test: " << "break" << "\n"; + sampler.set_post_cross_chain(); break; } } - return num_cross_chain_warmup; } } // namespace util diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp index fe20b3b8cb3..df3b57c5814 100644 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ b/src/stan/services/util/run_mpi_adaptive_sampler.hpp @@ -42,8 +42,6 @@ namespace util { template void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, std::vector& cont_vector, - int num_chains, int cross_chain_window, - double cross_chain_rhat, int cross_chain_ess, int num_warmup,int num_samples, int num_thin, int refresh, bool save_warmup, RNG& rng, callbacks::interrupt& interrupt, @@ -72,20 +70,15 @@ void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, // warmup clock_t start = clock(); - sampler.set_cross_chain_adaptation_params(num_warmup, - cross_chain_window, num_chains, - cross_chain_rhat, cross_chain_ess); - stan::mcmc::mpi_var_adaptation - var_adapt(sampler.z().q.size(), num_warmup, cross_chain_window); - sampler.set_cross_chain_var_adaptation(var_adapt); - int num_cross_chain_warmup = util::mpi_cross_chain_warmup(sampler, - num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, - writer, s, - model, rng, interrupt, logger); + util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, + num_thin, refresh, save_warmup, true, writer, s, + model, rng, interrupt, logger); + util::generate_transitions(sampler, 
sampler.num_post_warmup, sampler.num_cross_chain_draws(), + num_warmup + num_samples, num_thin, refresh, save_warmup, + true, writer, s, model, rng, interrupt, logger); + sampler.write_num_cross_chain_warmup(sample_writer, num_thin); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; - sample_writer("num_warmup = " + std::to_string(num_cross_chain_warmup)); sampler.disengage_adaptation(); writer.write_adapt_finish(sampler); From de0b7543adf74b6cbe3eb4c66a97da53450eddeb Mon Sep 17 00:00:00 2001 From: yiz Date: Fri, 7 Feb 2020 13:28:18 -0800 Subject: [PATCH 54/73] use mpi_cross_chain as single entry to cross-chain methods --- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 21 +--- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 19 ++-- src/stan/services/util/mpi_cross_chain.hpp | 91 ++++++++++++++++- .../services/util/mpi_cross_chain_warmup.hpp | 89 ----------------- .../services/util/run_adaptive_sampler.hpp | 8 ++ .../util/run_mpi_adaptive_sampler.hpp | 99 ------------------- 6 files changed, 109 insertions(+), 218 deletions(-) delete mode 100644 src/stan/services/util/mpi_cross_chain_warmup.hpp delete mode 100644 src/stan/services/util/run_mpi_adaptive_sampler.hpp diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index c3c8c8c0f9c..7c6165e71fc 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -15,10 +15,6 @@ #include #include -#ifdef MPI_ADAPTED_WARMUP -#include -#endif - namespace stan { namespace services { namespace sample { @@ -102,22 +98,15 @@ int hmc_nuts_diag_e_adapt( logger); // cross chain adaptation - sampler.set_cross_chain_adaptation_params(num_warmup, - cross_chain_window, num_cross_chains, - cross_chain_rhat, cross_chain_ess); - sampler.set_cross_chain_var_adaptation(model.num_params_r(), - num_warmup, cross_chain_window); + util::mpi_cross_chain::set_params(sampler, 
num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + util::mpi_cross_chain::set_var_adaptation(sampler, model.num_params_r(), + num_warmup, cross_chain_window); -#ifdef MPI_ADAPTED_WARMUP - util::run_mpi_adaptive_sampler(sampler, - model, cont_vector, - num_warmup, num_samples, num_thin, refresh, - save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); -#else util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); -#endif return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 83bf5199590..b97a8ba8ab3 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -73,23 +73,16 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_kappa(kappa); sampler.get_stepsize_adaptation().set_t0(t0); - // cross chain adaptation - sampler.set_cross_chain_adaptation_params(num_warmup, - cross_chain_window, num_cross_chains, - cross_chain_rhat, cross_chain_ess); - sampler.set_cross_chain_var_adaptation(model.num_params_r(), - num_warmup, cross_chain_window); + // cross chain adaptation setup + util::mpi_cross_chain::set_params(sampler, num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + util::mpi_cross_chain::set_var_adaptation(sampler, model.num_params_r(), + num_warmup, cross_chain_window); -#ifdef MPI_ADAPTED_WARMUP - util::run_mpi_adaptive_sampler( - sampler, model, cont_vector, - num_warmup, num_samples, num_thin, refresh, - save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); -#else util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); -#endif return 
error_codes::OK; } diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index a0586d27d20..2c17ebed4c8 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -23,6 +23,29 @@ namespace util { template static void set_post_iter(Sampler& sampler) {} + template + static int num_post_warmup(Sampler& sampler) { return 0;} + + template + static int num_draws(Sampler& sampler) { return 0;} + + template + static void write_num_warmup(Sampler& sampler, + callbacks::writer& sample_writer, + int num_thin) + {} + + template + static void set_params(Sampler& sampler, int num_iterations, + int window_size, int num_chains, + double target_rhat, double target_ess) + {} + + template + static void set_var_adaptation(Sampler& sampler, + int num_params, int num_iterations, int window_size) + {} + static void set_seed(unsigned int& seed, int num_chains) { #ifdef MPI_ADAPTED_WARMUP using stan::math::mpi::Session; @@ -41,7 +64,6 @@ namespace util { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - // hard-coded nb. of chains if (Session::is_in_inter_chain_comm(num_chains)) { const Communicator& comm = Session::inter_chain_comm(num_chains); file_name = "mpi." + std::to_string(comm.rank()) + "." 
+ file_name; @@ -61,6 +83,40 @@ namespace util { return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); } + template + static int num_post_warmup(mcmc::adapt_diag_e_nuts& sampler) { + return sampler.num_post_warmup; + } + + template + static int num_post_warmup(mcmc::adapt_unit_e_nuts& sampler) { + return sampler.num_post_warmup; + } + + template + static int num_draws(mcmc::adapt_diag_e_nuts& sampler) { + return sampler.num_cross_chain_draws(); + } + + template + static int num_draws(mcmc::adapt_unit_e_nuts& sampler) { + return sampler.num_cross_chain_draws(); + } + + template + static void write_num_warmup(mcmc::adapt_diag_e_nuts& sampler, + callbacks::writer& sample_writer, + int num_thin) { + sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + } + + template + static void write_num_warmup(mcmc::adapt_unit_e_nuts& sampler, + callbacks::writer& sample_writer, + int num_thin) { + sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + } + template static void set_post_iter(mcmc::adapt_diag_e_nuts& sampler) { sampler.set_post_cross_chain(); @@ -70,6 +126,39 @@ namespace util { static void set_post_iter(mcmc::adapt_unit_e_nuts& sampler) { sampler.set_post_cross_chain(); } + + template + static void set_params(mcmc::adapt_diag_e_nuts& sampler, + int num_iterations, + int window_size, int num_chains, + double target_rhat, double target_ess) { + sampler.set_cross_chain_adaptation_params(num_iterations, + window_size, num_chains, + target_rhat, target_ess); + } + + template + static void set_params(mcmc::adapt_unit_e_nuts& sampler, + int num_iterations, + int window_size, int num_chains, + double target_rhat, double target_ess) { + sampler.set_cross_chain_adaptation_params(num_iterations, + window_size, num_chains, + target_rhat, target_ess); + } + + template + static void set_var_adaptation(mcmc::adapt_diag_e_nuts& sampler, + int num_params, int num_iterations, int window_size) { + 
sampler.set_cross_chain_var_adaptation(num_params, num_iterations, window_size); + } + + template + static void set_var_adaptation(mcmc::adapt_unit_e_nuts& sampler, + int num_params, int num_iterations, int window_size) { + sampler.set_cross_chain_var_adaptation(num_params, num_iterations, window_size); + } + #endif }; } diff --git a/src/stan/services/util/mpi_cross_chain_warmup.hpp b/src/stan/services/util/mpi_cross_chain_warmup.hpp deleted file mode 100644 index aa032046476..00000000000 --- a/src/stan/services/util/mpi_cross_chain_warmup.hpp +++ /dev/null @@ -1,89 +0,0 @@ -#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_WARMUP_HPP -#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_WARMUP_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace stan { -namespace services { -namespace util { - -/** - * Generates MCMC transitions. - * - * @tparam Model model class - * @tparam RNG random number generator class - * @param[in,out] sampler MCMC sampler used to generate transitions - * @param[in] num_iterations number of MCMC transitions - * @param[in] start starting iteration number used for printing messages - * @param[in] finish end iteration number used for printing messages - * @param[in] num_thin when save is true, a draw will be written to the - * mcmc_writer every num_thin iterations - * @param[in] refresh number of iterations to print a message. If - * refresh is zero, iteration number messages will not be printed - * @param[in] save if save is true, the transitions will be written - * to the mcmc_writer. If false, transitions will not be written - * @param[in] warmup indicates whether these transitions are warmup. Used - * for printing iteration number messages - * @param[in,out] mcmc_writer writer to handle mcmc otuput - * @param[in,out] init_s starts as the initial unconstrained parameter - * values. 
When the function completes, this will have the final - * iteration's unconstrained parameter values - * @param[in] model model - * @param[in,out] base_rng random number generator - * @param[in,out] callback interrupt callback called once an iteration - * @param[in,out] logger logger for messages - */ -template -void mpi_cross_chain_warmup(Sampler& sampler, int num_iterations, - int start, int finish, int num_thin, int refresh, - bool save, bool warmup, - util::mcmc_writer& mcmc_writer, - stan::mcmc::sample& init_s, Model& model, - RNG& base_rng, callbacks::interrupt& callback, - callbacks::logger& logger) { - for (int m = 0; m < num_iterations; ++m) { - callback(); - - if (refresh > 0 - && (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) { - int it_print_width = std::ceil(std::log10(static_cast(finish))); - std::stringstream message; - message << "Iteration: "; - message << std::setw(it_print_width) << m + 1 + start << " / " << finish; - message << " [" << std::setw(3) - << static_cast((100.0 * (start + m + 1)) / finish) << "%] "; - message << (warmup ? 
" (Warmup)" : " (Sampling)"); - - logger.info(message); - } - - init_s = sampler.transition(init_s, logger); - - if (save && ((m % num_thin) == 0)) { - mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); - mcmc_writer.write_diagnostic_params(init_s, sampler); - } - - // check cross-chain convergence - if ((!sampler.is_post_cross_chain()) && sampler.is_cross_chain_adapted()) { - std::cout << "taki test: " << "break" << "\n"; - sampler.set_post_cross_chain(); - break; - } - } -} - -} // namespace util -} // namespace services -} // namespace stan - -#endif diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index c4758eb06c7..7dfc34bace6 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -67,6 +67,14 @@ void run_adaptive_sampler(Sampler& sampler, Model& model, util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, writer, s, model, rng, interrupt, logger); + + // cross-chain post convergence iterations + util::generate_transitions(sampler, mpi_cross_chain::num_post_warmup(sampler), + mpi_cross_chain::num_draws(sampler), + num_warmup + num_samples, num_thin, refresh, save_warmup, + true, writer, s, model, rng, interrupt, logger); + mpi_cross_chain::write_num_warmup(sampler, sample_writer, num_thin); + clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; diff --git a/src/stan/services/util/run_mpi_adaptive_sampler.hpp b/src/stan/services/util/run_mpi_adaptive_sampler.hpp deleted file mode 100644 index df3b57c5814..00000000000 --- a/src/stan/services/util/run_mpi_adaptive_sampler.hpp +++ /dev/null @@ -1,99 +0,0 @@ -#ifndef STAN_SERVICES_UTIL_RUN_MPI_ADAPTIVE_SAMPLER_HPP -#define STAN_SERVICES_UTIL_RUN_MPI_ADAPTIVE_SAMPLER_HPP - -#include -#include -#include -#include -#include -#include - -#ifdef MPI_ADAPTED_WARMUP -#include -#include 
-#endif - -namespace stan { -namespace services { -namespace util { - -/** - * Runs the sampler with adaptation. - * - * @tparam Sampler Type of adaptive sampler. - * @tparam Model Type of model - * @tparam RNG Type of random number generator - * @param[in,out] sampler the mcmc sampler to use on the model - * @param[in] model the model concept to use for computing log probability - * @param[in] cont_vector initial parameter values - * @param[in] num_warmup number of warmup draws - * @param[in] num_samples number of post warmup draws - * @param[in] num_thin number to thin the draws. Must be greater than - * or equal to 1. - * @param[in] refresh controls output to the logger - * @param[in] save_warmup indicates whether the warmup draws should be - * sent to the sample writer - * @param[in,out] rng random number generator - * @param[in,out] interrupt interrupt callback - * @param[in,out] logger logger for messages - * @param[in,out] sample_writer writer for draws - * @param[in,out] diagnostic_writer writer for diagnostic information - */ -template -void run_mpi_adaptive_sampler(Sampler& sampler, Model& model, - std::vector& cont_vector, - int num_warmup,int num_samples, int num_thin, int refresh, - bool save_warmup, RNG& rng, - callbacks::interrupt& interrupt, - callbacks::logger& logger, - callbacks::writer& sample_writer, - callbacks::writer& diagnostic_writer) { - Eigen::Map cont_params(cont_vector.data(), - cont_vector.size()); - - sampler.engage_adaptation(); - try { - sampler.z().q = cont_params; - sampler.init_stepsize(logger); - } catch (const std::exception& e) { - logger.info("Exception initializing step size."); - logger.info(e.what()); - return; - } - - services::util::mcmc_writer writer(sample_writer, diagnostic_writer, logger); - stan::mcmc::sample s(cont_params, 0, 0); - - // Headers - writer.write_sample_names(s, sampler, model); - writer.write_diagnostic_names(s, sampler, model); - - // warmup - clock_t start = clock(); - 
util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, - num_thin, refresh, save_warmup, true, writer, s, - model, rng, interrupt, logger); - util::generate_transitions(sampler, sampler.num_post_warmup, sampler.num_cross_chain_draws(), - num_warmup + num_samples, num_thin, refresh, save_warmup, - true, writer, s, model, rng, interrupt, logger); - sampler.write_num_cross_chain_warmup(sample_writer, num_thin); - clock_t end = clock(); - double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; - - sampler.disengage_adaptation(); - writer.write_adapt_finish(sampler); - sampler.write_sampler_state(sample_writer); - - start = clock(); - util::generate_transitions(sampler, num_samples, num_warmup, - num_warmup + num_samples, num_thin, refresh, true, - false, writer, s, model, rng, interrupt, logger); - end = clock(); - double sample_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; - - writer.write_timing(warm_delta_t, sample_delta_t); -} -} // namespace util -} // namespace services -} // namespace stan -#endif From 88cfa38f95654fb5a16e794689fe1dd4d5e78d67 Mon Sep 17 00:00:00 2001 From: yiz Date: Sat, 8 Feb 2020 23:01:02 -0800 Subject: [PATCH 55/73] move MPI_ADAPTED_WARMUP macro into adapter --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 45 +++++++++++--- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 15 ++--- src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp | 5 -- src/stan/mcmc/mpi_covar_adaptation.hpp | 60 +++++++++++++++++++ src/stan/mcmc/mpi_metric_adaptation.hpp | 35 +++++++++++ src/stan/mcmc/mpi_var_adaptation.hpp | 18 +++++- 6 files changed, 153 insertions(+), 25 deletions(-) create mode 100644 src/stan/mcmc/mpi_covar_adaptation.hpp create mode 100644 src/stan/mcmc/mpi_metric_adaptation.hpp diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index f5eea4fd7bc..e8188c55c6d 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ 
b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -1,14 +1,11 @@ #ifndef STAN_MCMC_HMC_MPI_CROSS_CHAIN_ADAPTER_HPP #define STAN_MCMC_HMC_MPI_CROSS_CHAIN_ADAPTER_HPP -#ifdef MPI_ADAPTED_WARMUP - #include #include #include -#include +#include #include -#include #include #include #include @@ -18,9 +15,28 @@ #include #include +#ifdef MPI_ADAPTED_WARMUP +#include +#include +#include +#endif + namespace stan { namespace mcmc { +#ifdef MPI_ADAPTED_WARMUP + // template ::value>* = nullptr> + // struct mpi_cross_chain_metric_adapt_t { + // using type = mpi_var_adaptation; + // }; + + // template ::value>* = nullptr> + // struct mpi_cross_chain_metric_adapt_t { + // using type = mpi_covar_adaptation; + // }; + template class mpi_cross_chain_adapter { protected: @@ -311,7 +327,7 @@ namespace mcmc { if (adapted_win >= 0) { MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); new_stepsize /= num_chains_; - var_adapt -> learn_variance(sampler.z().inv_e_metric_, adapted_win, comm); + var_adapt -> learn_metric(sampler.z().inv_e_metric_, adapted_win, win_count, comm); } } const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); @@ -320,7 +336,7 @@ namespace mcmc { if (is_adapted_) { chain_stepsize = new_stepsize; MPI_Bcast(sampler.z().inv_e_metric_.data(), - var_adapt -> estimators[0].num_params(), MPI_DOUBLE, 0, intra_comm.comm()); + sampler.z().inv_e_metric_.size(), MPI_DOUBLE, 0, intra_comm.comm()); } if (is_adapted_) { sampler.set_nominal_stepsize(chain_stepsize); @@ -328,8 +344,23 @@ namespace mcmc { } } }; + +#else // sequential version + + template + class mpi_cross_chain_adapter { + public: + inline void add_cross_chain_sample(double s) {} + + inline void cross_chain_adaptation(callbacks::logger& logger) {} + + inline bool is_cross_chain_adapted() { return false; } + }; + +#endif + } } #endif -#endif + diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 
4699ce0dbae..908a6d2efb2 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -4,10 +4,7 @@ #include #include #include - -#ifdef MPI_ADAPTED_WARMUP #include -#endif namespace stan { namespace mcmc { @@ -18,9 +15,7 @@ namespace mcmc { */ template class adapt_diag_e_nuts : public diag_e_nuts, -#ifdef MPI_ADAPTED_WARMUP public mpi_cross_chain_adapter>, -#endif public stepsize_var_adapter { public: adapt_diag_e_nuts(const Model& model, BaseRNG& rng) @@ -39,11 +34,11 @@ class adapt_diag_e_nuts : public diag_e_nuts, bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, this->z_.q); -#ifdef MPI_ADAPTED_WARMUP - this -> add_cross_chain_sample(s.log_prob()); - this -> cross_chain_adaptation(logger); - if (this -> is_cross_chain_adapted()) update = false; -#endif + // cross-chain adaptation + this -> add_cross_chain_sample(s.log_prob()); + this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) update = false; + // cross-chain adaptation if (update) { this->init_stepsize(logger); diff --git a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp index b6cd7541278..3fe3b915a10 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp @@ -4,10 +4,7 @@ #include #include #include - -#ifdef MPI_ADAPTED_WARMUP #include -#endif namespace stan { namespace mcmc { @@ -18,9 +15,7 @@ namespace mcmc { */ template class adapt_unit_e_nuts : public unit_e_nuts, -#ifdef MPI_ADAPTED_WARMUP public mpi_cross_chain_adapter>, -#endif public stepsize_adapter { public: adapt_unit_e_nuts(const Model& model, BaseRNG& rng) diff --git a/src/stan/mcmc/mpi_covar_adaptation.hpp b/src/stan/mcmc/mpi_covar_adaptation.hpp new file mode 100644 index 00000000000..1f1f67126de --- /dev/null +++ b/src/stan/mcmc/mpi_covar_adaptation.hpp @@ -0,0 +1,60 @@ +#ifndef STAN_MCMC_MPI_COVAR_ADAPTATION_HPP +#define 
STAN_MCMC_MPI_COVAR_ADAPTATION_HPP + +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include + +namespace stan { + +namespace mcmc { + + class mpi_covar_adaptation : public mpi_metric_adaptation { + using est_t = stan::math::mpi::mpi_covar_estimator; + + int window_size_; +public: + est_t estimator; + + mpi_covar_adaptation(int n_params, int num_iterations, int window_size) + : window_size_(window_size), + estimator(n_params, num_iterations) + {} + + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { + estimator.add_sample(q); + } + + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) { + int col_begin = win * window_size_; + int num_draws = (curr_win_count - win) * window_size_; + learn_covariance(covar, col_begin, num_draws, comm); + } + + void learn_covariance(Eigen::MatrixXd& covar, + int col_begin, int n_samples, + const stan::math::mpi::Communicator& comm) { + estimator.sample_covariance(covar, col_begin, n_samples, comm); + double n = static_cast(estimator.num_samples(comm)); + covar = (n / (n + 5.0)) * covar + + 1e-3 * (5.0 / (n + 5.0)) + * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); + restart(); + } + + virtual void restart() { + estimator.restart(); + } +}; + +} // namespace mcmc + +} // namespace stan + +#endif + +#endif diff --git a/src/stan/mcmc/mpi_metric_adaptation.hpp b/src/stan/mcmc/mpi_metric_adaptation.hpp new file mode 100644 index 00000000000..87b9c4d1879 --- /dev/null +++ b/src/stan/mcmc/mpi_metric_adaptation.hpp @@ -0,0 +1,35 @@ +#ifndef STAN_MCMC_MPI_METRIC_ADAPTATION_HPP +#define STAN_MCMC_MPI_METRIC_ADAPTATION_HPP + +#ifdef STAN_LANG_MPI + +#include +#include + +namespace stan { + +namespace mcmc { + + class mpi_metric_adaptation { + public: + mpi_metric_adaptation() = default; + + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) {} + + virtual void learn_metric(Eigen::VectorXd& var, int win, int 
curr_win_count, + const stan::math::mpi::Communicator& comm) + {} + + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) + {} + virtual void restart() {} + }; + +} // namespace mcmc + +} // namespace stan + +#endif + +#endif diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 68e0b3462d6..76a2e088fcf 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -4,6 +4,7 @@ #ifdef STAN_LANG_MPI #include +#include #include #include @@ -11,7 +12,7 @@ namespace stan { namespace mcmc { -class mpi_var_adaptation { + class mpi_var_adaptation : public mpi_metric_adaptation { using est_t = stan::math::mpi::mpi_var_estimator; public: @@ -27,7 +28,18 @@ class mpi_var_adaptation { : mpi_var_adaptation(n_params, num_iterations / window_size) {} - void learn_variance(Eigen::VectorXd& var, int win, + virtual void add_sample(Eigen::VectorXd& q, int curr_win_count) { + for (int win = 0; win < curr_win_count; ++win) { + estimators[win].add_sample(q); + } + } + + virtual void learn_metric(Eigen::VectorXd& var, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) { + learn_variance(var, win, curr_win_count, comm); + } + + void learn_variance(Eigen::VectorXd& var, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) { double n = static_cast(estimators[win].sample_variance(var, comm)); var = (n / (n + 5.0)) * var @@ -35,7 +47,7 @@ class mpi_var_adaptation { restart(); } - void restart() { + virtual void restart() { for (auto&& e : estimators) { e.restart(); } From da2f7b05c02d4467a061f7c5799d7fc036f444eb Mon Sep 17 00:00:00 2001 From: yiz Date: Sun, 9 Feb 2020 21:35:07 -0800 Subject: [PATCH 56/73] simplify cross chain interface --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 30 +++++------- src/stan/mcmc/mpi_metric_adaptation.hpp | 12 ++--- src/stan/mcmc/mpi_var_adaptation.hpp | 23 ++++----- 
.../sample/hmc_nuts_dense_e_adapt.hpp | 12 +++-- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 10 ++-- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 11 +++-- src/stan/services/util/mpi_cross_chain.hpp | 48 ++----------------- 7 files changed, 52 insertions(+), 94 deletions(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index e8188c55c6d..fc1d4e3f264 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include #include @@ -16,8 +18,6 @@ #include #ifdef MPI_ADAPTED_WARMUP -#include -#include #include #endif @@ -25,18 +25,6 @@ namespace stan { namespace mcmc { #ifdef MPI_ADAPTED_WARMUP - // template ::value>* = nullptr> - // struct mpi_cross_chain_metric_adapt_t { - // using type = mpi_var_adaptation; - // }; - - // template ::value>* = nullptr> - // struct mpi_cross_chain_metric_adapt_t { - // using type = mpi_covar_adaptation; - // }; - template class mpi_cross_chain_adapter { protected: @@ -56,7 +44,7 @@ namespace mcmc { boost::accumulators::features > draw_count_acc_; Eigen::ArrayXd rhat_; Eigen::ArrayXd ess_; - std::shared_ptr var_adapt; + mpi_metric_adaptation* var_adapt; public: const static int num_post_warmup = 50; @@ -81,9 +69,7 @@ namespace mcmc { ess_(Eigen::ArrayXd::Zero(max_num_windows_)) {} - inline void set_cross_chain_var_adaptation(int num_params, int num_iterations, int window_size) { - var_adapt = std::shared_ptr(new mpi_var_adaptation(num_params, num_iterations, window_size)); - } + inline void set_cross_chain_metric_adaptation(mpi_metric_adaptation* ptr) {var_adapt = ptr;} inline void set_cross_chain_adaptation_params(int num_iterations, int window_size, @@ -163,8 +149,8 @@ namespace mcmc { lp_draws_[i] = s; for (int win = 0; win < n_win; ++win) { lp_acc_[win](s); - var_adapt -> estimators[win].add_sample(sampler.z().q); } + var_adapt -> 
add_sample(sampler.z().q, n_win); } } } @@ -350,6 +336,12 @@ namespace mcmc { template class mpi_cross_chain_adapter { public: + inline void set_cross_chain_metric_adaptation(mpi_metric_adaptation* ptr) {} + + inline void set_cross_chain_adaptation_params(int num_iterations, + int window_size, + int num_chains, + double target_rhat, double target_ess) {} inline void add_cross_chain_sample(double s) {} inline void cross_chain_adaptation(callbacks::logger& logger) {} diff --git a/src/stan/mcmc/mpi_metric_adaptation.hpp b/src/stan/mcmc/mpi_metric_adaptation.hpp index 87b9c4d1879..a71506e532a 100644 --- a/src/stan/mcmc/mpi_metric_adaptation.hpp +++ b/src/stan/mcmc/mpi_metric_adaptation.hpp @@ -1,8 +1,6 @@ #ifndef STAN_MCMC_MPI_METRIC_ADAPTATION_HPP #define STAN_MCMC_MPI_METRIC_ADAPTATION_HPP -#ifdef STAN_LANG_MPI - #include #include @@ -12,18 +10,20 @@ namespace mcmc { class mpi_metric_adaptation { public: - mpi_metric_adaptation() = default; + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) {}; - virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) {} + virtual void restart() {} +#ifdef MPI_ADAPTED_WARMUP virtual void learn_metric(Eigen::VectorXd& var, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) {} + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) {} - virtual void restart() {} +#endif }; } // namespace mcmc @@ -31,5 +31,3 @@ namespace mcmc { } // namespace stan #endif - -#endif diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 76a2e088fcf..dd47fefba39 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -1,18 +1,20 @@ #ifndef STAN_MCMC_MPI_VAR_ADAPTATION_HPP #define STAN_MCMC_MPI_VAR_ADAPTATION_HPP -#ifdef STAN_LANG_MPI - #include #include -#include #include +#ifdef STAN_LANG_MPI +#include +#endif + namespace stan { namespace mcmc { class 
mpi_var_adaptation : public mpi_metric_adaptation { +#ifdef STAN_LANG_MPI using est_t = stan::math::mpi::mpi_var_estimator; public: @@ -28,7 +30,7 @@ namespace mcmc { : mpi_var_adaptation(n_params, num_iterations / window_size) {} - virtual void add_sample(Eigen::VectorXd& q, int curr_win_count) { + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { for (int win = 0; win < curr_win_count; ++win) { estimators[win].add_sample(q); } @@ -53,18 +55,17 @@ namespace mcmc { } } - // void restart(int n_params, int num_iterations, int window_size) { - // estimators.resize(num_iterations / window_size); - // for (auto&& e : estimators) { - // e.restart(n_params); - // } - // } +#else + public: + mpi_var_adaptation(int n_params, int num_iterations, int window_size) + {} +#endif }; } // namespace mcmc } // namespace stan -#endif + #endif diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index 80adbf0b549..9481c9cb621 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -57,7 +57,9 @@ template int hmc_nuts_dense_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -138,7 +140,9 @@ int hmc_nuts_dense_e_adapt( template int hmc_nuts_dense_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, 
double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -150,7 +154,9 @@ int hmc_nuts_dense_e_adapt( stan::io::var_context& unit_e_metric = dmp; return hmc_nuts_dense_e_adapt( - model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup, + model, init, unit_e_metric, random_seed, chain, init_radius, + num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 7c6165e71fc..78f9b473b6c 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -98,11 +98,11 @@ int hmc_nuts_diag_e_adapt( logger); // cross chain adaptation - util::mpi_cross_chain::set_params(sampler, num_warmup, - cross_chain_window, num_cross_chains, - cross_chain_rhat, cross_chain_ess); - util::mpi_cross_chain::set_var_adaptation(sampler, model.num_params_r(), - num_warmup, cross_chain_window); + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_var_adaptation var_adapt(model.num_params_r(), num_warmup, cross_chain_window); + sampler.set_cross_chain_metric_adaptation(&var_adapt); util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp 
b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index b97a8ba8ab3..6fcf9ec3db5 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -74,11 +74,12 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_t0(t0); // cross chain adaptation setup - util::mpi_cross_chain::set_params(sampler, num_warmup, - cross_chain_window, num_cross_chains, - cross_chain_rhat, cross_chain_ess); - util::mpi_cross_chain::set_var_adaptation(sampler, model.num_params_r(), - num_warmup, cross_chain_window); + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_var_adaptation var_adapt(model.num_params_r(), num_warmup, cross_chain_window); + sampler.set_cross_chain_metric_adaptation(&var_adapt); + util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index 2c17ebed4c8..e3e90dc8863 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -16,6 +16,10 @@ namespace stan { namespace services { namespace util { + /* + * Helper functions for samplers with MPI WARMUP. Other + * samplers have dummy implmenentation. 
+ */ struct mpi_cross_chain { template static bool end_transitions(Sampler& sampler) {return false;} @@ -35,17 +39,6 @@ namespace util { int num_thin) {} - template - static void set_params(Sampler& sampler, int num_iterations, - int window_size, int num_chains, - double target_rhat, double target_ess) - {} - - template - static void set_var_adaptation(Sampler& sampler, - int num_params, int num_iterations, int window_size) - {} - static void set_seed(unsigned int& seed, int num_chains) { #ifdef MPI_ADAPTED_WARMUP using stan::math::mpi::Session; @@ -126,39 +119,6 @@ namespace util { static void set_post_iter(mcmc::adapt_unit_e_nuts& sampler) { sampler.set_post_cross_chain(); } - - template - static void set_params(mcmc::adapt_diag_e_nuts& sampler, - int num_iterations, - int window_size, int num_chains, - double target_rhat, double target_ess) { - sampler.set_cross_chain_adaptation_params(num_iterations, - window_size, num_chains, - target_rhat, target_ess); - } - - template - static void set_params(mcmc::adapt_unit_e_nuts& sampler, - int num_iterations, - int window_size, int num_chains, - double target_rhat, double target_ess) { - sampler.set_cross_chain_adaptation_params(num_iterations, - window_size, num_chains, - target_rhat, target_ess); - } - - template - static void set_var_adaptation(mcmc::adapt_diag_e_nuts& sampler, - int num_params, int num_iterations, int window_size) { - sampler.set_cross_chain_var_adaptation(num_params, num_iterations, window_size); - } - - template - static void set_var_adaptation(mcmc::adapt_unit_e_nuts& sampler, - int num_params, int num_iterations, int window_size) { - sampler.set_cross_chain_var_adaptation(num_params, num_iterations, window_size); - } - #endif }; } From 203e2717b10fdd77d8e2b88c09492ea2acb5d9ce Mon Sep 17 00:00:00 2001 From: yiz Date: Sun, 9 Feb 2020 22:28:20 -0800 Subject: [PATCH 57/73] mpi warmup for dense_e nuts --- lib/stan_math | 2 +- src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp | 8 ++++++ 
src/stan/mcmc/mpi_covar_adaptation.hpp | 15 +++++++--- src/stan/mcmc/mpi_metric_adaptation.hpp | 4 +++ src/stan/mcmc/mpi_var_adaptation.hpp | 2 +- .../sample/hmc_nuts_dense_e_adapt.hpp | 7 +++++ src/stan/services/util/mpi_cross_chain.hpp | 28 +++++++++++++++++++ 7 files changed, 60 insertions(+), 6 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index 6b8a0e8e7e2..fd218be2cf2 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 6b8a0e8e7e209ff77c38fb1fcb4c38445a601f1e +Subproject commit fd218be2cf261acd31d539bbf412c1a9b65530e3 diff --git a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp index 05b6c80523f..17312a96720 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace stan { namespace mcmc { @@ -14,6 +15,7 @@ namespace mcmc { */ template class adapt_dense_e_nuts : public dense_e_nuts, + public mpi_cross_chain_adapter>, public stepsize_covar_adapter { public: adapt_dense_e_nuts(const Model& model, BaseRNG& rng) @@ -32,6 +34,12 @@ class adapt_dense_e_nuts : public dense_e_nuts, bool update = this->covar_adaptation_.learn_covariance( this->z_.inv_e_metric_, this->z_.q); + // cross-chain adaptation + this -> add_cross_chain_sample(s.log_prob()); + this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) update = false; + // cross-chain adaptation + if (update) { this->init_stepsize(logger); diff --git a/src/stan/mcmc/mpi_covar_adaptation.hpp b/src/stan/mcmc/mpi_covar_adaptation.hpp index 1f1f67126de..35e68053687 100644 --- a/src/stan/mcmc/mpi_covar_adaptation.hpp +++ b/src/stan/mcmc/mpi_covar_adaptation.hpp @@ -1,18 +1,20 @@ #ifndef STAN_MCMC_MPI_COVAR_ADAPTATION_HPP #define STAN_MCMC_MPI_COVAR_ADAPTATION_HPP -#ifdef STAN_LANG_MPI - #include #include -#include #include +#ifdef STAN_LANG_MPI +#include +#endif + namespace stan { namespace mcmc { class 
mpi_covar_adaptation : public mpi_metric_adaptation { +#ifdef STAN_LANG_MPI using est_t = stan::math::mpi::mpi_covar_estimator; int window_size_; @@ -49,12 +51,17 @@ namespace mcmc { virtual void restart() { estimator.restart(); } +#else + public: + mpi_covar_adaptation(int n_params, int num_iterations, int window_size) + {} +#endif }; } // namespace mcmc } // namespace stan -#endif + #endif diff --git a/src/stan/mcmc/mpi_metric_adaptation.hpp b/src/stan/mcmc/mpi_metric_adaptation.hpp index a71506e532a..21c6df72c9f 100644 --- a/src/stan/mcmc/mpi_metric_adaptation.hpp +++ b/src/stan/mcmc/mpi_metric_adaptation.hpp @@ -4,6 +4,10 @@ #include #include +#ifdef STAN_LANG_MPI +#include +#endif + namespace stan { namespace mcmc { diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index dd47fefba39..2242cbbef91 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -58,7 +58,7 @@ namespace mcmc { #else public: mpi_var_adaptation(int n_params, int num_iterations, int window_size) - {} + {} #endif }; diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index 9481c9cb621..0c8ab83bcb5 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -98,6 +98,13 @@ int hmc_nuts_dense_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); + // cross chain adaptation + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_covar_adaptation var_adapt(model.num_params_r(), num_warmup, cross_chain_window); + sampler.set_cross_chain_metric_adaptation(&var_adapt); + util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); diff --git 
a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index e3e90dc8863..92b76f29bd9 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #ifdef STAN_LANG_MPI @@ -76,6 +77,11 @@ namespace util { return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); } + template + static bool end_transitions(mcmc::adapt_dense_e_nuts& sampler) { + return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); + } + template static int num_post_warmup(mcmc::adapt_diag_e_nuts& sampler) { return sampler.num_post_warmup; @@ -86,6 +92,11 @@ namespace util { return sampler.num_post_warmup; } + template + static int num_post_warmup(mcmc::adapt_dense_e_nuts& sampler) { + return sampler.num_post_warmup; + } + template static int num_draws(mcmc::adapt_diag_e_nuts& sampler) { return sampler.num_cross_chain_draws(); @@ -96,6 +107,11 @@ namespace util { return sampler.num_cross_chain_draws(); } + template + static int num_draws(mcmc::adapt_dense_e_nuts& sampler) { + return sampler.num_cross_chain_draws(); + } + template static void write_num_warmup(mcmc::adapt_diag_e_nuts& sampler, callbacks::writer& sample_writer, @@ -110,6 +126,13 @@ namespace util { sampler.write_num_cross_chain_warmup(sample_writer, num_thin); } + template + static void write_num_warmup(mcmc::adapt_dense_e_nuts& sampler, + callbacks::writer& sample_writer, + int num_thin) { + sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + } + template static void set_post_iter(mcmc::adapt_diag_e_nuts& sampler) { sampler.set_post_cross_chain(); @@ -119,6 +142,11 @@ namespace util { static void set_post_iter(mcmc::adapt_unit_e_nuts& sampler) { sampler.set_post_cross_chain(); } + + template + static void set_post_iter(mcmc::adapt_dense_e_nuts& sampler) { + sampler.set_post_cross_chain(); + } #endif }; } From cce0945013697ad8a6e6ccf762b0ed993d1af5d8 
Mon Sep 17 00:00:00 2001 From: yiz Date: Mon, 10 Feb 2020 10:29:46 -0800 Subject: [PATCH 58/73] tests for mpi_covar_adpt. type traits for cross-chain warmup sampler --- .../services/util/generate_transitions.hpp | 4 +- src/stan/services/util/mpi_cross_chain.hpp | 175 +++++++++--------- .../services/util/run_adaptive_sampler.hpp | 6 +- .../unit/mcmc/mpi_covar_adaptation_test.cpp | 101 ++++++++++ .../unit/mcmc/mpi_var_adaptation_test.cpp | 12 +- 5 files changed, 199 insertions(+), 99 deletions(-) create mode 100644 src/test/unit/mcmc/mpi_covar_adaptation_test.cpp diff --git a/src/stan/services/util/generate_transitions.hpp b/src/stan/services/util/generate_transitions.hpp index 50373e2ff0a..47bcb10b5ff 100644 --- a/src/stan/services/util/generate_transitions.hpp +++ b/src/stan/services/util/generate_transitions.hpp @@ -71,8 +71,8 @@ void generate_transitions(Sampler& sampler, int num_iterations, } // check cross-chain convergence - if (mpi_cross_chain::end_transitions(sampler)) { - mpi_cross_chain::set_post_iter(sampler); + if (mpi_cross_chain::end_transitions(sampler)) { + mpi_cross_chain::set_post_iter(sampler); break; } } diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index 92b76f29bd9..b4dc2408805 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -17,138 +17,137 @@ namespace stan { namespace services { namespace util { + template + struct has_cross_chain_warmup { + static const bool value = false; + }; + + + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + /* * Helper functions for samplers with MPI WARMUP. Other * samplers have dummy implmenentation. 
*/ - struct mpi_cross_chain { - template + template + struct mpi_cross_chain_impl { static bool end_transitions(Sampler& sampler) {return false;} - template static void set_post_iter(Sampler& sampler) {} - template static int num_post_warmup(Sampler& sampler) { return 0;} - template static int num_draws(Sampler& sampler) { return 0;} - template static void write_num_warmup(Sampler& sampler, callbacks::writer& sample_writer, - int num_thin) - {} - - static void set_seed(unsigned int& seed, int num_chains) { -#ifdef MPI_ADAPTED_WARMUP - using stan::math::mpi::Session; - using stan::math::mpi::Communicator; - - const Communicator& inter_comm = Session::inter_chain_comm(num_chains); - const Communicator& intra_comm = Session::intra_chain_comm(num_chains); - MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_STAN); - seed += inter_comm.rank(); - MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); -#endif - } - - static void set_file(std::string& file_name, int num_chains) { -#ifdef MPI_ADAPTED_WARMUP - using stan::math::mpi::Session; - using stan::math::mpi::Communicator; - - if (Session::is_in_inter_chain_comm(num_chains)) { - const Communicator& comm = Session::inter_chain_comm(num_chains); - file_name = "mpi." + std::to_string(comm.rank()) + "." 
+ file_name; - } -#endif - } + int num_thin) {} + }; -// MPI versions + /* + * Partial specialization that is only active for MPI warmups + */ #ifdef MPI_ADAPTED_WARMUP - template - static bool end_transitions(mcmc::adapt_diag_e_nuts& sampler) { - return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); - } - - template - static bool end_transitions(mcmc::adapt_unit_e_nuts& sampler) { + template + struct mpi_cross_chain_impl { + static bool end_transitions(Sampler& sampler) { return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); } - template - static bool end_transitions(mcmc::adapt_dense_e_nuts& sampler) { - return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); + static void set_post_iter(Sampler& sampler) { + sampler.set_post_cross_chain(); } - template - static int num_post_warmup(mcmc::adapt_diag_e_nuts& sampler) { + static int num_post_warmup(Sampler& sampler) { return sampler.num_post_warmup; } - template - static int num_post_warmup(mcmc::adapt_unit_e_nuts& sampler) { - return sampler.num_post_warmup; + static int num_draws(Sampler& sampler) { + return sampler.num_cross_chain_draws(); } - template - static int num_post_warmup(mcmc::adapt_dense_e_nuts& sampler) { - return sampler.num_post_warmup; + static void write_num_warmup(Sampler& sampler, + callbacks::writer& sample_writer, + int num_thin) { + sampler.write_num_cross_chain_warmup(sample_writer, num_thin); } + }; +#endif - template - static int num_draws(mcmc::adapt_diag_e_nuts& sampler) { - return sampler.num_cross_chain_draws(); + template + struct mpi_cross_chain { + static bool end_transitions(Sampler& sampler) { + return mpi_cross_chain_impl::value>:: + end_transitions(sampler); } - template - static int num_draws(mcmc::adapt_unit_e_nuts& sampler) { - return sampler.num_cross_chain_draws(); - } - - template - static int num_draws(mcmc::adapt_dense_e_nuts& sampler) { - return sampler.num_cross_chain_draws(); + static void 
set_post_iter(Sampler& sampler) { + mpi_cross_chain_impl::value>:: + set_post_iter(sampler); } - template - static void write_num_warmup(mcmc::adapt_diag_e_nuts& sampler, - callbacks::writer& sample_writer, - int num_thin) { - sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + static int num_post_warmup(Sampler& sampler) { + return mpi_cross_chain_impl::value>:: + num_post_warmup(sampler); } - template - static void write_num_warmup(mcmc::adapt_unit_e_nuts& sampler, - callbacks::writer& sample_writer, - int num_thin) { - sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + static int num_draws(Sampler& sampler) { + return mpi_cross_chain_impl::value>:: + num_draws(sampler); } - template - static void write_num_warmup(mcmc::adapt_dense_e_nuts& sampler, + static void write_num_warmup(Sampler& sampler, callbacks::writer& sample_writer, int num_thin) { - sampler.write_num_cross_chain_warmup(sample_writer, num_thin); + mpi_cross_chain_impl::value>:: + write_num_warmup(sampler, sample_writer, num_thin); } + }; - template - static void set_post_iter(mcmc::adapt_diag_e_nuts& sampler) { - sampler.set_post_cross_chain(); - } - template - static void set_post_iter(mcmc::adapt_unit_e_nuts& sampler) { - sampler.set_post_cross_chain(); - } + /* + * modify cmdstan::command seed + */ + void set_cross_chain_seed(unsigned int& seed, int num_chains) { +#ifdef MPI_ADAPTED_WARMUP + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; - template - static void set_post_iter(mcmc::adapt_dense_e_nuts& sampler) { - sampler.set_post_cross_chain(); + const Communicator& inter_comm = Session::inter_chain_comm(num_chains); + const Communicator& intra_comm = Session::intra_chain_comm(num_chains); + MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_STAN); + seed += inter_comm.rank(); + MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); +#endif + } + + /* + * modify cmdstan::command file + */ + void set_cross_chain_file(std::string& file_name, int 
num_chains) { +#ifdef MPI_ADAPTED_WARMUP + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + if (Session::is_in_inter_chain_comm(num_chains)) { + const Communicator& comm = Session::inter_chain_comm(num_chains); + file_name = "mpi." + std::to_string(comm.rank()) + "." + file_name; } #endif - }; + } } } } diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index 7dfc34bace6..345e7ebde33 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -69,11 +69,11 @@ void run_adaptive_sampler(Sampler& sampler, Model& model, model, rng, interrupt, logger); // cross-chain post convergence iterations - util::generate_transitions(sampler, mpi_cross_chain::num_post_warmup(sampler), - mpi_cross_chain::num_draws(sampler), + util::generate_transitions(sampler, mpi_cross_chain::num_post_warmup(sampler), + mpi_cross_chain::num_draws(sampler), num_warmup + num_samples, num_thin, refresh, save_warmup, true, writer, s, model, rng, interrupt, logger); - mpi_cross_chain::write_num_warmup(sampler, sample_writer, num_thin); + mpi_cross_chain::write_num_warmup(sampler, sample_writer, num_thin); clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; diff --git a/src/test/unit/mcmc/mpi_covar_adaptation_test.cpp b/src/test/unit/mcmc/mpi_covar_adaptation_test.cpp new file mode 100644 index 00000000000..10480389b8b --- /dev/null +++ b/src/test/unit/mcmc/mpi_covar_adaptation_test.cpp @@ -0,0 +1,101 @@ +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include +#include + +TEST(McmcVarAdaptation, mpi_learn_covariance) { + stan::test::unit::instrumented_logger logger; + + const int n = 10; + Eigen::VectorXd q = Eigen::VectorXd::Zero(n); + Eigen::MatrixXd covar(Eigen::MatrixXd::Zero(n, n)); + + const int n_learn = 12; + + Eigen::MatrixXd target_covar(Eigen::MatrixXd::Identity(n, n)); + target_covar *= 1e-3 * 5.0 / 
(n_learn + 5.0); + + stan::mcmc::covar_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + + for (int i = 0; i < n_learn; ++i) + adapter.learn_covariance(covar, q); + + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + EXPECT_EQ(target_covar(i, j), covar(i, j)); + } + } + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); // must be <= 4 + if (n_learn % num_chains != 0) + throw std::domain_error("this test function was called with inconsistent MPI COMM size"); + + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_covar_adaptation mpi_adapter(n, n_learn_chain, n_learn_chain); + Eigen::MatrixXd mpi_covar(Eigen::MatrixXd::Zero(n, n)); + for (int i = 0; i < n_learn_chain; ++i) + mpi_adapter.add_sample(q, 1); + + mpi_adapter.learn_metric(mpi_covar, 0, 1, comm); + + for (int i = 0; i < covar.size(); ++i) { + EXPECT_FLOAT_EQ(covar(i), mpi_covar(i)); + } +} + +TEST(McmcVarAdaptation, mpi_data_learn_covariance) { + stan::test::unit::instrumented_logger logger; + + const int n = 5; + const int n_learn = 12; + Eigen::VectorXd q_all(n * n_learn); + q_all << + -276.606, -277.168, -272.621, -271.142, -271.95 , + -269.749, -267.016, -273.508, -268.65 , -265.904, + -264.629, -260.797, -263.184, -263.892, -268.81 , + -272.563, -268.32 , -266.297, -265.787, -266.073, + -265.788, -262.26 , -265.073, -265.511, -264.318, + -264.318, -266.261, -265.633, -265.323, -265.633, + -265.426, -265.69 , -266.122, -264.876, -264.829, + -264.238, -265.822, -262.979, -264.012, -263.801, + -264.745, -263.94 , -263.586, -263.284, -262.566, + -261.816, -265.308, -266.467, -265.915, -266.122, + -266.122, -265.903, -265.903, -265.717, -271.78 , + -271.78 , -271.712, -271.712, -271.011, -273.137; + + stan::mcmc::covar_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + Eigen::MatrixXd covar(Eigen::MatrixXd::Zero(n, n)); + for (int i = 0; i < 
n_learn; ++i) { + Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n), n); + adapter.learn_covariance(covar, q); + } + + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); + if (n_learn % num_chains != 0) + throw std::domain_error("this test function was called with inconsistent MPI COMM size"); + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_covar_adaptation mpi_adapter(n, n_learn_chain, n_learn_chain); + Eigen::MatrixXd mpi_covar(Eigen::MatrixXd::Zero(n, n)); + for (int i = 0; i < n_learn_chain; ++i) { + Eigen::VectorXd q = + Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); + mpi_adapter.add_sample(q, 1); + } + mpi_adapter.learn_metric(mpi_covar, 0, 1, comm); + + for (int i = 0; i < n; ++i) { + EXPECT_FLOAT_EQ(covar(i), mpi_covar(i)); + } +} + +#endif diff --git a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp index b04b06bd84a..ccdf6de7a34 100644 --- a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp +++ b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp @@ -28,12 +28,12 @@ TEST(McmcVarAdaptation, mpi_learn_variance) { stan::math::mpi::Communicator comm(MPI_COMM_STAN); const int num_chains = comm.size(); const int n_learn_chain = n_learn / num_chains; - stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1); + stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1, 1); Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); for (int i = 0; i < n_learn_chain; ++i) - mpi_adapter.estimators[0].add_sample(q); + mpi_adapter.add_sample(q, 1); - mpi_adapter.learn_variance(mpi_var, 0, comm); + mpi_adapter.learn_variance(mpi_var, 0, 1, comm); for (int i = 0; i < n; ++i) { EXPECT_FLOAT_EQ(var(i), mpi_var(i)); @@ -73,14 +73,14 @@ TEST(McmcVarAdaptation, mpi_data_learn_variance) { stan::math::mpi::Communicator comm(MPI_COMM_STAN); const int num_chains = comm.size(); const int n_learn_chain = n_learn / num_chains; - 
stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1); + stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1, 1); Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); for (int i = 0; i < n_learn_chain; ++i) { Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); - mpi_adapter.estimators[0].add_sample(q); + mpi_adapter.add_sample(q, 1); } - mpi_adapter.learn_variance(mpi_var, 0, comm); + mpi_adapter.learn_variance(mpi_var, 0, 1, comm); for (int i = 0; i < n; ++i) { EXPECT_FLOAT_EQ(var(i), mpi_var(i)); From 4acb8f7fd66b83860ea07fb9743af56659437851 Mon Sep 17 00:00:00 2001 From: yiz Date: Mon, 10 Feb 2020 16:58:04 -0800 Subject: [PATCH 59/73] change sed seed to set id --- src/stan/services/util/mpi_cross_chain.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index b4dc2408805..e8e9a8bfab8 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -121,16 +121,15 @@ namespace util { /* * modify cmdstan::command seed */ - void set_cross_chain_seed(unsigned int& seed, int num_chains) { + void set_cross_chain_id(unsigned int& id, int num_chains) { #ifdef MPI_ADAPTED_WARMUP using stan::math::mpi::Session; using stan::math::mpi::Communicator; const Communicator& inter_comm = Session::inter_chain_comm(num_chains); const Communicator& intra_comm = Session::intra_chain_comm(num_chains); - MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_STAN); - seed += inter_comm.rank(); - MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, intra_comm.comm()); + id = inter_comm.rank(); + MPI_Bcast(&id, 1, MPI_UNSIGNED, 0, intra_comm.comm()); #endif } From 29826268698c5c6e4e0116188993b8e024fca9cc Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 11 Feb 2020 09:46:16 -0800 Subject: [PATCH 60/73] fix: unit nuts doesn't need var; no post warmup when max out --- src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp | 4 ++-- 
src/stan/services/util/mpi_cross_chain.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 6fcf9ec3db5..4260646cde2 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -77,8 +77,8 @@ int hmc_nuts_unit_e_adapt( sampler.set_cross_chain_adaptation_params(num_warmup, cross_chain_window, num_cross_chains, cross_chain_rhat, cross_chain_ess); - mcmc::mpi_var_adaptation var_adapt(model.num_params_r(), num_warmup, cross_chain_window); - sampler.set_cross_chain_metric_adaptation(&var_adapt); + mcmc::mpi_metric_adaptation dummy_adapt; + sampler.set_cross_chain_metric_adaptation(&dummy_adapt); util::run_adaptive_sampler( diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index e8e9a8bfab8..2f1e99f3b66 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -72,7 +72,7 @@ namespace util { } static int num_post_warmup(Sampler& sampler) { - return sampler.num_post_warmup; + return sampler.is_cross_chain_adapted()? 
sampler.num_post_warmup : 0; } static int num_draws(Sampler& sampler) { From 6a95b53e9f8f671da9ac1b731ab40e0d21dc3d1e Mon Sep 17 00:00:00 2001 From: yiz Date: Tue, 11 Feb 2020 11:41:16 -0800 Subject: [PATCH 61/73] win loop to decrease counter in cross-chain adapter We prefer the most recent window that passes convergence test --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index fc1d4e3f264..486b417e229 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -277,7 +277,7 @@ namespace mcmc { } } - for (int win = 0; win < win_count; ++win) { + for (int win = win_count - 1; win >= 0; --win) { accumulator_set> acc_chain_mean; accumulator_set> acc_chain_var; accumulator_set> acc_step; From 9d5c99ac40097714d4a4e36682c97aa6911fe458 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 12 Feb 2020 00:40:04 -0800 Subject: [PATCH 62/73] full campfire version where every window has aggregation --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 113 +++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 486b417e229..f0554180401 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -226,7 +226,7 @@ namespace mcmc { * maximum windows for all chains. 
# @return vector {stepsize, rhat(only in rank 0)} */ - inline void cross_chain_adaptation(callbacks::logger& logger) { + inline void cross_chain_adaptation_v1(callbacks::logger& logger) { using boost::accumulators::accumulator_set; using boost::accumulators::stats; using boost::accumulators::tag::mean; @@ -329,6 +329,117 @@ namespace mcmc { } } } + + inline void cross_chain_adaptation(callbacks::logger& logger) { + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + Sampler& sampler = static_cast(*this); + + if ((!is_adapted_) && is_cross_chain_adapt_window_end()) { + double chain_stepsize = sampler.get_nominal_stepsize(); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + double invalid_stepsize = -999.0; + double new_stepsize = invalid_stepsize; + double max_ess = 0.0; + if (is_inter_rank) { + const Communicator& comm = Session::inter_chain_comm(num_chains_); + + const int nd_win = 3; // mean, variance, chain_stepsize + const int win_count = current_cross_chain_window_counter(); + int n_gather = nd_win * win_count + window_size_; + std::vector chain_gather(n_gather, 0.0); + for (int win = 0; win < win_count; ++win) { + int num_draws = (win_count - win) * window_size_; + double unbiased_var_scale = num_draws / (num_draws - 1.0); + chain_gather[nd_win * win] = boost::accumulators::mean(lp_acc_[win]); + chain_gather[nd_win * win + 1] = boost::accumulators::variance(lp_acc_[win]) * + unbiased_var_scale; + chain_gather[nd_win * win + 2] = chain_stepsize; + } + std::copy(lp_draws_.begin(), lp_draws_.end(), + chain_gather.begin() + nd_win * win_count); + + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); + const int invalid_win = -999; + int adapted_win = invalid_win; + + if (comm.rank() == 0) { + std::vector 
all_chain_gather(n_gather * num_chains_); + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + int begin_row = (win_count - 1) * window_size_; + for (int chain = 0; chain < num_chains_; ++chain) { + int j = n_gather * chain + nd_win * win_count; + for (int i = 0; i < window_size_; ++i) { + all_lp_draws_(begin_row + i, chain) = all_chain_gather[j + i]; + } + } + + for (int win = 0; win < win_count; ++win) { + bool win_adapted; + accumulator_set> acc_chain_mean; + accumulator_set> acc_chain_var; + accumulator_set> acc_step; + Eigen::VectorXd chain_mean(num_chains_); + Eigen::VectorXd chain_var(num_chains_); + for (int chain = 0; chain < num_chains_; ++chain) { + chain_mean(chain) = all_chain_gather[chain * n_gather + nd_win * win]; + acc_chain_mean(chain_mean(chain)); + chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; + acc_chain_var(chain_var(chain)); + acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); + } + size_t num_draws = (win_count - win) * window_size_; + double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains_ / (num_chains_ - 1); + double var_within = boost::accumulators::mean(acc_chain_var); + rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); + ess_[win] = compute_effective_sample_size(win, win_count); + win_adapted = rhat_(win) < target_rhat_ && ess_[win] > target_ess_; + + msg_adaptation(win, logger); + + if(ess_[win] > max_ess) { + max_ess = ess_[win]; + adapted_win = -(win + 1); + if (win_adapted) { + adapted_win = std::abs(adapted_win) - 1; + } + } + } + } else { + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + NULL, 0, MPI_DOUBLE, 0, comm.comm()); + } + MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); + if (adapted_win >= 0) { + MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); + new_stepsize /= num_chains_; + var_adapt -> 
learn_metric(sampler.z().inv_e_metric_, adapted_win, win_count, comm); + std::cout << "cross chain win: " << adapted_win + 1 << "\n"; + } else { + MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); + new_stepsize /= -num_chains_; + var_adapt -> learn_metric(sampler.z().inv_e_metric_, std::abs(adapted_win)-1, win_count, comm); + std::cout << "cross chain win: " << std::abs(adapted_win) << "\n"; + } + } + const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); + MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + is_adapted_ = new_stepsize > 0.0; + chain_stepsize = std::abs(new_stepsize); + MPI_Bcast(sampler.z().inv_e_metric_.data(), + sampler.z().inv_e_metric_.size(), MPI_DOUBLE, 0, intra_comm.comm()); + sampler.set_nominal_stepsize(chain_stepsize); + } + } }; #else // sequential version From 6da0fa65a8fe3c7737f922ef4d28e028d9a2c1d4 Mon Sep 17 00:00:00 2001 From: yiz Date: Wed, 12 Feb 2020 14:49:52 -0800 Subject: [PATCH 63/73] harmonic mean for stepsize --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index f0554180401..639f5ffd84a 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -419,14 +419,15 @@ namespace mcmc { NULL, 0, MPI_DOUBLE, 0, comm.comm()); } MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); + chain_stepsize = 1.0/(chain_stepsize * chain_stepsize); // harmonic mean if (adapted_win >= 0) { MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); - new_stepsize /= num_chains_; + new_stepsize = sqrt(num_chains_ / new_stepsize); var_adapt -> learn_metric(sampler.z().inv_e_metric_, adapted_win, win_count, comm); std::cout << "cross chain win: " << adapted_win + 1 << "\n"; } else { MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, 
MPI_DOUBLE, MPI_SUM, comm.comm()); - new_stepsize /= -num_chains_; + new_stepsize = -sqrt(num_chains_ / new_stepsize); var_adapt -> learn_metric(sampler.z().inv_e_metric_, std::abs(adapted_win)-1, win_count, comm); std::cout << "cross chain win: " << std::abs(adapted_win) << "\n"; } @@ -458,6 +459,8 @@ namespace mcmc { inline void cross_chain_adaptation(callbacks::logger& logger) {} inline bool is_cross_chain_adapted() { return false; } + + inline bool is_cross_chain_adapt_window_end() { return false; } }; #endif From 7f6ec8619c6b983eb40912383044b2a05b448112 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 13 Feb 2020 10:41:18 -0800 Subject: [PATCH 64/73] separate stepsize & metric aggregation --- src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 58 ++++++++++++------- src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp | 27 ++++++--- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 23 +++++--- src/stan/mcmc/mpi_covar_adaptation.hpp | 2 +- 4 files changed, 71 insertions(+), 39 deletions(-) diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 639f5ffd84a..5518ecfa30f 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -45,7 +45,6 @@ namespace mcmc { Eigen::ArrayXd rhat_; Eigen::ArrayXd ess_; mpi_metric_adaptation* var_adapt; - public: const static int num_post_warmup = 50; @@ -143,7 +142,6 @@ namespace mcmc { if (!is_adapted_) { int n_win = current_cross_chain_window_counter(); - Sampler& sampler = static_cast(*this); bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); if (is_inter_rank) { lp_draws_[i] = s; @@ -330,7 +328,7 @@ namespace mcmc { } } - inline void cross_chain_adaptation(callbacks::logger& logger) { + inline bool cross_chain_adaptation(callbacks::logger& logger) { using boost::accumulators::accumulator_set; using boost::accumulators::stats; using boost::accumulators::tag::mean; @@ -419,28 +417,44 @@ namespace mcmc { NULL, 0, 
MPI_DOUBLE, 0, comm.comm()); } MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); - chain_stepsize = 1.0/(chain_stepsize * chain_stepsize); // harmonic mean - if (adapted_win >= 0) { - MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); - new_stepsize = sqrt(num_chains_ / new_stepsize); - var_adapt -> learn_metric(sampler.z().inv_e_metric_, adapted_win, win_count, comm); - std::cout << "cross chain win: " << adapted_win + 1 << "\n"; - } else { - MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); - new_stepsize = -sqrt(num_chains_ / new_stepsize); - var_adapt -> learn_metric(sampler.z().inv_e_metric_, std::abs(adapted_win)-1, win_count, comm); - std::cout << "cross chain win: " << std::abs(adapted_win) << "\n"; - } + is_adapted_ = adapted_win >= 0; + int max_ess_win = is_adapted_ ? adapted_win : (-adapted_win - 1); + var_adapt -> learn_metric(sampler.z().inv_e_metric_, max_ess_win, win_count, comm); + std::cout << "cross chain win: " << max_ess_win + 1 << "\n"; } const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); - MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); - is_adapted_ = new_stepsize > 0.0; - chain_stepsize = std::abs(new_stepsize); + MPI_Bcast(&is_adapted_, 1, MPI_C_BOOL, 0, intra_comm.comm()); MPI_Bcast(sampler.z().inv_e_metric_.data(), sampler.z().inv_e_metric_.size(), MPI_DOUBLE, 0, intra_comm.comm()); - sampler.set_nominal_stepsize(chain_stepsize); + if (is_adapted_) { + set_cross_chain_stepsize(); + } + return true; + } else { + return false; } } + + inline void set_cross_chain_stepsize() { + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + const Communicator& comm = Session::inter_chain_comm(num_chains_); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + double new_stepsize; + Sampler& sampler = static_cast(*this); + if (is_inter_rank) { + double chain_stepsize = sampler.get_nominal_stepsize(); + 
chain_stepsize = 1.0/(chain_stepsize * chain_stepsize); // harmonic mean + MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); + new_stepsize = sqrt(num_chains_ / new_stepsize); + } + const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); + MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + sampler.set_nominal_stepsize(new_stepsize); + } + + inline bool use_cross_chain_adapt() { return true; } }; #else // sequential version @@ -456,11 +470,13 @@ namespace mcmc { double target_rhat, double target_ess) {} inline void add_cross_chain_sample(double s) {} - inline void cross_chain_adaptation(callbacks::logger& logger) {} + inline bool cross_chain_adaptation(callbacks::logger& logger) { return false; } inline bool is_cross_chain_adapted() { return false; } - inline bool is_cross_chain_adapt_window_end() { return false; } + inline void set_cross_chain_stepsize() {} + + inline bool use_cross_chain_adapt() { return false; } }; #endif diff --git a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp index 17312a96720..8aa589d9689 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp @@ -31,20 +31,27 @@ class adapt_dense_e_nuts : public dense_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); - - // cross-chain adaptation - this -> add_cross_chain_sample(s.log_prob()); - this -> cross_chain_adaptation(logger); - if (this -> is_cross_chain_adapted()) update = false; - // cross-chain adaptation + bool update; + if (this -> use_cross_chain_adapt()) { + this -> add_cross_chain_sample(s.log_prob()); + update = this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) { + update = false; + } + } else { + update = 
this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } if (update) { this->init_stepsize(logger); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + + if (this -> use_cross_chain_adapt()) { + this->set_cross_chain_stepsize(); + } } } return s; @@ -52,7 +59,9 @@ class adapt_dense_e_nuts : public dense_e_nuts, void disengage_adaptation() { base_adapter::disengage_adaptation(); - this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + if (!this -> is_cross_chain_adapted()) { + this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + } } }; diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 908a6d2efb2..96c34969205 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -31,20 +31,27 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); - - // cross-chain adaptation - this -> add_cross_chain_sample(s.log_prob()); - this -> cross_chain_adaptation(logger); - if (this -> is_cross_chain_adapted()) update = false; - // cross-chain adaptation + bool update; + if (this -> use_cross_chain_adapt()) { + this -> add_cross_chain_sample(s.log_prob()); + update = this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) { + update = false; + } + } else { + update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + } if (update) { this->init_stepsize(logger); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + + if (this -> use_cross_chain_adapt()) { + this->set_cross_chain_stepsize(); + } } } return s; diff --git a/src/stan/mcmc/mpi_covar_adaptation.hpp b/src/stan/mcmc/mpi_covar_adaptation.hpp 
index 35e68053687..2c6cf421834 100644 --- a/src/stan/mcmc/mpi_covar_adaptation.hpp +++ b/src/stan/mcmc/mpi_covar_adaptation.hpp @@ -45,7 +45,7 @@ namespace mcmc { covar = (n / (n + 5.0)) * covar + 1e-3 * (5.0 / (n + 5.0)) * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); - restart(); + // restart(); } virtual void restart() { From 87356a123d521b960a82ae182cbdcb875c6b9d8b Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 13 Feb 2020 15:39:12 -0800 Subject: [PATCH 65/73] no restart for mpi_var_adapt --- src/stan/mcmc/mpi_var_adaptation.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 2242cbbef91..d7ef4e8031e 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -46,7 +46,6 @@ namespace mcmc { double n = static_cast(estimators[win].sample_variance(var, comm)); var = (n / (n + 5.0)) * var + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); - restart(); } virtual void restart() { From 92b8f3fa8a735223982569c65038a6d990a87dc0 Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 13 Feb 2020 17:21:00 -0800 Subject: [PATCH 66/73] exclude initial buffer from covar calculation --- lib/stan_math | 2 +- src/stan/mcmc/mpi_covar_adaptation.hpp | 7 +++---- src/stan/mcmc/mpi_metric_adaptation.hpp | 2 ++ 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index fd218be2cf2..91244285f84 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit fd218be2cf261acd31d539bbf412c1a9b65530e3 +Subproject commit 91244285f842281a7c5371f5e941c327dbeff462 diff --git a/src/stan/mcmc/mpi_covar_adaptation.hpp b/src/stan/mcmc/mpi_covar_adaptation.hpp index 2c6cf421834..e467209183a 100644 --- a/src/stan/mcmc/mpi_covar_adaptation.hpp +++ b/src/stan/mcmc/mpi_covar_adaptation.hpp @@ -32,8 +32,8 @@ namespace mcmc { virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, const 
stan::math::mpi::Communicator& comm) { - int col_begin = win * window_size_; - int num_draws = (curr_win_count - win) * window_size_; + int col_begin = win == 0 ? init_bufer_size : (win * window_size_); + int num_draws = win == 0 ? (curr_win_count * window_size_ - init_bufer_size) : ((curr_win_count - win) * window_size_); learn_covariance(covar, col_begin, num_draws, comm); } @@ -41,11 +41,10 @@ namespace mcmc { int col_begin, int n_samples, const stan::math::mpi::Communicator& comm) { estimator.sample_covariance(covar, col_begin, n_samples, comm); - double n = static_cast(estimator.num_samples(comm)); + double n = n_samples * comm.size(); covar = (n / (n + 5.0)) * covar + 1e-3 * (5.0 / (n + 5.0)) * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); - // restart(); } virtual void restart() { diff --git a/src/stan/mcmc/mpi_metric_adaptation.hpp b/src/stan/mcmc/mpi_metric_adaptation.hpp index 21c6df72c9f..8f37e4b696b 100644 --- a/src/stan/mcmc/mpi_metric_adaptation.hpp +++ b/src/stan/mcmc/mpi_metric_adaptation.hpp @@ -13,6 +13,8 @@ namespace stan { namespace mcmc { class mpi_metric_adaptation { + protected: + static const int init_bufer_size = 75; public: virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) {}; From 83a450abe1022b69f96ecc192e98ae41621fb6ba Mon Sep 17 00:00:00 2001 From: yiz Date: Thu, 13 Feb 2020 17:49:21 -0800 Subject: [PATCH 67/73] ignore init buffer draws in mpi var adaptation --- src/stan/mcmc/mpi_var_adaptation.hpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index d7ef4e8031e..5a1fe18e139 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -17,13 +17,14 @@ namespace mcmc { #ifdef STAN_LANG_MPI using est_t = stan::math::mpi::mpi_var_estimator; + int init_draw_counter; public: std::vector estimators; mpi_var_adaptation() = default; mpi_var_adaptation(int n_params, int 
max_num_windows) - : estimators(max_num_windows, est_t(n_params)) + : init_draw_counter(0), estimators(max_num_windows, est_t(n_params)) {} mpi_var_adaptation(int n_params, int num_iterations, int window_size) @@ -31,8 +32,11 @@ namespace mcmc { {} virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { - for (int win = 0; win < curr_win_count; ++win) { - estimators[win].add_sample(q); + init_draw_counter++; + if (init_draw_counter > init_bufer_size) { + for (int win = 0; win < curr_win_count; ++win) { + estimators[win].add_sample(q); + } } } From f2835ce826c20920f4a5e1f37afd640faee4dbda Mon Sep 17 00:00:00 2001 From: Ben Date: Sat, 15 Feb 2020 16:24:07 -0500 Subject: [PATCH 68/73] Updated mpi file name adjuster to work with diagnostic files --- src/stan/services/util/mpi_cross_chain.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index 2f1e99f3b66..7c10381656a 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -141,9 +141,9 @@ namespace util { using stan::math::mpi::Session; using stan::math::mpi::Communicator; - if (Session::is_in_inter_chain_comm(num_chains)) { + if (file_name.size() > 0 && num_chains > 1 && Session::is_in_inter_chain_comm(num_chains)) { const Communicator& comm = Session::inter_chain_comm(num_chains); - file_name = "mpi." + std::to_string(comm.rank()) + "." + file_name; + file_name = file_name + "." + "mpi." 
+ std::to_string(comm.rank()); } #endif } From 97af9ac78a1639be465c32c1c14d297e0ad2c095 Mon Sep 17 00:00:00 2001 From: Ben Date: Mon, 17 Feb 2020 18:44:24 -0500 Subject: [PATCH 69/73] auto_e metric might be working --- .../mcmc/hmc/hamiltonians/auto_e_point.hpp | 12 - src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp | 117 +------ src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp | 37 ++- src/stan/mcmc/mpi_auto_adaptation.hpp | 292 ++++++++++++++++++ .../services/sample/hmc_nuts_auto_e_adapt.hpp | 17 +- 5 files changed, 336 insertions(+), 139 deletions(-) create mode 100644 src/stan/mcmc/mpi_auto_adaptation.hpp diff --git a/src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp index 69ba07fb35b..d882f1a3910 100644 --- a/src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/auto_e_point.hpp @@ -33,18 +33,6 @@ namespace stan { inv_e_metric_.setIdentity(); } - /** - * Copy constructor which does fast copy of inverse mass matrix. 
- * - * @param z point to copy - */ - auto_e_point(const auto_e_point& z) - : ps_point(z), inv_e_metric_(z.inv_e_metric_.rows(), - z.inv_e_metric_.cols()) { - fast_matrix_copy_(inv_e_metric_, z.inv_e_metric_); - is_diagonal_ = z.is_diagonal_; - } - /** * Set elements of mass matrix * diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp index 5518ecfa30f..b4586aa55d0 100644 --- a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -44,8 +44,9 @@ namespace mcmc { boost::accumulators::features > draw_count_acc_; Eigen::ArrayXd rhat_; Eigen::ArrayXd ess_; - mpi_metric_adaptation* var_adapt; public: + mpi_metric_adaptation* var_adapt; + const static int num_post_warmup = 50; mpi_cross_chain_adapter() = default; @@ -214,120 +215,6 @@ namespace mcmc { logger.info(message); } - /* - * @tparam Sampler sampler class - * @param[in] m_win number of windows - * @param[in] window_size window size - * @param[in] num_chains number of chains - * @param[in,out] chain_gather gathered information from each chain, - * must have enough capacity to store up to - * maximum windows for all chains. 
- # @return vector {stepsize, rhat(only in rank 0)} - */ - inline void cross_chain_adaptation_v1(callbacks::logger& logger) { - using boost::accumulators::accumulator_set; - using boost::accumulators::stats; - using boost::accumulators::tag::mean; - using boost::accumulators::tag::variance; - - using stan::math::mpi::Session; - using stan::math::mpi::Communicator; - - Sampler& sampler = static_cast(*this); - - if ((!is_adapted_) && is_cross_chain_adapt_window_end()) { - double chain_stepsize = sampler.get_nominal_stepsize(); - bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); - double invalid_stepsize = -999.0; - double new_stepsize = invalid_stepsize; - if (is_inter_rank) { - const Communicator& comm = Session::inter_chain_comm(num_chains_); - - const int nd_win = 3; // mean, variance, chain_stepsize - const int win_count = current_cross_chain_window_counter(); - int n_gather = nd_win * win_count + window_size_; - std::vector chain_gather(n_gather, 0.0); - for (int win = 0; win < win_count; ++win) { - int num_draws = (win_count - win) * window_size_; - double unbiased_var_scale = num_draws / (num_draws - 1.0); - chain_gather[nd_win * win] = boost::accumulators::mean(lp_acc_[win]); - chain_gather[nd_win * win + 1] = boost::accumulators::variance(lp_acc_[win]) * - unbiased_var_scale; - chain_gather[nd_win * win + 2] = chain_stepsize; - } - std::copy(lp_draws_.begin(), lp_draws_.end(), - chain_gather.begin() + nd_win * win_count); - - rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); - ess_ = Eigen::ArrayXd::Zero(max_num_windows_); - const int invalid_win = -999; - int adapted_win = invalid_win; - - if (comm.rank() == 0) { - std::vector all_chain_gather(n_gather * num_chains_); - MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, - all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); - int begin_row = (win_count - 1) * window_size_; - for (int chain = 0; chain < num_chains_; ++chain) { - int j = n_gather * chain + nd_win * win_count; - 
for (int i = 0; i < window_size_; ++i) { - all_lp_draws_(begin_row + i, chain) = all_chain_gather[j + i]; - } - } - - for (int win = win_count - 1; win >= 0; --win) { - accumulator_set> acc_chain_mean; - accumulator_set> acc_chain_var; - accumulator_set> acc_step; - Eigen::VectorXd chain_mean(num_chains_); - Eigen::VectorXd chain_var(num_chains_); - for (int chain = 0; chain < num_chains_; ++chain) { - chain_mean(chain) = all_chain_gather[chain * n_gather + nd_win * win]; - acc_chain_mean(chain_mean(chain)); - chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; - acc_chain_var(chain_var(chain)); - acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); - } - size_t num_draws = (win_count - win) * window_size_; - double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) - * num_chains_ / (num_chains_ - 1); - double var_within = boost::accumulators::mean(acc_chain_var); - rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); - ess_[win] = compute_effective_sample_size(win, win_count); - is_adapted_ = rhat_(win) < target_rhat_ && ess_[win] > target_ess_; - - msg_adaptation(win, logger); - - if (is_adapted_) { - adapted_win = win; - break; - } - } - } else { - MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, - NULL, 0, MPI_DOUBLE, 0, comm.comm()); - } - MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); - if (adapted_win >= 0) { - MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); - new_stepsize /= num_chains_; - var_adapt -> learn_metric(sampler.z().inv_e_metric_, adapted_win, win_count, comm); - } - } - const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); - MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); - is_adapted_ = new_stepsize > 0.0; - if (is_adapted_) { - chain_stepsize = new_stepsize; - MPI_Bcast(sampler.z().inv_e_metric_.data(), - sampler.z().inv_e_metric_.size(), MPI_DOUBLE, 0, intra_comm.comm()); - } - if 
(is_adapted_) { - sampler.set_nominal_stepsize(chain_stepsize); - } - } - } - inline bool cross_chain_adaptation(callbacks::logger& logger) { using boost::accumulators::accumulator_set; using boost::accumulators::stats; diff --git a/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp index 4335f61df9d..fc2523da9e5 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp @@ -2,25 +2,27 @@ #define STAN_MCMC_HMC_NUTS_ADAPT_AUTO_E_NUTS_HPP #include -#include #include +#include +#include namespace stan { namespace mcmc { /** * The No-U-Turn sampler (NUTS) with multinomial sampling * with a Gaussian-Euclidean disintegration and adaptive - * dense metric and adaptive step size + * dense or diagonal metric and adaptive step size */ template class adapt_auto_e_nuts : public auto_e_nuts, - public stepsize_auto_adapter { + public mpi_cross_chain_adapter>, + public stepsize_covar_adapter { protected: const Model& model_; public: adapt_auto_e_nuts(const Model& model, BaseRNG& rng) : model_(model), auto_e_nuts(model, rng), - stepsize_auto_adapter(model.num_params_r()) {} + stepsize_covar_adapter(model.num_params_r()) {} ~adapt_auto_e_nuts() {} @@ -33,17 +35,30 @@ namespace stan { this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->auto_adaptation_.learn_covariance( - model_, - this->z_.inv_e_metric_, - this->z_.is_diagonal_, - this->z_.q); + bool update; + if (this -> use_cross_chain_adapt()) { + this -> add_cross_chain_sample(s.log_prob()); + update = this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) { + update = false; + } + } else { + update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } if (update) { + //std::cout << this->z_.inv_e_metric_ << std::endl; + this->z_.is_diagonal_ = reinterpret_cast *>(this->var_adapt)->is_diagonal_; + this->init_stepsize(logger); 
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); + + if (this -> use_cross_chain_adapt()) { + this->set_cross_chain_stepsize(); + } } } return s; @@ -51,7 +66,9 @@ namespace stan { void disengage_adaptation() { base_adapter::disengage_adaptation(); - this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + if (!this -> is_cross_chain_adapted()) { + this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + } } }; diff --git a/src/stan/mcmc/mpi_auto_adaptation.hpp b/src/stan/mcmc/mpi_auto_adaptation.hpp new file mode 100644 index 00000000000..b3b8ce3fafd --- /dev/null +++ b/src/stan/mcmc/mpi_auto_adaptation.hpp @@ -0,0 +1,292 @@ +#ifndef STAN_MCMC_MPI_AUTO_ADAPTATION_HPP +#define STAN_MCMC_MPI_AUTO_ADAPTATION_HPP + +#include +#include +#include + +#ifdef STAN_LANG_MPI +#include +#endif + +namespace stan { + +namespace mcmc { + template + struct log_prob_wrapper_covar { + const Model& model_; + log_prob_wrapper_covar(const Model& model) : model_(model) {} + + template + T operator()(const Eigen::Matrix& q) const { + return model_.template log_prob(const_cast& >(q), &std::cout); + } + }; + + namespace internal { + /** + * Compute the covariance of data in Y. + * + * Columns of Y are different variables. Rows are different samples. + * + * When there is only one row in Y, return a covariance matrix of the expected + * size filled with zeros. + * + * @param Y Data + * @return Covariance of Y + */ + Eigen::MatrixXd covariance(const Eigen::MatrixXd& Y) { + stan::math::check_nonzero_size("covariance", "Y", Y); + + Eigen::MatrixXd centered = Y.rowwise() - Y.colwise().mean(); + return centered.transpose() * centered / std::max(centered.rows() - 1.0, 1.0); + } + + /** + * Compute the largest magnitude eigenvalue of a symmetric matrix using the power method. The function f + * should return the product of that matrix with an abitrary vector. 
+ * + * f should take one Eigen::VectorXd argument, x, and return the product of a matrix with x as + * an Eigen::VectorXd argument of the same size. + * + * The eigenvalue is estimated iteratively. If the kth estimate is e_k, then the function returns when + * either abs(e_{k + 1} - e_k) < tol * abs(e_k) or the maximum number of iterations have been performed + * + * This means the returned eigenvalue might not be computed to full precision + * + * @param initial_guess Initial guess of the eigenvector of the largest eigenvalue + * @param[in,out] max_iterations Maximum number of power iterations, on return number of iterations used + * @param[in,out] tol Relative tolerance, on return the relative error in the eigenvalue estimate + * @return Largest magnitude eigenvalue of operator f + */ + template + double power_method(F& f, const Eigen::VectorXd& initial_guess, int& max_iterations, double& tol) { + Eigen::VectorXd v = initial_guess; + double eval = 0.0; + Eigen::VectorXd Av = f(v); + stan::math::check_matching_sizes("power_method", "matrix vector product", Av, "vector", v); + + int i = 0; + for(; i < max_iterations; ++i) { + double v_norm = v.norm(); + double new_eval = v.dot(Av) / (v_norm * v_norm); + if(i == max_iterations - 1 || std::abs(new_eval - eval) <= tol * std::abs(eval)) { + tol = std::abs(new_eval - eval) / std::abs(eval); + eval = new_eval; + max_iterations = i + 1; + break; + } + + eval = new_eval; + v = Av / Av.norm(); + + Av = f(v); + } + + return eval; + } + + /** + * Compute the largest eigenvalue of the sample covariance rescaled by a metric, + * that is, the largest eigenvalue of L^{-1} \Sigma L^{-T} + * + * @param L Cholesky decomposition of Metric + * @param Sigma Sample covariance + * @return Largest eigenvalue + */ + double eigenvalue_scaled_covariance(const Eigen::MatrixXd& L, const Eigen::MatrixXd& Sigma) { + Eigen::MatrixXd S = L.template triangularView(). 
+ solve(L.template triangularView().solve(Sigma).transpose()).transpose(); + + auto Sx = [&](Eigen::VectorXd x) -> Eigen::VectorXd { + return S * x; + }; + + int max_iterations = 100; + double tol = 1e-3; + + return internal::power_method(Sx, Eigen::VectorXd::Random(Sigma.cols()), max_iterations, tol); + } + + /** + * Compute the largest eigenvalue of the Hessian of the log density rescaled by a metric, + * that is, the largest eigenvalue of L^T \nabla^2_{qq} H(q) L + * + * @tparam Model Type of model + * @param model Defines the log density + * @param q Point around which to compute the Hessian + * @param L Cholesky decomposition of Metric + * @return Largest eigenvalue + */ + template + double eigenvalue_scaled_hessian(const Model& model, const Eigen::MatrixXd& L, const Eigen::VectorXd& q) { + Eigen::VectorXd eigenvalues; + Eigen::MatrixXd eigenvectors; + + auto hessian_vector = [&](const Eigen::VectorXd& x) -> Eigen::VectorXd { + double lp; + Eigen::VectorXd grad1; + Eigen::VectorXd grad2; + //stan::math::hessian_times_vector(log_prob_wrapper_covar(model), q, x, lp, Ax); + double dx = 1e-5; + Eigen::VectorXd dr = L * x * dx; + stan::math::gradient(log_prob_wrapper_covar(model), q + dr / 2.0, lp, grad1); + stan::math::gradient(log_prob_wrapper_covar(model), q - dr / 2.0, lp, grad2); + return L.transpose() * (grad1 - grad2) / dx; + }; + + int max_iterations = 100; + double tol = 1e-3; + + return internal::power_method(hessian_vector, Eigen::VectorXd::Random(q.size()), max_iterations, tol); + } + } + +template +class mpi_auto_adaptation : public mpi_metric_adaptation { +#ifdef STAN_LANG_MPI + using est_t = stan::math::mpi::mpi_covar_estimator; + + int window_size_; + int n_params_; + Model& model_; + std::deque last_qs_; +public: + est_t estimator; + bool is_diagonal_; + + mpi_auto_adaptation(Model& model, int n_params, int num_iterations, int window_size) + : window_size_(window_size), + n_params_(n_params), + model_(model), + estimator(n_params, num_iterations), 
+ is_diagonal_(false) {} + + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { + estimator.add_sample(q); + last_qs_.push_back(q); + if(last_qs_.size() > 5) { + last_qs_.pop_front(); + } + } + + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) { + if(curr_win_count > 1) { + win = std::max(1, win); + } + + int col_begin = win * window_size_; + int num_draws = (curr_win_count - win) * window_size_; + + int M = n_params_; + + try { + bool use_dense = false; + for(auto state : { "selection", "refinement" }) { + Eigen::MatrixXd cov_train = Eigen::MatrixXd::Zero(M, M); + Eigen::MatrixXd cov_test = Eigen::MatrixXd::Zero(M, M); + + int Ntest; + if(state == "selection") { + Ntest = int(0.2 * num_draws); + if(Ntest < 5) { + Ntest = 5; + } + + if(num_draws < 10) { + throw std::runtime_error("Each warmup stage must have at least 10 samples"); + } + + learn_covariance(cov_train, col_begin, num_draws - Ntest, comm); + learn_covariance(cov_test, col_begin + num_draws - Ntest, Ntest, comm); + //Ytrain = Y.block(0, 0, M, Y.cols() - Mtest); + //Ytest = Y.block(0, Ytrain.cols(), M, Mtest); + } else { + learn_covariance(cov_train, col_begin, num_draws, comm); + Ntest = 0; + //Ytrain = Y; + } + + Eigen::MatrixXd dense = ((num_draws - Ntest) / ((num_draws - Ntest) + 5.0)) * cov_train + + 1e-3 * (5.0 / ((num_draws - Ntest) + 5.0)) * Eigen::MatrixXd::Identity(cov_train.rows(), cov_train.cols()); + + Eigen::MatrixXd diag = dense.diagonal().asDiagonal(); + + covar = dense; + + if(state == "selection") { + Eigen::MatrixXd L_dense = dense.llt().matrixL(); + Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); + + double low_eigenvalue_dense = -1.0 / internal::eigenvalue_scaled_covariance(L_dense, cov_test); + double low_eigenvalue_diag = -1.0 / internal::eigenvalue_scaled_covariance(L_diag, cov_test); + + std::cout << "TRAIN low:" << low_eigenvalue_dense << std::endl; + 
std::cout << "TEST low :" << low_eigenvalue_diag << std::endl; + + double c_dense = 0.0; + double c_diag = 0.0; + for(int i = 0; i < last_qs_.size(); i++) { + double high_eigenvalue_dense = internal::eigenvalue_scaled_hessian(model_, L_dense, last_qs_[i]); + double high_eigenvalue_diag = internal::eigenvalue_scaled_hessian(model_, L_diag, last_qs_[i]); + + c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); + c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); + } + + std::cout << "adapt dense, max: " << c_dense << std::endl; + std::cout << "adapt diag, max: " << c_diag << std::endl; + + if(c_dense < c_diag) { + use_dense = true; + } else { + use_dense = false; + } + } else { + if(use_dense) { + covar = dense; + is_diagonal_ = false; + } else { + covar = diag; + is_diagonal_ = true; + } + } + } + } catch(const std::exception& e) { + std::cout << e.what() << std::endl; + std::cout << "Exception while using auto adaptation, falling back to diagonal" << std::endl; + Eigen::MatrixXd cov = Eigen::MatrixXd::Zero(M, M); + learn_covariance(cov, col_begin, num_draws, comm); + covar = ((num_draws / (num_draws + 5.0)) * cov.diagonal() + + 1e-3 * (5.0 / (num_draws + 5.0)) * Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); + is_diagonal_ = true; + } + + std::cout << covar << std::endl; + } + + void learn_covariance(Eigen::MatrixXd& covar, + int col_begin, int n_samples, + const stan::math::mpi::Communicator& comm) { + estimator.sample_covariance(covar, col_begin, n_samples, comm); + //double n = static_cast(estimator.num_samples(comm)); + //covar = (n / (n + 5.0)) * covar + // + 1e-3 * (5.0 / (n + 5.0)) + // * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); + // restart(); + } + + virtual void restart() { + estimator.restart(); + } +#else +public: + mpi_auto_adaptation(int n_params, int num_iterations, int window_size) {} +#endif +}; + +} // namespace mcmc +} // namespace stan + +#endif diff --git 
a/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp index 464880d3164..4bb4a7b996a 100644 --- a/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -57,7 +58,9 @@ namespace stan { int hmc_nuts_auto_e_adapt(Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, unsigned int chain, - double init_radius, int num_warmup, + double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, @@ -104,6 +107,13 @@ namespace stan { sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); + // cross chain adaptation + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_auto_adaptation var_adapt(model, model.num_params_r(), num_warmup, cross_chain_window); + sampler.set_cross_chain_metric_adaptation(&var_adapt); + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, @@ -147,7 +157,9 @@ namespace stan { template int hmc_nuts_auto_e_adapt(Model& model, stan::io::var_context& init, unsigned int random_seed, unsigned int chain, - double init_radius, int num_warmup, + double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, @@ -165,6 +177,7 @@ namespace stan { return hmc_nuts_auto_e_adapt(model, init, unit_e_metric, random_seed, chain, init_radius, + num_cross_chains, 
cross_chain_window, cross_chain_rhat, cross_chain_ess, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, From 227a0a674c8671406973075da73064cda0d7c710 Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 18 Feb 2020 11:15:37 -0500 Subject: [PATCH 70/73] Updating stan math Don't use first samples for warmup --- lib/stan_math | 2 +- src/stan/mcmc/mpi_auto_adaptation.hpp | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index 91244285f84..fd218be2cf2 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 91244285f842281a7c5371f5e941c327dbeff462 +Subproject commit fd218be2cf261acd31d539bbf412c1a9b65530e3 diff --git a/src/stan/mcmc/mpi_auto_adaptation.hpp b/src/stan/mcmc/mpi_auto_adaptation.hpp index b3b8ce3fafd..bb98099d3ce 100644 --- a/src/stan/mcmc/mpi_auto_adaptation.hpp +++ b/src/stan/mcmc/mpi_auto_adaptation.hpp @@ -151,6 +151,7 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { int n_params_; Model& model_; std::deque last_qs_; + int init_draw_counter_; public: est_t estimator; bool is_diagonal_; @@ -159,23 +160,23 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { : window_size_(window_size), n_params_(n_params), model_(model), + init_draw_counter_(0), estimator(n_params, num_iterations), is_diagonal_(false) {} virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { - estimator.add_sample(q); - last_qs_.push_back(q); - if(last_qs_.size() > 5) { - last_qs_.pop_front(); + init_draw_counter++; + if (init_draw_counter > init_bufer_size) { + estimator.add_sample(q); + last_qs_.push_back(q); + if(last_qs_.size() > 5) { + last_qs_.pop_front(); + } } } - + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) { - if(curr_win_count > 1) { - win = std::max(1, win); - } - int col_begin = win * window_size_; int num_draws = (curr_win_count - 
win) * window_size_; From 37a3fa1da43b82575d6ea6b1c4c95d7a9b81073e Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 18 Feb 2020 12:38:10 -0500 Subject: [PATCH 71/73] auto adaptation terminates properly now Changed auto adaptation to specific ignore init_buffer samples --- src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp | 110 +++++++++---------- src/stan/mcmc/mpi_auto_adaptation.hpp | 24 ++-- src/stan/services/util/mpi_cross_chain.hpp | 6 + 3 files changed, 69 insertions(+), 71 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp index fc2523da9e5..d84d7c7aca3 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp @@ -7,71 +7,71 @@ #include namespace stan { - namespace mcmc { - /** - * The No-U-Turn sampler (NUTS) with multinomial sampling - * with a Gaussian-Euclidean disintegration and adaptive - * dense or diagonal metric and adaptive step size - */ - template - class adapt_auto_e_nuts : public auto_e_nuts, - public mpi_cross_chain_adapter>, - public stepsize_covar_adapter { - protected: - const Model& model_; - public: - adapt_auto_e_nuts(const Model& model, BaseRNG& rng) - : model_(model), auto_e_nuts(model, rng), - stepsize_covar_adapter(model.num_params_r()) {} +namespace mcmc { +/** + * The No-U-Turn sampler (NUTS) with multinomial sampling + * with a Gaussian-Euclidean disintegration and adaptive + * dense or diagonal metric and adaptive step size + */ +template +class adapt_auto_e_nuts : public auto_e_nuts, + public mpi_cross_chain_adapter>, + public stepsize_covar_adapter { +protected: + const Model& model_; +public: + adapt_auto_e_nuts(const Model& model, BaseRNG& rng) + : model_(model), auto_e_nuts(model, rng), + stepsize_covar_adapter(model.num_params_r()) {} - ~adapt_auto_e_nuts() {} + ~adapt_auto_e_nuts() {} - sample - transition(sample& init_sample, callbacks::logger& logger) { - sample s = auto_e_nuts::transition(init_sample, - logger); 
+ sample + transition(sample& init_sample, callbacks::logger& logger) { + sample s = auto_e_nuts::transition(init_sample, + logger); - if (this->adapt_flag_) { - this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, - s.accept_stat()); + if (this->adapt_flag_) { + this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, + s.accept_stat()); - bool update; - if (this -> use_cross_chain_adapt()) { - this -> add_cross_chain_sample(s.log_prob()); - update = this -> cross_chain_adaptation(logger); - if (this -> is_cross_chain_adapted()) { - update = false; - } - } else { - update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, - this->z_.q); - } + bool update; + if (this -> use_cross_chain_adapt()) { + this -> add_cross_chain_sample(s.log_prob()); + update = this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) { + update = false; + } + } else { + update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + } - if (update) { - //std::cout << this->z_.inv_e_metric_ << std::endl; - this->z_.is_diagonal_ = reinterpret_cast *>(this->var_adapt)->is_diagonal_; + if (update) { + //std::cout << this->z_.inv_e_metric_ << std::endl; + this->z_.is_diagonal_ = reinterpret_cast *>(this->var_adapt)->is_diagonal_; - this->init_stepsize(logger); + this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); - if (this -> use_cross_chain_adapt()) { - this->set_cross_chain_stepsize(); - } - } - } - return s; - } - - void disengage_adaptation() { - base_adapter::disengage_adaptation(); - if (!this -> is_cross_chain_adapted()) { - this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + if (this -> use_cross_chain_adapt()) { + this->set_cross_chain_stepsize(); } } - }; + } + return s; + } + + void 
disengage_adaptation() { + base_adapter::disengage_adaptation(); + if (!this -> is_cross_chain_adapted()) { + this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + } + } +}; - } // mcmc +} // mcmc } // stan #endif diff --git a/src/stan/mcmc/mpi_auto_adaptation.hpp b/src/stan/mcmc/mpi_auto_adaptation.hpp index bb98099d3ce..6377281b579 100644 --- a/src/stan/mcmc/mpi_auto_adaptation.hpp +++ b/src/stan/mcmc/mpi_auto_adaptation.hpp @@ -151,7 +151,6 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { int n_params_; Model& model_; std::deque last_qs_; - int init_draw_counter_; public: est_t estimator; bool is_diagonal_; @@ -160,25 +159,21 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { : window_size_(window_size), n_params_(n_params), model_(model), - init_draw_counter_(0), estimator(n_params, num_iterations), is_diagonal_(false) {} virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { - init_draw_counter++; - if (init_draw_counter > init_bufer_size) { - estimator.add_sample(q); - last_qs_.push_back(q); - if(last_qs_.size() > 5) { - last_qs_.pop_front(); - } + estimator.add_sample(q); + last_qs_.push_back(q); + if(last_qs_.size() > 5) { + last_qs_.pop_front(); } } virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) { - int col_begin = win * window_size_; - int num_draws = (curr_win_count - win) * window_size_; + int col_begin = std::max(win * window_size_, init_bufer_size); + int num_draws = std::max(curr_win_count * window_size_ - col_begin, 0); int M = n_params_; @@ -188,6 +183,8 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { Eigen::MatrixXd cov_train = Eigen::MatrixXd::Zero(M, M); Eigen::MatrixXd cov_test = Eigen::MatrixXd::Zero(M, M); + //std::cout << "col_begin: " << col_begin << ", num_draws: " << num_draws << std::endl; + int Ntest; if(state == "selection") { Ntest = int(0.2 * num_draws); @@ -223,9 +220,6 @@ class 
mpi_auto_adaptation : public mpi_metric_adaptation { double low_eigenvalue_dense = -1.0 / internal::eigenvalue_scaled_covariance(L_dense, cov_test); double low_eigenvalue_diag = -1.0 / internal::eigenvalue_scaled_covariance(L_diag, cov_test); - std::cout << "TRAIN low:" << low_eigenvalue_dense << std::endl; - std::cout << "TEST low :" << low_eigenvalue_diag << std::endl; - double c_dense = 0.0; double c_diag = 0.0; for(int i = 0; i < last_qs_.size(); i++) { @@ -263,8 +257,6 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { + 1e-3 * (5.0 / (num_draws + 5.0)) * Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); is_diagonal_ = true; } - - std::cout << covar << std::endl; } void learn_covariance(Eigen::MatrixXd& covar, diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp index 7c10381656a..aa9f3d46b20 100644 --- a/src/stan/services/util/mpi_cross_chain.hpp +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #ifdef STAN_LANG_MPI @@ -38,6 +39,11 @@ namespace util { static const bool value = true; }; + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + /* * Helper functions for samplers with MPI WARMUP. Other * samplers have dummy implmenentation. 
From 02e93b4378935a93b239f68e441ba319b28a6e43 Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 18 Feb 2020 14:13:49 -0500 Subject: [PATCH 72/73] Removed some unused files --- src/stan/mcmc/auto_adaptation.hpp | 270 ------------------ src/stan/mcmc/stepsize_auto_adapter.hpp | 48 ---- ...ation_learn_covariance_pick_dense_test.cpp | 64 ----- ...tation_learn_covariance_pick_diag_test.cpp | 63 ---- src/test/unit/mcmc/auto_adaptation_test.cpp | 170 ----------- 5 files changed, 615 deletions(-) delete mode 100644 src/stan/mcmc/auto_adaptation.hpp delete mode 100644 src/stan/mcmc/stepsize_auto_adapter.hpp delete mode 100644 src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp delete mode 100644 src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp delete mode 100644 src/test/unit/mcmc/auto_adaptation_test.cpp diff --git a/src/stan/mcmc/auto_adaptation.hpp b/src/stan/mcmc/auto_adaptation.hpp deleted file mode 100644 index f7f85164ec2..00000000000 --- a/src/stan/mcmc/auto_adaptation.hpp +++ /dev/null @@ -1,270 +0,0 @@ -#ifndef STAN_MCMC_AUTO_ADAPTATION_HPP -#define STAN_MCMC_AUTO_ADAPTATION_HPP - -#include -#include -#include - -namespace stan { - - namespace mcmc { - template - struct log_prob_wrapper_covar { - const Model& model_; - log_prob_wrapper_covar(const Model& model) : model_(model) {} - - template - T operator()(const Eigen::Matrix& q) const { - return model_.template log_prob(const_cast& >(q), &std::cout); - } - }; - - namespace internal { - /** - * Compute the covariance of data in Y. - * - * Columns of Y are different variables. Rows are different samples. - * - * When there is only one row in Y, return a covariance matrix of the expected - * size filled with zeros. 
- * - * @param Y Data - * @return Covariance of Y - */ - Eigen::MatrixXd covariance(const Eigen::MatrixXd& Y) { - stan::math::check_nonzero_size("covariance", "Y", Y); - - Eigen::MatrixXd centered = Y.rowwise() - Y.colwise().mean(); - return centered.transpose() * centered / std::max(centered.rows() - 1.0, 1.0); - } - - /** - * Compute the largest magnitude eigenvalue of a symmetric matrix using the power method. The function f - * should return the product of that matrix with an abitrary vector. - * - * f should take one Eigen::VectorXd argument, x, and return the product of a matrix with x as - * an Eigen::VectorXd argument of the same size. - * - * The eigenvalue is estimated iteratively. If the kth estimate is e_k, then the function returns when - * either abs(e_{k + 1} - e_k) < tol * abs(e_k) or the maximum number of iterations have been performed - * - * This means the returned eigenvalue might not be computed to full precision - * - * @param initial_guess Initial guess of the eigenvector of the largest eigenvalue - * @param[in,out] max_iterations Maximum number of power iterations, on return number of iterations used - * @param[in,out] tol Relative tolerance, on return the relative error in the eigenvalue estimate - * @return Largest magnitude eigenvalue of operator f - */ - template - double power_method(F& f, const Eigen::VectorXd& initial_guess, int& max_iterations, double& tol) { - Eigen::VectorXd v = initial_guess; - double eval = 0.0; - Eigen::VectorXd Av = f(v); - stan::math::check_matching_sizes("power_method", "matrix vector product", Av, "vector", v); - - int i = 0; - for(; i < max_iterations; ++i) { - double v_norm = v.norm(); - double new_eval = v.dot(Av) / (v_norm * v_norm); - if(i == max_iterations - 1 || std::abs(new_eval - eval) <= tol * std::abs(eval)) { - tol = std::abs(new_eval - eval) / std::abs(eval); - eval = new_eval; - max_iterations = i + 1; - break; - } - - eval = new_eval; - v = Av / Av.norm(); - - Av = f(v); - } - - return eval; - 
} - - /** - * Compute the largest eigenvalue of the sample covariance rescaled by a metric, - * that is, the largest eigenvalue of L^{-1} \Sigma L^{-T} - * - * @param L Cholesky decomposition of Metric - * @param Sigma Sample covariance - * @return Largest eigenvalue - */ - double eigenvalue_scaled_covariance(const Eigen::MatrixXd& L, const Eigen::MatrixXd& Sigma) { - Eigen::MatrixXd S = L.template triangularView(). - solve(L.template triangularView().solve(Sigma).transpose()).transpose(); - - auto Sx = [&](Eigen::VectorXd x) -> Eigen::VectorXd { - return S * x; - }; - - int max_iterations = 100; - double tol = 1e-3; - - return internal::power_method(Sx, Eigen::VectorXd::Random(Sigma.cols()), max_iterations, tol); - } - - /** - * Compute the largest eigenvalue of the Hessian of the log density rescaled by a metric, - * that is, the largest eigenvalue of L^T \nabla^2_{qq} H(q) L - * - * @tparam Model Type of model - * @param model Defines the log density - * @param q Point around which to compute the Hessian - * @param L Cholesky decomposition of Metric - * @return Largest eigenvalue - */ - template - double eigenvalue_scaled_hessian(const Model& model, const Eigen::MatrixXd& L, const Eigen::VectorXd& q) { - Eigen::VectorXd eigenvalues; - Eigen::MatrixXd eigenvectors; - - auto hessian_vector = [&](const Eigen::VectorXd& x) -> Eigen::VectorXd { - double lp; - Eigen::VectorXd grad1; - Eigen::VectorXd grad2; - //stan::math::hessian_times_vector(log_prob_wrapper_covar(model), q, x, lp, Ax); - double dx = 1e-5; - Eigen::VectorXd dr = L * x * dx; - stan::math::gradient(log_prob_wrapper_covar(model), q + dr / 2.0, lp, grad1); - stan::math::gradient(log_prob_wrapper_covar(model), q - dr / 2.0, lp, grad2); - return L.transpose() * (grad1 - grad2) / dx; - }; - - int max_iterations = 100; - double tol = 1e-3; - - return internal::power_method(hessian_vector, Eigen::VectorXd::Random(q.size()), max_iterations, tol); - } - } - - class auto_adaptation: public windowed_adaptation { 
- public: - explicit auto_adaptation(int n) - : windowed_adaptation("covariance") {} - /** - * Update the metric if at the end of an adaptation window. - * - * @tparam Model Type of model - * @param model Defines the log density - * @param covar[out] New metric - * @param covar_is_diagonal[out] Set to true if metric is diagonal, false otherwise - * @param q New MCMC draw - * @return True if this was the end of an adaptation window, false otherwise - */ - template - bool learn_covariance(const Model& model, Eigen::MatrixXd& covar, bool& covar_is_diagonal, const Eigen::VectorXd& q) { - if (adaptation_window()) - qs_.push_back(q); - - if (end_adaptation_window()) { - compute_next_window(); - - int M = q.size(); - int N = qs_.size(); - - Eigen::MatrixXd Y = Eigen::MatrixXd::Zero(M, N); - std::vector idxs(N); - for(int i = 0; i < qs_.size(); i++) - idxs[i] = i; - - std::random_shuffle(idxs.begin(), idxs.end()); - for(int i = 0; i < qs_.size(); i++) - Y.block(0, i, M, 1) = qs_[idxs[i]]; - - try { - bool use_dense = false; - for(auto state : { "selection", "refinement" }) { - Eigen::MatrixXd Ytrain; - Eigen::MatrixXd Ytest; - - if(state == "selection") { - int Mtest; - Mtest = int(0.2 * Y.cols()); - if(Mtest < 5) { - Mtest = 5; - } - - if(Y.cols() < 10) { - throw std::runtime_error("Each warmup stage must have at least 10 samples"); - } - - Ytrain = Y.block(0, 0, M, Y.cols() - Mtest); - Ytest = Y.block(0, Ytrain.cols(), M, Mtest); - } else { - Ytrain = Y; - } - - Eigen::MatrixXd cov_train = (Ytrain.cols() > 0) ? internal::covariance(Ytrain.transpose()) : Eigen::MatrixXd::Zero(M, M); - Eigen::MatrixXd cov_test = (Ytest.cols() > 0) ? 
internal::covariance(Ytest.transpose()) : Eigen::MatrixXd::Zero(M, M); - - Eigen::MatrixXd dense = (N / (N + 5.0)) * cov_train + - 1e-3 * (5.0 / (N + 5.0)) * Eigen::MatrixXd::Identity(cov_train.rows(), cov_train.cols()); - - Eigen::MatrixXd diag = dense.diagonal().asDiagonal(); - - covar = dense; - - if(state == "selection") { - Eigen::MatrixXd L_dense = dense.llt().matrixL(); - Eigen::MatrixXd L_diag = diag.diagonal().array().sqrt().matrix().asDiagonal(); - - double low_eigenvalue_dense = -1.0 / internal::eigenvalue_scaled_covariance(L_dense, cov_test); - double low_eigenvalue_diag = -1.0 / internal::eigenvalue_scaled_covariance(L_diag, cov_test); - - double c_dense = 0.0; - double c_diag = 0.0; - for(int i = 0; i < 5; i++) { - double high_eigenvalue_dense = internal::eigenvalue_scaled_hessian(model, L_dense, Ytest.block(0, i, M, 1)); - double high_eigenvalue_diag = internal::eigenvalue_scaled_hessian(model, L_diag, Ytest.block(0, i, M, 1)); - - c_dense = std::max(c_dense, std::sqrt(high_eigenvalue_dense / low_eigenvalue_dense)); - c_diag = std::max(c_diag, std::sqrt(high_eigenvalue_diag / low_eigenvalue_diag)); - } - - std::cout << "adapt: " << adapt_window_counter_ << ", which: dense, max: " << c_dense << std::endl; - std::cout << "adapt: " << adapt_window_counter_ << ", which: diag, max: " << c_diag << std::endl; - - if(c_dense < c_diag) { - use_dense = true; - } else { - use_dense = false; - } - } else { - if(use_dense) { - covar = dense; - covar_is_diagonal = false; - } else { - covar = diag; - covar_is_diagonal = true; - } - } - } - } catch(const std::exception& e) { - std::cout << e.what() << std::endl; - std::cout << "Exception while using auto adaptation, falling back to diagonal" << std::endl; - Eigen::MatrixXd cov = internal::covariance(Y.transpose()); - covar = ((N / (N + 5.0)) * cov.diagonal() - + 1e-3 * (5.0 / (N + 5.0)) * Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); - covar_is_diagonal = true; - } - - ++adapt_window_counter_; - qs_.clear(); - - 
return true; - } - - ++adapt_window_counter_; - return false; - } - - protected: - std::vector< Eigen::VectorXd > qs_; - }; - - } // mcmc - -} // stan - -#endif diff --git a/src/stan/mcmc/stepsize_auto_adapter.hpp b/src/stan/mcmc/stepsize_auto_adapter.hpp deleted file mode 100644 index 0db72870d38..00000000000 --- a/src/stan/mcmc/stepsize_auto_adapter.hpp +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef STAN_MCMC_STEPSIZE_AUTO_ADAPTER_HPP -#define STAN_MCMC_STEPSIZE_AUTO_ADAPTER_HPP - -#include -#include -#include -#include - -namespace stan { - - namespace mcmc { - - class stepsize_auto_adapter: public base_adapter { - public: - explicit stepsize_auto_adapter(int n) - : auto_adaptation_(n) { - } - - stepsize_adaptation& get_stepsize_adaptation() { - return stepsize_adaptation_; - } - - auto_adaptation& get_auto_adaptation() { - return auto_adaptation_; - } - - void set_window_params(unsigned int num_warmup, - unsigned int init_buffer, - unsigned int term_buffer, - unsigned int base_window, - callbacks::logger& logger) { - auto_adaptation_.set_window_params(num_warmup, - init_buffer, - term_buffer, - base_window, - logger); - } - - protected: - stepsize_adaptation stepsize_adaptation_; - auto_adaptation auto_adaptation_; - }; - - } // mcmc - -} // stan - -#endif diff --git a/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp b/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp deleted file mode 100644 index 0a8bd02139f..00000000000 --- a/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_dense_test.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include -#include -#include -#include - -TEST(McmcVarAdaptation, learn_covariance_pick_dense) { - std::fstream data_stream(std::string("").c_str(), std::fstream::in); - stan::io::dump data_var_context(data_stream); - data_stream.close(); - - std::stringstream output; - correlated_gaussian_model_namespace::correlated_gaussian_model - correlated_gaussian_model(data_var_context, &output); - - 
stan::test::unit::instrumented_logger logger; - - const int M = 2; - const int N = 20; - Eigen::MatrixXd qs(N, M); - qs << 0.256173753306128, -0.0238087093098673, - -1.63218152810157, -1.5309929638363, - -0.812451331685826, -1.15062373620068, - -1.49814775191801, -1.51310110681996, - 0.738630631536685, 1.03588205799336, - 0.472288580035284, 0.250286770328584, - -1.63634486169493, -1.6222798835089, - -0.400790615207103, -0.337669147200631, - -0.568779612417544, -0.424833495378187, - 0.103690913176746, 0.272885200284842, - -0.453017424229528, -0.504634004215693, - 3.34484533887237, 3.29418872328382, - -1.3376507113241, -1.32724775403694, - -0.137543235057544, -0.0290938109919368, - -1.58194496352741, -1.39338740677379, - 0.312166136194586, 0.336989933768233, - -0.628941448228566, -0.850758612234264, - -0.766816808981044, -0.645020468024267, - -0.75078110234827, -0.502544092120385, - -0.00694807494461906, -0.186748159558166; - - Eigen::MatrixXd covar(M, M); - bool covar_is_diagonal; - - Eigen::MatrixXd target_covar(M, M); - - target_covar << 1.0311414783609130, 1.0100577463968425, - 1.0100577463968425, 1.0148380697138280; - - stan::mcmc::auto_adaptation adapter(M); - adapter.set_window_params(50, 0, 0, N, logger); - - for (int i = 0; i < N; ++i) { - Eigen::VectorXd q = qs.block(i, 0, 1, M).transpose(); - adapter.learn_covariance(correlated_gaussian_model, covar, covar_is_diagonal, q); - } - - for (int i = 0; i < covar.size(); ++i) { - EXPECT_FLOAT_EQ(target_covar(i), covar(i)); - } - - EXPECT_EQ(covar_is_diagonal, false); - - EXPECT_EQ(0, logger.call_count()); -} diff --git a/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp b/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp deleted file mode 100644 index 26f3da56cef..00000000000 --- a/src/test/unit/mcmc/auto_adaptation_learn_covariance_pick_diag_test.cpp +++ /dev/null @@ -1,63 +0,0 @@ -#include -#include -#include -#include - -TEST(McmcVarAdaptation, 
learn_covariance_pick_diagonal) { - std::fstream data_stream(std::string("").c_str(), std::fstream::in); - stan::io::dump data_var_context(data_stream); - data_stream.close(); - - std::stringstream output; - independent_gaussian_model_namespace::independent_gaussian_model - independent_gaussian_model(data_var_context, &output); - - stan::test::unit::instrumented_logger logger; - - const int M = 2; - const int N = 20; - Eigen::MatrixXd qs(N, M); - qs << 0.607446257145326, 0.338465765807058, - 1.47389672467345, -1.0577986841911, - 1.02886652895522, 0.364277500948572, - 0.492316893603469, 2.19693408641558, - -0.931854393410476, 1.62634580968769, - -0.443145375724188, 0.902790875582656, - 0.517782110245233, -1.56724331755861, - -1.7556390097031, 0.310274990315213, - 0.0394975482340945, 0.366999438969482, - 1.29372950054929, 0.361369734821582, - -0.258301497542829, 0.166994731172984, - 0.492639248874412, -0.659502589885556, - 0.913729457222598, 1.99580706461809, - 0.669655370469707, -0.509028392475839, - -0.626041244059129, -0.771981104624195, - -0.842385483586737, 0.337166271031201, - 0.548177804329155, -0.0462961925005498, - 0.955748803092952, 1.3141117316189, - 0.335670079140694, 1.09112083087171, - 0.759245358940033, -1.11318882201676; - - Eigen::MatrixXd covar(M, M); - bool covar_is_diagonal; - - Eigen::MatrixXd target_covar(M, M); - - target_covar << 0.55350038163333048, 0.0, 0.0, 0.86122545968912112; - - stan::mcmc::auto_adaptation adapter(M); - adapter.set_window_params(50, 0, 0, N, logger); - - for (int i = 0; i < N; ++i) { - Eigen::VectorXd q = qs.block(i, 0, 1, M).transpose(); - adapter.learn_covariance(independent_gaussian_model, covar, covar_is_diagonal, q); - } - - for (int i = 0; i < covar.size(); ++i) { - EXPECT_FLOAT_EQ(target_covar(i), covar(i)); - } - - EXPECT_EQ(covar_is_diagonal, true); - - EXPECT_EQ(0, logger.call_count()); -} diff --git a/src/test/unit/mcmc/auto_adaptation_test.cpp b/src/test/unit/mcmc/auto_adaptation_test.cpp deleted file mode 
100644 index 69890163251..00000000000 --- a/src/test/unit/mcmc/auto_adaptation_test.cpp +++ /dev/null @@ -1,170 +0,0 @@ -#include -#include -#include -#include - -TEST(McmcAutoAdaptation, test_covariance_zero_rows_zero_cols) { - Eigen::MatrixXd X1(0, 5); - - EXPECT_THROW(stan::mcmc::internal::covariance(X1), std::invalid_argument); - - Eigen::MatrixXd X2(1, 0); - - EXPECT_THROW(stan::mcmc::internal::covariance(X2), std::invalid_argument); -} - -TEST(McmcAutoAdaptation, test_covariance_one_row_one_col) { - Eigen::MatrixXd X1(1, 2); - Eigen::MatrixXd X2(3, 1); - - X1 << 1.0, 2.0; - X2 << 1.0, 2.0, 3.0; - - Eigen::MatrixXd cov1 = stan::mcmc::internal::covariance(X1); - Eigen::MatrixXd cov2 = stan::mcmc::internal::covariance(X2); - - ASSERT_EQ(cov1.rows(), 2); - ASSERT_EQ(cov1.cols(), 2); - - ASSERT_EQ(cov2.rows(), 1); - ASSERT_EQ(cov2.cols(), 1); - - for(int i = 0; i < cov1.size(); ++i) { - ASSERT_FLOAT_EQ(cov1(i), 0.0); - } - - ASSERT_FLOAT_EQ(cov2(0), 1.0); -} - -TEST(McmcAutoAdaptation, test_covariance) { - Eigen::MatrixXd X1(3, 2); - Eigen::MatrixXd X2(2, 3); - - X1 << 0.0, -1.0, 0.5, -2.7, 3.0, 5.0; - X2 << 0.0, 3, -2.7, 0.5, -1, 5.0; - - Eigen::MatrixXd cov1 = stan::mcmc::internal::covariance(X1); - Eigen::MatrixXd cov2 = stan::mcmc::internal::covariance(X2); - - Eigen::MatrixXd cov1_ref(2, 2); - Eigen::MatrixXd cov2_ref(3, 3); - - cov1_ref << 2.5833333333333335, 6.0666666666666664, - 6.0666666666666664, 16.3633333333333333; - - cov2_ref << 0.125, -1.0, 1.925, - -1.000, 8.0, -15.4, - 1.925, -15.4, 29.645; - - ASSERT_EQ(cov1.rows(), cov1_ref.rows()); - ASSERT_EQ(cov1.cols(), cov1_ref.cols()); - - ASSERT_EQ(cov2.rows(), cov2_ref.rows()); - ASSERT_EQ(cov2.cols(), cov2_ref.cols()); - - for(int i = 0; i < cov1_ref.size(); ++i) { - ASSERT_FLOAT_EQ(cov1(i), cov1_ref(i)); - } - - for(int i = 0; i < cov2_ref.size(); ++i) { - ASSERT_FLOAT_EQ(cov2(i), cov2_ref(i)); - } -} - -TEST(McmcAutoAdaptation, power_method) { - Eigen::MatrixXd X(2, 2); - Eigen::VectorXd x0(2); - - X 
<< 2.0, 0.5, 0.5, 1.0; - x0 << 1.0, 0.0; - - const int max_iterations = 10; - const double tol = 1e-10; - - auto Av = [&](const Eigen::VectorXd& v) { return X * v; }; - - int max_iterations_1 = max_iterations; - double tol_1 = tol; - - double eval = stan::mcmc::internal::power_method(Av, x0, max_iterations_1, tol_1); - - EXPECT_FLOAT_EQ(eval, 2.20710678118654746); -} - -TEST(McmcAutoAdaptation, power_method_tol_check) { - Eigen::MatrixXd X(2, 2); - Eigen::VectorXd x0(2); - - X << 2.0, 0.5, 0.5, 1.0; - x0 << 1.0, 0.0; - - const int max_iterations = 1000; - const double tol = 1e-12; - - auto Av = [&](const Eigen::VectorXd& v) { return X * v; }; - - int max_iterations_1 = max_iterations; - double tol_1 = tol; - double eval = stan::mcmc::internal::power_method(Av, x0, max_iterations_1, tol_1); - - EXPECT_LT(tol_1, tol); -} - -TEST(McmcAutoAdaptation, power_method_iter_check) { - Eigen::MatrixXd X(2, 2); - Eigen::VectorXd x0(2); - - X << 2.0, 0.5, 0.5, 1.0; - x0 << 1.0, 0.0; - - const int max_iterations = 10; - const double tol = 1e-50; - - auto Av = [&](const Eigen::VectorXd& v) { return X * v; }; - - int max_iterations_1 = max_iterations; - double tol_1 = tol; - double eval = stan::mcmc::internal::power_method(Av, x0, max_iterations_1, tol_1); - - EXPECT_GT(tol_1, tol); - EXPECT_EQ(max_iterations_1, max_iterations); -} - -// The checks in here are very coarse because eigenvalue_scaled_covariance -// only estimates things with low precision -TEST(McmcAutoAdaptation, eigenvalue_scaled_covariance) { - Eigen::MatrixXd L(2, 2), Sigma(2, 2); - - L << 1.0, 0.0, 0.5, 1.0; - Sigma << 2.0, 0.7, 0.7, 1.3; - - double eval = stan::mcmc::internal::eigenvalue_scaled_covariance(L, Sigma); - - EXPECT_LT(std::abs(eval - 2.0908326913195983) / eval, 1e-2); - - L << 2.0, 0.0, 0.7, 1.3; - - eval = stan::mcmc::internal::eigenvalue_scaled_covariance(L, Sigma); - - EXPECT_LT(std::abs(eval - 0.62426035502958577) / eval, 1e-2); -} - -// The checks in here are very coarse because 
eigenvalue_scaled_hessian -// only estimates things with low precision -TEST(McmcAutoAdaptation, eigenvalue_scaled_hessian) { - std::fstream data_stream(std::string("").c_str(), std::fstream::in); - stan::io::dump data_var_context(data_stream); - data_stream.close(); - - std::stringstream output; - known_hessian_model_namespace::known_hessian_model known_hessian_model(data_var_context, &output); - - Eigen::MatrixXd L(3, 3); - Eigen::VectorXd q(3); - L << 2.0, 0.0, 0.0, 0.7, 1.3, 0.0, -1.5, 2.0, 4.0; - q << 0.0, 0.0, 0.0; - - double eval = stan::mcmc::internal::eigenvalue_scaled_hessian(known_hessian_model, L, q); - - EXPECT_LT(std::abs(eval - 22.8141075806892850) / eval, 1e-2); -} From 36af38db156654c8bb8bf2077e73041d105ee080 Mon Sep 17 00:00:00 2001 From: Ben Date: Fri, 28 Feb 2020 13:30:04 -0500 Subject: [PATCH 73/73] Updated math includes Got auto adaptation working with merged mpi_warmup_v2 --- src/stan/analyze/mcmc/autocovariance.hpp | 3 +- .../mcmc/compute_effective_sample_size.hpp | 2 +- .../compute_potential_scale_reduction.hpp | 2 +- src/stan/io/array_var_context.hpp | 2 +- src/stan/io/dump.hpp | 2 +- src/stan/io/reader.hpp | 2 +- src/stan/io/stan_csv_reader.hpp | 2 +- src/stan/io/writer.hpp | 2 +- src/stan/mcmc/chains.hpp | 2 +- src/stan/mcmc/covar_adaptation.hpp | 2 +- .../mcmc/hmc/hamiltonians/auto_e_metric.hpp | 2 +- .../hmc/hamiltonians/base_hamiltonian.hpp | 2 +- .../mcmc/hmc/hamiltonians/dense_e_metric.hpp | 2 +- src/stan/mcmc/hmc/hamiltonians/ps_point.hpp | 2 +- .../mcmc/hmc/integrators/expl_leapfrog.hpp | 2 +- src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp | 40 ++++--- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 2 +- src/stan/mcmc/mpi_auto_adaptation.hpp | 109 ++++++++++++------ src/stan/mcmc/mpi_covar_adaptation.hpp | 34 +++--- src/stan/mcmc/mpi_metric_adaptation.hpp | 2 +- src/stan/mcmc/mpi_var_adaptation.hpp | 2 +- src/stan/mcmc/sample.hpp | 2 +- src/stan/mcmc/var_adaptation.hpp | 2 +- src/stan/model/gradient.hpp | 2 +- 
src/stan/model/indexing/lvalue.hpp | 2 +- src/stan/model/indexing/rvalue.hpp | 2 +- src/stan/model/log_prob_grad.hpp | 2 +- src/stan/model/log_prob_propto.hpp | 2 +- src/stan/model/model_functional.hpp | 2 +- src/stan/optimization/bfgs.hpp | 2 +- src/stan/services/sample/fixed_param.hpp | 2 +- .../services/sample/hmc_nuts_auto_e_adapt.hpp | 11 +- src/stan/services/sample/hmc_nuts_dense_e.hpp | 3 +- .../sample/hmc_nuts_dense_e_adapt.hpp | 3 +- src/stan/services/sample/hmc_nuts_diag_e.hpp | 3 +- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 3 +- src/stan/services/sample/hmc_nuts_unit_e.hpp | 2 +- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 2 +- .../services/sample/hmc_static_dense_e.hpp | 3 +- .../sample/hmc_static_dense_e_adapt.hpp | 3 +- .../services/sample/hmc_static_diag_e.hpp | 3 +- .../sample/hmc_static_diag_e_adapt.hpp | 3 +- .../services/sample/hmc_static_unit_e.hpp | 2 +- .../sample/hmc_static_unit_e_adapt.hpp | 2 +- src/stan/services/util/initialize.hpp | 2 +- .../services/util/read_dense_inv_metric.hpp | 2 +- .../util/validate_dense_inv_metric.hpp | 2 +- .../util/validate_diag_inv_metric.hpp | 2 +- src/stan/variational/base_family.hpp | 2 +- .../variational/families/normal_fullrank.hpp | 2 +- .../variational/families/normal_meanfield.hpp | 2 +- src/stan/variational/print_progress.hpp | 4 +- 52 files changed, 170 insertions(+), 131 deletions(-) diff --git a/src/stan/analyze/mcmc/autocovariance.hpp b/src/stan/analyze/mcmc/autocovariance.hpp index 74c0634cfc2..36130732dc5 100644 --- a/src/stan/analyze/mcmc/autocovariance.hpp +++ b/src/stan/analyze/mcmc/autocovariance.hpp @@ -1,8 +1,7 @@ #ifndef STAN_ANALYZE_MCMC_AUTOCOVARIANCE_HPP #define STAN_ANALYZE_MCMC_AUTOCOVARIANCE_HPP -#include -#include +#include #include #include #include diff --git a/src/stan/analyze/mcmc/compute_effective_sample_size.hpp b/src/stan/analyze/mcmc/compute_effective_sample_size.hpp index 75502d12144..4d58330b202 100644 --- a/src/stan/analyze/mcmc/compute_effective_sample_size.hpp 
+++ b/src/stan/analyze/mcmc/compute_effective_sample_size.hpp @@ -1,7 +1,7 @@ #ifndef STAN_ANALYZE_MCMC_COMPUTE_EFFECTIVE_SAMPLE_SIZE_HPP #define STAN_ANALYZE_MCMC_COMPUTE_EFFECTIVE_SAMPLE_SIZE_HPP -#include +#include #include #include #include diff --git a/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp b/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp index 7189b17b590..22ff44c2eb1 100644 --- a/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp +++ b/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp @@ -1,7 +1,7 @@ #ifndef STAN_ANALYZE_MCMC_COMPUTE_POTENTIAL_SCALE_REDUCTION_HPP #define STAN_ANALYZE_MCMC_COMPUTE_POTENTIAL_SCALE_REDUCTION_HPP -#include +#include #include #include #include diff --git a/src/stan/io/array_var_context.hpp b/src/stan/io/array_var_context.hpp index af2bf20b87f..8b433ba179d 100644 --- a/src/stan/io/array_var_context.hpp +++ b/src/stan/io/array_var_context.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/stan/io/dump.hpp b/src/stan/io/dump.hpp index 05e28b9dcab..f2cd2c8ca93 100644 --- a/src/stan/io/dump.hpp +++ b/src/stan/io/dump.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/stan/io/reader.hpp b/src/stan/io/reader.hpp index feaf8b4c51e..912f23e7478 100644 --- a/src/stan/io/reader.hpp +++ b/src/stan/io/reader.hpp @@ -2,7 +2,7 @@ #define STAN_IO_READER_HPP #include -#include +#include #include #include #include diff --git a/src/stan/io/stan_csv_reader.hpp b/src/stan/io/stan_csv_reader.hpp index e39d29538ea..c92acb84b8f 100644 --- a/src/stan/io/stan_csv_reader.hpp +++ b/src/stan/io/stan_csv_reader.hpp @@ -2,7 +2,7 @@ #define STAN_IO_STAN_CSV_READER_HPP #include -#include +#include #include #include #include diff --git a/src/stan/io/writer.hpp b/src/stan/io/writer.hpp index 609eee0bb28..d534716bf61 100644 --- a/src/stan/io/writer.hpp +++ b/src/stan/io/writer.hpp @@ -1,7 +1,7 @@ #ifndef 
STAN_IO_WRITER_HPP #define STAN_IO_WRITER_HPP -#include +#include #include #include diff --git a/src/stan/mcmc/chains.hpp b/src/stan/mcmc/chains.hpp index ff20b91a97e..620358c9fac 100644 --- a/src/stan/mcmc/chains.hpp +++ b/src/stan/mcmc/chains.hpp @@ -2,7 +2,7 @@ #define STAN_MCMC_CHAINS_HPP #include -#include +#include #include #include #include diff --git a/src/stan/mcmc/covar_adaptation.hpp b/src/stan/mcmc/covar_adaptation.hpp index 7a1a07b6220..ef01301ccc1 100644 --- a/src/stan/mcmc/covar_adaptation.hpp +++ b/src/stan/mcmc/covar_adaptation.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_COVAR_ADAPTATION_HPP #define STAN_MCMC_COVAR_ADAPTATION_HPP -#include +#include #include #include diff --git a/src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp index bdacb513e72..a81699d7ce1 100644 --- a/src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp @@ -2,7 +2,7 @@ #define STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_METRIC_HPP #include -#include +#include #include #include #include diff --git a/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp b/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp index a0405d539cc..2ce64932984 100644 --- a/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp @@ -2,7 +2,7 @@ #define STAN_MCMC_HMC_HAMILTONIANS_BASE_HAMILTONIAN_HPP #include -#include +#include #include #include #include diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp index 9799c356a9e..09de4898153 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp @@ -2,7 +2,7 @@ #define STAN_MCMC_HMC_HAMILTONIANS_DENSE_E_METRIC_HPP #include -#include +#include #include #include #include diff --git a/src/stan/mcmc/hmc/hamiltonians/ps_point.hpp b/src/stan/mcmc/hmc/hamiltonians/ps_point.hpp index 8a292c46c4a..d1811c0290f 
100644 --- a/src/stan/mcmc/hmc/hamiltonians/ps_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/ps_point.hpp @@ -2,7 +2,7 @@ #define STAN_MCMC_HMC_HAMILTONIANS_PS_POINT_HPP #include -#include +#include #include #include #include diff --git a/src/stan/mcmc/hmc/integrators/expl_leapfrog.hpp b/src/stan/mcmc/hmc/integrators/expl_leapfrog.hpp index 5aae58c2c40..eae05136e6d 100644 --- a/src/stan/mcmc/hmc/integrators/expl_leapfrog.hpp +++ b/src/stan/mcmc/hmc/integrators/expl_leapfrog.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include namespace stan { namespace mcmc { diff --git a/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp index d84d7c7aca3..c6deef4f4d1 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_auto_e_nuts.hpp @@ -35,30 +35,32 @@ class adapt_auto_e_nuts : public auto_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update; if (this -> use_cross_chain_adapt()) { - this -> add_cross_chain_sample(s.log_prob()); - update = this -> cross_chain_adaptation(logger); - if (this -> is_cross_chain_adapted()) { - update = false; - } - } else { - update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, - this->z_.q); - } + this -> add_cross_chain_sample(s.log_prob()); + bool update = this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) { + update = false; + } - if (update) { - //std::cout << this->z_.inv_e_metric_ << std::endl; - this->z_.is_diagonal_ = reinterpret_cast *>(this->var_adapt)->is_diagonal_; + if (update) { + this->z_.is_diagonal_ = reinterpret_cast *>(this->metric_adapt)->is_diagonal_; - this->init_stepsize(logger); + this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + + 
this->set_cross_chain_stepsize(); + } + } else { + bool update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + if (update) { + this->init_stepsize(logger); - if (this -> use_cross_chain_adapt()) { - this->set_cross_chain_stepsize(); - } + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + } } } return s; diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index ae57d970440..b8912e69f72 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/stan/mcmc/mpi_auto_adaptation.hpp b/src/stan/mcmc/mpi_auto_adaptation.hpp index 6377281b579..76d5b06624f 100644 --- a/src/stan/mcmc/mpi_auto_adaptation.hpp +++ b/src/stan/mcmc/mpi_auto_adaptation.hpp @@ -1,14 +1,10 @@ #ifndef STAN_MCMC_MPI_AUTO_ADAPTATION_HPP #define STAN_MCMC_MPI_AUTO_ADAPTATION_HPP -#include +#include #include #include -#ifdef STAN_LANG_MPI -#include -#endif - namespace stan { namespace mcmc { @@ -147,23 +143,60 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { #ifdef STAN_LANG_MPI using est_t = stan::math::mpi::mpi_covar_estimator; - int window_size_; + int num_chains_; int n_params_; + int window_size_; + int init_buffer_; + int init_draw_counter_; + int num_iterations_; + int draw_req_counter_; + int draws_collected_counter_; Model& model_; std::deque last_qs_; public: - est_t estimator; + std::vector reqs; + std::vector draws; + std::vector num_draws; + Eigen::MatrixXd Y_; bool is_diagonal_; - mpi_auto_adaptation(Model& model, int n_params, int num_iterations, int window_size) - : window_size_(window_size), + mpi_auto_adaptation(Model& model, int n_params, int num_chains, + int num_iterations, int window_size, int init_buffer) + : num_chains_(num_chains), + window_size_(window_size), + init_buffer_(init_buffer), + 
init_draw_counter_(0), + draw_req_counter_(0), + draws_collected_counter_(0), n_params_(n_params), + reqs(window_size), + num_iterations_(num_iterations), model_(model), - estimator(n_params, num_iterations), - is_diagonal_(false) {} + draws(window_size, Eigen::MatrixXd(n_params, num_chains)), + Y_(num_chains * num_iterations, n_params), + is_diagonal_(false) { + std::cout << "numchains " << num_chains_ << + ", n_params " << n_params_ << + ", num_iter " << num_iterations_ << + ", window_size " << window_size_ << + ", init_buffer " << init_buffer_ << std::endl; + } + + void reset_req() { + draw_req_counter_ = 0; + reqs.clear(); + reqs.resize(window_size_); + } virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { - estimator.add_sample(q); + const stan::math::mpi::Communicator& comm = + stan::math::mpi::Session::inter_chain_comm(num_chains_); + + MPI_Iallgather(q.data(), q.size(), MPI_DOUBLE, + draws[draw_req_counter_].data(), q.size(), MPI_DOUBLE, + comm.comm(), &reqs[draw_req_counter_]); + draw_req_counter_++; + last_qs_.push_back(q); if(last_qs_.size() > 5) { last_qs_.pop_front(); @@ -172,9 +205,13 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) { - int col_begin = std::max(win * window_size_, init_bufer_size); - int num_draws = std::max(curr_win_count * window_size_ - col_begin, 0); + //std::cout << "win: " << win << ", current_win_count: " << curr_win_count << std::endl << std::flush; + collect_draws(win, comm); + int first_draw = num_chains_ * (std::max(win - 1, 0) * window_size_ + (win > 0) * (window_size_ - init_buffer_)); + int num_draws = std::max(num_chains_ * draws_collected_counter_ - first_draw, 0); + //std::cout << "first_draw: " << first_draw << ", num_draws: " << num_draws << std::endl << std::flush; + int M = n_params_; try { @@ -183,8 +220,6 @@ class mpi_auto_adaptation : public 
mpi_metric_adaptation { Eigen::MatrixXd cov_train = Eigen::MatrixXd::Zero(M, M); Eigen::MatrixXd cov_test = Eigen::MatrixXd::Zero(M, M); - //std::cout << "col_begin: " << col_begin << ", num_draws: " << num_draws << std::endl; - int Ntest; if(state == "selection") { Ntest = int(0.2 * num_draws); @@ -196,14 +231,14 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { throw std::runtime_error("Each warmup stage must have at least 10 samples"); } - learn_covariance(cov_train, col_begin, num_draws - Ntest, comm); - learn_covariance(cov_test, col_begin + num_draws - Ntest, Ntest, comm); - //Ytrain = Y.block(0, 0, M, Y.cols() - Mtest); - //Ytest = Y.block(0, Ytrain.cols(), M, Mtest); + Eigen::MatrixXd Ytrain = Y_.block(first_draw, 0, num_draws - Ntest, n_params_); + Eigen::MatrixXd Ytest = Y_.block(first_draw + num_draws - Ntest, 0, Ntest, n_params_); + cov_train = internal::covariance(Ytrain); + cov_test = internal::covariance(Ytest); } else { - learn_covariance(cov_train, col_begin, num_draws, comm); Ntest = 0; - //Ytrain = Y; + Eigen::MatrixXd Ytrain = Y_.block(first_draw, 0, num_draws, n_params_); + cov_train = internal::covariance(Ytrain); } Eigen::MatrixXd dense = ((num_draws - Ntest) / ((num_draws - Ntest) + 5.0)) * cov_train + @@ -252,26 +287,34 @@ class mpi_auto_adaptation : public mpi_metric_adaptation { std::cout << e.what() << std::endl; std::cout << "Exception while using auto adaptation, falling back to diagonal" << std::endl; Eigen::MatrixXd cov = Eigen::MatrixXd::Zero(M, M); - learn_covariance(cov, col_begin, num_draws, comm); covar = ((num_draws / (num_draws + 5.0)) * cov.diagonal() + 1e-3 * (5.0 / (num_draws + 5.0)) * Eigen::VectorXd::Ones(cov.cols())).asDiagonal(); is_diagonal_ = true; } } - void learn_covariance(Eigen::MatrixXd& covar, - int col_begin, int n_samples, - const stan::math::mpi::Communicator& comm) { - estimator.sample_covariance(covar, col_begin, n_samples, comm); - //double n = static_cast(estimator.num_samples(comm)); - 
//covar = (n / (n + 5.0)) * covar - // + 1e-3 * (5.0 / (n + 5.0)) - // * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); - // restart(); + void collect_draws(int win, const stan::math::mpi::Communicator& comm) { + int finished = 0; + int index; + int flag = 0; + + while(finished < draw_req_counter_) { + MPI_Testany(draw_req_counter_, reqs.data(), &index, &flag, MPI_STATUS_IGNORE); + if (flag) { + finished++; + for (int chain = 0; chain < num_chains_; ++chain) { + Eigen::RowVectorXd draw = draws[index].col(chain).transpose(); + Y_.block((draws_collected_counter_ + index) * num_chains_ + chain, 0, 1, n_params_) = draw; + } + } + } + + draws_collected_counter_ += draw_req_counter_; + + reset_req(); } virtual void restart() { - estimator.restart(); } #else public: diff --git a/src/stan/mcmc/mpi_covar_adaptation.hpp b/src/stan/mcmc/mpi_covar_adaptation.hpp index 6c9e5d457a3..aa1e4a64652 100644 --- a/src/stan/mcmc/mpi_covar_adaptation.hpp +++ b/src/stan/mcmc/mpi_covar_adaptation.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_MPI_COVAR_ADAPTATION_HPP #define STAN_MCMC_MPI_COVAR_ADAPTATION_HPP -#include +#include #include #include @@ -13,7 +13,7 @@ namespace stan { namespace mcmc { - class mpi_covar_adaptation : public mpi_metric_adaptation { +class mpi_covar_adaptation : public mpi_metric_adaptation { #ifdef STAN_LANG_MPI // using est_t = stan::math::mpi::mpi_covar_estimator; using est_t = stan::math::welford_covar_estimator; @@ -37,23 +37,23 @@ namespace mcmc { num_draws(num_iterations / window_size, 0) {} - void reset_req() { - draw_req_counter_ = 0; - reqs.clear(); - reqs.resize(window_size_); - } + void reset_req() { + draw_req_counter_ = 0; + reqs.clear(); + reqs.resize(window_size_); + } - virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { - const stan::math::mpi::Communicator& comm = - stan::math::mpi::Session::inter_chain_comm(num_chains_); - MPI_Iallgather(q.data(), q.size(), MPI_DOUBLE, - draws[draw_req_counter_].data(), q.size(), MPI_DOUBLE, 
- comm.comm(), &reqs[draw_req_counter_]); - draw_req_counter_++; - for (int win = 0; win < curr_win_count; ++win) { - num_draws[win]++; - } + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { + const stan::math::mpi::Communicator& comm = + stan::math::mpi::Session::inter_chain_comm(num_chains_); + MPI_Iallgather(q.data(), q.size(), MPI_DOUBLE, + draws[draw_req_counter_].data(), q.size(), MPI_DOUBLE, + comm.comm(), &reqs[draw_req_counter_]); + draw_req_counter_++; + for (int win = 0; win < curr_win_count; ++win) { + num_draws[win]++; } + } virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, const stan::math::mpi::Communicator& comm) { diff --git a/src/stan/mcmc/mpi_metric_adaptation.hpp b/src/stan/mcmc/mpi_metric_adaptation.hpp index 21c6df72c9f..f2b51603737 100644 --- a/src/stan/mcmc/mpi_metric_adaptation.hpp +++ b/src/stan/mcmc/mpi_metric_adaptation.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_MPI_METRIC_ADAPTATION_HPP #define STAN_MCMC_MPI_METRIC_ADAPTATION_HPP -#include +#include #include #ifdef STAN_LANG_MPI diff --git a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp index 6534cc44dbc..33a6f21436d 100644 --- a/src/stan/mcmc/mpi_var_adaptation.hpp +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_MPI_VAR_ADAPTATION_HPP #define STAN_MCMC_MPI_VAR_ADAPTATION_HPP -#include +#include #include #include diff --git a/src/stan/mcmc/sample.hpp b/src/stan/mcmc/sample.hpp index 4450f7fba75..9bf45c91201 100644 --- a/src/stan/mcmc/sample.hpp +++ b/src/stan/mcmc/sample.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_SAMPLE_HPP #define STAN_MCMC_SAMPLE_HPP -#include +#include #include #include diff --git a/src/stan/mcmc/var_adaptation.hpp b/src/stan/mcmc/var_adaptation.hpp index a33165e34d5..c9ca255bad0 100644 --- a/src/stan/mcmc/var_adaptation.hpp +++ b/src/stan/mcmc/var_adaptation.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MCMC_VAR_ADAPTATION_HPP #define STAN_MCMC_VAR_ADAPTATION_HPP -#include 
+#include #include #include diff --git a/src/stan/model/gradient.hpp b/src/stan/model/gradient.hpp index f89640dfd52..be7cf47f9b9 100644 --- a/src/stan/model/gradient.hpp +++ b/src/stan/model/gradient.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/stan/model/indexing/lvalue.hpp b/src/stan/model/indexing/lvalue.hpp index 963c07d5931..69e607f9fba 100644 --- a/src/stan/model/indexing/lvalue.hpp +++ b/src/stan/model/indexing/lvalue.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/stan/model/indexing/rvalue.hpp b/src/stan/model/indexing/rvalue.hpp index 6d2c833bf2f..2a714ee1a72 100644 --- a/src/stan/model/indexing/rvalue.hpp +++ b/src/stan/model/indexing/rvalue.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/stan/model/log_prob_grad.hpp b/src/stan/model/log_prob_grad.hpp index c55927a229f..6e5eb11771e 100644 --- a/src/stan/model/log_prob_grad.hpp +++ b/src/stan/model/log_prob_grad.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MODEL_LOG_PROB_GRAD_HPP #define STAN_MODEL_LOG_PROB_GRAD_HPP -#include +#include #include #include diff --git a/src/stan/model/log_prob_propto.hpp b/src/stan/model/log_prob_propto.hpp index 52667335131..bde538d15d9 100644 --- a/src/stan/model/log_prob_propto.hpp +++ b/src/stan/model/log_prob_propto.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MODEL_LOG_PROB_PROPTO_HPP #define STAN_MODEL_LOG_PROB_PROPTO_HPP -#include +#include #include #include diff --git a/src/stan/model/model_functional.hpp b/src/stan/model/model_functional.hpp index 48735e46a03..0324804a5a1 100644 --- a/src/stan/model/model_functional.hpp +++ b/src/stan/model/model_functional.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MODEL_MODEL_FUNCTIONAL_HPP #define STAN_MODEL_MODEL_FUNCTIONAL_HPP -#include +#include #include namespace stan { diff --git a/src/stan/optimization/bfgs.hpp b/src/stan/optimization/bfgs.hpp index 48936576481..beee8ecb649 100644 --- 
a/src/stan/optimization/bfgs.hpp +++ b/src/stan/optimization/bfgs.hpp @@ -1,7 +1,7 @@ #ifndef STAN_OPTIMIZATION_BFGS_HPP #define STAN_OPTIMIZATION_BFGS_HPP -#include +#include #include #include #include diff --git a/src/stan/services/sample/fixed_param.hpp b/src/stan/services/sample/fixed_param.hpp index 8954e904ac5..85d2c508f01 100644 --- a/src/stan/services/sample/fixed_param.hpp +++ b/src/stan/services/sample/fixed_param.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp index 4bb4a7b996a..30d32f6dde9 100644 --- a/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_auto_e_adapt.hpp @@ -1,8 +1,7 @@ #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_AUTO_E_ADAPT_HPP #define STAN_SERVICES_SAMPLE_HMC_NUTS_AUTO_E_ADAPT_HPP -#include -#include +#include #include #include #include @@ -108,10 +107,14 @@ namespace stan { window, logger); // cross chain adaptation - sampler.set_cross_chain_adaptation_params(num_warmup, + sampler.set_cross_chain_adaptation_params(num_warmup, init_buffer, term_buffer, cross_chain_window, num_cross_chains, cross_chain_rhat, cross_chain_ess); - mcmc::mpi_auto_adaptation var_adapt(model, model.num_params_r(), num_warmup, cross_chain_window); + std::cout << "num warmup: " << num_warmup << std::endl; + std::cout << "cross chain window: " << cross_chain_window << std::endl; + mcmc::mpi_auto_adaptation var_adapt(model, model.num_params_r(), + num_cross_chains, num_warmup, + cross_chain_window, init_buffer); sampler.set_cross_chain_metric_adaptation(&var_adapt); util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, diff --git a/src/stan/services/sample/hmc_nuts_dense_e.hpp b/src/stan/services/sample/hmc_nuts_dense_e.hpp index e4b080fb612..a9c9129ea2a 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e.hpp +++ 
b/src/stan/services/sample/hmc_nuts_dense_e.hpp @@ -4,8 +4,7 @@ #include #include #include -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index f1eb7ab58c5..c9a6a7d599f 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -1,8 +1,7 @@ #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DENSE_E_ADAPT_HPP #define STAN_SERVICES_SAMPLE_HMC_NUTS_DENSE_E_ADAPT_HPP -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_nuts_diag_e.hpp b/src/stan/services/sample/hmc_nuts_diag_e.hpp index 86bac7ee3f2..67f698bcdd5 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e.hpp @@ -1,8 +1,7 @@ #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_HPP #define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_HPP -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 34fdd24059d..13a713e9030 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -1,8 +1,7 @@ #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_HPP #define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_HPP -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_nuts_unit_e.hpp b/src/stan/services/sample/hmc_nuts_unit_e.hpp index 17e51bf468f..cea0607eced 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 2c42ca52625..993f99d9497 100644 --- 
a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_static_dense_e.hpp b/src/stan/services/sample/hmc_static_dense_e.hpp index 19cc4e0e7fe..7d021b39115 100644 --- a/src/stan/services/sample/hmc_static_dense_e.hpp +++ b/src/stan/services/sample/hmc_static_dense_e.hpp @@ -4,8 +4,7 @@ #include #include #include -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp index 11204f861f7..6d06be044e9 100644 --- a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp @@ -4,8 +4,7 @@ #include #include #include -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_static_diag_e.hpp b/src/stan/services/sample/hmc_static_diag_e.hpp index 72e77f1b08d..4851f2a7ab9 100644 --- a/src/stan/services/sample/hmc_static_diag_e.hpp +++ b/src/stan/services/sample/hmc_static_diag_e.hpp @@ -4,8 +4,7 @@ #include #include #include -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp index 7a40088e368..c4131bcf3c6 100644 --- a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp @@ -4,8 +4,7 @@ #include #include #include -#include -#include +#include #include #include #include diff --git a/src/stan/services/sample/hmc_static_unit_e.hpp b/src/stan/services/sample/hmc_static_unit_e.hpp index 1b9d26d84bc..a55cfdb2eab 100644 --- a/src/stan/services/sample/hmc_static_unit_e.hpp +++ b/src/stan/services/sample/hmc_static_unit_e.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include 
#include #include diff --git a/src/stan/services/sample/hmc_static_unit_e_adapt.hpp b/src/stan/services/sample/hmc_static_unit_e_adapt.hpp index d12f0831a8c..d44edff83e3 100644 --- a/src/stan/services/sample/hmc_static_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_unit_e_adapt.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/stan/services/util/initialize.hpp b/src/stan/services/util/initialize.hpp index cfd35415ddd..a7d37e5f7e3 100644 --- a/src/stan/services/util/initialize.hpp +++ b/src/stan/services/util/initialize.hpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/stan/services/util/read_dense_inv_metric.hpp b/src/stan/services/util/read_dense_inv_metric.hpp index e00a6d357a7..8b24b523c04 100644 --- a/src/stan/services/util/read_dense_inv_metric.hpp +++ b/src/stan/services/util/read_dense_inv_metric.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/stan/services/util/validate_dense_inv_metric.hpp b/src/stan/services/util/validate_dense_inv_metric.hpp index 34aace9895d..c6d0188ff48 100644 --- a/src/stan/services/util/validate_dense_inv_metric.hpp +++ b/src/stan/services/util/validate_dense_inv_metric.hpp @@ -2,7 +2,7 @@ #define STAN_SERVICES_UTIL_VALIDATE_DENSE_INV_METRIC_HPP #include -#include +#include namespace stan { namespace services { diff --git a/src/stan/services/util/validate_diag_inv_metric.hpp b/src/stan/services/util/validate_diag_inv_metric.hpp index 8a4378a3ff2..4f88bff95d8 100644 --- a/src/stan/services/util/validate_diag_inv_metric.hpp +++ b/src/stan/services/util/validate_diag_inv_metric.hpp @@ -2,7 +2,7 @@ #define STAN_SERVICES_UTIL_VALIDATE_DIAG_INV_METRIC_HPP #include -#include +#include namespace stan { namespace services { diff --git a/src/stan/variational/base_family.hpp b/src/stan/variational/base_family.hpp index 025fa500a87..5e2d3d7a3d8 100644 --- 
a/src/stan/variational/base_family.hpp +++ b/src/stan/variational/base_family.hpp @@ -2,7 +2,7 @@ #define STAN_VARIATIONAL_BASE_FAMILY_HPP #include -#include +#include #include #include diff --git a/src/stan/variational/families/normal_fullrank.hpp b/src/stan/variational/families/normal_fullrank.hpp index d7999b51384..f62fdc14e84 100644 --- a/src/stan/variational/families/normal_fullrank.hpp +++ b/src/stan/variational/families/normal_fullrank.hpp @@ -2,7 +2,7 @@ #define STAN_VARIATIONAL_NORMAL_FULLRANK_HPP #include -#include +#include #include #include #include diff --git a/src/stan/variational/families/normal_meanfield.hpp b/src/stan/variational/families/normal_meanfield.hpp index d54e0fb8f51..f6d509f059f 100644 --- a/src/stan/variational/families/normal_meanfield.hpp +++ b/src/stan/variational/families/normal_meanfield.hpp @@ -2,7 +2,7 @@ #define STAN_VARIATIONAL_NORMAL_MEANFIELD_HPP #include -#include +#include #include #include #include diff --git a/src/stan/variational/print_progress.hpp b/src/stan/variational/print_progress.hpp index 968b80942f8..6558e672aa6 100644 --- a/src/stan/variational/print_progress.hpp +++ b/src/stan/variational/print_progress.hpp @@ -2,8 +2,8 @@ #define STAN_VARIATIONAL_PRINT_PROGRESS_HPP #include -#include -#include +#include +#include #include #include #include