diff --git a/.gitmodules b/.gitmodules index 235095d569a..0162c3dc7cf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "lib/stan_math"] path = lib/stan_math url = https://github.com/stan-dev/math.git + branch = mpi_warmup_v2 diff --git a/lib/stan_math b/lib/stan_math index 025a142ec01..ae73534895e 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 025a142ec01b68e91adf339a9b86d67e6d0e20ee +Subproject commit ae73534895e5783139e6610c35b5c913c047e0a5 diff --git a/make/mpi_warmup.mk b/make/mpi_warmup.mk new file mode 100644 index 00000000000..7b047a025ad --- /dev/null +++ b/make/mpi_warmup.mk @@ -0,0 +1,5 @@ +ifdef MPI_ADAPTED_WARMUP + CXXFLAGS += -DSTAN_LANG_MPI -DMPI_ADAPTED_WARMUP + CC=mpicxx + CXX=mpicxx +endif diff --git a/makefile b/makefile index da551f3baaa..0e6d98ab23f 100644 --- a/makefile +++ b/makefile @@ -13,6 +13,7 @@ help: -include $(HOME)/.config/stan/make.local # user-defined variables -include make/local # user-defined variables +-include make/mpi_warmup.mk MATH ?= lib/stan_math/ ifeq ($(OS),Windows_NT) diff --git a/src/stan/callbacks/mpi_stream_writer.hpp b/src/stan/callbacks/mpi_stream_writer.hpp new file mode 100644 index 00000000000..1329ef498d6 --- /dev/null +++ b/src/stan/callbacks/mpi_stream_writer.hpp @@ -0,0 +1,138 @@ +#ifndef STAN_CALLBACKS_MPI_STREAM_WRITER_HPP +#define STAN_CALLBACKS_MPI_STREAM_WRITER_HPP + +#ifdef MPI_ADAPTED_WARMUP + +#include +#include +#include +#include +#include + +namespace stan { + namespace callbacks { + /** + * mpi_stream_writer is an implementation + * of writer that writes to a stream. + */ + class mpi_stream_writer : public writer { + public: + /** + * Constructs a stream writer with an output stream + * and an optional prefix for comments. + * + * @param[in, out] output stream to write + * @param[in] comment_prefix string to stream before + * each comment line. Default is "". + */ + mpi_stream_writer(int num_chains, std::ostream& output, + const std::string& comment_prefix = "") + : num_chains_(num_chains), output_(output), + comment_prefix_(comment_prefix) + {} + + /** + * Virtual destructor + */ + virtual ~mpi_stream_writer() {} + + /** + * Set new value for @c num_chains_. + * + * @param[in] n new value of @c num_chains_ + */ + void set_num_chains(int n) { + num_chains_ = n; + } + + /** + * Writes a set of names on a single line in csv format followed + * by a newline. + * + * Note: the names are not escaped. + * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) { + write_vector(names); + } + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] state Values in a std::vector + */ + void operator()(const std::vector& state) { + write_vector(state); + } + + /** + * Writes the comment_prefix to the stream followed by a newline. + */ + void operator()() { + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) { + output_ << comment_prefix_ << std::endl; + } + } + + /** + * Writes the comment_prefix then the message followed by a newline. + * + * @param[in] message A string + */ + void operator()(const std::string& message) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) { + output_ << comment_prefix_ << message << std::endl; + } + } + + private: + + /** + * nb. 
of chains that have its own output stream + */ + int num_chains_; + + /** + * Output stream + */ + std::ostream& output_; + + /** + * Comment prefix to use when printing comments: strings and blank lines + */ + std::string comment_prefix_; + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] v Values in a std::vector + */ + template + void write_vector(const std::vector& v) { + if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) { + if (v.empty()) return; + + typename std::vector::const_iterator last = v.end(); + --last; + + for (typename std::vector::const_iterator it = v.begin(); + it != last; ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; + } + } + }; + + } +} + +#endif + +#endif diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 6519531c0ac..3191ced10bd 100644 --- a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -7,101 +7,104 @@ #include namespace stan { -namespace callbacks { + namespace callbacks { -/** - * stream_writer is an implementation - * of writer that writes to a stream. - */ -class stream_writer : public writer { - public: - /** - * Constructs a stream writer with an output stream - * and an optional prefix for comments. - * - * @param[in, out] output stream to write - * @param[in] comment_prefix string to stream before - * each comment line. Default is "". - */ - explicit stream_writer(std::ostream& output, - const std::string& comment_prefix = "") - : output_(output), comment_prefix_(comment_prefix) {} + /** + * stream_writer is an implementation + * of writer that writes to a stream. + */ + class stream_writer : public writer { + public: + /** + * Constructs a stream writer with an output stream + * and an optional prefix for comments. + * + * @param[in, out] output stream to write + * @param[in] comment_prefix string to stream before + * each comment line. Default is "". + */ + stream_writer(std::ostream& output, + const std::string& comment_prefix = ""): + output_(output), comment_prefix_(comment_prefix) {} - /** - * Virtual destructor - */ - virtual ~stream_writer() {} + /** + * Virtual destructor + */ + virtual ~stream_writer() {} - /** - * Writes a set of names on a single line in csv format followed - * by a newline. - * - * Note: the names are not escaped. - * - * @param[in] names Names in a std::vector - */ - void operator()(const std::vector& names) { - write_vector(names); - } + /** + * Writes a set of names on a single line in csv format followed + * by a newline. + * + * Note: the names are not escaped. + * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) { + write_vector(names); + } - /** - * Writes a set of values in csv format followed by a newline. - * - * Note: the precision of the output is determined by the settings - * of the stream on construction. - * - * @param[in] state Values in a std::vector - */ - void operator()(const std::vector& state) { write_vector(state); } + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] state Values in a std::vector + */ + void operator()(const std::vector& state) { + write_vector(state); + } - /** - * Writes the comment_prefix to the stream followed by a newline. 
- */ - void operator()() { output_ << comment_prefix_ << std::endl; } + /** + * Writes the comment_prefix to the stream followed by a newline. + */ + void operator()() { + output_ << comment_prefix_ << std::endl; + } - /** - * Writes the comment_prefix then the message followed by a newline. - * - * @param[in] message A string - */ - void operator()(const std::string& message) { - output_ << comment_prefix_ << message << std::endl; - } + /** + * Writes the comment_prefix then the message followed by a newline. + * + * @param[in] message A string + */ + void operator()(const std::string& message) { + output_ << comment_prefix_ << message << std::endl; + } - private: - /** - * Output stream - */ - std::ostream& output_; + private: + /** + * Output stream + */ + std::ostream& output_; - /** - * Comment prefix to use when printing comments: strings and blank lines - */ - std::string comment_prefix_; + /** + * Comment prefix to use when printing comments: strings and blank lines + */ + std::string comment_prefix_; - /** - * Writes a set of values in csv format followed by a newline. - * - * Note: the precision of the output is determined by the settings - * of the stream on construction. - * - * @param[in] v Values in a std::vector - */ - template - void write_vector(const std::vector& v) { - if (v.empty()) - return; + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] v Values in a std::vector + */ + template + void write_vector(const std::vector& v) { + if (v.empty()) return; - typename std::vector::const_iterator last = v.end(); - --last; + typename std::vector::const_iterator last = v.end(); + --last; - for (typename std::vector::const_iterator it = v.begin(); it != last; - ++it) - output_ << *it << ","; - output_ << v.back() << std::endl; - } -}; + for (typename std::vector::const_iterator it = v.begin(); + it != last; ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; + } + }; -} // namespace callbacks -} // namespace stan + } +} #endif diff --git a/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp new file mode 100644 index 00000000000..4a557f1b1bb --- /dev/null +++ b/src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp @@ -0,0 +1,494 @@ +#ifndef STAN_MCMC_HMC_MPI_CROSS_CHAIN_ADAPTER_HPP +#define STAN_MCMC_HMC_MPI_CROSS_CHAIN_ADAPTER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef MPI_ADAPTED_WARMUP +#include +#endif + +namespace stan { +namespace mcmc { + +#ifdef MPI_ADAPTED_WARMUP + template + class mpi_cross_chain_adapter { + protected: + bool is_adapted_; + bool is_post_warmup_; + int window_size_; + int num_chains_; + int max_num_windows_; + double target_rhat_; + double target_ess_; + std::vector lp_draws_; + Eigen::MatrixXd all_lp_draws_; + std::vector>> lp_acc_; // NOLINT + boost::accumulators::accumulator_set > draw_count_acc_; + Eigen::ArrayXd rhat_; + Eigen::ArrayXd ess_; + mpi_metric_adaptation* var_adapt; + public: + const static int num_post_warmup = 50; + + mpi_cross_chain_adapter() = default; + + mpi_cross_chain_adapter(int num_iterations, int window_size, + int num_chains, + double target_rhat, double target_ess) : + is_adapted_(false), + is_post_warmup_(false), + window_size_(window_size), + num_chains_(num_chains), + max_num_windows_(num_iterations / 
window_size), + target_rhat_(target_rhat), + target_ess_(target_ess), + lp_draws_(window_size), + all_lp_draws_(window_size_ * max_num_windows_, num_chains_), + lp_acc_(max_num_windows_), + draw_count_acc_(), + rhat_(Eigen::ArrayXd::Zero(max_num_windows_)), + ess_(Eigen::ArrayXd::Zero(max_num_windows_)) + {} + + inline void set_cross_chain_metric_adaptation(mpi_metric_adaptation* ptr) {var_adapt = ptr;} + + inline void set_cross_chain_adaptation_params(int num_iterations, + int window_size, + int num_chains, + double target_rhat, double target_ess) { + is_adapted_ = false; + is_post_warmup_ = false, + window_size_ = window_size; + num_chains_ = num_chains; + max_num_windows_ = num_iterations / window_size; + target_rhat_ = target_rhat; + target_ess_ = target_ess; + lp_draws_.resize(window_size); + all_lp_draws_.resize(window_size_ * max_num_windows_, num_chains_); + lp_acc_.clear(); + lp_acc_.resize(max_num_windows_); + draw_count_acc_ = {}; + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); + } + + inline void reset_cross_chain_adaptation() { + is_adapted_ = false; + is_post_warmup_ = false, + lp_draws_.clear(); + lp_acc_.clear(); + lp_acc_.resize(max_num_windows_); + draw_count_acc_ = {}; + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); + var_adapt -> restart(); + } + + inline int max_num_windows() {return max_num_windows_;} + + /* + * Calculate the number of draws. + */ + inline int num_cross_chain_draws() { + return boost::accumulators::count(draw_count_acc_); + } + + inline void + write_num_cross_chain_warmup(callbacks::writer& sample_writer, + int num_thin, int num_warmup) { + if (use_cross_chain_adapt()) { + size_t n = num_cross_chain_draws(); + sample_writer("num_warmup = " + std::to_string(n / num_thin)); + } else { + sample_writer("num_warmup = " + std::to_string(num_warmup)); + } + } + + /* + * Calculate the number of active windows when NEXT + * sample is added. + */ + inline int current_cross_chain_window_counter() { + size_t n = num_cross_chain_draws() - 1; + return n / window_size_ + 1; + } + + // only add samples to inter-chain ranks + // CRTP to sampler + inline void add_cross_chain_sample(double s) { + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + Sampler& sampler = static_cast(*this); + + if (sampler.adapting()) { + int i = num_cross_chain_draws() % window_size_; + draw_count_acc_(0); + + if (!is_adapted_) { + int n_win = current_cross_chain_window_counter(); + + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + if (is_inter_rank) { + lp_draws_[i] = s; + for (int win = 0; win < n_win; ++win) { + lp_acc_[win](s); + } + var_adapt -> add_sample(sampler.z().q, n_win); + } + } + } + } + + /* + * Computes the effective sample size (ESS) for the specified + * parameter across all kept samples. The value returned is the + * minimum of ESS and the number_total_draws * + * log10(number_total_draws). 
+ * + * This version is based on the one at + * stan/analyze/mcmc/compute_effective_sample_size.hpp + * but assuming the chain_mean and chain_var has been + * calculated(on the fly during adaptation) + * + */ + inline double compute_effective_sample_size(int win, int win_count) { + std::vector draws(num_chains_); + size_t num_draws = (win_count - win) * window_size_; + for (int chain = 0; chain < num_chains_; ++chain) { + draws[chain] = &all_lp_draws_(win * window_size_, chain); + } + return stan::analyze::compute_effective_sample_size(draws, num_draws); + } + + inline const Eigen::ArrayXd& cross_chain_adapt_rhat() { + return rhat_; + } + + inline const Eigen::ArrayXd& cross_chain_adapt_ess() { + return ess_; + } + + inline bool is_cross_chain_adapt_window_end() { + size_t n = num_cross_chain_draws(); + return n > 0 && (n % window_size_ == 0); + } + + inline bool is_cross_chain_adapted() { + return is_adapted_; + } + + inline bool is_post_cross_chain() { + return is_post_warmup_; + } + + inline void set_post_cross_chain() { + is_post_warmup_ = !is_post_warmup_; + } + + inline void msg_adaptation(int win, callbacks::logger& logger) { + std::stringstream message; + message << "iteration: "; + message << std::setw(3); + message << num_cross_chain_draws(); + message << " window: " << win + 1 << " / " << current_cross_chain_window_counter(); + message << std::setw(5) << std::setprecision(2); + message << " Rhat: " << std::fixed << cross_chain_adapt_rhat()[win]; + const Eigen::ArrayXd& ess(cross_chain_adapt_ess()); + message << " ESS: " << std::fixed << ess_[win]; + + logger.info(message); + } + + /* + * @tparam Sampler sampler class + * @param[in] m_win number of windows + * @param[in] window_size window size + * @param[in] num_chains number of chains + * @param[in,out] chain_gather gathered information from each chain, + * must have enough capacity to store up to + * maximum windows for all chains. 
+ # @return vector {stepsize, rhat(only in rank 0)} + */ + inline void cross_chain_adaptation_v1(callbacks::logger& logger) { + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + Sampler& sampler = static_cast(*this); + + if ((!is_adapted_) && is_cross_chain_adapt_window_end()) { + double chain_stepsize = sampler.get_nominal_stepsize(); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + double invalid_stepsize = -999.0; + double new_stepsize = invalid_stepsize; + if (is_inter_rank) { + const Communicator& comm = Session::inter_chain_comm(num_chains_); + + const int nd_win = 3; // mean, variance, chain_stepsize + const int win_count = current_cross_chain_window_counter(); + int n_gather = nd_win * win_count + window_size_; + std::vector chain_gather(n_gather, 0.0); + for (int win = 0; win < win_count; ++win) { + int num_draws = (win_count - win) * window_size_; + double unbiased_var_scale = num_draws / (num_draws - 1.0); + chain_gather[nd_win * win] = boost::accumulators::mean(lp_acc_[win]); + chain_gather[nd_win * win + 1] = boost::accumulators::variance(lp_acc_[win]) * + unbiased_var_scale; + chain_gather[nd_win * win + 2] = chain_stepsize; + } + std::copy(lp_draws_.begin(), lp_draws_.end(), + chain_gather.begin() + nd_win * win_count); + + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); + const int invalid_win = -999; + int adapted_win = invalid_win; + + if (comm.rank() == 0) { + std::vector all_chain_gather(n_gather * num_chains_); + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + int begin_row = (win_count - 1) * window_size_; + for (int chain = 0; chain < num_chains_; ++chain) { + int j = n_gather * chain + nd_win * win_count; + for (int i = 0; i < window_size_; ++i) { + all_lp_draws_(begin_row + i, chain) = all_chain_gather[j + i]; + } + } + + for (int win = win_count - 1; win >= 0; --win) { + accumulator_set> acc_chain_mean; + accumulator_set> acc_chain_var; + accumulator_set> acc_step; + Eigen::VectorXd chain_mean(num_chains_); + Eigen::VectorXd chain_var(num_chains_); + for (int chain = 0; chain < num_chains_; ++chain) { + chain_mean(chain) = all_chain_gather[chain * n_gather + nd_win * win]; + acc_chain_mean(chain_mean(chain)); + chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; + acc_chain_var(chain_var(chain)); + acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); + } + size_t num_draws = (win_count - win) * window_size_; + double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains_ / (num_chains_ - 1); + double var_within = boost::accumulators::mean(acc_chain_var); + rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); + ess_[win] = compute_effective_sample_size(win, win_count); + is_adapted_ = rhat_(win) < target_rhat_ && ess_[win] > target_ess_; + + msg_adaptation(win, logger); + + if (is_adapted_) { + adapted_win = win; + break; + } + } + } else { + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + NULL, 0, MPI_DOUBLE, 0, comm.comm()); + } + MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); + if (adapted_win >= 0) { + MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); + new_stepsize /= num_chains_; + var_adapt 
-> learn_metric(sampler.z().inv_e_metric_, adapted_win, win_count, comm); + } + } + const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); + MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + is_adapted_ = new_stepsize > 0.0; + if (is_adapted_) { + chain_stepsize = new_stepsize; + MPI_Bcast(sampler.z().inv_e_metric_.data(), + sampler.z().inv_e_metric_.size(), MPI_DOUBLE, 0, intra_comm.comm()); + } + if (is_adapted_) { + sampler.set_nominal_stepsize(chain_stepsize); + } + } + } + + inline bool cross_chain_adaptation(callbacks::logger& logger) { + using boost::accumulators::accumulator_set; + using boost::accumulators::stats; + using boost::accumulators::tag::mean; + using boost::accumulators::tag::variance; + + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + Sampler& sampler = static_cast(*this); + + if ((!is_adapted_) && is_cross_chain_adapt_window_end()) { + double chain_stepsize = sampler.get_nominal_stepsize(); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + double invalid_stepsize = -999.0; + double new_stepsize = invalid_stepsize; + double max_ess = 0.0; + if (is_inter_rank) { + const Communicator& comm = Session::inter_chain_comm(num_chains_); + + const int nd_win = 3; // mean, variance, chain_stepsize + const int win_count = current_cross_chain_window_counter(); + int n_gather = nd_win * win_count + window_size_; + std::vector chain_gather(n_gather, 0.0); + for (int win = 0; win < win_count; ++win) { + int num_draws = (win_count - win) * window_size_; + double unbiased_var_scale = num_draws / (num_draws - 1.0); + chain_gather[nd_win * win] = boost::accumulators::mean(lp_acc_[win]); + chain_gather[nd_win * win + 1] = boost::accumulators::variance(lp_acc_[win]) * + unbiased_var_scale; + chain_gather[nd_win * win + 2] = chain_stepsize; + } + std::copy(lp_draws_.begin(), lp_draws_.end(), + chain_gather.begin() + nd_win * win_count); + + rhat_ = Eigen::ArrayXd::Zero(max_num_windows_); + ess_ = Eigen::ArrayXd::Zero(max_num_windows_); + const int invalid_win = -999; + int adapted_win = invalid_win; + + if (comm.rank() == 0) { + std::vector all_chain_gather(n_gather * num_chains_); + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + all_chain_gather.data(), n_gather, MPI_DOUBLE, 0, comm.comm()); + int begin_row = (win_count - 1) * window_size_; + for (int chain = 0; chain < num_chains_; ++chain) { + int j = n_gather * chain + nd_win * win_count; + for (int i = 0; i < window_size_; ++i) { + all_lp_draws_(begin_row + i, chain) = all_chain_gather[j + i]; + } + } + + for (int win = 0; win < win_count; ++win) { + bool win_adapted; + accumulator_set> acc_chain_mean; + accumulator_set> acc_chain_var; + accumulator_set> acc_step; + Eigen::VectorXd chain_mean(num_chains_); + Eigen::VectorXd chain_var(num_chains_); + for (int chain = 0; chain < num_chains_; ++chain) { + chain_mean(chain) = all_chain_gather[chain * n_gather + nd_win * win]; + acc_chain_mean(chain_mean(chain)); + chain_var(chain) = all_chain_gather[chain * n_gather + nd_win * win + 1]; + acc_chain_var(chain_var(chain)); + acc_step(all_chain_gather[chain * n_gather + nd_win * win + 2]); + } + size_t num_draws = (win_count - win) * window_size_; + double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) + * num_chains_ / (num_chains_ - 1); + double var_within = boost::accumulators::mean(acc_chain_var); + rhat_(win) = sqrt((var_between / var_within + num_draws - 1) / num_draws); + ess_[win] = compute_effective_sample_size(win, 
win_count); + win_adapted = rhat_(win) < target_rhat_ && ess_[win] > target_ess_; + + msg_adaptation(win, logger); + + if(ess_[win] > max_ess) { + max_ess = ess_[win]; + adapted_win = -(win + 1); + if (win_adapted) { + adapted_win = std::abs(adapted_win) - 1; + } + } + } + } else { + MPI_Gather(chain_gather.data(), n_gather, MPI_DOUBLE, + NULL, 0, MPI_DOUBLE, 0, comm.comm()); + } + MPI_Bcast(&adapted_win, 1, MPI_INT, 0, comm.comm()); + is_adapted_ = adapted_win >= 0; + int max_ess_win = is_adapted_ ? adapted_win : (-adapted_win - 1); + var_adapt -> learn_metric(sampler.z().inv_e_metric_, max_ess_win, win_count, comm); + std::cout << "cross chain win: " << max_ess_win + 1 << "\n"; + } + const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); + MPI_Bcast(&is_adapted_, 1, MPI_C_BOOL, 0, intra_comm.comm()); + MPI_Bcast(sampler.z().inv_e_metric_.data(), + sampler.z().inv_e_metric_.size(), MPI_DOUBLE, 0, intra_comm.comm()); + if (is_adapted_) { + set_cross_chain_stepsize(); + } + return true; + } else { + return false; + } + } + + inline void set_cross_chain_stepsize() { + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + const Communicator& comm = Session::inter_chain_comm(num_chains_); + bool is_inter_rank = Session::is_in_inter_chain_comm(num_chains_); + double new_stepsize; + Sampler& sampler = static_cast(*this); + if (is_inter_rank) { + double chain_stepsize = sampler.get_nominal_stepsize(); + chain_stepsize = 1.0/(chain_stepsize * chain_stepsize); // harmonic mean + MPI_Allreduce(&chain_stepsize, &new_stepsize, 1, MPI_DOUBLE, MPI_SUM, comm.comm()); + new_stepsize = sqrt(num_chains_ / new_stepsize); + } + const Communicator& intra_comm = Session::intra_chain_comm(num_chains_); + MPI_Bcast(&new_stepsize, 1, MPI_DOUBLE, 0, intra_comm.comm()); + sampler.set_nominal_stepsize(new_stepsize); + } + + inline bool use_cross_chain_adapt() { + return num_chains_ > 1; + } + }; + +#else // sequential version + + template + class mpi_cross_chain_adapter { + public: + inline void set_cross_chain_metric_adaptation(mpi_metric_adaptation* ptr) {} + + inline void set_cross_chain_adaptation_params(int num_iterations, + int window_size, + int num_chains, + double target_rhat, double target_ess) {} + inline void add_cross_chain_sample(double s) {} + + inline bool cross_chain_adaptation(callbacks::logger& logger) { return false; } + + inline bool is_cross_chain_adapted() { return false; } + + inline void set_cross_chain_stepsize() {} + + inline bool use_cross_chain_adapt() { return false; } + }; + +#endif + +} +} +#endif + + diff --git a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp index 05b6c80523f..09cb7b8299d 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace stan { namespace mcmc { @@ -14,6 +15,7 @@ namespace mcmc { */ template class adapt_dense_e_nuts : public dense_e_nuts, + public mpi_cross_chain_adapter>, public stepsize_covar_adapter { public: adapt_dense_e_nuts(const Model& model, BaseRNG& rng) @@ -29,14 +31,30 @@ class adapt_dense_e_nuts : public dense_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + if (this -> use_cross_chain_adapt()) { + this -> add_cross_chain_sample(s.log_prob()); + bool update = this -> cross_chain_adaptation(logger); + if (this -> 
is_cross_chain_adapted()) { + update = false; + } - if (update) { - this->init_stepsize(logger); + if (update) { + this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + + this->set_cross_chain_stepsize(); + } + } else { + bool update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, + this->z_.q); + if (update) { + this->init_stepsize(logger); + + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + } } } return s; @@ -44,7 +62,9 @@ class adapt_dense_e_nuts : public dense_e_nuts, void disengage_adaptation() { base_adapter::disengage_adaptation(); - this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + if (!this -> is_cross_chain_adapted()) { + this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); + } } }; diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 45e92380f57..e3ae02e905a 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace stan { namespace mcmc { @@ -14,6 +15,7 @@ namespace mcmc { */ template class adapt_diag_e_nuts : public diag_e_nuts, + public mpi_cross_chain_adapter>, public stepsize_var_adapter { public: adapt_diag_e_nuts(const Model& model, BaseRNG& rng) @@ -29,14 +31,31 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); + + if (this -> use_cross_chain_adapt()) { + this -> add_cross_chain_sample(s.log_prob()); + bool update = this -> cross_chain_adaptation(logger); + if (this -> is_cross_chain_adapted()) { + update = false; + } - if (update) { - this->init_stepsize(logger); + if (update) { + this->init_stepsize(logger); - this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); - this->stepsize_adaptation_.restart(); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + + this->set_cross_chain_stepsize(); + } + } else { + bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + this->z_.q); + if (update) { + this->init_stepsize(logger); + + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); + this->stepsize_adaptation_.restart(); + } } } return s; diff --git a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp index 3929fc7ed12..3fe3b915a10 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_unit_e_nuts.hpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace stan { namespace mcmc { @@ -14,6 +15,7 @@ namespace mcmc { */ template class adapt_unit_e_nuts : public unit_e_nuts, + public mpi_cross_chain_adapter>, public stepsize_adapter { public: adapt_unit_e_nuts(const Model& model, BaseRNG& rng) diff --git a/src/stan/mcmc/mpi_covar_adaptation.hpp b/src/stan/mcmc/mpi_covar_adaptation.hpp new file mode 100644 index 00000000000..89c6a76a90c --- /dev/null +++ b/src/stan/mcmc/mpi_covar_adaptation.hpp @@ -0,0 +1,106 @@ +#ifndef STAN_MCMC_MPI_COVAR_ADAPTATION_HPP +#define STAN_MCMC_MPI_COVAR_ADAPTATION_HPP + +#include +#include +#include + +#ifdef STAN_LANG_MPI +#include 
+#endif + +namespace stan { + +namespace mcmc { + + class mpi_covar_adaptation : public mpi_metric_adaptation { +#ifdef STAN_LANG_MPI + // using est_t = stan::math::mpi::mpi_covar_estimator; + using est_t = stan::math::welford_covar_estimator; + int num_chains_; + int window_size_; + int init_draw_counter_; + int draw_req_counter_; +public: + std::vector estimators; + std::vector reqs; + std::vector draws; + std::vector num_draws; + + mpi_covar_adaptation(int n_params, int num_chains, int num_iterations, int window_size) + : num_chains_(num_chains), + window_size_(window_size), + init_draw_counter_(0), draw_req_counter_(0), + estimators(num_iterations / window_size, est_t(n_params)), + reqs(window_size), + draws(window_size, Eigen::MatrixXd(n_params, num_chains)), + num_draws(num_iterations / window_size, 0) + {} + + void reset_req() { + draw_req_counter_ = 0; + reqs.clear(); + reqs.resize(window_size_); + } + + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { + const stan::math::mpi::Communicator& comm = + stan::math::mpi::Session::inter_chain_comm(num_chains_); + init_draw_counter_++; + if (init_draw_counter_ > init_bufer_size) { + MPI_Iallgather(q.data(), q.size(), MPI_DOUBLE, + draws[draw_req_counter_].data(), q.size(), MPI_DOUBLE, + comm.comm(), &reqs[draw_req_counter_]); + draw_req_counter_++; + for (int win = 0; win < curr_win_count; ++win) { + num_draws[win]++; + } + } + } + + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) { + learn_covariance(covar, win, curr_win_count); + } + + void learn_covariance(Eigen::MatrixXd& covar, int win, int curr_win_count) { + int finished = 0; + int index; + int flag = 0; + while(finished < draw_req_counter_) { + MPI_Testany(draw_req_counter_, reqs.data(), &index, &flag, MPI_STATUS_IGNORE); + if (flag) { + finished++; + for (int i = 0; i < curr_win_count; ++i) { + for (int chain = 0; chain < num_chains_; ++chain) { + estimators[i].add_sample(draws[index].col(chain)); + } + } + } + } + estimators[win].sample_covariance(covar); + double n = num_draws[win] * num_chains_; + covar = (n / (n + 5.0)) * covar + + 1e-3 * (5.0 / (n + 5.0)) + * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); + + reset_req(); + } + + virtual void restart() { + // estimator.restart(); + } +#else + public: + mpi_covar_adaptation(int n_params, int num_chains, int num_iterations, int window_size) + {} +#endif +}; + +} // namespace mcmc + +} // namespace stan + + + +#endif diff --git a/src/stan/mcmc/mpi_metric_adaptation.hpp b/src/stan/mcmc/mpi_metric_adaptation.hpp new file mode 100644 index 00000000000..8f37e4b696b --- /dev/null +++ b/src/stan/mcmc/mpi_metric_adaptation.hpp @@ -0,0 +1,39 @@ +#ifndef STAN_MCMC_MPI_METRIC_ADAPTATION_HPP +#define STAN_MCMC_MPI_METRIC_ADAPTATION_HPP + +#include +#include + +#ifdef STAN_LANG_MPI +#include +#endif + +namespace stan { + +namespace mcmc { + + class mpi_metric_adaptation { + protected: + static const int init_bufer_size = 75; + public: + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) {}; + + virtual void restart() {} + +#ifdef MPI_ADAPTED_WARMUP + virtual void learn_metric(Eigen::VectorXd& var, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) + {} + + + virtual void learn_metric(Eigen::MatrixXd& covar, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) + {} +#endif + }; + +} // namespace mcmc + +} // namespace stan + +#endif diff --git 
a/src/stan/mcmc/mpi_var_adaptation.hpp b/src/stan/mcmc/mpi_var_adaptation.hpp new file mode 100644 index 00000000000..5a1fe18e139 --- /dev/null +++ b/src/stan/mcmc/mpi_var_adaptation.hpp @@ -0,0 +1,74 @@ +#ifndef STAN_MCMC_MPI_VAR_ADAPTATION_HPP +#define STAN_MCMC_MPI_VAR_ADAPTATION_HPP + +#include +#include +#include + +#ifdef STAN_LANG_MPI +#include +#endif + +namespace stan { + +namespace mcmc { + + class mpi_var_adaptation : public mpi_metric_adaptation { +#ifdef STAN_LANG_MPI + using est_t = stan::math::mpi::mpi_var_estimator; + + int init_draw_counter; +public: + std::vector estimators; + + mpi_var_adaptation() = default; + + mpi_var_adaptation(int n_params, int max_num_windows) + : init_draw_counter(0), estimators(max_num_windows, est_t(n_params)) + {} + + mpi_var_adaptation(int n_params, int num_iterations, int window_size) + : mpi_var_adaptation(n_params, num_iterations / window_size) + {} + + virtual void add_sample(const Eigen::VectorXd& q, int curr_win_count) { + init_draw_counter++; + if (init_draw_counter > init_bufer_size) { + for (int win = 0; win < curr_win_count; ++win) { + estimators[win].add_sample(q); + } + } + } + + virtual void learn_metric(Eigen::VectorXd& var, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) { + learn_variance(var, win, curr_win_count, comm); + } + + void learn_variance(Eigen::VectorXd& var, int win, int curr_win_count, + const stan::math::mpi::Communicator& comm) { + double n = static_cast(estimators[win].sample_variance(var, comm)); + var = (n / (n + 5.0)) * var + + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); + } + + virtual void restart() { + for (auto&& e : estimators) { + e.restart(); + } + } + +#else + public: + mpi_var_adaptation(int n_params, int num_iterations, int window_size) + {} +#endif +}; + +} // namespace mcmc + +} // namespace stan + + + +#endif diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index 80adbf0b549..032985d45ad 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -57,7 +57,9 @@ template int hmc_nuts_dense_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -96,6 +98,14 @@ int hmc_nuts_dense_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); + // cross chain adaptation + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_covar_adaptation var_adapt(model.num_params_r(), + num_cross_chains, num_warmup, cross_chain_window); + sampler.set_cross_chain_metric_adaptation(&var_adapt); + util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); @@ -138,7 +148,9 @@ int hmc_nuts_dense_e_adapt( template int hmc_nuts_dense_e_adapt( Model& model, stan::io::var_context& init, unsigned int 
random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -150,7 +162,9 @@ int hmc_nuts_dense_e_adapt( stan::io::var_context& unit_e_metric = dmp; return hmc_nuts_dense_e_adapt( - model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup, + model, init, unit_e_metric, random_seed, chain, init_radius, + num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index e349c25a681..78f9b473b6c 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -57,7 +57,9 @@ template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, stan::io::var_context& init_inv_metric, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -95,6 +97,13 @@ int hmc_nuts_diag_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); + // cross chain adaptation + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_var_adaptation var_adapt(model.num_params_r(), num_warmup, cross_chain_window); + sampler.set_cross_chain_metric_adaptation(&var_adapt); + util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); @@ -136,7 +145,9 @@ int hmc_nuts_diag_e_adapt( template int hmc_nuts_diag_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, @@ -148,7 +159,8 @@ int hmc_nuts_diag_e_adapt( stan::io::var_context& unit_e_metric = dmp; return hmc_nuts_diag_e_adapt( - model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup, + model, init, unit_e_metric, random_seed, chain, init_radius, + num_cross_chains, cross_chain_window, cross_chain_rhat, cross_chain_ess, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, 
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index adc84353d61..4260646cde2 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -48,7 +48,9 @@ namespace sample { template int hmc_nuts_unit_e_adapt( Model& model, stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, + unsigned int chain, double init_radius, + int num_cross_chains, int cross_chain_window, double cross_chain_rhat, int cross_chain_ess, + int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, double delta, double gamma, double kappa, double t0, callbacks::interrupt& interrupt, @@ -71,6 +73,14 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_kappa(kappa); sampler.get_stepsize_adaptation().set_t0(t0); + // cross chain adaptation setup + sampler.set_cross_chain_adaptation_params(num_warmup, + cross_chain_window, num_cross_chains, + cross_chain_rhat, cross_chain_ess); + mcmc::mpi_metric_adaptation dummy_adapt; + sampler.set_cross_chain_metric_adaptation(&dummy_adapt); + + util::run_adaptive_sampler( sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer); diff --git a/src/stan/services/util/generate_transitions.hpp b/src/stan/services/util/generate_transitions.hpp index 2c72f2e1138..47bcb10b5ff 100644 --- a/src/stan/services/util/generate_transitions.hpp +++ b/src/stan/services/util/generate_transitions.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include namespace stan { @@ -37,8 +38,9 @@ namespace util { * @param[in,out] callback interrupt callback called once an iteration * @param[in,out] logger logger for messages */ -template -void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations, +template ::value>* = nullptr> +void generate_transitions(Sampler& sampler, int num_iterations, int start, int finish, int num_thin, int refresh, bool save, bool warmup, util::mcmc_writer& mcmc_writer, @@ -67,6 +69,12 @@ void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations, mcmc_writer.write_sample_params(base_rng, init_s, sampler, model); mcmc_writer.write_diagnostic_params(init_s, sampler); } + + // check cross-chain convergence + if (mpi_cross_chain::end_transitions(sampler)) { + mpi_cross_chain::set_post_iter(sampler); + break; + } } } diff --git a/src/stan/services/util/mpi.hpp b/src/stan/services/util/mpi.hpp new file mode 100644 index 00000000000..2b6b93df1f3 --- /dev/null +++ b/src/stan/services/util/mpi.hpp @@ -0,0 +1,327 @@ +#ifndef STAN_SERVICES_UTIL_MPI_WARMUP_HPP +#define STAN_SERVICES_UTIL_MPI_WARMUP_HPP + +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// default comm to world comm, in case stan needs to be +// called as library. +#define MPI_COMM_STAN MPI_COMM_WORLD + +// by default there are no warmup-pulling chains. 
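A minimal sketch of the build settings this default assumes (hypothetical make/local entries, relying only on the makefile and make/mpi_warmup.mk hunks above; the chain count of 4 is just an example, and NUM_MPI_CHAINS can stay unset to keep the default of 1 defined just below):

  # make/local (read before make/mpi_warmup.mk per the makefile hunk above)
  MPI_ADAPTED_WARMUP = 1           # mpi_warmup.mk then adds -DSTAN_LANG_MPI -DMPI_ADAPTED_WARMUP and builds with mpicxx
  CXXFLAGS += -DNUM_MPI_CHAINS=4   # override the warmup-pulling chain count at compile time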
+#ifndef NUM_MPI_CHAINS +#define NUM_MPI_CHAINS 1 +#endif + +namespace stan { +namespace services { +namespace util { +namespace mpi { + + /* + * MPI Evionment that initializes and finalizes the MPI + */ + struct Envionment { + struct Envionment_ { + Envionment_() { + init(); + } + ~Envionment_() { + finalize(); + } + + static void init() { +#ifdef STAN_LANG_MPI + int flag; + MPI_Initialized(&flag); + if(!flag) { + int provided; + MPI_Init_thread(NULL, NULL, MPI_THREAD_SINGLE, &provided); + // print provided when needed + } +#endif + } + + static void finalize() { +#ifdef STAN_LANG_MPI + int flag; + MPI_Finalized(&flag); + if(!flag) MPI_Finalize(); +#endif + } + }; + + static const Envionment_ env; + }; + + // out-of-line initilization + const Envionment::Envionment_ Envionment::env; + + /* + * MPI Communicators. With default constructor disabled, + * a communicator can only be created through duplication. + */ + struct Communicator { + private: + Communicator(); + + public: + MPI_Comm comm; + int size; + int rank; + + /* + * communicator constructor using @c Envionment and @c MPI_Comm + */ + explicit Communicator(MPI_Comm other) : + comm(MPI_COMM_NULL), size(0), rank(-1) { + if (other != MPI_COMM_NULL) { + MPI_Comm_dup(other, &comm); + MPI_Comm_size(comm, &size); + MPI_Comm_rank(comm, &rank); + } + } + + /* + * copy constructor is deep + */ + explicit Communicator(const Communicator& other) : + Communicator(other.comm) + {} + + /* + * type-cast to MPI_Comm object + */ + operator MPI_Comm() { + return this -> comm; + } + + /* + * destructor needs to free MPI_Comm + */ + ~Communicator() { + if (comm != MPI_COMM_NULL) { + MPI_Comm_free(&comm); + } + } + }; + + MPI_Comm inter_chain_comm(int num_mpi_chains) { + Envionment::env.init(); + + int world_size; + MPI_Comm_size(MPI_COMM_STAN, &world_size); + stan::math::check_greater_or_equal("MPI inter-chain session", + "number of procs", world_size, + num_mpi_chains); + + MPI_Group stan_group, new_group; + MPI_Comm_group(MPI_COMM_STAN, &stan_group); + int num_chain_with_extra_proc = world_size % num_mpi_chains; + int num_proc_per_chain = world_size / num_mpi_chains; + std::vector ranks(num_mpi_chains); + if (num_chain_with_extra_proc == 0) { + for (int i = 0, j = 0; i < world_size; i += num_proc_per_chain, ++j) { + ranks[j] = i; + } + } else { + num_proc_per_chain++; + int i = 0; + for (int j = 0; j < num_chain_with_extra_proc; ++j) { + ranks[j] = i; + i += num_proc_per_chain; + } + num_proc_per_chain--; + for (int j = num_chain_with_extra_proc; j < num_mpi_chains; ++j) { + ranks[j] = i; + i += num_proc_per_chain; + } + } + + MPI_Group_incl(stan_group, num_mpi_chains, ranks.data(), &new_group); + MPI_Comm new_inter_comm, new_intra_comm; + MPI_Comm_create_group(MPI_COMM_STAN, new_group, 99, &new_inter_comm); + MPI_Group_free(&new_group); + MPI_Group_free(&stan_group); + return new_inter_comm; + } + + MPI_Comm intra_chain_comm(int num_mpi_chains) { + Envionment::env.init(); + + int world_size, world_rank, color; + MPI_Comm_size(MPI_COMM_STAN, &world_size); + MPI_Comm_rank(MPI_COMM_STAN, &world_rank); + + int num_chain_with_extra_proc = world_size % num_mpi_chains; + const int n_proc = world_size / num_mpi_chains; + if (num_chain_with_extra_proc == 0) { + color = world_rank / n_proc; + } else { + int i = 0; + for (int j = 0; j < num_mpi_chains; ++j) { + const int n = j < num_chain_with_extra_proc ? 
(n_proc + 1) : n_proc; + if (world_rank >= i && world_rank < i + n) { + color = i; + break; + } + i += n; + } + } + + MPI_Comm new_intra_comm; + MPI_Comm_split(MPI_COMM_STAN, color, world_rank, &new_intra_comm); + return new_intra_comm; + } + + // * MPI communicator wrapper for RAII. Note that no + // * MPI's predfined comm such as @c MPI_COMM_WOLRD are allowed. + template + struct Session { + static const Communicator stan_comm; + static const MPI_Comm MPI_COMM_INTER_CHAIN; + static const MPI_Comm MPI_COMM_INTRA_CHAIN; + }; + + template + const Communicator Session::stan_comm(MPI_COMM_WORLD); + + template + const MPI_Comm Session:: + MPI_COMM_INTER_CHAIN(inter_chain_comm(num_mpi_chains)); + + template + const MPI_Comm Session:: + MPI_COMM_INTRA_CHAIN(intra_chain_comm(num_mpi_chains)); + + /** + * Dynamic loader that manages master & slave + * communication and data assembly. + */ + struct mpi_loader_base { + static const int work_tag = 1; + static const int err_tag = 2; + static const int done_tag = 3; + + //! communicator wrapper for warmup + const Communicator& comm; + //! MPI communicator + const MPI_Comm mpi_comm; + //! double workspace + Eigen::MatrixXd workspace_r; + + //! construct loader given MPI communicator + mpi_loader_base(const Communicator& comm_in) : + comm(comm_in), mpi_comm(comm.comm) + { + // make sure there are slave chains. + // comm.rank == -1 indicates non inter-comm node + if (comm.rank >= 0) { + static const char* caller = "MPI load balance initialization"; + stan::math::check_greater(caller, "MPI comm size", comm.size, 1); + } + } + }; + + /** + * master receives adaptation info from slave chains and + * process that info through an external functor @c ensemble_func + * in order to improve the quality of adaptation, before + * sending the improved adapt info to slave chains. + */ + struct mpi_warmup { + mpi_loader_base& loader; + Eigen::MatrixXd& workspace_r; + int interval; + MPI_Request req; + const bool is_inter_comm_node; + + //! construct loader given MPI communicator + mpi_warmup(mpi_loader_base& l, int inter) : + loader(l), workspace_r(l.workspace_r), interval(inter), + is_inter_comm_node(loader.comm.size > 0) + {} + + ~mpi_warmup() {} + + /* + * run transitions and process each chain's adaptation + * information before sending it to others. + * + * @tparam Sampler sampler used + * @tparam Model model struct + * @tparam S functor that process the adaptation + * information and return it as a vector. + * @tparam F functor that does transitions. + * @tparam Ts args of @c F. + */ + template + void operator()(Sampler& sampler, Model& model, + stan::mcmc::sample& sample, + const S& fs, F& f, Ts... pars) { + if (is_inter_comm_node) { + f(pars...); + + const int rank = loader.comm.rank; + const int mpi_size = loader.comm.size; + const int size = S::size(sampler, model, sample); + workspace_r.resize(size, mpi_size); + Eigen::VectorXd work(fs(sampler, model, sample)); + MPI_Iallgather(work.data(), size, MPI_DOUBLE, + workspace_r.data(), size, MPI_DOUBLE, + loader.mpi_comm, &req); + + } + } + + /* + * check if the MPI communication is finished. While + * waiting, keep doing transitions. When communication + * is done, generate updated adaptation information and + * update sampler. This function must be called before + * exiting the scope in which @c mpi_warmup obj is declared. + * + * @tparam Sampler sampler used + * @tparam Model model struct + * @tparam S functor that update sampler with new adaptation . + * @tparam F functor that does transitions. 
+ * @tparam Ts args of @c F. + */ + template + void finalize(Sampler& sampler, Model& model, + stan::mcmc::sample& sample, + const S& fs, F& f, Ts... pars) { + if (is_inter_comm_node) { + int flag = 0; + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + while (flag == 0) { + f(pars...); + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + } + fs(workspace_r, sampler, model, sample); + } + } + }; + +} // mpi +} // namespace util +} // namespace services +} // namespace stan +#endif + + +#endif diff --git a/src/stan/services/util/mpi_cross_chain.hpp b/src/stan/services/util/mpi_cross_chain.hpp new file mode 100644 index 00000000000..15c03a28afa --- /dev/null +++ b/src/stan/services/util/mpi_cross_chain.hpp @@ -0,0 +1,154 @@ +#ifndef STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_HPP +#define STAN_SERVICES_UTIL_MPI_CROSS_CHAIN_HPP + +#include +#include +#include +#include +#include +#include +#include + +#ifdef STAN_LANG_MPI +#include +#endif + +namespace stan { +namespace services { +namespace util { + + template + struct has_cross_chain_warmup { + static const bool value = false; + }; + + + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + + template + struct has_cross_chain_warmup> { + static const bool value = true; + }; + + /* + * Helper functions for samplers with MPI WARMUP. Other + * samplers have dummy implmenentation. + */ + template + struct mpi_cross_chain_impl { + static bool end_transitions(Sampler& sampler) {return false;} + + static void set_post_iter(Sampler& sampler) {} + + static int num_post_warmup(Sampler& sampler) { return 0;} + + static int num_draws(Sampler& sampler) { return 0;} + + static void write_num_warmup(Sampler& sampler, + callbacks::writer& sample_writer, + int num_thin, int num_warmup) {} + }; + + /* + * Partial specialization that is only active for MPI warmups + */ +#ifdef MPI_ADAPTED_WARMUP + template + struct mpi_cross_chain_impl { + static bool end_transitions(Sampler& sampler) { + return !sampler.is_post_cross_chain() && sampler.is_cross_chain_adapted(); + } + + static void set_post_iter(Sampler& sampler) { + sampler.set_post_cross_chain(); + } + + static int num_post_warmup(Sampler& sampler) { + return sampler.is_cross_chain_adapted()? 
sampler.num_post_warmup : 0; + } + + static int num_draws(Sampler& sampler) { + return sampler.num_cross_chain_draws(); + } + + static void write_num_warmup(Sampler& sampler, + callbacks::writer& sample_writer, + int num_thin, int num_warmup) { + sampler.write_num_cross_chain_warmup(sample_writer, num_thin, num_warmup); + } + }; +#endif + + template + struct mpi_cross_chain { + static bool end_transitions(Sampler& sampler) { + return mpi_cross_chain_impl::value>:: + end_transitions(sampler); + } + + static void set_post_iter(Sampler& sampler) { + mpi_cross_chain_impl::value>:: + set_post_iter(sampler); + } + + static int num_post_warmup(Sampler& sampler) { + return mpi_cross_chain_impl::value>:: + num_post_warmup(sampler); + } + + static int num_draws(Sampler& sampler) { + return mpi_cross_chain_impl::value>:: + num_draws(sampler); + } + + static void write_num_warmup(Sampler& sampler, + callbacks::writer& sample_writer, + int num_thin, int num_warmup) { + mpi_cross_chain_impl::value>:: + write_num_warmup(sampler, sample_writer, num_thin, num_warmup); + } + }; + + + /* + * modify cmdstan::command seed + */ + void set_cross_chain_id(unsigned int& id, int num_chains) { +#ifdef MPI_ADAPTED_WARMUP + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + const Communicator& inter_comm = Session::inter_chain_comm(num_chains); + const Communicator& intra_comm = Session::intra_chain_comm(num_chains); + id = inter_comm.rank(); + MPI_Bcast(&id, 1, MPI_UNSIGNED, 0, intra_comm.comm()); +#endif + } + + /* + * modify cmdstan::command file + */ + void set_cross_chain_file(std::string& file_name, int num_chains) { +#ifdef MPI_ADAPTED_WARMUP + using stan::math::mpi::Session; + using stan::math::mpi::Communicator; + + if (Session::is_in_inter_chain_comm(num_chains)) { + const Communicator& comm = Session::inter_chain_comm(num_chains); + file_name = "mpi." + std::to_string(comm.rank()) + "." 
+ file_name; + } +#endif + } +} +} +} + +#endif diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index c4758eb06c7..a3048394599 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -67,6 +67,14 @@ void run_adaptive_sampler(Sampler& sampler, Model& model, util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, writer, s, model, rng, interrupt, logger); + + // cross-chain post convergence iterations + util::generate_transitions(sampler, mpi_cross_chain::num_post_warmup(sampler), + mpi_cross_chain::num_draws(sampler), + num_warmup + num_samples, num_thin, refresh, save_warmup, + true, writer, s, model, rng, interrupt, logger); + mpi_cross_chain::write_num_warmup(sampler, sample_writer, num_thin, num_warmup); + clock_t end = clock(); double warm_delta_t = static_cast(end - start) / CLOCKS_PER_SEC; diff --git a/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp b/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp new file mode 100644 index 00000000000..54fba6b9a55 --- /dev/null +++ b/src/test/unit/mcmc/hmc/mpi_warmup_test.cpp @@ -0,0 +1,339 @@ +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using Eigen::MatrixXd; +using Eigen::Matrix; +using std::vector; +using stan::math::mpi::Session; +using stan::math::mpi::Communicator; +using Eigen::MatrixXd; +using Eigen::VectorXd; +using boost::accumulators::accumulator_set; +using boost::accumulators::stats; +using boost::accumulators::tag::mean; +using boost::accumulators::tag::variance; + +struct dummy_sampler { + double stepsize; + + dummy_sampler(double step) : stepsize(step) {} + + double get_nominal_stepsize() { + return stepsize; + } + + void set_nominal_stepsize(double new_stepsize) { + stepsize = new_stepsize; + } +}; + +TEST(mpi_warmup_test, mpi_cross_chain_adapter) { + stan::callbacks::stream_logger logger(std::cout, std::cout, std::cout, + std::cerr, std::cerr); + + const int num_chains = 4; + const int max_num_windows = 5; + const int window_size = 50; + std::vector + draw_vecs(num_chains, Eigen::VectorXd(window_size * max_num_windows)); +draw_vecs[0] << +-276.606, -277.168, -272.621, -271.142, -271.95 , +-269.749, -267.016, -273.508, -268.65 , -265.904, +-264.629, -260.797, -263.184, -263.892, -268.81 , +-272.563, -268.32 , -266.297, -265.787, -266.073, +-265.788, -262.26 , -265.073, -265.511, -264.318, +-264.318, -266.261, -265.633, -265.323, -265.633, +-265.426, -265.69 , -266.122, -264.876, -264.829, +-264.238, -265.822, -262.979, -264.012, -263.801, +-264.745, -263.94 , -263.586, -263.284, -262.566, +-261.816, -265.308, -266.467, -265.915, -266.122, +-266.122, -265.903, -265.903, -265.717, -271.78 , +-271.78 , -271.712, -271.712, -271.011, -273.137, +-272.125, -265.535, -265.168, -267.824, -262.983, +-262.985, -261.967, -265.455, -265.9 , -265.623, +-262.111, -262.111, -262.111, -266.586, -266.545, +-266.545, -263.267, -268.256, -270.425, -268.454, +-268.807, -269.154, -269.154, -269.528, -268.206, +-271.774, -269.453, -267.725, -266.435, -269.434, +-267.838, -267.676, -267.925, -268.343, -267.824, +-267.824, -267.05 , -268.138, -268.072, -267.321, +-267.529, -267.481, -267.118, -267.872, -269.605, +-269.974, -269.347, -269.806, -273.444, -272.257, +-269.983, -271.206, -271.453, -268.328, -268.185, +-268.817, -266.788, -264.052, -270.256, -269.739, +-271.512, -266.883, 
-266.736, -266.872, -267.525, +-266.845, -267.412, -267.754, -267.754, -267.625, +-266.819, -266.978, -267.949, -266.816, -267.641, +-268.377, -267.13 , -266.892, -269.544, -270.316, +-270.461, -270.989, -273.724, -273.155, -272.725, +-272.082, -264.071, -265.269, -263.945, -261.799, +-261.854, -264.487, -267.127, -265.134, -264.052, +-269.239, -263.838, -264.494, -261.844, -264.41 , +-261.969, -264.178, -265.37 , -266.054, -264.703, +-266.988, -267.21 , -265.177, -263.338, -266.309, +-272.157, -269.383, -266.892, -266.822, -268.786, +-271.036, -266.955, -267.356, -270.616, -265.706, +-264.444, -263.224, -263.313, -265.252, -263.874, +-265.89 , -260.837, -262.717, -262.073, -264.779, +-264.05 , -265.203, -262.597, -261.822, -264.143, +-268.655, -269.055, -270.736, -265.17 , -265.217, +-269.879, -270.83 , -271.194, -269.754, -263.825, +-263.737, -265.485, -264.626, -264.713, -265.561, +-266.183, -262.944, -263.938, -263.534, -263.802, +-262.138, -262.138, -261.331, -261.777, -261.62 , +-263.027, -263.062, -262.453, -263.18 , -264.445, +-266.134, -265.103, -264.626, -264.427, -265.528, +-263.938, -263.587, -263.358, -264.897, -265.179, +-264.573, -270.805, -270.824, -268.878, -268.878, +-269.638, -269.536, -276.973, -274.614, -277.589, +-273.321, -273.301, -271.049, -273.554, -269.292; + +draw_vecs[1] << +-270.284, -266.874, -263.361, -260.089, -262.139, +-262.139, -265.862, -265.862, -264.475, -264.475, +-263.834, -262.765, -263.039, -265.855, -267.63 , +-266.903, -267.004, -262.547, -262.196, -259.266, +-259.185, -258.549, -259.665, -258.64 , -258.7 , +-260.475, -260.475, -261.463, -261.483, -261.603, +-261.016, -262.461, -262.359, -262.543, -261.563, +-261.563, -262.017, -262.425, -262.895, -263.366, +-267.275, -265.694, -266.102, -265.527, -261.296, +-261.296, -263.983, -262.662, -263.794, -268.656, +-267.987, -268.543, -267.519, -265.96 , -267.899, +-268.445, -269.063, -270.79 , -264.644, -265.781, +-268.941, -269.489, -269.419, -272.76 , -267.807, +-270.202, -267.557, -265.109, -265.497, -268.019, +-266.981, -268.117, -265.153, -265.451, -271.16 , +-266.011, -265.764, -267.551, -267.334, -264.686, +-266.051, -267.103, -268.63 , -269.366, -269.251, +-267.918, -267.476, -265.557, -266.437, -264.879, +-265.035, -266.154, -268.055, -265.552, -265.48 , +-264.146, -264.952, -264.283, -264.768, -262.73 , +-263.659, -263.659, -270.215, -269.461, -271.369, +-276.308, -270.751, -267.617, -268.485, -266.52 , +-266.03 , -263.784, -263.786, -264.258, -264.176, +-265.869, -264.22 , -265.61 , -264.198, -263.931, +-264.291, -266.761, -267.04 , -268.286, -266.905, +-268.179, -266.984, -268.538, -267.756, -267.756, +-269.435, -269.391, -264.409, -264.465, -266.222, +-270.104, -269.579, -267.769, -263.89 , -264.642, +-264.289, -262.59 , -260.718, -266.693, -272.436, +-272.034, -269.769, -262.51 , -265.406, -272.021, +-270.796, -267.703, -267.549, -265.936, -265.205, +-268.592, -265.528, -261.628, -262.462, -262.253, +-262.93 , -262.932, -262.872, -263.627, -264.512, +-263.074, -263.795, -263.434, -265.15 , -264.709, +-262.096, -263.163, -259.09 , -259.09 , -261.602, +-261.602, -264.06 , -263.836, -260.945, -261.985, +-262.039, -261.927, -268.013, -272.047, -273.161, +-268.47 , -269.855, -269.855, -267.957, -267.957, +-261.807, -261.807, -261.807, -261.807, -262.876, +-262.905, -262.086, -262.461, -262.948, -261.041, +-258.394, -258.675, -259.269, -262.313, -261.776, -259 +, -257.015, -258.733, -259.681, -261.881, -263.303, +-263.303, -264.105, -263.857, -264.845, -265.121, +-265.121, -265.121, -265.121, 
-264.857, -263.811, +-263.796, -263.796, -265.803, -263.442, -263.442, +-262.237, -262.237, -262.237, -261.879, -262.177, +-264.667, -265.174, -265.174, -264.707, -264.707, +-264.535, -265.637, -261.316, -261.456, -262.575, +-265.12 , -263.7 , -263.7 , -263.7 , -263.7 , +-264.145, -268.846, -261.643, -261.561; + +draw_vecs[2] << +-262.744, -261.998, -261.994, -262.239, -264.747, +-263.467, -266.498, -266.158, -266.158, -266.158, +-266.334, -268.21 , -266.863, -265.772, -267.149, +-266.097, -266.097, -265.845, -265.976, -267.216, +-269.566, -269.566, -270.05 , -269.622, -269.981, +-270.698, -270.698, -270.698, -268.899, -268.504, +-268.814, -267.439, -268.08 , -267.438, -268.135, +-268.135, -268.135, -268.135, -267.767, -267.448, +-268.76 , -268.76 , -267.301, -268.337, -267.902, +-269.79 , -267.688, -265.888, -266.014, -266.122, +-266.953, -266.722, -267.119, -267.119, -267.047, +-267.047, -266.797, -266.797, -266.269, -265.727, +-266.522, -266.522, -267.202, -265.66 , -265.66 , +-268.183, -266.952, -267.373, -264.304, -264.59 , +-263.903, -263.988, -264.204, -264.204, -265.643, +-265.643, -264.296, -264.457, -265.484, -265.378, +-265.128, -265.128, -265.128, -265.128, -264.844, +-266.096, -265.205, -265.205, -265.205, -265.198, +-265.198, -265.741, -265.314, -265.903, -265.903, +-266.001, -266.001, -265.504, -265.313, -265.949, +-265.194, -264.247, -264.247, -264.455, -264.455, +-264.455, -264.303, -264.303, -264.303, -264.303, +-265.261, -265.885, -265.097, -264.3 , -264.3 , +-264.555, -264.798, -264.404, -265.854, -265.858, +-265.858, -265.858, -265.858, -265.122, -264.49 , +-264.49 , -264.066, -264.129, -265.102, -264.296, +-264.324, -264.273, -264.078, -264.07 , -263.923, +-263.923, -264.11 , -264.3 , -264.3 , -264.051, +-264.051, -264.864, -265.61 , -265.61 , -264.518, +-264.555, -264.711, -264.711, -265.346, -265.346, +-264.946, -265.135, -265.102, -265.625, -265.482, +-265.482, -265.482, -265.194, -264.499, -264.178, +-264.848, -264.848, -264.155, -264.155, -264.117, +-264.45 , -264.45 , -265.476, -265.476, -267.104, +-264.804, -264.496, -264.565, -264.637, -264.426, +-264.574, -264.659, -265.509, -265.509, -264.669, +-264.669, -264.669, -264.669, -265.014, -265.014, +-264.797, -266.052, -266.052, -267.349, -266.748, +-266.266, -267.778, -266.736, -266.736, -268.388, +-265.949, -265.949, -266.144, -267.147, -267.147, +-265.965, -265.329, -265.411, -267.016, -265.516, +-265.516, -265.516, -265.516, -265.516, -267.111, +-266.987, -266.662, -265.979, -265.517, -265.495, +-265.898, -266.085, -266.085, -265.282, -265.337, +-265.337, -265.337, -265.873, -265.044, -265.044, +-265.044, -267.565, -265.853, -266.693, -265.688, +-265.92 , -266.021, -266.147, -266.802, -266.84 , +-266.84 , -266.84 , -266.84 , -266.84 , -267.173, +-266.361, -266.361, -266.825, -266.444, -266.444, +-266.444, -266.444, -267.152, -266.515, -266.533; + +draw_vecs[3] << + -266.014,-265.791, -265.791, -266.053, -266.196, +-265.953, -265.787, -265.787, -266.336, -266.658, +-267.189, -267.232, -270.645, -270.645, -270.645, +-270.645, -272.083, -269 , -266.46 , -266.786, +-267.246, -266.353, -266.782, -266.782, -266.823, +-266.781, -266.781, -266.781, -266.781, -266.718, +-266.562, -266.837, -268.308, -267.161, -267.081, +-267.889, -267.103, -267.103, -265.488, -265.73 , +-265.73 , -266.46 , -266.46 , -267.288, -265.69 , +-265.69 , -265.69 , -266.302, -266.107, -266.107, +-266.107, -266.107, -266.107, -264.795, -264.659, +-265.365, -266.233, -265.995, -265.995, -266.013, +-266.025, -265.512, -265.512, -265.512, 
-265.512, +-265.605, -265.605, -265.605, -265.571, -265.916, +-265.325, -266.295, -265.598, -266.856, -266.856, +-266.19 , -266.19 , -265.077, -265.249, -265.43 , +-265.43 , -265.429, -265.481, -265.408, -265.408, +-265.991, -265.595, -266.051, -266.051, -266.525, +-267.047, -265.283, -265.167, -265.223, -265.223, +-265.526, -265.158, -265.11 , -265.11 , -265.24 , +-265.293, -265.45 , -265.45 , -265.109, -265.863, +-265.112, -265.112, -265.112, -265.112, -265.112, +-264.967, -264.967, -266.176, -265.038, -265.238, +-265.238, -265.238, -265.531, -265.531, -265.461, +-265.882, -265.882, -265.301, -265.301, -266.118, +-266.254, -264.316, -264.316, -264.316, -265.241, +-264.463, -264.658, -265.323, -264.331, -264.331, +-266.603, -264.131, -264.131, -264.289, -264.289, +-265.96 , -264.685, -264.731, -265.294, -264.663, +-264.831, -264.288, -265.753, -265.753, -265.925, +-265.925, -268.329, -266.288, -266.288, -266.288, +-266.288, -264.796, -264.573, -265.464, -264.95 , +-264.966, -264.602, -264.602, -265.338, -265.549, +-265.575, -267.306, -266.802, -266.268, -265.888, +-265.746, -265.746, -265.746, -266.17 , -266.134, +-265.365, -265.365, -265.484, -266.118, -265.285, +-265.285, -265.285, -265.285, -265.285, -266.041, +-266.041, -267.61 , -267.557, -267.557, -266.593, +-266.132, -265.76 , -265.757, -265.793, -265.793, +-265.793, -265.598, -265.354, -267.131, -265.039, +-265.039, -265.039, -265.039, -265.039, -265.869, +-265.869, -266.309, -265.897, -265.727, -265.958, +-267.231, -266.862, -267.255, -267.545, -267.009, +-265.8 , -265.551, -266.254, -265.394, -264.825, +-264.825, -265.245, -265.245, -264.312, -264.365, +-264.253, -264.514, -264.413, -264.413, -264.413, +-264.413, -264.413, -264.589, -265.277, -265.378, +-265.69 , -265.69 , -264.972, -264.972, -264.972, +-264.972, -265.283, -265.237, -264.671, -264.88 , +-265.099, -266.919, -265.878, -264.653, -264.653; + + const Communicator& comm = Session::inter_chain_comm(num_chains); + + const std::vector draws{draw_vecs[0].data(), + draw_vecs[1].data(), draw_vecs[2].data(), draw_vecs[3].data()}; + + double chain_stepsize = 1.1 + 0.1 * comm.rank(); + + const int num_iterations = window_size * max_num_windows; + + Eigen::VectorXd dummy; + + // a large ESS target should make all windows fail to pass tests + { + stan::mcmc::mpi_cross_chain_adapter cc_adapter(num_iterations, window_size, num_chains, 1.1, 100); + stan::mcmc::mpi_var_adaptation var_adapt(0, cc_adapter.max_num_windows()); + cc_adapter.set_cross_chain_var_adaptation(var_adapt); + for (int i = 0; i < num_iterations; ++i) { + cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); + + dummy_sampler sampler(chain_stepsize); + cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); + + EXPECT_FALSE(cc_adapter.is_cross_chain_adapted()); + + if (cc_adapter.is_cross_chain_adapt_window_end()) { + int curr_win_count = cc_adapter.current_cross_chain_window_counter(); + for (int win = 0; win < curr_win_count; ++win) { + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_win_count - win) * window_size); + double ess = + stan::analyze::compute_effective_sample_size(p, (curr_win_count - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + EXPECT_FLOAT_EQ(ess, cc_adapter.cross_chain_adapt_ess()(win)); + } + } + } + } + } + + // a target_ess that 
4-window tests should pass + { + int curr_win_count = 4; + stan::mcmc::mpi_cross_chain_adapter cc_adapter(num_iterations, window_size, num_chains, 1.1, 50); + stan::mcmc::mpi_var_adaptation var_adapt(0, cc_adapter.max_num_windows()); + cc_adapter.set_cross_chain_var_adaptation(var_adapt); + for (int i = 0; i < num_iterations; ++i) { + dummy_sampler sampler(chain_stepsize); + cc_adapter.add_cross_chain_sample(dummy, draw_vecs[comm.rank()](i)); + + double step = chain_stepsize; + cc_adapter.cross_chain_adaptation(&sampler, dummy, logger); + if (cc_adapter.is_cross_chain_adapted()) break; + } + int win = 0; // win = 0 @c is_adapted + const std::vector p{ + draws[0] + win * window_size, + draws[1] + win * window_size, + draws[2] + win * window_size, + draws[3] + win * window_size}; + double rhat = + stan::analyze::compute_potential_scale_reduction(p, (curr_win_count - win) * window_size); + double ess = + stan::analyze::compute_effective_sample_size(p, (curr_win_count - win) * window_size); + if (comm.rank() == 0) { + EXPECT_FLOAT_EQ(rhat, cc_adapter.cross_chain_adapt_rhat()(win)); + EXPECT_FLOAT_EQ(ess, cc_adapter.cross_chain_adapt_ess()(win)); + for (int i = win + 1; i < max_num_windows; ++i) { + EXPECT_FLOAT_EQ(cc_adapter.cross_chain_adapt_rhat()(i) ,0.0); + EXPECT_FLOAT_EQ(cc_adapter.cross_chain_adapt_ess()(i) ,0.0); + } + } + } +} + +#endif diff --git a/src/test/unit/mcmc/mpi_covar_adaptation_test.cpp b/src/test/unit/mcmc/mpi_covar_adaptation_test.cpp new file mode 100644 index 00000000000..10480389b8b --- /dev/null +++ b/src/test/unit/mcmc/mpi_covar_adaptation_test.cpp @@ -0,0 +1,101 @@ +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include +#include + +TEST(McmcVarAdaptation, mpi_learn_covariance) { + stan::test::unit::instrumented_logger logger; + + const int n = 10; + Eigen::VectorXd q = Eigen::VectorXd::Zero(n); + Eigen::MatrixXd covar(Eigen::MatrixXd::Zero(n, n)); + + const int n_learn = 12; + + Eigen::MatrixXd target_covar(Eigen::MatrixXd::Identity(n, n)); + target_covar *= 1e-3 * 5.0 / (n_learn + 5.0); + + stan::mcmc::covar_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + + for (int i = 0; i < n_learn; ++i) + adapter.learn_covariance(covar, q); + + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + EXPECT_EQ(target_covar(i, j), covar(i, j)); + } + } + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); // must be <= 4 + if (n_learn % num_chains != 0) + throw std::domain_error("this test function was called with inconsistent MPI COMM size"); + + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_covar_adaptation mpi_adapter(n, n_learn_chain, n_learn_chain); + Eigen::MatrixXd mpi_covar(Eigen::MatrixXd::Zero(n, n)); + for (int i = 0; i < n_learn_chain; ++i) + mpi_adapter.add_sample(q, 1); + + mpi_adapter.learn_metric(mpi_covar, 0, 1, comm); + + for (int i = 0; i < covar.size(); ++i) { + EXPECT_FLOAT_EQ(covar(i), mpi_covar(i)); + } +} + +TEST(McmcVarAdaptation, mpi_data_learn_covariance) { + stan::test::unit::instrumented_logger logger; + + const int n = 5; + const int n_learn = 12; + Eigen::VectorXd q_all(n * n_learn); + q_all << + -276.606, -277.168, -272.621, -271.142, -271.95 , + -269.749, -267.016, -273.508, -268.65 , -265.904, + -264.629, -260.797, -263.184, -263.892, -268.81 , + -272.563, -268.32 , -266.297, -265.787, -266.073, + -265.788, -262.26 , -265.073, -265.511, -264.318, + -264.318, -266.261, -265.633, -265.323, -265.633, + 
-265.426, -265.69 , -266.122, -264.876, -264.829, + -264.238, -265.822, -262.979, -264.012, -263.801, + -264.745, -263.94 , -263.586, -263.284, -262.566, + -261.816, -265.308, -266.467, -265.915, -266.122, + -266.122, -265.903, -265.903, -265.717, -271.78 , + -271.78 , -271.712, -271.712, -271.011, -273.137; + + stan::mcmc::covar_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + Eigen::MatrixXd covar(Eigen::MatrixXd::Zero(n, n)); + for (int i = 0; i < n_learn; ++i) { + Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n), n); + adapter.learn_covariance(covar, q); + } + + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); + if (n_learn % num_chains != 0) + throw std::domain_error("this test function was called with inconsistent MPI COMM size"); + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_covar_adaptation mpi_adapter(n, n_learn_chain, n_learn_chain); + Eigen::MatrixXd mpi_covar(Eigen::MatrixXd::Zero(n, n)); + for (int i = 0; i < n_learn_chain; ++i) { + Eigen::VectorXd q = + Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); + mpi_adapter.add_sample(q, 1); + } + mpi_adapter.learn_metric(mpi_covar, 0, 1, comm); + + for (int i = 0; i < n; ++i) { + EXPECT_FLOAT_EQ(covar(i), mpi_covar(i)); + } +} + +#endif diff --git a/src/test/unit/mcmc/mpi_var_adaptation_test.cpp b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp new file mode 100644 index 00000000000..ccdf6de7a34 --- /dev/null +++ b/src/test/unit/mcmc/mpi_var_adaptation_test.cpp @@ -0,0 +1,90 @@ +#ifdef STAN_LANG_MPI + +#include +#include +#include +#include +#include + +TEST(McmcVarAdaptation, mpi_learn_variance) { + stan::test::unit::instrumented_logger logger; + + const int n = 10; + Eigen::VectorXd q(Eigen::VectorXd::Zero(n)), var(q); + const int n_learn = 12; + Eigen::VectorXd target_var(Eigen::VectorXd::Ones(n)); + target_var *= 1e-3 * 5.0 / (n_learn + 5.0); + + stan::mcmc::var_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + for (int i = 0; i < n_learn; ++i) + adapter.learn_variance(var, q); + + for (int i = 0; i < n; ++i) + EXPECT_FLOAT_EQ(target_var(i), var(i)); + + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1, 1); + Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); + for (int i = 0; i < n_learn_chain; ++i) + mpi_adapter.add_sample(q, 1); + + mpi_adapter.learn_variance(mpi_var, 0, 1, comm); + + for (int i = 0; i < n; ++i) { + EXPECT_FLOAT_EQ(var(i), mpi_var(i)); + } +} + +TEST(McmcVarAdaptation, mpi_data_learn_variance) { + stan::test::unit::instrumented_logger logger; + + const int n = 5; + const int n_learn = 12; + Eigen::VectorXd q_all(n * n_learn); + q_all << + -276.606, -277.168, -272.621, -271.142, -271.95 , + -269.749, -267.016, -273.508, -268.65 , -265.904, + -264.629, -260.797, -263.184, -263.892, -268.81 , + -272.563, -268.32 , -266.297, -265.787, -266.073, + -265.788, -262.26 , -265.073, -265.511, -264.318, + -264.318, -266.261, -265.633, -265.323, -265.633, + -265.426, -265.69 , -266.122, -264.876, -264.829, + -264.238, -265.822, -262.979, -264.012, -263.801, + -264.745, -263.94 , -263.586, -263.284, -262.566, + -261.816, -265.308, -266.467, -265.915, -266.122, + -266.122, -265.903, -265.903, -265.717, -271.78 , + -271.78 , -271.712, -271.712, -271.011, 
-273.137; + + stan::mcmc::var_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + Eigen::VectorXd var(Eigen::VectorXd::Zero(n)); + for (int i = 0; i < n_learn; ++i) { + Eigen::VectorXd q = Eigen::VectorXd::Map(&q_all(i * n), n); + adapter.learn_variance(var, q); + } + + EXPECT_EQ(0, logger.call_count()); + + stan::math::mpi::Communicator comm(MPI_COMM_STAN); + const int num_chains = comm.size(); + const int n_learn_chain = n_learn / num_chains; + stan::mcmc::mpi_var_adaptation mpi_adapter(n, 1, 1); + Eigen::VectorXd mpi_var(Eigen::VectorXd::Zero(n)); + for (int i = 0; i < n_learn_chain; ++i) { + Eigen::VectorXd q = + Eigen::VectorXd::Map(&q_all(i * n + comm.rank() * n * n_learn_chain), n); + mpi_adapter.add_sample(q, 1); + } + mpi_adapter.learn_variance(mpi_var, 0, 1, comm); + + for (int i = 0; i < n; ++i) { + EXPECT_FLOAT_EQ(var(i), mpi_var(i)); + } +} + +#endif
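
The mpi_cross_chain<Sampler> wrapper introduced above is a compile-time switch: the mpi_cross_chain_impl helper presumably pairs a no-op default with the MPI-enabled specialization shown here, which forwards to the sampler's own num_cross_chain_draws() / write_num_cross_chain_warmup() methods, so non-MPI builds of run_adaptive_sampler keep their current behavior. Below is a minimal standalone sketch of that dispatch pattern under assumed names (plain_sampler, cc_sampler, is_cross_chain); it is an illustration of the idiom, not the Stan code itself.

// Standalone illustration of the trait-based dispatch used by mpi_cross_chain.
// All names here are hypothetical; only the pattern mirrors the patch.
#include <iostream>
#include <type_traits>

struct plain_sampler {};                 // no cross-chain support
struct cc_sampler {                      // opts in to cross-chain warmup
  static constexpr bool cross_chain = true;
  int num_cross_chain_draws() const { return 150; }
};

template <class S, class = void>
struct is_cross_chain : std::false_type {};
template <class S>
struct is_cross_chain<S, std::enable_if_t<S::cross_chain>> : std::true_type {};

// Primary template: everything is a no-op for ordinary samplers.
template <class S, bool = is_cross_chain<S>::value>
struct cross_chain_impl {
  static int num_draws(S&) { return 0; }
};
// Specialization: forward to the sampler when it supports cross-chain warmup.
template <class S>
struct cross_chain_impl<S, true> {
  static int num_draws(S& s) { return s.num_cross_chain_draws(); }
};

template <class S>
int num_draws(S& s) { return cross_chain_impl<S>::num_draws(s); }

int main() {
  plain_sampler a;
  cc_sampler b;
  std::cout << num_draws(a) << " " << num_draws(b) << "\n";  // prints "0 150"
}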
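
set_cross_chain_id and set_cross_chain_file adjust cmdstan's per-chain id and output file: every process serving the same chain receives that chain's inter-chain rank as its id (broadcast over the intra-chain communicator), and the chain's output file gets an "mpi.<rank>." prefix. A rough plain-MPI analogue of that convention, assuming equal-sized process groups per chain and without Stan's Session/Communicator wrappers, is sketched here.

// Plain-MPI sketch of the id/file conventions above (hypothetical layout:
// num_chains groups of equal size; rank 0 of each group leads its chain).
#include <mpi.h>
#include <cstdio>
#include <string>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int world_rank, world_size;
  MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &world_size);

  const int num_chains = 4;                        // assumed chain count
  if (world_size % num_chains != 0)                // mirror the tests' size check
    MPI_Abort(MPI_COMM_WORLD, 1);
  const int procs_per_chain = world_size / num_chains;
  const int chain = world_rank / procs_per_chain;  // which chain this process serves

  // intra-chain communicator: all processes working on the same chain
  MPI_Comm intra;
  MPI_Comm_split(MPI_COMM_WORLD, chain, world_rank, &intra);

  // chain id: decided by the chain's lead process, broadcast to its helpers
  unsigned int id = chain;
  MPI_Bcast(&id, 1, MPI_UNSIGNED, 0, intra);

  // only the lead process of each chain writes, to a rank-prefixed file
  int intra_rank;
  MPI_Comm_rank(intra, &intra_rank);
  if (intra_rank == 0) {
    std::string file_name = "mpi." + std::to_string(id) + ".output.csv";
    std::printf("chain %u writes %s\n", id, file_name.c_str());
  }

  MPI_Comm_free(&intra);
  MPI_Finalize();
}

For example, mpiexec -n 8 ./a.out would give four chains of two processes each, with one output file per chain.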
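
Both adaptation tests split the n_learn draws evenly across ranks and require the MPI-aggregated metric (learn_metric / learn_variance) to match the serial covar_adaptation / var_adaptation result, i.e. pooling per-chain sample moments must reproduce the full-sample estimate. The following self-contained check of that merge identity for the variance case uses plain Welford accumulators, no Stan or MPI, and reuses the first twelve lp__ draws from the data above; it is a sketch of the underlying property, not the adapters' implementation.

// Standalone check: merging per-chunk Welford accumulators reproduces the
// variance of the pooled sample, which is what the MPI adapters rely on.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

struct welford {
  long n = 0;
  double mean = 0.0, m2 = 0.0;
  void add(double x) {
    ++n;
    double d = x - mean;
    mean += d / n;
    m2 += d * (x - mean);
  }
  // Chan/Golub/LeVeque pairwise combination of two accumulators
  void merge(const welford& o) {
    if (o.n == 0) return;
    double d = o.mean - mean;
    long total = n + o.n;
    mean += d * o.n / total;
    m2 += o.m2 + d * d * static_cast<double>(n) * o.n / total;
    n = total;
  }
  double sample_var() const { return m2 / (n - 1); }
};

int main() {
  std::vector<double> draws = {-276.606, -277.168, -272.621, -271.142,
                               -271.95,  -269.749, -267.016, -273.508,
                               -268.65,  -265.904, -264.629, -260.797};
  const int num_chains = 4;
  const int per_chain = static_cast<int>(draws.size()) / num_chains;

  welford serial, pooled;
  for (double x : draws) serial.add(x);

  for (int c = 0; c < num_chains; ++c) {  // each "rank" sees only its own chunk
    welford chunk;
    for (int i = 0; i < per_chain; ++i) chunk.add(draws[c * per_chain + i]);
    pooled.merge(chunk);
  }

  std::printf("serial %.6f pooled %.6f\n", serial.sample_var(), pooled.sample_var());
  assert(std::fabs(serial.sample_var() - pooled.sample_var()) < 1e-8);
}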