From 8d09106c8123e3574f69728ba6a83b7e74735599 Mon Sep 17 00:00:00 2001 From: yiz Date: Fri, 15 Nov 2019 07:57:44 -0800 Subject: [PATCH] unit test with mpi gathering of stepsize --- src/stan/services/util/mpi.hpp | 23 +--- .../unit/services/util/mpi_warmup_test.cpp | 117 +++++++++++++++++- 2 files changed, 118 insertions(+), 22 deletions(-) diff --git a/src/stan/services/util/mpi.hpp b/src/stan/services/util/mpi.hpp index 5e0520da2fe..2b6b93df1f3 100644 --- a/src/stan/services/util/mpi.hpp +++ b/src/stan/services/util/mpi.hpp @@ -244,7 +244,7 @@ namespace mpi { Eigen::MatrixXd& workspace_r; int interval; MPI_Request req; - bool is_inter_comm_node; + const bool is_inter_comm_node; //! construct loader given MPI communicator mpi_warmup(mpi_loader_base& l, int inter) : @@ -290,25 +290,8 @@ namespace mpi { * check if the MPI communication is finished. While * waiting, keep doing transitions. When communication * is done, generate updated adaptation information and - * update sampler. - * - * @tparam Sampler sampler used - * @tparam Model model struct - * @tparam S functor that update sampler with new adaptation . - * @tparam F functor that does transitions. - * @tparam Ts args of @c F. - */ - void finalize() { - if (is_inter_comm_node) { - MPI_Wait(&req, MPI_STATUS_IGNORE); - } - } - - /* - * check if the MPI communication is finished. While - * waiting, keep doing transitions. When communication - * is done, generate updated adaptation information and - * update sampler. + * update sampler. This function must be called before + * exiting the scope in which @c mpi_warmup obj is declared. * * @tparam Sampler sampler used * @tparam Model model struct diff --git a/src/test/unit/services/util/mpi_warmup_test.cpp b/src/test/unit/services/util/mpi_warmup_test.cpp index a41f9d040b6..70b29f63edf 100644 --- a/src/test/unit/services/util/mpi_warmup_test.cpp +++ b/src/test/unit/services/util/mpi_warmup_test.cpp @@ -3,6 +3,16 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + using Eigen::MatrixXd; using Eigen::Matrix; using std::vector; @@ -143,7 +153,7 @@ struct send_processor { template static int size(const Sampler& sampler, const Model& model, stan::mcmc::sample& sample) { - return 10; + return 1000000; } template @@ -198,8 +208,111 @@ TEST(mpi_warmup_test, mpi_warmup_loader) { f, dummy_sampler, dummy_model, sample); mpi_warmup_adapt.finalize(dummy_sampler, dummy_model, sample, fd, f); +} + + +struct send_adapt_processor { + const Communicator& comm; + + send_adapt_processor(const Communicator& comm_in) : + comm(comm_in) + {} + + template + static int size(const Sampler& sampler, const Model& model, + stan::mcmc::sample& sample) { + return 1; + } + + template + Eigen::VectorXd operator()(Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { + Eigen::VectorXd x(Eigen::VectorXd::Zero(size(sampler, model, sample))); + x(0) = sampler.get_nominal_stepsize() + 0.01 * comm.rank; + return x; + } +}; + +struct warmup_processor { +template +void operator()(stan::mcmc::base_mcmc& sampler, int num_iterations, + int start, int finish, int num_thin, int refresh, bool save, + stan::services::util::mcmc_writer& mcmc_writer, + stan::mcmc::sample& s, Model& model, + RNG& base_rng, stan::callbacks::interrupt& callback, + stan::callbacks::logger& logger) { + stan::services::util::generate_transitions(sampler, num_iterations, start, finish, + num_thin, refresh, save, true, mcmc_writer, s, + model, base_rng, callback, logger); +} +}; + +struct collect_adapt_processor { + const Communicator& comm; + + collect_adapt_processor(const Communicator& comm_in) : + comm(comm_in) + {} + + template + void operator()(const Eigen::MatrixXd& workspace_r, Sampler& sampler, Model& model, stan::mcmc::sample& sample) const { + EXPECT_EQ(workspace_r.cols(), comm.size); + for (int i = 0; i < comm.size; ++i) { + EXPECT_FLOAT_EQ(workspace_r(0, i), 0.01 * i + workspace_r(0, 0)); + } + } +}; + +TEST(mpi_warmup_test, unit_e_nuts) { + using Model = gauss3D_model_namespace::gauss3D_model; + using Sampler = stan::mcmc::adapt_unit_e_nuts; + boost::ecuyer1988 rng(4839294); + + stan::mcmc::unit_e_point z_init(3); + z_init.q(0) = 1; + z_init.q(1) = -1; + z_init.q(2) = 1; + z_init.p(0) = -1; + z_init.p(1) = 1; + z_init.p(2) = -1; + + std::stringstream debug, info, warn, error, fatal; + stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); + + std::fstream empty_stream("", std::fstream::in); + stan::io::dump data_var_context(empty_stream); + Model model(data_var_context); + + Sampler sampler(model, rng); + sampler.z() = z_init; + sampler.init_hamiltonian(logger); + sampler.set_nominal_stepsize(0.1); + sampler.set_stepsize_jitter(0); + sampler.sample_stepsize(); + + stan::mcmc::sample s(z_init.q, 0, 0); + + stan::callbacks::writer sample_writer; + stan::callbacks::writer diagnostic_writer; + stan::services::util::mcmc_writer writer(sample_writer, diagnostic_writer, logger); + stan::callbacks::interrupt interrupt; + + stan::services::util::generate_transitions(sampler, 10, 0, 20, + 1, 0, false, true, writer, s, + model, rng, interrupt, logger); + + const Communicator inter_comm(Session<3>::MPI_COMM_INTER_CHAIN); + mpi_loader_base loader(inter_comm); + mpi_warmup mpi_warmup_adapt(loader, 10); + warmup_processor f_warmup; + send_adapt_processor fs(inter_comm); + mpi_warmup_adapt(sampler, model, s, fs, + f_warmup, + sampler, 10, 0, 20, 1, 0, false, writer, s, model, rng, interrupt, logger); - mpi_warmup_adapt.finalize(); + collect_adapt_processor fd(inter_comm); + mpi_warmup_adapt.finalize(sampler, model, s, fd, + f_warmup, + sampler, 10, 0, 20, 1, 0, false, writer, s, model, rng, interrupt, logger); } #endif