From 586989fc61d04cbaf66ee6da0883879f44e09186 Mon Sep 17 00:00:00 2001 From: Sebastian Weber Date: Fri, 11 Feb 2022 21:47:53 +0100 Subject: [PATCH 1/8] prototype speculative NUTS --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 1082 ++++++++++++++++++-------- 1 file changed, 745 insertions(+), 337 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 38964283c89..f851eb67200 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -11,355 +11,763 @@ #include #include +#include + +#include + +#include "tbb/task_scheduler_init.h" +#include "tbb/flow_graph.h" +#include "tbb/concurrent_vector.h" + +using namespace tbb::flow; + +// Prototype of speculative NUTS. +// Uses the Intel Flow Graph concept to turn NUTS into a parallel +// algorithm in that the forward and backward sweep run at the same +// time in parallel. + namespace stan { -namespace mcmc { -/** - * The No-U-Turn sampler (NUTS) with multinomial sampling - */ -template class Hamiltonian, - template class Integrator, class BaseRNG> -class base_nuts : public base_hmc { - public: - base_nuts(const Model& model, BaseRNG& rng) - : base_hmc(model, rng), - depth_(0), - max_depth_(5), - max_deltaH_(1000), - n_leapfrog_(0), - divergent_(false), - energy_(0) {} - - /** - * specialized constructor for specified diag mass matrix - */ - base_nuts(const Model& model, BaseRNG& rng, Eigen::VectorXd& inv_e_metric) - : base_hmc(model, rng, - inv_e_metric), - depth_(0), - max_depth_(5), - max_deltaH_(1000), - n_leapfrog_(0), - divergent_(false), - energy_(0) {} - - /** - * specialized constructor for specified dense mass matrix - */ - base_nuts(const Model& model, BaseRNG& rng, Eigen::MatrixXd& inv_e_metric) - : base_hmc(model, rng, - inv_e_metric), - depth_(0), - max_depth_(5), - max_deltaH_(1000), - n_leapfrog_(0), - divergent_(false), - energy_(0) {} - - ~base_nuts() {} - - void set_metric(const Eigen::MatrixXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); - } - - void set_metric(const Eigen::VectorXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); - } - - void set_max_depth(int d) { - if (d > 0) - max_depth_ = d; - } - - void set_max_delta(double d) { max_deltaH_ = d; } - - int get_max_depth() { return this->max_depth_; } - double get_max_delta() { return this->max_deltaH_; } - - sample transition(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); - - this->seed(init_sample.cont_params()); - - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); - - ps_point z_fwd(this->z_); // State at forward end of trajectory - ps_point z_bck(z_fwd); // State at backward end of trajectory - - ps_point z_sample(z_fwd); - ps_point z_propose(z_fwd); - - // Momentum and sharp momentum at forward end of forward subtree - Eigen::VectorXd p_fwd_fwd = this->z_.p; - Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_); - - // Momentum and sharp momentum at backward end of forward subtree - Eigen::VectorXd p_fwd_bck = this->z_.p; - Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd; - - // Momentum and sharp momentum at forward end of backward subtree - Eigen::VectorXd p_bck_fwd = this->z_.p; - Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd; - - // Momentum and sharp momentum at backward end of backward subtree - Eigen::VectorXd p_bck_bck = this->z_.p; - Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd; - - // Integrated momenta along trajectory - Eigen::VectorXd rho = this->z_.p.transpose(); - - // Log sum of state weights (offset by H0) along trajectory - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - int n_leapfrog = 0; - double sum_metro_prob = 0; - - // Build a trajectory until the no-u-turn - // criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; - - while (this->depth_ < this->max_depth_) { - // Build a new subtree in a random direction - Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size()); - Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size()); - - bool valid_subtree = false; - double log_sum_weight_subtree = -std::numeric_limits::infinity(); - - if (this->rand_uniform_() > 0.5) { - // Extend the current trajectory forward - this->z_.ps_point::operator=(z_fwd); - rho_bck = rho; - p_bck_fwd = p_fwd_fwd; - p_sharp_bck_fwd = p_sharp_fwd_fwd; - - valid_subtree = build_tree( - this->depth_, z_propose, p_sharp_fwd_bck, p_sharp_fwd_fwd, rho_fwd, - p_fwd_bck, p_fwd_fwd, H0, 1, n_leapfrog, log_sum_weight_subtree, - sum_metro_prob, logger); - z_fwd.ps_point::operator=(this->z_); - } else { - // Extend the current trajectory backwards - this->z_.ps_point::operator=(z_bck); - rho_fwd = rho; - p_fwd_bck = p_bck_bck; - p_sharp_fwd_bck = p_sharp_bck_bck; - - valid_subtree = build_tree( - this->depth_, z_propose, p_sharp_bck_fwd, p_sharp_bck_bck, rho_bck, - p_bck_fwd, p_bck_bck, H0, -1, n_leapfrog, log_sum_weight_subtree, - sum_metro_prob, logger); - z_bck.ps_point::operator=(this->z_); + namespace mcmc { + /** + * The No-U-Turn sampler (NUTS) with multinomial sampling + */ + template class Hamiltonian, + template class Integrator, class BaseRNG> + class base_nuts : public base_hmc { + public: + typedef typename Hamiltonian::PointType state_t; + + base_nuts(const Model& model, BaseRNG& rng) + : base_hmc(model, rng), + depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), + n_leapfrog_(0), divergent_(false), energy_(0) { } - if (!valid_subtree) - break; + /** + * specialized constructor for specified diag mass matrix + */ + base_nuts(const Model& model, BaseRNG& rng, + Eigen::VectorXd& inv_e_metric) + : base_hmc(model, rng, + inv_e_metric), + depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), + n_leapfrog_(0), divergent_(false), energy_(0) { + } + + /** + * specialized constructor for specified dense mass matrix + */ + base_nuts(const Model& model, BaseRNG& rng, + Eigen::MatrixXd& inv_e_metric) + : base_hmc(model, rng, + inv_e_metric), + depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), + n_leapfrog_(0), divergent_(false), energy_(0) { + } - // Sample from accepted subtree - ++(this->depth_); + ~base_nuts() {} - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); - if (this->rand_uniform_() < accept_prob) - z_sample = z_propose; + void set_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); } - log_sum_weight + void set_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + void set_max_depth(int d) { + if (d > 0) + max_depth_ = d; + } + + void set_max_delta(double d) { + max_deltaH_ = d; + } + + int get_max_depth() { return this->max_depth_; } + double get_max_delta() { return this->max_deltaH_; } + + // stores from left/right subtree entire information + struct subtree { + subtree(const double sign, + const ps_point& z_end, + const Eigen::VectorXd& p_sharp_end, + double H0) + : z_end_(z_end), z_propose_(z_end), + p_sharp_end_(p_sharp_end), + H0_(H0), + sign_(sign), + n_leapfrog_(0), + sum_metro_prob_(0) + {} + + ps_point z_end_; + ps_point z_propose_; + Eigen::VectorXd p_sharp_end_; + const double H0_; + const double sign_; + int n_leapfrog_; + double sum_metro_prob_; + }; + + + // extends the tree into the direction of the sign of the + // subtree + typedef std::tuple extend_tree_t; + + extend_tree_t + extend_tree(int depth, subtree& tree, state_t& z, + callbacks::logger& logger) { + // save the current ends needed for later criterion computations + //Eigen::VectorXd p_end = tree.p_end_; + //Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; + Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + tree.n_leapfrog_ = 0; + tree.sum_metro_prob_ = 0; + + z.ps_point::operator=(tree.z_end_); + + bool valid_subtree = build_tree(depth, + z, tree.z_propose_, + p_sharp_dummy, tree.p_sharp_end_, + rho_subtree, + tree.H0_, + tree.sign_, + tree.n_leapfrog_, + log_sum_weight_subtree, tree.sum_metro_prob_, + logger); + + tree.z_end_.ps_point::operator=(z); + + return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, tree.sum_metro_prob_); + } + + + sample + transition(sample& init_sample, callbacks::logger& logger) { + return transition_parallel(init_sample, logger); + } + + // this implementation builds up the dependence graph every call + // to transition. Things which should be refactored: + // 1. build up the nodes only once + // 2. add a prepare method to each node which samples its + // direction and needed random numbers for multinomial sampling + // 3. only the edges are added dynamically. So the forward nodes + // are wired-up and the backward nodes are wired-up if run + // parallel. If run serially, then each grow node is alternated + // with a check node. + sample + transition_parallel(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + //ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + //int n_leapfrog = 0; + //double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // actual states which move... copy construct atm...revise?! + state_t z_fwd(this->z_); + state_t z_bck(this->z_); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + this->valid_trees_ = true; + + // the actual number of leapfrog steps in trajectory used + // excluding the ones executed speculative + int n_leapfrog = 0; + + // actually summed metropolis prob of used trajectory + double sum_metro_prob = 0; + + std::vector fwd_direction(this->max_depth_); + + for (std::size_t i = 0; i != this->max_depth_; ++i) + fwd_direction[i] = this->rand_uniform_() > 0.5; + + const std::size_t num_fwd = std::accumulate(fwd_direction.begin(), fwd_direction.end(), 0); + const std::size_t num_bck = this->max_depth_ - num_fwd; + + /* + std::cout << "sampled turns: "; + for (std::size_t i = 0; i != this->max_depth_; ++i) { + if(fwd_direction[i]) + std::cout << "+,"; + else + std::cout << "-,"; + } + std::cout << std::endl; + */ + + tbb::concurrent_vector ends(this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), Eigen::VectorXd(), z_sample, 0, 0.0)); + tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); + tbb::concurrent_vector valid_subtree_bck(num_bck, true); + + // HACK!!! + callbacks::logger logger_fwd; + callbacks::logger logger_bck; + + // build TBB flow graph + graph g; + + // add nodes which advance the left/right tree + typedef continue_node tree_builder_t; + + tbb::concurrent_vector all_builder_idx(this->max_depth_); + tbb::concurrent_vector fwd_builder; + tbb::concurrent_vector bck_builder; + typedef tbb::concurrent_vector::iterator builder_iter_t; + + // now wire up the fwd and bck build of the trees which + // depends on single-core or multi-core run + const bool run_serial = stan::math::internal::get_num_threads() == 1; + + std::size_t fwd_idx = 0; + std::size_t bck_idx = 0; + // TODO: the extenders should also check for a global flag if + // we want to keep running + for (std::size_t depth=0; depth != this->max_depth_; ++depth) { + if (fwd_direction[depth]) { + builder_iter_t fwd_iter = + fwd_builder.emplace_back(g, [&,depth,fwd_idx](continue_msg) { + //std::cout << "fwd turn at depth " << depth; + bool valid_parent = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx-1]; + if (valid_parent) { + //std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); + valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_fwd[fwd_idx] = false; + } + //std::cout << " nothing to do." << std::endl; + }); + if(!run_serial && fwd_idx != 0) { + // in this case this is not the starting node, we + // connect this with its predecessor + make_edge(*(fwd_iter-1), *fwd_iter); + } + all_builder_idx[depth] = fwd_idx; + ++fwd_idx; + } else { + builder_iter_t bck_iter = + bck_builder.emplace_back(g, [&,depth,bck_idx](continue_msg) { + //std::cout << "bck turn at depth " << depth; + bool valid_parent = bck_idx == 0 ? true : valid_subtree_bck[bck_idx-1]; + if (valid_parent) { + //std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); + valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_bck[bck_idx] = false; + } + //std::cout << " nothing to do." << std::endl; + }); + if(!run_serial && bck_idx != 0) { + // in case this is not the starting node, we connect + // this with his predecessor + //make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); + make_edge(*(bck_iter-1), *bck_iter); + } + all_builder_idx[depth] = bck_idx; + ++bck_idx; + } + } + + // finally wire in the checker which accepts or rejects the + // proposed states from the subtrees + //typedef function_node< tbb::flow::tuple, bool> checker_t; + //typedef join_node< tbb::flow::tuple > joiner_t; + typedef continue_node checker_t; + + tbb::concurrent_vector checks; + //std::vector joins; + + Eigen::VectorXd p_sharp_fwd(p_sharp); + Eigen::VectorXd p_sharp_bck(p_sharp); + + for (std::size_t depth=0; depth != this->max_depth_; ++depth) { + //joins.push_back(joiner_t(g)); + //std::cout << "creating check at depth " << depth << std::endl; + checks.emplace_back(g, [&,depth](continue_msg) { + bool is_fwd = fwd_direction[depth]; + + extend_tree_t& subtree_result = ends[depth]; + + // if we are still on the + // trajectories which are + // actually used update the + // running tree stats + if (this->valid_trees_) { + this->depth_ = depth + 1; + n_leapfrog += std::get<5>(subtree_result); + sum_metro_prob += std::get<6>(subtree_result); + } + + bool valid_subtree = is_fwd ? + valid_subtree_fwd[all_builder_idx[depth]] : + valid_subtree_bck[all_builder_idx[depth]]; + + bool is_valid = valid_subtree & this->valid_trees_; + + //std::cout << "CHECK at depth " << depth; + + if(!is_valid) { + //std::cout << " we are done (early)" << std::endl; + + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + return; + } + + //std::cout << " checking" << std::endl; + + double log_sum_weight_subtree = std::get<1>(subtree_result); + const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); + + // update correct side + if (is_fwd) { + p_sharp_fwd = std::get<3>(subtree_result); + } else { + p_sharp_bck = std::get<3>(subtree_result); + } + + const ps_point& z_propose = std::get<4>(subtree_result); + + // update running sums + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + //if (this->rand_uniform_() < + //accept_prob) + // HACK + if (get_rand_uniform() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_bck, p_sharp_fwd, rho)) { + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + //std::cout << " we are done (later)" << std::endl; + } + //std::cout << " continuing (later)" << std::endl; + }); + if(fwd_direction[depth]) { + //std::cout << "depth " << depth << ": joining fwd node " << all_builder_idx[depth] << " into join node." << std::endl; + make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); + } else { + //std::cout << "depth " << depth << ": joining bck node " << all_builder_idx[depth] << " into join node." << std::endl; + make_edge(bck_builder[all_builder_idx[depth]], checks.back()); + } + if(!run_serial && depth != 0) { + make_edge(checks[depth-1], checks.back()); + } + } + + if(run_serial) { + for(std::size_t i = 1; i < this->max_depth_; ++i) { + make_edge(checks[i-1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] : bck_builder[all_builder_idx[i]]); + } + } + + // kick off work + if(fwd_direction[0]) { + fwd_builder[0].try_put(continue_msg()); + // the first turn is fwd, so kick off the bck walker if needed + if (!run_serial && num_bck != 0) + bck_builder[0].try_put(continue_msg()); + } else { + bck_builder[0].try_put(continue_msg()); + if (!run_serial && num_fwd != 0) + fwd_builder[0].try_put(continue_msg()); + } + + g.wait_for_all(); + + this->n_leapfrog_ = n_leapfrog; + //this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + // this includes the speculative executed ones + //const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + sample + transition_refactored(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + //int n_leapfrog = 0; + //double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + bool valid_subtree; + double log_sum_weight_subtree; + Eigen::VectorXd rho_subtree; + + if (this->rand_uniform_() > 0.5) { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_fwd, this->z_, logger); + } else { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_bck, this->z_, logger); + } + + if (!valid_subtree) break; + + // Sample from an accepted subtree + ++(this->depth_); + + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(tree_bck.p_sharp_end_, tree_fwd.p_sharp_end_, rho)) + break; + //if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + // break; + } + + //this->n_leapfrog_ = n_leapfrog; + this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + sample + transition_old(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + ps_point z_plus(this->z_); + ps_point z_minus(z_plus); + + ps_point z_sample(z_plus); + ps_point z_propose(z_plus); + + Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); + //Eigen::VectorXd p_sharp_dummy = p_sharp_plus; + Eigen::VectorXd p_sharp_minus = p_sharp_plus; + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + int n_leapfrog = 0; + double sum_metro_prob = 0; + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + // Build a new subtree in a random direction + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; + double log_sum_weight_subtree + = -std::numeric_limits::infinity(); + + // this should be fine (modified from orig) + Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(this->z_.p.size()); + + if (this->rand_uniform_() > 0.5) { + this->z_.ps_point::operator=(z_plus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, + p_sharp_dummy, p_sharp_plus, rho_subtree, + H0, 1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, + logger); + z_plus.ps_point::operator=(this->z_); + } else { + this->z_.ps_point::operator=(z_minus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, + p_sharp_dummy, p_sharp_minus, rho_subtree, + H0, -1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, + logger); + z_minus.ps_point::operator=(this->z_); + } + + if (!valid_subtree) break; + + // Sample from an accepted subtree + ++(this->depth_); + + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + break; + } + + this->n_leapfrog_ = n_leapfrog; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(n_leapfrog); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + void get_sampler_param_names(std::vector& names) { + names.push_back("stepsize__"); + names.push_back("treedepth__"); + names.push_back("n_leapfrog__"); + names.push_back("divergent__"); + names.push_back("energy__"); + } + + void get_sampler_params(std::vector& values) { + values.push_back(this->epsilon_); + values.push_back(this->depth_); + values.push_back(this->n_leapfrog_); + values.push_back(this->divergent_); + values.push_back(this->energy_); + } + + virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return p_sharp_plus.dot(rho) > 0 + && p_sharp_minus.dot(rho) > 0; + } + + /** + * Recursively build a new subtree to completion or until + * the subtree becomes invalid. Returns validity of the + * resulting subtree. + * + * @param depth Depth of the desired subtree + * @param z_beg State beginning from subtree + * @param z_propose State proposed from subtree + * @param p_sharp_left p_sharp from left boundary of returned tree + * @param p_sharp_right p_sharp from the right boundary of returned tree + * @param rho Summed momentum across trajectory + * @param H0 Hamiltonian of initial state + * @param sign Direction in time to built subtree + * @param n_leapfrog Summed number of leapfrog evaluations + * @param log_sum_weight Log of summed weights across trajectory + * @param sum_metro_prob Summed Metropolis probabilities across trajectory + * @param logger Logger for messages + */ + bool build_tree(int depth, state_t& z_beg, + ps_point& z_propose, + Eigen::VectorXd& p_sharp_left, + Eigen::VectorXd& p_sharp_right, + Eigen::VectorXd& rho, + double H0, double sign, int& n_leapfrog, + double& log_sum_weight, double& sum_metro_prob, + callbacks::logger& logger) { + // Base case + if (depth == 0) { + // check if trees are still valid or if we should terminate + if(!this->valid_trees_) + return false; + + this->integrator_.evolve(z_beg, this->hamiltonian_, + sign * this->epsilon_, + logger); + + ++n_leapfrog; + + double h = this->hamiltonian_.H(z_beg); + if (boost::math::isnan(h)) + h = std::numeric_limits::infinity(); + + // TODO: in parallel case we cannot use the global divergent + // flag since this could be a speculative tree!! + //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + bool is_divergent = (h - H0) > this->max_deltaH_; + //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + + log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); + + if (H0 - h > 0) + sum_metro_prob += 1; + else + sum_metro_prob += std::exp(H0 - h); + + z_propose = z_beg; + rho += z_beg.p; + + p_sharp_left = this->hamiltonian_.dtau_dp(z_beg); + p_sharp_right = p_sharp_left; + + return !is_divergent; + } + // General recursion + Eigen::VectorXd p_sharp_dummy(z_beg.p.size()); + + // Build the left subtree + double log_sum_weight_left = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); + + bool valid_left + = build_tree(depth - 1, z_beg, z_propose, + p_sharp_left, p_sharp_dummy, rho_left, + H0, sign, n_leapfrog, + log_sum_weight_left, sum_metro_prob, + logger); + + if (!valid_left) return false; + + // Build the right subtree + ps_point z_propose_right(z_beg); + + double log_sum_weight_right = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); + + bool valid_right + = build_tree(depth - 1, z_beg, z_propose_right, + p_sharp_dummy, p_sharp_right, rho_right, + H0, sign, n_leapfrog, + log_sum_weight_right, sum_metro_prob, + logger); + + if (!valid_right) return false; + + // Multinomial sample from right subtree + double log_sum_weight_subtree + = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); + log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - // Break when no-u-turn criterion is no longer satisfied - rho = rho_bck + rho_fwd; - - // Demand satisfaction around merged subtrees - bool persist_criterion - = compute_criterion(p_sharp_bck_bck, p_sharp_fwd_fwd, rho); - - // Demand satisfaction between subtrees - Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck; - - persist_criterion - &= compute_criterion(p_sharp_bck_bck, p_sharp_fwd_bck, rho_extended); - - rho_extended = rho_fwd + p_bck_fwd; - persist_criterion - &= compute_criterion(p_sharp_bck_fwd, p_sharp_fwd_fwd, rho_extended); - - if (!persist_criterion) - break; - } - - this->n_leapfrog_ = n_leapfrog; - - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob = sum_metro_prob / static_cast(n_leapfrog); - - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); - } - - void get_sampler_param_names(std::vector& names) { - names.push_back("stepsize__"); - names.push_back("treedepth__"); - names.push_back("n_leapfrog__"); - names.push_back("divergent__"); - names.push_back("energy__"); - } - - void get_sampler_params(std::vector& values) { - values.push_back(this->epsilon_); - values.push_back(this->depth_); - values.push_back(this->n_leapfrog_); - values.push_back(this->divergent_); - values.push_back(this->energy_); - } - - virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, - Eigen::VectorXd& p_sharp_plus, - Eigen::VectorXd& rho) { - return p_sharp_plus.dot(rho) > 0 && p_sharp_minus.dot(rho) > 0; - } - - /** - * Recursively build a new subtree to completion or until - * the subtree becomes invalid. Returns validity of the - * resulting subtree. - * - * @param depth Depth of the desired subtree - * @param z_propose State proposed from subtree - * @param p_sharp_beg Sharp momentum at beginning of new tree - * @param p_sharp_end Sharp momentum at end of new tree - * @param rho Summed momentum across trajectory - * @param p_beg Momentum at beginning of returned tree - * @param p_end Momentum at end of returned tree - * @param H0 Hamiltonian of initial state - * @param sign Direction in time to built subtree - * @param n_leapfrog Summed number of leapfrog evaluations - * @param log_sum_weight Log of summed weights across trajectory - * @param sum_metro_prob Summed Metropolis probabilities across trajectory - * @param logger Logger for messages - */ - bool build_tree(int depth, ps_point& z_propose, Eigen::VectorXd& p_sharp_beg, - Eigen::VectorXd& p_sharp_end, Eigen::VectorXd& rho, - Eigen::VectorXd& p_beg, Eigen::VectorXd& p_end, double H0, - double sign, int& n_leapfrog, double& log_sum_weight, - double& sum_metro_prob, callbacks::logger& logger) { - // Base case - if (depth == 0) { - this->integrator_.evolve(this->z_, this->hamiltonian_, - sign * this->epsilon_, logger); - ++n_leapfrog; - - double h = this->hamiltonian_.H(this->z_); - if (std::isnan(h)) - h = std::numeric_limits::infinity(); - - if ((h - H0) > this->max_deltaH_) - this->divergent_ = true; - - log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); - - if (H0 - h > 0) - sum_metro_prob += 1; - else - sum_metro_prob += std::exp(H0 - h); - - z_propose = this->z_; - - p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_); - p_sharp_end = p_sharp_beg; - - rho += this->z_.p; - p_beg = this->z_.p; - p_end = p_beg; - - return !this->divergent_; - } - // General recursion - - // Build the initial subtree - double log_sum_weight_init = -std::numeric_limits::infinity(); - - // Momentum and sharp momentum at end of the initial subtree - Eigen::VectorXd p_init_end(this->z_.p.size()); - Eigen::VectorXd p_sharp_init_end(this->z_.p.size()); - - Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size()); - - bool valid_init - = build_tree(depth - 1, z_propose, p_sharp_beg, p_sharp_init_end, - rho_init, p_beg, p_init_end, H0, sign, n_leapfrog, - log_sum_weight_init, sum_metro_prob, logger); - - if (!valid_init) - return false; - - // Build the final subtree - ps_point z_propose_final(this->z_); - - double log_sum_weight_final = -std::numeric_limits::infinity(); - - // Momentum and sharp momentum at beginning of the final subtree - Eigen::VectorXd p_final_beg(this->z_.p.size()); - Eigen::VectorXd p_sharp_final_beg(this->z_.p.size()); - - Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size()); - - bool valid_final - = build_tree(depth - 1, z_propose_final, p_sharp_final_beg, p_sharp_end, - rho_final, p_final_beg, p_end, H0, sign, n_leapfrog, - log_sum_weight_final, sum_metro_prob, logger); - - if (!valid_final) - return false; - - // Multinomial sample from right subtree - double log_sum_weight_subtree - = math::log_sum_exp(log_sum_weight_init, log_sum_weight_final); - log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - - if (log_sum_weight_final > log_sum_weight_subtree) { - z_propose = z_propose_final; - } else { - double accept_prob - = std::exp(log_sum_weight_final - log_sum_weight_subtree); - if (this->rand_uniform_() < accept_prob) - z_propose = z_propose_final; - } - - Eigen::VectorXd rho_subtree = rho_init + rho_final; - rho += rho_subtree; - - // Demand satisfaction around merged subtrees - bool persist_criterion - = compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree); - - // Demand satisfaction between subtrees - rho_subtree = rho_init + p_final_beg; - persist_criterion - &= compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree); + if (log_sum_weight_right > log_sum_weight_subtree) { + z_propose = z_propose_right; + } else { + double accept_prob + = std::exp(log_sum_weight_right - log_sum_weight_subtree); + //if (this->rand_uniform_() < accept_prob) + if (get_rand_uniform() < accept_prob) + z_propose = z_propose_right; + } + + Eigen::VectorXd rho_subtree = rho_left + rho_right; + rho += rho_subtree; - rho_subtree = rho_final + p_init_end; - persist_criterion - &= compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree); + return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + } - return persist_criterion; - } + inline double get_rand_uniform() { + static std::mutex rng_mutex; + std::lock_guard lock(rng_mutex); + return this->rand_uniform_(); + } - int depth_; - int max_depth_; - double max_deltaH_; + int depth_; + int max_depth_; + double max_deltaH_; + bool valid_trees_; - int n_leapfrog_; - bool divergent_; - double energy_; -}; + int n_leapfrog_; + bool divergent_; + double energy_; + }; -} // namespace mcmc -} // namespace stan + } // mcmc +} // stan #endif From 5d9a09acf3ca085f277f0c881afc13483ae7fc75 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 15 Feb 2022 18:15:12 -0500 Subject: [PATCH 2/8] get rid of mutex when sampling uniforms for parallel nuts --- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 12 +- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 50 +- src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 790 ++++++++++++++++++ src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp | 12 +- .../sample/hmc_nuts_diag_e_adapt_parallel.hpp | 375 +++++++++ 5 files changed, 1206 insertions(+), 33 deletions(-) create mode 100644 src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp create mode 100644 src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 45e92380f57..5e142ce8821 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -12,18 +12,20 @@ namespace mcmc { * with a Gaussian-Euclidean disintegration and adaptive * diagonal metric and adaptive step size */ -template -class adapt_diag_e_nuts : public diag_e_nuts, +template +class adapt_diag_e_nuts : public diag_e_nuts, public stepsize_var_adapter { public: adapt_diag_e_nuts(const Model& model, BaseRNG& rng) - : diag_e_nuts(model, rng), + : diag_e_nuts(model, rng), stepsize_var_adapter(model.num_params_r()) {} - ~adapt_diag_e_nuts() {} + diag_e_nuts(const Model& model, std::vector& thread_rngs) + : diag_e_nuts(model, thread_rngs), + stepsize_var_adapter(model.num_params_r()) {} sample transition(sample& init_sample, callbacks::logger& logger) { - sample s = diag_e_nuts::transition(init_sample, logger); + sample s = diag_e_nuts::transition(init_sample, logger); if (this->adapt_flag_) { this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index f851eb67200..74aca4d105c 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -24,7 +24,7 @@ using namespace tbb::flow; // Prototype of speculative NUTS. // Uses the Intel Flow Graph concept to turn NUTS into a parallel // algorithm in that the forward and backward sweep run at the same -// time in parallel. +// time in parallel. namespace stan { namespace mcmc { @@ -36,7 +36,7 @@ namespace stan { class base_nuts : public base_hmc { public: typedef typename Hamiltonian::PointType state_t; - + base_nuts(const Model& model, BaseRNG& rng) : base_hmc(model, rng), depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), @@ -114,7 +114,7 @@ namespace stan { // extends the tree into the direction of the sign of the // subtree typedef std::tuple extend_tree_t; - + extend_tree_t extend_tree(int depth, subtree& tree, state_t& z, callbacks::logger& logger) { @@ -122,15 +122,15 @@ namespace stan { //Eigen::VectorXd p_end = tree.p_end_; //Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); - + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); double log_sum_weight_subtree = -std::numeric_limits::infinity(); tree.n_leapfrog_ = 0; tree.sum_metro_prob_ = 0; - + z.ps_point::operator=(tree.z_end_); - + bool valid_subtree = build_tree(depth, z, tree.z_propose_, p_sharp_dummy, tree.p_sharp_end_, @@ -142,10 +142,10 @@ namespace stan { logger); tree.z_end_.ps_point::operator=(z); - + return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, tree.sum_metro_prob_); } - + sample transition(sample& init_sample, callbacks::logger& logger) { @@ -172,13 +172,13 @@ namespace stan { this->hamiltonian_.init(this->z_, logger); const ps_point z_init(this->z_); - + ps_point z_sample(z_init); //ps_point z_propose(z_init); const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); Eigen::VectorXd rho = this->z_.p; - + double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); //int n_leapfrog = 0; @@ -192,7 +192,7 @@ namespace stan { // actual states which move... copy construct atm...revise?! state_t z_fwd(this->z_); state_t z_bck(this->z_); - + // Build a trajectory until the NUTS criterion is no longer satisfied this->depth_ = 0; this->divergent_ = false; @@ -223,7 +223,7 @@ namespace stan { } std::cout << std::endl; */ - + tbb::concurrent_vector ends(this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), Eigen::VectorXd(), z_sample, 0, 0.0)); tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); tbb::concurrent_vector valid_subtree_bck(num_bck, true); @@ -231,7 +231,7 @@ namespace stan { // HACK!!! callbacks::logger logger_fwd; callbacks::logger logger_bck; - + // build TBB flow graph graph g; @@ -309,7 +309,7 @@ namespace stan { Eigen::VectorXd p_sharp_fwd(p_sharp); Eigen::VectorXd p_sharp_bck(p_sharp); - + for (std::size_t depth=0; depth != this->max_depth_; ++depth) { //joins.push_back(joiner_t(g)); //std::cout << "creating check at depth " << depth << std::endl; @@ -317,7 +317,7 @@ namespace stan { bool is_fwd = fwd_direction[depth]; extend_tree_t& subtree_result = ends[depth]; - + // if we are still on the // trajectories which are // actually used update the @@ -327,7 +327,7 @@ namespace stan { n_leapfrog += std::get<5>(subtree_result); sum_metro_prob += std::get<6>(subtree_result); } - + bool valid_subtree = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] : valid_subtree_bck[all_builder_idx[depth]]; @@ -338,7 +338,7 @@ namespace stan { if(!is_valid) { //std::cout << " we are done (early)" << std::endl; - + // setting this globally here // will terminate all ongoing work this->valid_trees_ = false; @@ -349,7 +349,7 @@ namespace stan { double log_sum_weight_subtree = std::get<1>(subtree_result); const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); - + // update correct side if (is_fwd) { p_sharp_fwd = std::get<3>(subtree_result); @@ -416,7 +416,7 @@ namespace stan { } g.wait_for_all(); - + this->n_leapfrog_ = n_leapfrog; //this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; @@ -444,13 +444,13 @@ namespace stan { this->hamiltonian_.init(this->z_, logger); const ps_point z_init(this->z_); - + ps_point z_sample(z_init); ps_point z_propose(z_init); const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); Eigen::VectorXd rho = this->z_.p; - + double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); //int n_leapfrog = 0; @@ -460,7 +460,7 @@ namespace stan { subtree tree_fwd(1, z_init, p_sharp, H0); // backward tree subtree tree_bck(-1, z_init, p_sharp, H0); - + // Build a trajectory until the NUTS criterion is no longer satisfied this->depth_ = 0; this->divergent_ = false; @@ -469,7 +469,7 @@ namespace stan { bool valid_subtree; double log_sum_weight_subtree; Eigen::VectorXd rho_subtree; - + if (this->rand_uniform_() > 0.5) { std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) = extend_tree(this->depth_, tree_fwd, this->z_, logger); @@ -477,7 +477,7 @@ namespace stan { std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) = extend_tree(this->depth_, tree_bck, this->z_, logger); } - + if (!valid_subtree) break; // Sample from an accepted subtree @@ -671,7 +671,7 @@ namespace stan { this->integrator_.evolve(z_beg, this->hamiltonian_, sign * this->epsilon_, logger); - + ++n_leapfrog; double h = this->hamiltonian_.H(z_beg); diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp new file mode 100644 index 00000000000..8aed8d5cbec --- /dev/null +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -0,0 +1,790 @@ +#ifndef STAN_MCMC_HMC_NUTS_BASE_PARALLEL_NUTS_HPP +#define STAN_MCMC_HMC_NUTS_BASE_PARALLEL_NUTS_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include "tbb/task_scheduler_init.h" +#include "tbb/flow_graph.h" +#include "tbb/concurrent_vector.h" + +using namespace tbb::flow; + +template +inline auto make_uniform_vec(std::vector& thread_rngs) { + /* + std::vector> rand_uniform_vec; + const size_t num_thread_rngs = thread_rngs.size(); + rand_uniform_vec.reserve(num_thread_rngs); + for (size_t i = 0; i < rand_uniform_vec.size(); ++i) { + rand_uniform_vec.emplace_back(thread_rngs[i]); + } + */ + return std::vector>(thread_rngs.begin(), thread_rngs.end()); +} + +// Prototype of speculative NUTS. +// Uses the Intel Flow Graph concept to turn NUTS into a parallel +// algorithm in that the forward and backward sweep run at the same +// time in parallel. + +namespace stan { + namespace mcmc { + + /** + * The No-U-Turn sampler (NUTS) with multinomial sampling + */ + template class Hamiltonian, + template class Integrator, class BaseRNG> + class base_parallel_nuts : public base_hmc { + public: + using state_t = typename Hamiltonian::PointType; + + base_parallel_nuts(const Model& model, std::vector& thread_rngs) + : base_hmc(model, thread_rngs[tbb::this_task_arena::current_thread_index()]), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) { + } + + base_parallel_nuts(const Model& model, BaseRNG& rng, std::vector& thread_rngs) + : base_hmc(model, rng), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) { + } + + /** + * specialized constructor for specified diag mass matrix + */ + base_parallel_nuts(const Model& model, BaseRNG& rng, + Eigen::VectorXd& inv_e_metric, std::vector& thread_rngs) + : base_hmc(model, rng, inv_e_metric), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) { + } + + /** + * specialized constructor for specified dense mass matrix + */ + base_parallel_nuts(const Model& model, BaseRNG& rng, + Eigen::MatrixXd& inv_e_metric, std::vector& thread_rngs) + : base_hmc(model, rng, inv_e_metric), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) { + } + + ~base_parallel_nuts() {} + + inline void set_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + inline void set_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + inline void set_max_depth(int d) noexcept { + if (d > 0) { + max_depth_ = d; + } + } + + inline void set_max_delta(double d) noexcept { + max_deltaH_ = d; + } + + inline int get_max_depth() noexcept { return this->max_depth_; } + inline double get_max_delta() noexcept { return this->max_deltaH_; } + + // stores from left/right subtree entire information + struct subtree { + subtree(const double sign, + const ps_point& z_end, + const Eigen::VectorXd& p_sharp_end, + double H0) + : z_end_(z_end), z_propose_(z_end), + p_sharp_end_(p_sharp_end), + H0_(H0), + sign_(sign), + n_leapfrog_(0), + sum_metro_prob_(0) + {} + + ps_point z_end_; + ps_point z_propose_; + Eigen::VectorXd p_sharp_end_; + double H0_; + double sign_; + int n_leapfrog_{0}; + double sum_metro_prob_{0}; + }; + + + // extends the tree into the direction of the sign of the + // subtree + using extend_tree_t = std::tuple; + + inline extend_tree_t extend_tree(int depth, subtree& tree, state_t& z, + callbacks::logger& logger) { + // save the current ends needed for later criterion computations + //Eigen::VectorXd p_end = tree.p_end_; + //Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; + Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + tree.n_leapfrog_ = 0; + tree.sum_metro_prob_ = 0; + + z.ps_point::operator=(tree.z_end_); + + bool valid_subtree = build_tree(depth, + z, tree.z_propose_, + p_sharp_dummy, tree.p_sharp_end_, + rho_subtree, + tree.H0_, + tree.sign_, + tree.n_leapfrog_, + log_sum_weight_subtree, tree.sum_metro_prob_, + logger); + + tree.z_end_.ps_point::operator=(z); + + return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, tree.sum_metro_prob_); + } + + + inline sample transition(sample& init_sample, callbacks::logger& logger) { + return transition_parallel(init_sample, logger); + } + + // this implementation builds up the dependence graph every call + // to transition. Things which should be refactored: + // 1. build up the nodes only once + // 2. add a prepare method to each node which samples its + // direction and needed random numbers for multinomial sampling + // 3. only the edges are added dynamically. So the forward nodes + // are wired-up and the backward nodes are wired-up if run + // parallel. If run serially, then each grow node is alternated + // with a check node. + sample + transition_parallel(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + //ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + //int n_leapfrog = 0; + //double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // actual states which move... copy construct atm...revise?! + state_t z_fwd(this->z_); + state_t z_bck(this->z_); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + this->valid_trees_ = true; + + // the actual number of leapfrog steps in trajectory used + // excluding the ones executed speculative + int n_leapfrog = 0; + + // actually summed metropolis prob of used trajectory + double sum_metro_prob = 0; + + std::vector fwd_direction(this->max_depth_); + + for (std::size_t i = 0; i != this->max_depth_; ++i) + fwd_direction[i] = this->rand_uniform_() > 0.5; + + const std::size_t num_fwd = std::accumulate(fwd_direction.begin(), fwd_direction.end(), 0); + const std::size_t num_bck = this->max_depth_ - num_fwd; + + /* + std::cout << "sampled turns: "; + for (std::size_t i = 0; i != this->max_depth_; ++i) { + if(fwd_direction[i]) + std::cout << "+,"; + else + std::cout << "-,"; + } + std::cout << std::endl; + */ + + tbb::concurrent_vector ends(this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), Eigen::VectorXd(), z_sample, 0, 0.0)); + tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); + tbb::concurrent_vector valid_subtree_bck(num_bck, true); + + // HACK!!! + callbacks::logger logger_fwd; + callbacks::logger logger_bck; + + // build TBB flow graph + graph g; + + // add nodes which advance the left/right tree + typedef continue_node tree_builder_t; + + tbb::concurrent_vector all_builder_idx(this->max_depth_); + tbb::concurrent_vector fwd_builder; + tbb::concurrent_vector bck_builder; + typedef tbb::concurrent_vector::iterator builder_iter_t; + + // now wire up the fwd and bck build of the trees which + // depends on single-core or multi-core run + const bool run_serial = stan::math::internal::get_num_threads() == 1; + + std::size_t fwd_idx = 0; + std::size_t bck_idx = 0; + // TODO: the extenders should also check for a global flag if + // we want to keep running + for (std::size_t depth=0; depth != this->max_depth_; ++depth) { + if (fwd_direction[depth]) { + builder_iter_t fwd_iter = + fwd_builder.emplace_back(g, [&,depth,fwd_idx](continue_msg) { + //std::cout << "fwd turn at depth " << depth; + bool valid_parent = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx-1]; + if (valid_parent) { + //std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); + valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_fwd[fwd_idx] = false; + } + //std::cout << " nothing to do." << std::endl; + }); + if(!run_serial && fwd_idx != 0) { + // in this case this is not the starting node, we + // connect this with its predecessor + make_edge(*(fwd_iter-1), *fwd_iter); + } + all_builder_idx[depth] = fwd_idx; + ++fwd_idx; + } else { + builder_iter_t bck_iter = + bck_builder.emplace_back(g, [&,depth,bck_idx](continue_msg) { + //std::cout << "bck turn at depth " << depth; + bool valid_parent = bck_idx == 0 ? true : valid_subtree_bck[bck_idx-1]; + if (valid_parent) { + //std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); + valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_bck[bck_idx] = false; + } + //std::cout << " nothing to do." << std::endl; + }); + if(!run_serial && bck_idx != 0) { + // in case this is not the starting node, we connect + // this with his predecessor + //make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); + make_edge(*(bck_iter-1), *bck_iter); + } + all_builder_idx[depth] = bck_idx; + ++bck_idx; + } + } + + // finally wire in the checker which accepts or rejects the + // proposed states from the subtrees + //typedef function_node< tbb::flow::tuple, bool> checker_t; + //typedef join_node< tbb::flow::tuple > joiner_t; + typedef continue_node checker_t; + + tbb::concurrent_vector checks; + //std::vector joins; + + Eigen::VectorXd p_sharp_fwd(p_sharp); + Eigen::VectorXd p_sharp_bck(p_sharp); + + for (std::size_t depth=0; depth != this->max_depth_; ++depth) { + //joins.push_back(joiner_t(g)); + //std::cout << "creating check at depth " << depth << std::endl; + checks.emplace_back(g, [&,depth](continue_msg) { + bool is_fwd = fwd_direction[depth]; + + extend_tree_t& subtree_result = ends[depth]; + + // if we are still on the + // trajectories which are + // actually used update the + // running tree stats + if (this->valid_trees_) { + this->depth_ = depth + 1; + n_leapfrog += std::get<5>(subtree_result); + sum_metro_prob += std::get<6>(subtree_result); + } + + bool valid_subtree = is_fwd ? + valid_subtree_fwd[all_builder_idx[depth]] : + valid_subtree_bck[all_builder_idx[depth]]; + + bool is_valid = valid_subtree & this->valid_trees_; + + //std::cout << "CHECK at depth " << depth; + + if(!is_valid) { + //std::cout << " we are done (early)" << std::endl; + + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + return; + } + + //std::cout << " checking" << std::endl; + + double log_sum_weight_subtree = std::get<1>(subtree_result); + const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); + + // update correct side + if (is_fwd) { + p_sharp_fwd = std::get<3>(subtree_result); + } else { + p_sharp_bck = std::get<3>(subtree_result); + } + + const ps_point& z_propose = std::get<4>(subtree_result); + + // update running sums + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + //if (this->rand_uniform_() < + //accept_prob) + // HACK + if (get_rand_uniform() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_bck, p_sharp_fwd, rho)) { + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + //std::cout << " we are done (later)" << std::endl; + } + //std::cout << " continuing (later)" << std::endl; + }); + if(fwd_direction[depth]) { + //std::cout << "depth " << depth << ": joining fwd node " << all_builder_idx[depth] << " into join node." << std::endl; + make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); + } else { + //std::cout << "depth " << depth << ": joining bck node " << all_builder_idx[depth] << " into join node." << std::endl; + make_edge(bck_builder[all_builder_idx[depth]], checks.back()); + } + if(!run_serial && depth != 0) { + make_edge(checks[depth-1], checks.back()); + } + } + + if(run_serial) { + for(std::size_t i = 1; i < this->max_depth_; ++i) { + make_edge(checks[i-1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] : bck_builder[all_builder_idx[i]]); + } + } + + // kick off work + if(fwd_direction[0]) { + fwd_builder[0].try_put(continue_msg()); + // the first turn is fwd, so kick off the bck walker if needed + if (!run_serial && num_bck != 0) + bck_builder[0].try_put(continue_msg()); + } else { + bck_builder[0].try_put(continue_msg()); + if (!run_serial && num_fwd != 0) + fwd_builder[0].try_put(continue_msg()); + } + + g.wait_for_all(); + + this->n_leapfrog_ = n_leapfrog; + //this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + // this includes the speculative executed ones + //const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + sample + transition_refactored(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + //int n_leapfrog = 0; + //double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + bool valid_subtree; + double log_sum_weight_subtree; + Eigen::VectorXd rho_subtree; + + if (this->rand_uniform_() > 0.5) { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_fwd, this->z_, logger); + } else { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_bck, this->z_, logger); + } + + if (!valid_subtree) break; + + // Sample from an accepted subtree + ++(this->depth_); + + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(tree_bck.p_sharp_end_, tree_fwd.p_sharp_end_, rho)) + break; + //if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + // break; + } + + //this->n_leapfrog_ = n_leapfrog; + this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + sample + transition_old(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + ps_point z_plus(this->z_); + ps_point z_minus(z_plus); + + ps_point z_sample(z_plus); + ps_point z_propose(z_plus); + + Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); + //Eigen::VectorXd p_sharp_dummy = p_sharp_plus; + Eigen::VectorXd p_sharp_minus = p_sharp_plus; + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + int n_leapfrog = 0; + double sum_metro_prob = 0; + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + // Build a new subtree in a random direction + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; + double log_sum_weight_subtree + = -std::numeric_limits::infinity(); + + // this should be fine (modified from orig) + Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(this->z_.p.size()); + + if (this->rand_uniform_() > 0.5) { + this->z_.ps_point::operator=(z_plus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, + p_sharp_dummy, p_sharp_plus, rho_subtree, + H0, 1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, + logger); + z_plus.ps_point::operator=(this->z_); + } else { + this->z_.ps_point::operator=(z_minus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, + p_sharp_dummy, p_sharp_minus, rho_subtree, + H0, -1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, + logger); + z_minus.ps_point::operator=(this->z_); + } + + if (!valid_subtree) break; + + // Sample from an accepted subtree + ++(this->depth_); + + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + break; + } + + this->n_leapfrog_ = n_leapfrog; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(n_leapfrog); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + void get_sampler_param_names(std::vector& names) { + names.push_back("stepsize__"); + names.push_back("treedepth__"); + names.push_back("n_leapfrog__"); + names.push_back("divergent__"); + names.push_back("energy__"); + } + + void get_sampler_params(std::vector& values) { + values.push_back(this->epsilon_); + values.push_back(this->depth_); + values.push_back(this->n_leapfrog_); + values.push_back(this->divergent_); + values.push_back(this->energy_); + } + + virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return p_sharp_plus.dot(rho) > 0 + && p_sharp_minus.dot(rho) > 0; + } + + /** + * Recursively build a new subtree to completion or until + * the subtree becomes invalid. Returns validity of the + * resulting subtree. + * + * @param depth Depth of the desired subtree + * @param z_beg State beginning from subtree + * @param z_propose State proposed from subtree + * @param p_sharp_left p_sharp from left boundary of returned tree + * @param p_sharp_right p_sharp from the right boundary of returned tree + * @param rho Summed momentum across trajectory + * @param H0 Hamiltonian of initial state + * @param sign Direction in time to built subtree + * @param n_leapfrog Summed number of leapfrog evaluations + * @param log_sum_weight Log of summed weights across trajectory + * @param sum_metro_prob Summed Metropolis probabilities across trajectory + * @param logger Logger for messages + */ + bool build_tree(int depth, state_t& z_beg, + ps_point& z_propose, + Eigen::VectorXd& p_sharp_left, + Eigen::VectorXd& p_sharp_right, + Eigen::VectorXd& rho, + double H0, double sign, int& n_leapfrog, + double& log_sum_weight, double& sum_metro_prob, + callbacks::logger& logger) { + // Base case + if (depth == 0) { + // check if trees are still valid or if we should terminate + if(!this->valid_trees_) + return false; + + this->integrator_.evolve(z_beg, this->hamiltonian_, + sign * this->epsilon_, + logger); + + ++n_leapfrog; + + double h = this->hamiltonian_.H(z_beg); + if (boost::math::isnan(h)) + h = std::numeric_limits::infinity(); + + // TODO: in parallel case we cannot use the global divergent + // flag since this could be a speculative tree!! + //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + bool is_divergent = (h - H0) > this->max_deltaH_; + //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + + log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); + + if (H0 - h > 0) + sum_metro_prob += 1; + else + sum_metro_prob += std::exp(H0 - h); + + z_propose = z_beg; + rho += z_beg.p; + + p_sharp_left = this->hamiltonian_.dtau_dp(z_beg); + p_sharp_right = p_sharp_left; + + return !is_divergent; + } + // General recursion + Eigen::VectorXd p_sharp_dummy(z_beg.p.size()); + + // Build the left subtree + double log_sum_weight_left = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); + + bool valid_left + = build_tree(depth - 1, z_beg, z_propose, + p_sharp_left, p_sharp_dummy, rho_left, + H0, sign, n_leapfrog, + log_sum_weight_left, sum_metro_prob, + logger); + + if (!valid_left) return false; + + // Build the right subtree + ps_point z_propose_right(z_beg); + + double log_sum_weight_right = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); + + bool valid_right + = build_tree(depth - 1, z_beg, z_propose_right, + p_sharp_dummy, p_sharp_right, rho_right, + H0, sign, n_leapfrog, + log_sum_weight_right, sum_metro_prob, + logger); + + if (!valid_right) return false; + + // Multinomial sample from right subtree + double log_sum_weight_subtree + = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + if (log_sum_weight_right > log_sum_weight_subtree) { + z_propose = z_propose_right; + } else { + double accept_prob + = std::exp(log_sum_weight_right - log_sum_weight_subtree); + //if (this->rand_uniform_() < accept_prob) + if (get_rand_uniform() < accept_prob) + z_propose = z_propose_right; + } + + Eigen::VectorXd rho_subtree = rho_left + rho_right; + rho += rho_subtree; + + return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + } + + inline double get_rand_uniform() { + return this->rand_uniform_vec_[tbb::this_task_arena::current_thread_index()](); + } + + int depth_{0}; + int max_depth_{5}; + double max_deltaH_{1000}; + int n_leapfrog_{0}; + double energy_{0}; + bool valid_trees_{true}; + bool divergent_{false}; + // Uniform(0, 1) RNG + std::vector> rand_uniform_vec_; + }; + template class Hamiltonian, + template class Integrator, class BaseRNG> + using base_parallel_nuts_ct = std::conditional_t, + base_parallel_nuts>; + } // mcmc +} // stan +#endif diff --git a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp index 5f830f85cc6..896d219e098 100644 --- a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp @@ -2,22 +2,28 @@ #define STAN_MCMC_HMC_NUTS_DIAG_E_NUTS_HPP #include +#include #include #include #include namespace stan { namespace mcmc { + /** * The No-U-Turn sampler (NUTS) with multinomial sampling * with a Gaussian-Euclidean disintegration and diagonal metric */ -template +template class diag_e_nuts - : public base_nuts { + : public base_nuts_ct { + using base_nuts_t = base_nuts_ct; public: diag_e_nuts(const Model& model, BaseRNG& rng) - : base_nuts(model, rng) {} + : base_nuts_t(model, rng) {} + diag_e_nuts(const Model& model, std::vector& thread_rngs) + : base_nuts_t(model, thread_rngs) {} + }; } // namespace mcmc diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp new file mode 100644 index 00000000000..74b3669cfe3 --- /dev/null +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp @@ -0,0 +1,375 @@ +#ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_PARALLEL_HPP +#define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_PARALLEL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace sample { + +/** + * Runs HMC with NUTS with adaptation using diagonal Euclidean metric + * with a pre-specified Euclidean metric. + * + * @tparam Model Model class + * @tparam InitContextPtr A type derived from `stan::io::var_context` + * @tparam InitMetricContext A type derived from `stan::io::var_context` + * @tparam SamplerWriter A type derived from `stan::callbacks::writer` + * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer` + * @tparam InitWriter A type derived from `stan::callbacks::writer` + * @param[in] model Input model to test (with data already instantiated) + * @param[in] init var context for initialization + * @param[in] init_inv_metric var context exposing an initial diagonal + inverse Euclidean metric (must be positive definite) + * @param[in] random_seed random seed for the random number generator + * @param[in] chain chain id to advance the pseudo random number generator + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer Writer callback for unconstrained inits + * @param[in,out] sample_writer Writer for draws + * @param[in,out] diagnostic_writer Writer for diagnostic information + * @return error_codes::OK if successful + */ +template +int hmc_nuts_diag_e_adapt_parallel( + Model& model, const stan::io::var_context& init, + const stan::io::var_context& init_inv_metric, unsigned int random_seed, + unsigned int chain, double init_radius, int num_warmup, int num_samples, + int num_thin, bool save_warmup, int refresh, double stepsize, + double stepsize_jitter, int max_depth, double delta, double gamma, + double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, + unsigned int window, callbacks::interrupt& interrupt, + callbacks::logger& logger, callbacks::writer& init_writer, + callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) { + const int num_threads = stan::math::get_num_threads(); + std::vector rngs; + rngs.reserve(num_threads) + for (size_t i = 0; i < num_threads; ++i) { + rngs.emplace_back(util::create_rng(random_seed, chain + i)); + } + std::vector cont_vector = util::initialize( + model, init, rngs[0], init_radius, true, logger, init_writer); + + Eigen::VectorXd inv_metric; + try { + inv_metric = util::read_diag_inv_metric(init_inv_metric, + model.num_params_r(), logger); + util::validate_diag_inv_metric(inv_metric, logger); + } catch (const std::domain_error& e) { + return error_codes::CONFIG; + } + + stan::mcmc::adapt_diag_e_nuts sampler(model, rngs); + + sampler.set_metric(inv_metric); + sampler.set_nominal_stepsize(stepsize); + sampler.set_stepsize_jitter(stepsize_jitter); + sampler.set_max_depth(max_depth); + + sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); + sampler.get_stepsize_adaptation().set_delta(delta); + sampler.get_stepsize_adaptation().set_gamma(gamma); + sampler.get_stepsize_adaptation().set_kappa(kappa); + sampler.get_stepsize_adaptation().set_t0(t0); + + sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, + logger); + + util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rngs[0], interrupt, logger, sample_writer, diagnostic_writer); + + return error_codes::OK; +} + +/** + * Runs HMC with NUTS with adaptation using diagonal Euclidean metric. + * + * @tparam Model Model class + * @param[in] model Input model to test (with data already instantiated) + * @param[in] init var context for initialization + * @param[in] random_seed random seed for the random number generator + * @param[in] chain chain id to advance the pseudo random number generator + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer Writer callback for unconstrained inits + * @param[in,out] sample_writer Writer for draws + * @param[in,out] diagnostic_writer Writer for diagnostic information + * @return error_codes::OK if successful + */ +template +int hmc_nuts_diag_e_adapt_parallel( + Model& model, const stan::io::var_context& init, unsigned int random_seed, + unsigned int chain, double init_radius, int num_warmup, int num_samples, + int num_thin, bool save_warmup, int refresh, double stepsize, + double stepsize_jitter, int max_depth, double delta, double gamma, + double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, + unsigned int window, callbacks::interrupt& interrupt, + callbacks::logger& logger, callbacks::writer& init_writer, + callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) { + stan::io::dump unit_e_metric + = util::create_unit_e_diag_inv_metric(model.num_params_r()); + return hmc_nuts_diag_e_adapt_parallel( + model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init_writer, sample_writer, diagnostic_writer); +} + +/** + * Runs multiple chains of HMC with NUTS with adaptation using diagonal + * Euclidean metric with a pre-specified Euclidean metric. + * + * @tparam Model Model class + * @tparam InitContextPtr A pointer with underlying type derived from + `stan::io::var_context` + * @tparam InitInvContextPtr A pointer with underlying type derived from + `stan::io::var_context` + * @tparam SamplerWriter A type derived from `stan::callbacks::writer` + * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer` + * @tparam InitWriter A type derived from `stan::callbacks::writer` + * @param[in] model Input model to test (with data already instantiated) + * @param[in] num_chains The number of chains to run in parallel. `init`, + * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` + must + * be the same length as this value. + * @param[in] init An std vector of init var contexts for initialization of each + * chain. + * @param[in] init_inv_metric An std vector of var contexts exposing an initial + diagonal inverse Euclidean metric for each chain (must be positive definite) + * @param[in] random_seed random seed for the random number generator + * @param[in] init_chain_id first chain id. The pseudo random number generator + * will advance for each chain by an integer sequence from `init_chain_id` to + * `init_chain_id + num_chains - 1` + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer std vector of Writer callbacks for unconstrained + * inits of each chain. + * @param[in,out] sample_writer std vector of Writers for draws of each chain. + * @param[in,out] diagnostic_writer std vector of Writers for diagnostic + * information of each chain. + * @return error_codes::OK if successful + */ +template +int hmc_nuts_diag_e_adapt_parallel( + Model& model, size_t num_chains, const std::vector& init, + const std::vector& init_inv_metric, + unsigned int random_seed, unsigned int init_chain_id, double init_radius, + int num_warmup, int num_samples, int num_thin, bool save_warmup, + int refresh, double stepsize, double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, double t0, + unsigned int init_buffer, unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, callbacks::logger& logger, + std::vector& init_writer, + std::vector& sample_writer, + std::vector& diagnostic_writer) { + if (num_chains == 1 || stan::math::get_num_threads() == 1) { + return hmc_nuts_diag_e_adapt_parallel( + model, *init[0], *init_inv_metric[0], random_seed, init_chain_id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, + init_buffer, term_buffer, window, interrupt, logger, init_writer[0], + sample_writer[0], diagnostic_writer[0]); + } + const int num_threads = stan::math::get_num_threads(); + std::vector rngs; + rngs.reserve(num_threads); + try { + for (int i = 0; i < num_threads; ++i) { + rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i)); + } + } catch (const std::domain_error& e) { + return error_codes::CONFIG; + } + error_codes ret_code; + tbb::parallel_for(tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, + num_chains, init_chain_id, &ret_code, &model, &rngs, + &interrupt, &logger, &sample_writer, + &diagnostic_writer](const tbb::blocked_range& r) { + boost::ecuyer1988& thread_rng = rngs[tbb::this_task_arena::current_thread_index()] + using sample_t = stan::mcmc::adapt_diag_e_nuts; + Eigen::VectorXd inv_metric; + std::vector cont_vector; + for (size_t i = r.begin(); i != r.end(); ++i) { + sample_t sampler(model, rngs); + try { + cont_vector = util::initialize( + model, *init[i], thread_rng, init_radius, true, logger, init_writer[i]); + inv_metric = util::read_diag_inv_metric( + *init_inv_metric[i], model.num_params_r(), logger); + util::validate_diag_inv_metric(inv_metric, logger); + + sampler.set_metric(inv_metric); + sampler.set_nominal_stepsize(stepsize); + sampler.set_stepsize_jitter(stepsize_jitter); + sampler.set_max_depth(max_depth); + + sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); + sampler.get_stepsize_adaptation().set_delta(delta); + sampler.get_stepsize_adaptation().set_gamma(gamma); + sampler.get_stepsize_adaptation().set_kappa(kappa); + sampler.get_stepsize_adaptation().set_t0(t0); + sampler.set_window_params(num_warmup, init_buffer, term_buffer, + window, logger); + } catch (const std::domain_error& e) { + ret_code = error_codes::CONFIG; + return; + } + util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i, + num_chains); + } + }, + tbb::simple_partitioner()); + return ret_code == error_codes::CONFIG ? error_codes::CONFIG : error_codes::OK; +} + +/** + * Runs multiple chains of HMC with NUTS with adaptation using diagonal + * Euclidean metric. + * + * @tparam Model Model class + * @tparam InitContextPtr A pointer with underlying type derived from + * `stan::io::var_context` + * @tparam SamplerWriter A type derived from `stan::callbacks::writer` + * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer` + * @tparam InitWriter A type derived from `stan::callbacks::writer` + * @param[in] model Input model to test (with data already instantiated) + * @param[in] num_chains The number of chains to run in parallel. `init`, + * `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same + * length as this value. + * @param[in] init An std vector of init var contexts for initialization of each + * chain. + * @param[in] random_seed random seed for the random number generator + * @param[in] init_chain_id first chain id. The pseudo random number generator + * will advance by for each chain by an integer sequence from `init_chain_id` to + * `init_chain_id+num_chains-1` + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer std vector of Writer callbacks for unconstrained + * inits of each chain. + * @param[in,out] sample_writer std vector of Writers for draws of each chain. + * @param[in,out] diagnostic_writer std vector of Writers for diagnostic + * information of each chain. + * @return error_codes::OK if successful + */ +template +int hmc_nuts_diag_e_adapt_parallel( + Model& model, size_t num_chains, const std::vector& init, + unsigned int random_seed, unsigned int init_chain_id, double init_radius, + int num_warmup, int num_samples, int num_thin, bool save_warmup, + int refresh, double stepsize, double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, double t0, + unsigned int init_buffer, unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, callbacks::logger& logger, + std::vector& init_writer, + std::vector& sample_writer, + std::vector& diagnostic_writer) { + if (num_chains == 1 || stan::math::get_num_threads() == 1) { + return hmc_nuts_diag_e_adapt_parallel( + model, *init[0], random_seed, init_chain_id, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init_writer[0], sample_writer[0], + diagnostic_writer[0]); + } + std::vector> unit_e_metrics; + unit_e_metrics.reserve(num_chains); + for (size_t i = 0; i < num_chains; ++i) { + unit_e_metrics.emplace_back(std::make_unique( + util::create_unit_e_diag_inv_metric(model.num_params_r()))); + } + return hmc_nuts_diag_e_adapt_parallel( + model, num_chains, init, unit_e_metrics, random_seed, init_chain_id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, + init_buffer, term_buffer, window, interrupt, logger, init_writer, + sample_writer, diagnostic_writer); +} + +} // namespace sample +} // namespace services +} // namespace stan +#endif From 7669f6d43f44c99be839d39578f22ccc531ff3d2 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 15 Feb 2022 18:17:35 -0500 Subject: [PATCH 3/8] clang format --- src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 1340 ++++++++--------- 1 file changed, 666 insertions(+), 674 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp index 8aed8d5cbec..919e0c18aa6 100644 --- a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -32,7 +32,8 @@ inline auto make_uniform_vec(std::vector& thread_rngs) { rand_uniform_vec.emplace_back(thread_rngs[i]); } */ - return std::vector>(thread_rngs.begin(), thread_rngs.end()); + return std::vector>(thread_rngs.begin(), + thread_rngs.end()); } // Prototype of speculative NUTS. @@ -41,750 +42,741 @@ inline auto make_uniform_vec(std::vector& thread_rngs) { // time in parallel. namespace stan { - namespace mcmc { - - /** - * The No-U-Turn sampler (NUTS) with multinomial sampling - */ - template class Hamiltonian, - template class Integrator, class BaseRNG> - class base_parallel_nuts : public base_hmc { - public: - using state_t = typename Hamiltonian::PointType; - - base_parallel_nuts(const Model& model, std::vector& thread_rngs) - : base_hmc(model, thread_rngs[tbb::this_task_arena::current_thread_index()]), - rand_uniform_vec_(make_uniform_vec(thread_rngs)) { - } +namespace mcmc { + +/** + * The No-U-Turn sampler (NUTS) with multinomial sampling + */ +template class Hamiltonian, + template class Integrator, class BaseRNG> +class base_parallel_nuts + : public base_hmc { + public: + using state_t = typename Hamiltonian::PointType; + + base_parallel_nuts(const Model& model, std::vector& thread_rngs) + : base_hmc( + model, thread_rngs[tbb::this_task_arena::current_thread_index()]), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + base_parallel_nuts(const Model& model, BaseRNG& rng, + std::vector& thread_rngs) + : base_hmc(model, rng), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + /** + * specialized constructor for specified diag mass matrix + */ + base_parallel_nuts(const Model& model, BaseRNG& rng, + Eigen::VectorXd& inv_e_metric, + std::vector& thread_rngs) + : base_hmc(model, rng, + inv_e_metric), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + /** + * specialized constructor for specified dense mass matrix + */ + base_parallel_nuts(const Model& model, BaseRNG& rng, + Eigen::MatrixXd& inv_e_metric, + std::vector& thread_rngs) + : base_hmc(model, rng, + inv_e_metric), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + ~base_parallel_nuts() {} + + inline void set_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } - base_parallel_nuts(const Model& model, BaseRNG& rng, std::vector& thread_rngs) - : base_hmc(model, rng), - rand_uniform_vec_(make_uniform_vec(thread_rngs)) { - } + inline void set_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } - /** - * specialized constructor for specified diag mass matrix - */ - base_parallel_nuts(const Model& model, BaseRNG& rng, - Eigen::VectorXd& inv_e_metric, std::vector& thread_rngs) - : base_hmc(model, rng, inv_e_metric), - rand_uniform_vec_(make_uniform_vec(thread_rngs)) { - } + inline void set_max_depth(int d) noexcept { + if (d > 0) { + max_depth_ = d; + } + } - /** - * specialized constructor for specified dense mass matrix - */ - base_parallel_nuts(const Model& model, BaseRNG& rng, - Eigen::MatrixXd& inv_e_metric, std::vector& thread_rngs) - : base_hmc(model, rng, inv_e_metric), - rand_uniform_vec_(make_uniform_vec(thread_rngs)) { - } + inline void set_max_delta(double d) noexcept { max_deltaH_ = d; } + + inline int get_max_depth() noexcept { return this->max_depth_; } + inline double get_max_delta() noexcept { return this->max_deltaH_; } + + // stores from left/right subtree entire information + struct subtree { + subtree(const double sign, const ps_point& z_end, + const Eigen::VectorXd& p_sharp_end, double H0) + : z_end_(z_end), + z_propose_(z_end), + p_sharp_end_(p_sharp_end), + H0_(H0), + sign_(sign), + n_leapfrog_(0), + sum_metro_prob_(0) {} + + ps_point z_end_; + ps_point z_propose_; + Eigen::VectorXd p_sharp_end_; + double H0_; + double sign_; + int n_leapfrog_{0}; + double sum_metro_prob_{0}; + }; + + // extends the tree into the direction of the sign of the + // subtree + using extend_tree_t = std::tuple; + + inline extend_tree_t extend_tree(int depth, subtree& tree, state_t& z, + callbacks::logger& logger) { + // save the current ends needed for later criterion computations + // Eigen::VectorXd p_end = tree.p_end_; + // Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; + Eigen::VectorXd p_sharp_dummy + = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + + Eigen::VectorXd rho_subtree + = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + tree.n_leapfrog_ = 0; + tree.sum_metro_prob_ = 0; + + z.ps_point::operator=(tree.z_end_); + + bool valid_subtree = build_tree( + depth, z, tree.z_propose_, p_sharp_dummy, tree.p_sharp_end_, + rho_subtree, tree.H0_, tree.sign_, tree.n_leapfrog_, + log_sum_weight_subtree, tree.sum_metro_prob_, logger); + + tree.z_end_.ps_point::operator=(z); + + return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, + tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, + tree.sum_metro_prob_); + } - ~base_parallel_nuts() {} + inline sample transition(sample& init_sample, callbacks::logger& logger) { + return transition_parallel(init_sample, logger); + } - inline void set_metric(const Eigen::MatrixXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); + // this implementation builds up the dependence graph every call + // to transition. Things which should be refactored: + // 1. build up the nodes only once + // 2. add a prepare method to each node which samples its + // direction and needed random numbers for multinomial sampling + // 3. only the edges are added dynamically. So the forward nodes + // are wired-up and the backward nodes are wired-up if run + // parallel. If run serially, then each grow node is alternated + // with a check node. + sample transition_parallel(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + // ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + // int n_leapfrog = 0; + // double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // actual states which move... copy construct atm...revise?! + state_t z_fwd(this->z_); + state_t z_bck(this->z_); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + this->valid_trees_ = true; + + // the actual number of leapfrog steps in trajectory used + // excluding the ones executed speculative + int n_leapfrog = 0; + + // actually summed metropolis prob of used trajectory + double sum_metro_prob = 0; + + std::vector fwd_direction(this->max_depth_); + + for (std::size_t i = 0; i != this->max_depth_; ++i) + fwd_direction[i] = this->rand_uniform_() > 0.5; + + const std::size_t num_fwd + = std::accumulate(fwd_direction.begin(), fwd_direction.end(), 0); + const std::size_t num_bck = this->max_depth_ - num_fwd; + + /* + std::cout << "sampled turns: "; + for (std::size_t i = 0; i != this->max_depth_; ++i) { + if(fwd_direction[i]) + std::cout << "+,"; + else + std::cout << "-,"; + } + std::cout << std::endl; + */ + + tbb::concurrent_vector ends( + this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), + Eigen::VectorXd(), z_sample, 0, 0.0)); + tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); + tbb::concurrent_vector valid_subtree_bck(num_bck, true); + + // HACK!!! + callbacks::logger logger_fwd; + callbacks::logger logger_bck; + + // build TBB flow graph + graph g; + + // add nodes which advance the left/right tree + typedef continue_node tree_builder_t; + + tbb::concurrent_vector all_builder_idx(this->max_depth_); + tbb::concurrent_vector fwd_builder; + tbb::concurrent_vector bck_builder; + typedef tbb::concurrent_vector::iterator builder_iter_t; + + // now wire up the fwd and bck build of the trees which + // depends on single-core or multi-core run + const bool run_serial = stan::math::internal::get_num_threads() == 1; + + std::size_t fwd_idx = 0; + std::size_t bck_idx = 0; + // TODO: the extenders should also check for a global flag if + // we want to keep running + for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { + if (fwd_direction[depth]) { + builder_iter_t fwd_iter + = fwd_builder.emplace_back(g, [&, depth, fwd_idx](continue_msg) { + // std::cout << "fwd turn at depth " << depth; + bool valid_parent + = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx - 1]; + if (valid_parent) { + // std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); + valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_fwd[fwd_idx] = false; + } + // std::cout << " nothing to do." << std::endl; + }); + if (!run_serial && fwd_idx != 0) { + // in this case this is not the starting node, we + // connect this with its predecessor + make_edge(*(fwd_iter - 1), *fwd_iter); + } + all_builder_idx[depth] = fwd_idx; + ++fwd_idx; + } else { + builder_iter_t bck_iter + = bck_builder.emplace_back(g, [&, depth, bck_idx](continue_msg) { + // std::cout << "bck turn at depth " << depth; + bool valid_parent + = bck_idx == 0 ? true : valid_subtree_bck[bck_idx - 1]; + if (valid_parent) { + // std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); + valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_bck[bck_idx] = false; + } + // std::cout << " nothing to do." << std::endl; + }); + if (!run_serial && bck_idx != 0) { + // in case this is not the starting node, we connect + // this with his predecessor + // make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); + make_edge(*(bck_iter - 1), *bck_iter); + } + all_builder_idx[depth] = bck_idx; + ++bck_idx; } + } + + // finally wire in the checker which accepts or rejects the + // proposed states from the subtrees + // typedef function_node< tbb::flow::tuple, bool> checker_t; + // typedef join_node< tbb::flow::tuple > joiner_t; + typedef continue_node checker_t; + + tbb::concurrent_vector checks; + // std::vector joins; + + Eigen::VectorXd p_sharp_fwd(p_sharp); + Eigen::VectorXd p_sharp_bck(p_sharp); + + for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { + // joins.push_back(joiner_t(g)); + // std::cout << "creating check at depth " << depth << std::endl; + checks.emplace_back(g, [&, depth](continue_msg) { + bool is_fwd = fwd_direction[depth]; + + extend_tree_t& subtree_result = ends[depth]; + + // if we are still on the + // trajectories which are + // actually used update the + // running tree stats + if (this->valid_trees_) { + this->depth_ = depth + 1; + n_leapfrog += std::get<5>(subtree_result); + sum_metro_prob += std::get<6>(subtree_result); + } - inline void set_metric(const Eigen::VectorXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); - } + bool valid_subtree = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] + : valid_subtree_bck[all_builder_idx[depth]]; + + bool is_valid = valid_subtree & this->valid_trees_; + + // std::cout << "CHECK at depth " << depth; - inline void set_max_depth(int d) noexcept { - if (d > 0) { - max_depth_ = d; + if (!is_valid) { + // std::cout << " we are done (early)" << std::endl; + + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + return; } - } - inline void set_max_delta(double d) noexcept { - max_deltaH_ = d; - } + // std::cout << " checking" << std::endl; - inline int get_max_depth() noexcept { return this->max_depth_; } - inline double get_max_delta() noexcept { return this->max_deltaH_; } - - // stores from left/right subtree entire information - struct subtree { - subtree(const double sign, - const ps_point& z_end, - const Eigen::VectorXd& p_sharp_end, - double H0) - : z_end_(z_end), z_propose_(z_end), - p_sharp_end_(p_sharp_end), - H0_(H0), - sign_(sign), - n_leapfrog_(0), - sum_metro_prob_(0) - {} - - ps_point z_end_; - ps_point z_propose_; - Eigen::VectorXd p_sharp_end_; - double H0_; - double sign_; - int n_leapfrog_{0}; - double sum_metro_prob_{0}; - }; - - - // extends the tree into the direction of the sign of the - // subtree - using extend_tree_t = std::tuple; - - inline extend_tree_t extend_tree(int depth, subtree& tree, state_t& z, - callbacks::logger& logger) { - // save the current ends needed for later criterion computations - //Eigen::VectorXd p_end = tree.p_end_; - //Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; - Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + double log_sum_weight_subtree = std::get<1>(subtree_result); + const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); - Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); - double log_sum_weight_subtree = -std::numeric_limits::infinity(); + // update correct side + if (is_fwd) { + p_sharp_fwd = std::get<3>(subtree_result); + } else { + p_sharp_bck = std::get<3>(subtree_result); + } - tree.n_leapfrog_ = 0; - tree.sum_metro_prob_ = 0; + const ps_point& z_propose = std::get<4>(subtree_result); - z.ps_point::operator=(tree.z_end_); + // update running sums + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + // if (this->rand_uniform_() < + // accept_prob) + // HACK + if (get_rand_uniform() < accept_prob) + z_sample = z_propose; + } - bool valid_subtree = build_tree(depth, - z, tree.z_propose_, - p_sharp_dummy, tree.p_sharp_end_, - rho_subtree, - tree.H0_, - tree.sign_, - tree.n_leapfrog_, - log_sum_weight_subtree, tree.sum_metro_prob_, - logger); + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - tree.z_end_.ps_point::operator=(z); + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_bck, p_sharp_fwd, rho)) { + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + // std::cout << " we are done (later)" << std::endl; + } + // std::cout << " continuing (later)" << std::endl; + }); + if (fwd_direction[depth]) { + // std::cout << "depth " << depth << ": joining fwd node " << + // all_builder_idx[depth] << " into join node." << std::endl; + make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); + } else { + // std::cout << "depth " << depth << ": joining bck node " << + // all_builder_idx[depth] << " into join node." << std::endl; + make_edge(bck_builder[all_builder_idx[depth]], checks.back()); + } + if (!run_serial && depth != 0) { + make_edge(checks[depth - 1], checks.back()); + } + } - return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, tree.sum_metro_prob_); + if (run_serial) { + for (std::size_t i = 1; i < this->max_depth_; ++i) { + make_edge(checks[i - 1], fwd_direction[i] + ? fwd_builder[all_builder_idx[i]] + : bck_builder[all_builder_idx[i]]); } + } + + // kick off work + if (fwd_direction[0]) { + fwd_builder[0].try_put(continue_msg()); + // the first turn is fwd, so kick off the bck walker if needed + if (!run_serial && num_bck != 0) + bck_builder[0].try_put(continue_msg()); + } else { + bck_builder[0].try_put(continue_msg()); + if (!run_serial && num_fwd != 0) + fwd_builder[0].try_put(continue_msg()); + } + + g.wait_for_all(); + + this->n_leapfrog_ = n_leapfrog; + // this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + // this includes the speculative executed ones + // const double sum_metro_prob = tree_fwd.sum_metro_prob_ + + // tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + sample transition_refactored(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); - inline sample transition(sample& init_sample, callbacks::logger& logger) { - return transition_parallel(init_sample, logger); - } + this->seed(init_sample.cont_params()); - // this implementation builds up the dependence graph every call - // to transition. Things which should be refactored: - // 1. build up the nodes only once - // 2. add a prepare method to each node which samples its - // direction and needed random numbers for multinomial sampling - // 3. only the edges are added dynamically. So the forward nodes - // are wired-up and the backward nodes are wired-up if run - // parallel. If run serially, then each grow node is alternated - // with a check node. - sample - transition_parallel(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); - - this->seed(init_sample.cont_params()); - - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); - - const ps_point z_init(this->z_); - - ps_point z_sample(z_init); - //ps_point z_propose(z_init); - - const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd rho = this->z_.p; - - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - //int n_leapfrog = 0; - //double sum_metro_prob = 0; - - // forward tree - subtree tree_fwd(1, z_init, p_sharp, H0); - // backward tree - subtree tree_bck(-1, z_init, p_sharp, H0); - - // actual states which move... copy construct atm...revise?! - state_t z_fwd(this->z_); - state_t z_bck(this->z_); - - // Build a trajectory until the NUTS criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; - this->valid_trees_ = true; - - // the actual number of leapfrog steps in trajectory used - // excluding the ones executed speculative - int n_leapfrog = 0; - - // actually summed metropolis prob of used trajectory - double sum_metro_prob = 0; - - std::vector fwd_direction(this->max_depth_); - - for (std::size_t i = 0; i != this->max_depth_; ++i) - fwd_direction[i] = this->rand_uniform_() > 0.5; - - const std::size_t num_fwd = std::accumulate(fwd_direction.begin(), fwd_direction.end(), 0); - const std::size_t num_bck = this->max_depth_ - num_fwd; - - /* - std::cout << "sampled turns: "; - for (std::size_t i = 0; i != this->max_depth_; ++i) { - if(fwd_direction[i]) - std::cout << "+,"; - else - std::cout << "-,"; - } - std::cout << std::endl; - */ - - tbb::concurrent_vector ends(this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), Eigen::VectorXd(), z_sample, 0, 0.0)); - tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); - tbb::concurrent_vector valid_subtree_bck(num_bck, true); - - // HACK!!! - callbacks::logger logger_fwd; - callbacks::logger logger_bck; - - // build TBB flow graph - graph g; - - // add nodes which advance the left/right tree - typedef continue_node tree_builder_t; - - tbb::concurrent_vector all_builder_idx(this->max_depth_); - tbb::concurrent_vector fwd_builder; - tbb::concurrent_vector bck_builder; - typedef tbb::concurrent_vector::iterator builder_iter_t; - - // now wire up the fwd and bck build of the trees which - // depends on single-core or multi-core run - const bool run_serial = stan::math::internal::get_num_threads() == 1; - - std::size_t fwd_idx = 0; - std::size_t bck_idx = 0; - // TODO: the extenders should also check for a global flag if - // we want to keep running - for (std::size_t depth=0; depth != this->max_depth_; ++depth) { - if (fwd_direction[depth]) { - builder_iter_t fwd_iter = - fwd_builder.emplace_back(g, [&,depth,fwd_idx](continue_msg) { - //std::cout << "fwd turn at depth " << depth; - bool valid_parent = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx-1]; - if (valid_parent) { - //std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); - valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); - } else { - valid_subtree_fwd[fwd_idx] = false; - } - //std::cout << " nothing to do." << std::endl; - }); - if(!run_serial && fwd_idx != 0) { - // in this case this is not the starting node, we - // connect this with its predecessor - make_edge(*(fwd_iter-1), *fwd_iter); - } - all_builder_idx[depth] = fwd_idx; - ++fwd_idx; - } else { - builder_iter_t bck_iter = - bck_builder.emplace_back(g, [&,depth,bck_idx](continue_msg) { - //std::cout << "bck turn at depth " << depth; - bool valid_parent = bck_idx == 0 ? true : valid_subtree_bck[bck_idx-1]; - if (valid_parent) { - //std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); - valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); - } else { - valid_subtree_bck[bck_idx] = false; - } - //std::cout << " nothing to do." << std::endl; - }); - if(!run_serial && bck_idx != 0) { - // in case this is not the starting node, we connect - // this with his predecessor - //make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); - make_edge(*(bck_iter-1), *bck_iter); - } - all_builder_idx[depth] = bck_idx; - ++bck_idx; - } - } + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); - // finally wire in the checker which accepts or rejects the - // proposed states from the subtrees - //typedef function_node< tbb::flow::tuple, bool> checker_t; - //typedef join_node< tbb::flow::tuple > joiner_t; - typedef continue_node checker_t; - - tbb::concurrent_vector checks; - //std::vector joins; - - Eigen::VectorXd p_sharp_fwd(p_sharp); - Eigen::VectorXd p_sharp_bck(p_sharp); - - for (std::size_t depth=0; depth != this->max_depth_; ++depth) { - //joins.push_back(joiner_t(g)); - //std::cout << "creating check at depth " << depth << std::endl; - checks.emplace_back(g, [&,depth](continue_msg) { - bool is_fwd = fwd_direction[depth]; - - extend_tree_t& subtree_result = ends[depth]; - - // if we are still on the - // trajectories which are - // actually used update the - // running tree stats - if (this->valid_trees_) { - this->depth_ = depth + 1; - n_leapfrog += std::get<5>(subtree_result); - sum_metro_prob += std::get<6>(subtree_result); - } - - bool valid_subtree = is_fwd ? - valid_subtree_fwd[all_builder_idx[depth]] : - valid_subtree_bck[all_builder_idx[depth]]; - - bool is_valid = valid_subtree & this->valid_trees_; - - //std::cout << "CHECK at depth " << depth; - - if(!is_valid) { - //std::cout << " we are done (early)" << std::endl; - - // setting this globally here - // will terminate all ongoing work - this->valid_trees_ = false; - return; - } - - //std::cout << " checking" << std::endl; - - double log_sum_weight_subtree = std::get<1>(subtree_result); - const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); - - // update correct side - if (is_fwd) { - p_sharp_fwd = std::get<3>(subtree_result); - } else { - p_sharp_bck = std::get<3>(subtree_result); - } - - const ps_point& z_propose = std::get<4>(subtree_result); - - // update running sums - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob - = std::exp(log_sum_weight_subtree - log_sum_weight); - //if (this->rand_uniform_() < - //accept_prob) - // HACK - if (get_rand_uniform() < accept_prob) - z_sample = z_propose; - } - - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(p_sharp_bck, p_sharp_fwd, rho)) { - // setting this globally here - // will terminate all ongoing work - this->valid_trees_ = false; - //std::cout << " we are done (later)" << std::endl; - } - //std::cout << " continuing (later)" << std::endl; - }); - if(fwd_direction[depth]) { - //std::cout << "depth " << depth << ": joining fwd node " << all_builder_idx[depth] << " into join node." << std::endl; - make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); - } else { - //std::cout << "depth " << depth << ": joining bck node " << all_builder_idx[depth] << " into join node." << std::endl; - make_edge(bck_builder[all_builder_idx[depth]], checks.back()); - } - if(!run_serial && depth != 0) { - make_edge(checks[depth-1], checks.back()); - } - } + const ps_point z_init(this->z_); - if(run_serial) { - for(std::size_t i = 1; i < this->max_depth_; ++i) { - make_edge(checks[i-1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] : bck_builder[all_builder_idx[i]]); - } - } + ps_point z_sample(z_init); + ps_point z_propose(z_init); - // kick off work - if(fwd_direction[0]) { - fwd_builder[0].try_put(continue_msg()); - // the first turn is fwd, so kick off the bck walker if needed - if (!run_serial && num_bck != 0) - bck_builder[0].try_put(continue_msg()); - } else { - bck_builder[0].try_put(continue_msg()); - if (!run_serial && num_fwd != 0) - fwd_builder[0].try_put(continue_msg()); - } + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; - g.wait_for_all(); + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + // int n_leapfrog = 0; + // double sum_metro_prob = 0; - this->n_leapfrog_ = n_leapfrog; - //this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); - // this includes the speculative executed ones - //const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob - = sum_metro_prob / static_cast(this->n_leapfrog_); + while (this->depth_ < this->max_depth_) { + bool valid_subtree; + double log_sum_weight_subtree; + Eigen::VectorXd rho_subtree; - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); + if (this->rand_uniform_() > 0.5) { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_fwd, this->z_, logger); + } else { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_bck, this->z_, logger); } - sample - transition_refactored(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); + if (!valid_subtree) + break; - this->seed(init_sample.cont_params()); + // Sample from an accepted subtree + ++(this->depth_); - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } - const ps_point z_init(this->z_); + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - ps_point z_sample(z_init); - ps_point z_propose(z_init); + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(tree_bck.p_sharp_end_, tree_fwd.p_sharp_end_, rho)) + break; + // if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + // break; + } - const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd rho = this->z_.p; + // this->n_leapfrog_ = n_leapfrog; + this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - //int n_leapfrog = 0; - //double sum_metro_prob = 0; + const double sum_metro_prob + = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; - // forward tree - subtree tree_fwd(1, z_init, p_sharp, H0); - // backward tree - subtree tree_bck(-1, z_init, p_sharp, H0); + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); - // Build a trajectory until the NUTS criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } - while (this->depth_ < this->max_depth_) { - bool valid_subtree; - double log_sum_weight_subtree; - Eigen::VectorXd rho_subtree; + sample transition_old(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + ps_point z_plus(this->z_); + ps_point z_minus(z_plus); + + ps_point z_sample(z_plus); + ps_point z_propose(z_plus); + + Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); + // Eigen::VectorXd p_sharp_dummy = p_sharp_plus; + Eigen::VectorXd p_sharp_minus = p_sharp_plus; + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + int n_leapfrog = 0; + double sum_metro_prob = 0; + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + // Build a new subtree in a random direction + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + // this should be fine (modified from orig) + Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(this->z_.p.size()); + + if (this->rand_uniform_() > 0.5) { + this->z_.ps_point::operator=(z_plus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, p_sharp_dummy, + p_sharp_plus, rho_subtree, H0, 1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, logger); + z_plus.ps_point::operator=(this->z_); + } else { + this->z_.ps_point::operator=(z_minus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, p_sharp_dummy, + p_sharp_minus, rho_subtree, H0, -1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, logger); + z_minus.ps_point::operator=(this->z_); + } - if (this->rand_uniform_() > 0.5) { - std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) - = extend_tree(this->depth_, tree_fwd, this->z_, logger); - } else { - std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) - = extend_tree(this->depth_, tree_bck, this->z_, logger); - } + if (!valid_subtree) + break; - if (!valid_subtree) break; + // Sample from an accepted subtree + ++(this->depth_); - // Sample from an accepted subtree - ++(this->depth_); + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob - = std::exp(log_sum_weight_subtree - log_sum_weight); - if (this->rand_uniform_() < accept_prob) - z_sample = z_propose; - } + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + break; + } - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(tree_bck.p_sharp_end_, tree_fwd.p_sharp_end_, rho)) - break; - //if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) - // break; - } + this->n_leapfrog_ = n_leapfrog; - //this->n_leapfrog_ = n_leapfrog; - this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob = sum_metro_prob / static_cast(n_leapfrog); - const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob - = sum_metro_prob / static_cast(this->n_leapfrog_); + void get_sampler_param_names(std::vector& names) { + names.push_back("stepsize__"); + names.push_back("treedepth__"); + names.push_back("n_leapfrog__"); + names.push_back("divergent__"); + names.push_back("energy__"); + } - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); - } + void get_sampler_params(std::vector& values) { + values.push_back(this->epsilon_); + values.push_back(this->depth_); + values.push_back(this->n_leapfrog_); + values.push_back(this->divergent_); + values.push_back(this->energy_); + } - sample - transition_old(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); - - this->seed(init_sample.cont_params()); - - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); - - ps_point z_plus(this->z_); - ps_point z_minus(z_plus); - - ps_point z_sample(z_plus); - ps_point z_propose(z_plus); - - Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); - //Eigen::VectorXd p_sharp_dummy = p_sharp_plus; - Eigen::VectorXd p_sharp_minus = p_sharp_plus; - Eigen::VectorXd rho = this->z_.p; - - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - int n_leapfrog = 0; - double sum_metro_prob = 0; - - // Build a trajectory until the NUTS criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; - - while (this->depth_ < this->max_depth_) { - // Build a new subtree in a random direction - Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); - bool valid_subtree = false; - double log_sum_weight_subtree - = -std::numeric_limits::infinity(); - - // this should be fine (modified from orig) - Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(this->z_.p.size()); - - if (this->rand_uniform_() > 0.5) { - this->z_.ps_point::operator=(z_plus); - valid_subtree - = build_tree(this->depth_, this->z_, z_propose, - p_sharp_dummy, p_sharp_plus, rho_subtree, - H0, 1, n_leapfrog, - log_sum_weight_subtree, sum_metro_prob, - logger); - z_plus.ps_point::operator=(this->z_); - } else { - this->z_.ps_point::operator=(z_minus); - valid_subtree - = build_tree(this->depth_, this->z_, z_propose, - p_sharp_dummy, p_sharp_minus, rho_subtree, - H0, -1, n_leapfrog, - log_sum_weight_subtree, sum_metro_prob, - logger); - z_minus.ps_point::operator=(this->z_); - } - - if (!valid_subtree) break; - - // Sample from an accepted subtree - ++(this->depth_); - - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob - = std::exp(log_sum_weight_subtree - log_sum_weight); - if (this->rand_uniform_() < accept_prob) - z_sample = z_propose; - } + virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return p_sharp_plus.dot(rho) > 0 && p_sharp_minus.dot(rho) > 0; + } - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + /** + * Recursively build a new subtree to completion or until + * the subtree becomes invalid. Returns validity of the + * resulting subtree. + * + * @param depth Depth of the desired subtree + * @param z_beg State beginning from subtree + * @param z_propose State proposed from subtree + * @param p_sharp_left p_sharp from left boundary of returned tree + * @param p_sharp_right p_sharp from the right boundary of returned tree + * @param rho Summed momentum across trajectory + * @param H0 Hamiltonian of initial state + * @param sign Direction in time to built subtree + * @param n_leapfrog Summed number of leapfrog evaluations + * @param log_sum_weight Log of summed weights across trajectory + * @param sum_metro_prob Summed Metropolis probabilities across trajectory + * @param logger Logger for messages + */ + bool build_tree(int depth, state_t& z_beg, ps_point& z_propose, + Eigen::VectorXd& p_sharp_left, Eigen::VectorXd& p_sharp_right, + Eigen::VectorXd& rho, double H0, double sign, int& n_leapfrog, + double& log_sum_weight, double& sum_metro_prob, + callbacks::logger& logger) { + // Base case + if (depth == 0) { + // check if trees are still valid or if we should terminate + if (!this->valid_trees_) + return false; - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) - break; - } + this->integrator_.evolve(z_beg, this->hamiltonian_, sign * this->epsilon_, + logger); - this->n_leapfrog_ = n_leapfrog; + ++n_leapfrog; - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob - = sum_metro_prob / static_cast(n_leapfrog); + double h = this->hamiltonian_.H(z_beg); + if (boost::math::isnan(h)) + h = std::numeric_limits::infinity(); - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); - } + // TODO: in parallel case we cannot use the global divergent + // flag since this could be a speculative tree!! + // if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + bool is_divergent = (h - H0) > this->max_deltaH_; + // if ((h - H0) > this->max_deltaH_) this->divergent_ = true; - void get_sampler_param_names(std::vector& names) { - names.push_back("stepsize__"); - names.push_back("treedepth__"); - names.push_back("n_leapfrog__"); - names.push_back("divergent__"); - names.push_back("energy__"); - } + log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); - void get_sampler_params(std::vector& values) { - values.push_back(this->epsilon_); - values.push_back(this->depth_); - values.push_back(this->n_leapfrog_); - values.push_back(this->divergent_); - values.push_back(this->energy_); - } + if (H0 - h > 0) + sum_metro_prob += 1; + else + sum_metro_prob += std::exp(H0 - h); - virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, - Eigen::VectorXd& p_sharp_plus, - Eigen::VectorXd& rho) { - return p_sharp_plus.dot(rho) > 0 - && p_sharp_minus.dot(rho) > 0; - } + z_propose = z_beg; + rho += z_beg.p; - /** - * Recursively build a new subtree to completion or until - * the subtree becomes invalid. Returns validity of the - * resulting subtree. - * - * @param depth Depth of the desired subtree - * @param z_beg State beginning from subtree - * @param z_propose State proposed from subtree - * @param p_sharp_left p_sharp from left boundary of returned tree - * @param p_sharp_right p_sharp from the right boundary of returned tree - * @param rho Summed momentum across trajectory - * @param H0 Hamiltonian of initial state - * @param sign Direction in time to built subtree - * @param n_leapfrog Summed number of leapfrog evaluations - * @param log_sum_weight Log of summed weights across trajectory - * @param sum_metro_prob Summed Metropolis probabilities across trajectory - * @param logger Logger for messages - */ - bool build_tree(int depth, state_t& z_beg, - ps_point& z_propose, - Eigen::VectorXd& p_sharp_left, - Eigen::VectorXd& p_sharp_right, - Eigen::VectorXd& rho, - double H0, double sign, int& n_leapfrog, - double& log_sum_weight, double& sum_metro_prob, - callbacks::logger& logger) { - // Base case - if (depth == 0) { - // check if trees are still valid or if we should terminate - if(!this->valid_trees_) - return false; - - this->integrator_.evolve(z_beg, this->hamiltonian_, - sign * this->epsilon_, - logger); - - ++n_leapfrog; - - double h = this->hamiltonian_.H(z_beg); - if (boost::math::isnan(h)) - h = std::numeric_limits::infinity(); - - // TODO: in parallel case we cannot use the global divergent - // flag since this could be a speculative tree!! - //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; - bool is_divergent = (h - H0) > this->max_deltaH_; - //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; - - log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); - - if (H0 - h > 0) - sum_metro_prob += 1; - else - sum_metro_prob += std::exp(H0 - h); - - z_propose = z_beg; - rho += z_beg.p; - - p_sharp_left = this->hamiltonian_.dtau_dp(z_beg); - p_sharp_right = p_sharp_left; - - return !is_divergent; - } - // General recursion - Eigen::VectorXd p_sharp_dummy(z_beg.p.size()); + p_sharp_left = this->hamiltonian_.dtau_dp(z_beg); + p_sharp_right = p_sharp_left; - // Build the left subtree - double log_sum_weight_left = -std::numeric_limits::infinity(); - Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); + return !is_divergent; + } + // General recursion + Eigen::VectorXd p_sharp_dummy(z_beg.p.size()); - bool valid_left - = build_tree(depth - 1, z_beg, z_propose, - p_sharp_left, p_sharp_dummy, rho_left, - H0, sign, n_leapfrog, - log_sum_weight_left, sum_metro_prob, - logger); + // Build the left subtree + double log_sum_weight_left = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); - if (!valid_left) return false; + bool valid_left = build_tree(depth - 1, z_beg, z_propose, p_sharp_left, + p_sharp_dummy, rho_left, H0, sign, n_leapfrog, + log_sum_weight_left, sum_metro_prob, logger); - // Build the right subtree - ps_point z_propose_right(z_beg); + if (!valid_left) + return false; - double log_sum_weight_right = -std::numeric_limits::infinity(); - Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); + // Build the right subtree + ps_point z_propose_right(z_beg); - bool valid_right - = build_tree(depth - 1, z_beg, z_propose_right, - p_sharp_dummy, p_sharp_right, rho_right, - H0, sign, n_leapfrog, - log_sum_weight_right, sum_metro_prob, - logger); + double log_sum_weight_right = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); - if (!valid_right) return false; + bool valid_right + = build_tree(depth - 1, z_beg, z_propose_right, p_sharp_dummy, + p_sharp_right, rho_right, H0, sign, n_leapfrog, + log_sum_weight_right, sum_metro_prob, logger); - // Multinomial sample from right subtree - double log_sum_weight_subtree - = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + if (!valid_right) + return false; - if (log_sum_weight_right > log_sum_weight_subtree) { - z_propose = z_propose_right; - } else { - double accept_prob - = std::exp(log_sum_weight_right - log_sum_weight_subtree); - //if (this->rand_uniform_() < accept_prob) - if (get_rand_uniform() < accept_prob) - z_propose = z_propose_right; - } + // Multinomial sample from right subtree + double log_sum_weight_subtree + = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); + log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - Eigen::VectorXd rho_subtree = rho_left + rho_right; - rho += rho_subtree; + if (log_sum_weight_right > log_sum_weight_subtree) { + z_propose = z_propose_right; + } else { + double accept_prob + = std::exp(log_sum_weight_right - log_sum_weight_subtree); + // if (this->rand_uniform_() < accept_prob) + if (get_rand_uniform() < accept_prob) + z_propose = z_propose_right; + } - return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); - } + Eigen::VectorXd rho_subtree = rho_left + rho_right; + rho += rho_subtree; - inline double get_rand_uniform() { - return this->rand_uniform_vec_[tbb::this_task_arena::current_thread_index()](); - } + return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + } + + inline double get_rand_uniform() { + return this + ->rand_uniform_vec_[tbb::this_task_arena::current_thread_index()](); + } - int depth_{0}; - int max_depth_{5}; - double max_deltaH_{1000}; - int n_leapfrog_{0}; - double energy_{0}; - bool valid_trees_{true}; - bool divergent_{false}; - // Uniform(0, 1) RNG - std::vector> rand_uniform_vec_; - }; - template class Hamiltonian, - template class Integrator, class BaseRNG> - using base_parallel_nuts_ct = std::conditional_t, - base_parallel_nuts>; - } // mcmc -} // stan + int depth_{0}; + int max_depth_{5}; + double max_deltaH_{1000}; + int n_leapfrog_{0}; + double energy_{0}; + bool valid_trees_{true}; + bool divergent_{false}; + // Uniform(0, 1) RNG + std::vector> rand_uniform_vec_; +}; +template class Hamiltonian, + template class Integrator, class BaseRNG> +using base_parallel_nuts_ct = std::conditional_t< + ParallelBase, base_parallel_nuts, + base_parallel_nuts>; +} // namespace mcmc +} // namespace stan #endif From 4cda135d29eb80d328a780efba0032ca64fb54ae Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 15 Feb 2022 18:23:14 -0500 Subject: [PATCH 4/8] update with some comments --- src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp index 919e0c18aa6..79c6a6cd5e1 100644 --- a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -259,12 +259,16 @@ class base_parallel_nuts // now wire up the fwd and bck build of the trees which // depends on single-core or multi-core run + // TODO (Steve) We should only use this class if get_num_threads > 1 + // Else just use the non-parallel nuts. const bool run_serial = stan::math::internal::get_num_threads() == 1; std::size_t fwd_idx = 0; std::size_t bck_idx = 0; // TODO: the extenders should also check for a global flag if // we want to keep running + // TODO: We should also just run depth = 0 outside the loop to avoid the + // if statement here for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { if (fwd_direction[depth]) { builder_iter_t fwd_iter @@ -330,7 +334,7 @@ class base_parallel_nuts // joins.push_back(joiner_t(g)); // std::cout << "creating check at depth " << depth << std::endl; checks.emplace_back(g, [&, depth](continue_msg) { - bool is_fwd = fwd_direction[depth]; + const bool is_fwd = fwd_direction[depth]; extend_tree_t& subtree_result = ends[depth]; @@ -344,10 +348,10 @@ class base_parallel_nuts sum_metro_prob += std::get<6>(subtree_result); } - bool valid_subtree = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] + const bool valid_subtree = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] : valid_subtree_bck[all_builder_idx[depth]]; - bool is_valid = valid_subtree & this->valid_trees_; + const bool is_valid = valid_subtree & this->valid_trees_; // std::cout << "CHECK at depth " << depth; From dfdc55a99761b0146df0eeea063d544239928a2b Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Wed, 16 Feb 2022 17:04:07 -0500 Subject: [PATCH 5/8] tests running for parallel adaptation --- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 6 +- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 1082 +++++------------ src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 64 +- src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp | 2 + .../sample/hmc_nuts_diag_e_adapt_parallel.hpp | 258 +--- ...ts_diag_e_adapt_parallel_parallel_test.cpp | 175 +++ 6 files changed, 610 insertions(+), 977 deletions(-) create mode 100644 src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 5e142ce8821..d39cc5ef005 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -16,15 +16,17 @@ template class adapt_diag_e_nuts : public diag_e_nuts, public stepsize_var_adapter { public: + template * = nullptr> adapt_diag_e_nuts(const Model& model, BaseRNG& rng) : diag_e_nuts(model, rng), stepsize_var_adapter(model.num_params_r()) {} - diag_e_nuts(const Model& model, std::vector& thread_rngs) + template * = nullptr> + adapt_diag_e_nuts(const Model& model, std::vector& thread_rngs) : diag_e_nuts(model, thread_rngs), stepsize_var_adapter(model.num_params_r()) {} - sample transition(sample& init_sample, callbacks::logger& logger) { + inline sample transition(sample& init_sample, callbacks::logger& logger) { sample s = diag_e_nuts::transition(init_sample, logger); if (this->adapt_flag_) { diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 74aca4d105c..38964283c89 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -11,763 +11,355 @@ #include #include -#include - -#include - -#include "tbb/task_scheduler_init.h" -#include "tbb/flow_graph.h" -#include "tbb/concurrent_vector.h" - -using namespace tbb::flow; - -// Prototype of speculative NUTS. -// Uses the Intel Flow Graph concept to turn NUTS into a parallel -// algorithm in that the forward and backward sweep run at the same -// time in parallel. - namespace stan { - namespace mcmc { - /** - * The No-U-Turn sampler (NUTS) with multinomial sampling - */ - template class Hamiltonian, - template class Integrator, class BaseRNG> - class base_nuts : public base_hmc { - public: - typedef typename Hamiltonian::PointType state_t; - - base_nuts(const Model& model, BaseRNG& rng) - : base_hmc(model, rng), - depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), - n_leapfrog_(0), divergent_(false), energy_(0) { - } - - /** - * specialized constructor for specified diag mass matrix - */ - base_nuts(const Model& model, BaseRNG& rng, - Eigen::VectorXd& inv_e_metric) - : base_hmc(model, rng, - inv_e_metric), - depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), - n_leapfrog_(0), divergent_(false), energy_(0) { - } - - /** - * specialized constructor for specified dense mass matrix - */ - base_nuts(const Model& model, BaseRNG& rng, - Eigen::MatrixXd& inv_e_metric) - : base_hmc(model, rng, - inv_e_metric), - depth_(0), max_depth_(5), max_deltaH_(1000), valid_trees_(true), - n_leapfrog_(0), divergent_(false), energy_(0) { - } - - ~base_nuts() {} - - void set_metric(const Eigen::MatrixXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); - } - - void set_metric(const Eigen::VectorXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); - } - - void set_max_depth(int d) { - if (d > 0) - max_depth_ = d; - } - - void set_max_delta(double d) { - max_deltaH_ = d; +namespace mcmc { +/** + * The No-U-Turn sampler (NUTS) with multinomial sampling + */ +template class Hamiltonian, + template class Integrator, class BaseRNG> +class base_nuts : public base_hmc { + public: + base_nuts(const Model& model, BaseRNG& rng) + : base_hmc(model, rng), + depth_(0), + max_depth_(5), + max_deltaH_(1000), + n_leapfrog_(0), + divergent_(false), + energy_(0) {} + + /** + * specialized constructor for specified diag mass matrix + */ + base_nuts(const Model& model, BaseRNG& rng, Eigen::VectorXd& inv_e_metric) + : base_hmc(model, rng, + inv_e_metric), + depth_(0), + max_depth_(5), + max_deltaH_(1000), + n_leapfrog_(0), + divergent_(false), + energy_(0) {} + + /** + * specialized constructor for specified dense mass matrix + */ + base_nuts(const Model& model, BaseRNG& rng, Eigen::MatrixXd& inv_e_metric) + : base_hmc(model, rng, + inv_e_metric), + depth_(0), + max_depth_(5), + max_deltaH_(1000), + n_leapfrog_(0), + divergent_(false), + energy_(0) {} + + ~base_nuts() {} + + void set_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + void set_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + void set_max_depth(int d) { + if (d > 0) + max_depth_ = d; + } + + void set_max_delta(double d) { max_deltaH_ = d; } + + int get_max_depth() { return this->max_depth_; } + double get_max_delta() { return this->max_deltaH_; } + + sample transition(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + ps_point z_fwd(this->z_); // State at forward end of trajectory + ps_point z_bck(z_fwd); // State at backward end of trajectory + + ps_point z_sample(z_fwd); + ps_point z_propose(z_fwd); + + // Momentum and sharp momentum at forward end of forward subtree + Eigen::VectorXd p_fwd_fwd = this->z_.p; + Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_); + + // Momentum and sharp momentum at backward end of forward subtree + Eigen::VectorXd p_fwd_bck = this->z_.p; + Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd; + + // Momentum and sharp momentum at forward end of backward subtree + Eigen::VectorXd p_bck_fwd = this->z_.p; + Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd; + + // Momentum and sharp momentum at backward end of backward subtree + Eigen::VectorXd p_bck_bck = this->z_.p; + Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd; + + // Integrated momenta along trajectory + Eigen::VectorXd rho = this->z_.p.transpose(); + + // Log sum of state weights (offset by H0) along trajectory + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + int n_leapfrog = 0; + double sum_metro_prob = 0; + + // Build a trajectory until the no-u-turn + // criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + // Build a new subtree in a random direction + Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size()); + + bool valid_subtree = false; + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + if (this->rand_uniform_() > 0.5) { + // Extend the current trajectory forward + this->z_.ps_point::operator=(z_fwd); + rho_bck = rho; + p_bck_fwd = p_fwd_fwd; + p_sharp_bck_fwd = p_sharp_fwd_fwd; + + valid_subtree = build_tree( + this->depth_, z_propose, p_sharp_fwd_bck, p_sharp_fwd_fwd, rho_fwd, + p_fwd_bck, p_fwd_fwd, H0, 1, n_leapfrog, log_sum_weight_subtree, + sum_metro_prob, logger); + z_fwd.ps_point::operator=(this->z_); + } else { + // Extend the current trajectory backwards + this->z_.ps_point::operator=(z_bck); + rho_fwd = rho; + p_fwd_bck = p_bck_bck; + p_sharp_fwd_bck = p_sharp_bck_bck; + + valid_subtree = build_tree( + this->depth_, z_propose, p_sharp_bck_fwd, p_sharp_bck_bck, rho_bck, + p_bck_fwd, p_bck_bck, H0, -1, n_leapfrog, log_sum_weight_subtree, + sum_metro_prob, logger); + z_bck.ps_point::operator=(this->z_); } - int get_max_depth() { return this->max_depth_; } - double get_max_delta() { return this->max_deltaH_; } - - // stores from left/right subtree entire information - struct subtree { - subtree(const double sign, - const ps_point& z_end, - const Eigen::VectorXd& p_sharp_end, - double H0) - : z_end_(z_end), z_propose_(z_end), - p_sharp_end_(p_sharp_end), - H0_(H0), - sign_(sign), - n_leapfrog_(0), - sum_metro_prob_(0) - {} - - ps_point z_end_; - ps_point z_propose_; - Eigen::VectorXd p_sharp_end_; - const double H0_; - const double sign_; - int n_leapfrog_; - double sum_metro_prob_; - }; - - - // extends the tree into the direction of the sign of the - // subtree - typedef std::tuple extend_tree_t; - - extend_tree_t - extend_tree(int depth, subtree& tree, state_t& z, - callbacks::logger& logger) { - // save the current ends needed for later criterion computations - //Eigen::VectorXd p_end = tree.p_end_; - //Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; - Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); - - Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); - double log_sum_weight_subtree = -std::numeric_limits::infinity(); - - tree.n_leapfrog_ = 0; - tree.sum_metro_prob_ = 0; - - z.ps_point::operator=(tree.z_end_); - - bool valid_subtree = build_tree(depth, - z, tree.z_propose_, - p_sharp_dummy, tree.p_sharp_end_, - rho_subtree, - tree.H0_, - tree.sign_, - tree.n_leapfrog_, - log_sum_weight_subtree, tree.sum_metro_prob_, - logger); - - tree.z_end_.ps_point::operator=(z); - - return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, tree.sum_metro_prob_); - } + if (!valid_subtree) + break; + // Sample from accepted subtree + ++(this->depth_); - sample - transition(sample& init_sample, callbacks::logger& logger) { - return transition_parallel(init_sample, logger); + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; } - // this implementation builds up the dependence graph every call - // to transition. Things which should be refactored: - // 1. build up the nodes only once - // 2. add a prepare method to each node which samples its - // direction and needed random numbers for multinomial sampling - // 3. only the edges are added dynamically. So the forward nodes - // are wired-up and the backward nodes are wired-up if run - // parallel. If run serially, then each grow node is alternated - // with a check node. - sample - transition_parallel(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); - - this->seed(init_sample.cont_params()); - - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); - - const ps_point z_init(this->z_); - - ps_point z_sample(z_init); - //ps_point z_propose(z_init); - - const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd rho = this->z_.p; - - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - //int n_leapfrog = 0; - //double sum_metro_prob = 0; - - // forward tree - subtree tree_fwd(1, z_init, p_sharp, H0); - // backward tree - subtree tree_bck(-1, z_init, p_sharp, H0); - - // actual states which move... copy construct atm...revise?! - state_t z_fwd(this->z_); - state_t z_bck(this->z_); - - // Build a trajectory until the NUTS criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; - this->valid_trees_ = true; - - // the actual number of leapfrog steps in trajectory used - // excluding the ones executed speculative - int n_leapfrog = 0; - - // actually summed metropolis prob of used trajectory - double sum_metro_prob = 0; - - std::vector fwd_direction(this->max_depth_); - - for (std::size_t i = 0; i != this->max_depth_; ++i) - fwd_direction[i] = this->rand_uniform_() > 0.5; - - const std::size_t num_fwd = std::accumulate(fwd_direction.begin(), fwd_direction.end(), 0); - const std::size_t num_bck = this->max_depth_ - num_fwd; - - /* - std::cout << "sampled turns: "; - for (std::size_t i = 0; i != this->max_depth_; ++i) { - if(fwd_direction[i]) - std::cout << "+,"; - else - std::cout << "-,"; - } - std::cout << std::endl; - */ - - tbb::concurrent_vector ends(this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), Eigen::VectorXd(), z_sample, 0, 0.0)); - tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); - tbb::concurrent_vector valid_subtree_bck(num_bck, true); - - // HACK!!! - callbacks::logger logger_fwd; - callbacks::logger logger_bck; - - // build TBB flow graph - graph g; - - // add nodes which advance the left/right tree - typedef continue_node tree_builder_t; - - tbb::concurrent_vector all_builder_idx(this->max_depth_); - tbb::concurrent_vector fwd_builder; - tbb::concurrent_vector bck_builder; - typedef tbb::concurrent_vector::iterator builder_iter_t; - - // now wire up the fwd and bck build of the trees which - // depends on single-core or multi-core run - const bool run_serial = stan::math::internal::get_num_threads() == 1; - - std::size_t fwd_idx = 0; - std::size_t bck_idx = 0; - // TODO: the extenders should also check for a global flag if - // we want to keep running - for (std::size_t depth=0; depth != this->max_depth_; ++depth) { - if (fwd_direction[depth]) { - builder_iter_t fwd_iter = - fwd_builder.emplace_back(g, [&,depth,fwd_idx](continue_msg) { - //std::cout << "fwd turn at depth " << depth; - bool valid_parent = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx-1]; - if (valid_parent) { - //std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); - valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); - } else { - valid_subtree_fwd[fwd_idx] = false; - } - //std::cout << " nothing to do." << std::endl; - }); - if(!run_serial && fwd_idx != 0) { - // in this case this is not the starting node, we - // connect this with its predecessor - make_edge(*(fwd_iter-1), *fwd_iter); - } - all_builder_idx[depth] = fwd_idx; - ++fwd_idx; - } else { - builder_iter_t bck_iter = - bck_builder.emplace_back(g, [&,depth,bck_idx](continue_msg) { - //std::cout << "bck turn at depth " << depth; - bool valid_parent = bck_idx == 0 ? true : valid_subtree_bck[bck_idx-1]; - if (valid_parent) { - //std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); - valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); - } else { - valid_subtree_bck[bck_idx] = false; - } - //std::cout << " nothing to do." << std::endl; - }); - if(!run_serial && bck_idx != 0) { - // in case this is not the starting node, we connect - // this with his predecessor - //make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); - make_edge(*(bck_iter-1), *bck_iter); - } - all_builder_idx[depth] = bck_idx; - ++bck_idx; - } - } - - // finally wire in the checker which accepts or rejects the - // proposed states from the subtrees - //typedef function_node< tbb::flow::tuple, bool> checker_t; - //typedef join_node< tbb::flow::tuple > joiner_t; - typedef continue_node checker_t; - - tbb::concurrent_vector checks; - //std::vector joins; - - Eigen::VectorXd p_sharp_fwd(p_sharp); - Eigen::VectorXd p_sharp_bck(p_sharp); - - for (std::size_t depth=0; depth != this->max_depth_; ++depth) { - //joins.push_back(joiner_t(g)); - //std::cout << "creating check at depth " << depth << std::endl; - checks.emplace_back(g, [&,depth](continue_msg) { - bool is_fwd = fwd_direction[depth]; - - extend_tree_t& subtree_result = ends[depth]; - - // if we are still on the - // trajectories which are - // actually used update the - // running tree stats - if (this->valid_trees_) { - this->depth_ = depth + 1; - n_leapfrog += std::get<5>(subtree_result); - sum_metro_prob += std::get<6>(subtree_result); - } - - bool valid_subtree = is_fwd ? - valid_subtree_fwd[all_builder_idx[depth]] : - valid_subtree_bck[all_builder_idx[depth]]; - - bool is_valid = valid_subtree & this->valid_trees_; - - //std::cout << "CHECK at depth " << depth; - - if(!is_valid) { - //std::cout << " we are done (early)" << std::endl; - - // setting this globally here - // will terminate all ongoing work - this->valid_trees_ = false; - return; - } - - //std::cout << " checking" << std::endl; - - double log_sum_weight_subtree = std::get<1>(subtree_result); - const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); - - // update correct side - if (is_fwd) { - p_sharp_fwd = std::get<3>(subtree_result); - } else { - p_sharp_bck = std::get<3>(subtree_result); - } - - const ps_point& z_propose = std::get<4>(subtree_result); - - // update running sums - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob - = std::exp(log_sum_weight_subtree - log_sum_weight); - //if (this->rand_uniform_() < - //accept_prob) - // HACK - if (get_rand_uniform() < accept_prob) - z_sample = z_propose; - } - - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(p_sharp_bck, p_sharp_fwd, rho)) { - // setting this globally here - // will terminate all ongoing work - this->valid_trees_ = false; - //std::cout << " we are done (later)" << std::endl; - } - //std::cout << " continuing (later)" << std::endl; - }); - if(fwd_direction[depth]) { - //std::cout << "depth " << depth << ": joining fwd node " << all_builder_idx[depth] << " into join node." << std::endl; - make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); - } else { - //std::cout << "depth " << depth << ": joining bck node " << all_builder_idx[depth] << " into join node." << std::endl; - make_edge(bck_builder[all_builder_idx[depth]], checks.back()); - } - if(!run_serial && depth != 0) { - make_edge(checks[depth-1], checks.back()); - } - } - - if(run_serial) { - for(std::size_t i = 1; i < this->max_depth_; ++i) { - make_edge(checks[i-1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] : bck_builder[all_builder_idx[i]]); - } - } - - // kick off work - if(fwd_direction[0]) { - fwd_builder[0].try_put(continue_msg()); - // the first turn is fwd, so kick off the bck walker if needed - if (!run_serial && num_bck != 0) - bck_builder[0].try_put(continue_msg()); - } else { - bck_builder[0].try_put(continue_msg()); - if (!run_serial && num_fwd != 0) - fwd_builder[0].try_put(continue_msg()); - } - - g.wait_for_all(); - - this->n_leapfrog_ = n_leapfrog; - //this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; - - // this includes the speculative executed ones - //const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; - - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob - = sum_metro_prob / static_cast(this->n_leapfrog_); - - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); - } - - sample - transition_refactored(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); - - this->seed(init_sample.cont_params()); - - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); - - const ps_point z_init(this->z_); - - ps_point z_sample(z_init); - ps_point z_propose(z_init); - - const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd rho = this->z_.p; - - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - //int n_leapfrog = 0; - //double sum_metro_prob = 0; - - // forward tree - subtree tree_fwd(1, z_init, p_sharp, H0); - // backward tree - subtree tree_bck(-1, z_init, p_sharp, H0); - - // Build a trajectory until the NUTS criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; - - while (this->depth_ < this->max_depth_) { - bool valid_subtree; - double log_sum_weight_subtree; - Eigen::VectorXd rho_subtree; - - if (this->rand_uniform_() > 0.5) { - std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) - = extend_tree(this->depth_, tree_fwd, this->z_, logger); - } else { - std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) - = extend_tree(this->depth_, tree_bck, this->z_, logger); - } - - if (!valid_subtree) break; - - // Sample from an accepted subtree - ++(this->depth_); - - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob - = std::exp(log_sum_weight_subtree - log_sum_weight); - if (this->rand_uniform_() < accept_prob) - z_sample = z_propose; - } - - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(tree_bck.p_sharp_end_, tree_fwd.p_sharp_end_, rho)) - break; - //if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) - // break; - } - - //this->n_leapfrog_ = n_leapfrog; - this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; - - const double sum_metro_prob = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; - - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob - = sum_metro_prob / static_cast(this->n_leapfrog_); - - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); - } - - sample - transition_old(sample& init_sample, callbacks::logger& logger) { - // Initialize the algorithm - this->sample_stepsize(); - - this->seed(init_sample.cont_params()); - - this->hamiltonian_.sample_p(this->z_, this->rand_int_); - this->hamiltonian_.init(this->z_, logger); - - ps_point z_plus(this->z_); - ps_point z_minus(z_plus); - - ps_point z_sample(z_plus); - ps_point z_propose(z_plus); - - Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); - //Eigen::VectorXd p_sharp_dummy = p_sharp_plus; - Eigen::VectorXd p_sharp_minus = p_sharp_plus; - Eigen::VectorXd rho = this->z_.p; - - double log_sum_weight = 0; // log(exp(H0 - H0)) - double H0 = this->hamiltonian_.H(this->z_); - int n_leapfrog = 0; - double sum_metro_prob = 0; - - // Build a trajectory until the NUTS criterion is no longer satisfied - this->depth_ = 0; - this->divergent_ = false; - - while (this->depth_ < this->max_depth_) { - // Build a new subtree in a random direction - Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); - bool valid_subtree = false; - double log_sum_weight_subtree - = -std::numeric_limits::infinity(); - - // this should be fine (modified from orig) - Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(this->z_.p.size()); - - if (this->rand_uniform_() > 0.5) { - this->z_.ps_point::operator=(z_plus); - valid_subtree - = build_tree(this->depth_, this->z_, z_propose, - p_sharp_dummy, p_sharp_plus, rho_subtree, - H0, 1, n_leapfrog, - log_sum_weight_subtree, sum_metro_prob, - logger); - z_plus.ps_point::operator=(this->z_); - } else { - this->z_.ps_point::operator=(z_minus); - valid_subtree - = build_tree(this->depth_, this->z_, z_propose, - p_sharp_dummy, p_sharp_minus, rho_subtree, - H0, -1, n_leapfrog, - log_sum_weight_subtree, sum_metro_prob, - logger); - z_minus.ps_point::operator=(this->z_); - } - - if (!valid_subtree) break; - - // Sample from an accepted subtree - ++(this->depth_); - - if (log_sum_weight_subtree > log_sum_weight) { - z_sample = z_propose; - } else { - double accept_prob - = std::exp(log_sum_weight_subtree - log_sum_weight); - if (this->rand_uniform_() < accept_prob) - z_sample = z_propose; - } - - log_sum_weight - = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) - break; - } - - this->n_leapfrog_ = n_leapfrog; - - // Compute average acceptance probabilty across entire trajectory, - // even over subtrees that may have been rejected - double accept_prob - = sum_metro_prob / static_cast(n_leapfrog); - - this->z_.ps_point::operator=(z_sample); - this->energy_ = this->hamiltonian_.H(this->z_); - return sample(this->z_.q, -this->z_.V, accept_prob); - } - - void get_sampler_param_names(std::vector& names) { - names.push_back("stepsize__"); - names.push_back("treedepth__"); - names.push_back("n_leapfrog__"); - names.push_back("divergent__"); - names.push_back("energy__"); - } - - void get_sampler_params(std::vector& values) { - values.push_back(this->epsilon_); - values.push_back(this->depth_); - values.push_back(this->n_leapfrog_); - values.push_back(this->divergent_); - values.push_back(this->energy_); - } - - virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, - Eigen::VectorXd& p_sharp_plus, - Eigen::VectorXd& rho) { - return p_sharp_plus.dot(rho) > 0 - && p_sharp_minus.dot(rho) > 0; - } - - /** - * Recursively build a new subtree to completion or until - * the subtree becomes invalid. Returns validity of the - * resulting subtree. - * - * @param depth Depth of the desired subtree - * @param z_beg State beginning from subtree - * @param z_propose State proposed from subtree - * @param p_sharp_left p_sharp from left boundary of returned tree - * @param p_sharp_right p_sharp from the right boundary of returned tree - * @param rho Summed momentum across trajectory - * @param H0 Hamiltonian of initial state - * @param sign Direction in time to built subtree - * @param n_leapfrog Summed number of leapfrog evaluations - * @param log_sum_weight Log of summed weights across trajectory - * @param sum_metro_prob Summed Metropolis probabilities across trajectory - * @param logger Logger for messages - */ - bool build_tree(int depth, state_t& z_beg, - ps_point& z_propose, - Eigen::VectorXd& p_sharp_left, - Eigen::VectorXd& p_sharp_right, - Eigen::VectorXd& rho, - double H0, double sign, int& n_leapfrog, - double& log_sum_weight, double& sum_metro_prob, - callbacks::logger& logger) { - // Base case - if (depth == 0) { - // check if trees are still valid or if we should terminate - if(!this->valid_trees_) - return false; - - this->integrator_.evolve(z_beg, this->hamiltonian_, - sign * this->epsilon_, - logger); - - ++n_leapfrog; - - double h = this->hamiltonian_.H(z_beg); - if (boost::math::isnan(h)) - h = std::numeric_limits::infinity(); - - // TODO: in parallel case we cannot use the global divergent - // flag since this could be a speculative tree!! - //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; - bool is_divergent = (h - H0) > this->max_deltaH_; - //if ((h - H0) > this->max_deltaH_) this->divergent_ = true; - - log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); - - if (H0 - h > 0) - sum_metro_prob += 1; - else - sum_metro_prob += std::exp(H0 - h); - - z_propose = z_beg; - rho += z_beg.p; - - p_sharp_left = this->hamiltonian_.dtau_dp(z_beg); - p_sharp_right = p_sharp_left; - - return !is_divergent; - } - // General recursion - Eigen::VectorXd p_sharp_dummy(z_beg.p.size()); - - // Build the left subtree - double log_sum_weight_left = -std::numeric_limits::infinity(); - Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); - - bool valid_left - = build_tree(depth - 1, z_beg, z_propose, - p_sharp_left, p_sharp_dummy, rho_left, - H0, sign, n_leapfrog, - log_sum_weight_left, sum_metro_prob, - logger); - - if (!valid_left) return false; - - // Build the right subtree - ps_point z_propose_right(z_beg); - - double log_sum_weight_right = -std::numeric_limits::infinity(); - Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); - - bool valid_right - = build_tree(depth - 1, z_beg, z_propose_right, - p_sharp_dummy, p_sharp_right, rho_right, - H0, sign, n_leapfrog, - log_sum_weight_right, sum_metro_prob, - logger); - - if (!valid_right) return false; - - // Multinomial sample from right subtree - double log_sum_weight_subtree - = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); - log_sum_weight + log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - if (log_sum_weight_right > log_sum_weight_subtree) { - z_propose = z_propose_right; - } else { - double accept_prob - = std::exp(log_sum_weight_right - log_sum_weight_subtree); - //if (this->rand_uniform_() < accept_prob) - if (get_rand_uniform() < accept_prob) - z_propose = z_propose_right; - } - - Eigen::VectorXd rho_subtree = rho_left + rho_right; - rho += rho_subtree; + // Break when no-u-turn criterion is no longer satisfied + rho = rho_bck + rho_fwd; + + // Demand satisfaction around merged subtrees + bool persist_criterion + = compute_criterion(p_sharp_bck_bck, p_sharp_fwd_fwd, rho); + + // Demand satisfaction between subtrees + Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck; + + persist_criterion + &= compute_criterion(p_sharp_bck_bck, p_sharp_fwd_bck, rho_extended); + + rho_extended = rho_fwd + p_bck_fwd; + persist_criterion + &= compute_criterion(p_sharp_bck_fwd, p_sharp_fwd_fwd, rho_extended); + + if (!persist_criterion) + break; + } + + this->n_leapfrog_ = n_leapfrog; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob = sum_metro_prob / static_cast(n_leapfrog); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + void get_sampler_param_names(std::vector& names) { + names.push_back("stepsize__"); + names.push_back("treedepth__"); + names.push_back("n_leapfrog__"); + names.push_back("divergent__"); + names.push_back("energy__"); + } + + void get_sampler_params(std::vector& values) { + values.push_back(this->epsilon_); + values.push_back(this->depth_); + values.push_back(this->n_leapfrog_); + values.push_back(this->divergent_); + values.push_back(this->energy_); + } + + virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return p_sharp_plus.dot(rho) > 0 && p_sharp_minus.dot(rho) > 0; + } + + /** + * Recursively build a new subtree to completion or until + * the subtree becomes invalid. Returns validity of the + * resulting subtree. + * + * @param depth Depth of the desired subtree + * @param z_propose State proposed from subtree + * @param p_sharp_beg Sharp momentum at beginning of new tree + * @param p_sharp_end Sharp momentum at end of new tree + * @param rho Summed momentum across trajectory + * @param p_beg Momentum at beginning of returned tree + * @param p_end Momentum at end of returned tree + * @param H0 Hamiltonian of initial state + * @param sign Direction in time to built subtree + * @param n_leapfrog Summed number of leapfrog evaluations + * @param log_sum_weight Log of summed weights across trajectory + * @param sum_metro_prob Summed Metropolis probabilities across trajectory + * @param logger Logger for messages + */ + bool build_tree(int depth, ps_point& z_propose, Eigen::VectorXd& p_sharp_beg, + Eigen::VectorXd& p_sharp_end, Eigen::VectorXd& rho, + Eigen::VectorXd& p_beg, Eigen::VectorXd& p_end, double H0, + double sign, int& n_leapfrog, double& log_sum_weight, + double& sum_metro_prob, callbacks::logger& logger) { + // Base case + if (depth == 0) { + this->integrator_.evolve(this->z_, this->hamiltonian_, + sign * this->epsilon_, logger); + ++n_leapfrog; + + double h = this->hamiltonian_.H(this->z_); + if (std::isnan(h)) + h = std::numeric_limits::infinity(); + + if ((h - H0) > this->max_deltaH_) + this->divergent_ = true; + + log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); + + if (H0 - h > 0) + sum_metro_prob += 1; + else + sum_metro_prob += std::exp(H0 - h); + + z_propose = this->z_; + + p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_); + p_sharp_end = p_sharp_beg; + + rho += this->z_.p; + p_beg = this->z_.p; + p_end = p_beg; + + return !this->divergent_; + } + // General recursion + + // Build the initial subtree + double log_sum_weight_init = -std::numeric_limits::infinity(); + + // Momentum and sharp momentum at end of the initial subtree + Eigen::VectorXd p_init_end(this->z_.p.size()); + Eigen::VectorXd p_sharp_init_end(this->z_.p.size()); + + Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size()); + + bool valid_init + = build_tree(depth - 1, z_propose, p_sharp_beg, p_sharp_init_end, + rho_init, p_beg, p_init_end, H0, sign, n_leapfrog, + log_sum_weight_init, sum_metro_prob, logger); + + if (!valid_init) + return false; + + // Build the final subtree + ps_point z_propose_final(this->z_); + + double log_sum_weight_final = -std::numeric_limits::infinity(); + + // Momentum and sharp momentum at beginning of the final subtree + Eigen::VectorXd p_final_beg(this->z_.p.size()); + Eigen::VectorXd p_sharp_final_beg(this->z_.p.size()); + + Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size()); + + bool valid_final + = build_tree(depth - 1, z_propose_final, p_sharp_final_beg, p_sharp_end, + rho_final, p_final_beg, p_end, H0, sign, n_leapfrog, + log_sum_weight_final, sum_metro_prob, logger); + + if (!valid_final) + return false; + + // Multinomial sample from right subtree + double log_sum_weight_subtree + = math::log_sum_exp(log_sum_weight_init, log_sum_weight_final); + log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + if (log_sum_weight_final > log_sum_weight_subtree) { + z_propose = z_propose_final; + } else { + double accept_prob + = std::exp(log_sum_weight_final - log_sum_weight_subtree); + if (this->rand_uniform_() < accept_prob) + z_propose = z_propose_final; + } + + Eigen::VectorXd rho_subtree = rho_init + rho_final; + rho += rho_subtree; + + // Demand satisfaction around merged subtrees + bool persist_criterion + = compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree); + + // Demand satisfaction between subtrees + rho_subtree = rho_init + p_final_beg; + persist_criterion + &= compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree); - return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); - } + rho_subtree = rho_final + p_init_end; + persist_criterion + &= compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree); - inline double get_rand_uniform() { - static std::mutex rng_mutex; - std::lock_guard lock(rng_mutex); - return this->rand_uniform_(); - } + return persist_criterion; + } - int depth_; - int max_depth_; - double max_deltaH_; - bool valid_trees_; + int depth_; + int max_depth_; + double max_deltaH_; - int n_leapfrog_; - bool divergent_; - double energy_; - }; + int n_leapfrog_; + bool divergent_; + double energy_; +}; - } // mcmc -} // stan +} // namespace mcmc +} // namespace stan #endif diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp index 79c6a6cd5e1..e1ab28152a9 100644 --- a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -16,11 +15,16 @@ #include -#include "tbb/task_scheduler_init.h" -#include "tbb/flow_graph.h" -#include "tbb/concurrent_vector.h" +#include +#include -using namespace tbb::flow; +// Prototype of speculative NUTS. +// Uses the Intel Flow Graph concept to turn NUTS into a parallel +// algorithm in that the forward and backward sweep run at the same +// time in parallel. + +namespace stan { +namespace mcmc { template inline auto make_uniform_vec(std::vector& thread_rngs) { @@ -36,14 +40,6 @@ inline auto make_uniform_vec(std::vector& thread_rngs) { thread_rngs.end()); } -// Prototype of speculative NUTS. -// Uses the Intel Flow Graph concept to turn NUTS into a parallel -// algorithm in that the forward and backward sweep run at the same -// time in parallel. - -namespace stan { -namespace mcmc { - /** * The No-U-Turn sampler (NUTS) with multinomial sampling */ @@ -247,15 +243,15 @@ class base_parallel_nuts callbacks::logger logger_bck; // build TBB flow graph - graph g; + tbb::flow::graph g; // add nodes which advance the left/right tree - typedef continue_node tree_builder_t; + using tree_builder_t = tbb::flow::continue_node; tbb::concurrent_vector all_builder_idx(this->max_depth_); tbb::concurrent_vector fwd_builder; tbb::concurrent_vector bck_builder; - typedef tbb::concurrent_vector::iterator builder_iter_t; + using builder_iter_t = tbb::concurrent_vector::iterator; // now wire up the fwd and bck build of the trees which // depends on single-core or multi-core run @@ -272,7 +268,7 @@ class base_parallel_nuts for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { if (fwd_direction[depth]) { builder_iter_t fwd_iter - = fwd_builder.emplace_back(g, [&, depth, fwd_idx](continue_msg) { + = fwd_builder.emplace_back(g, [&, depth, fwd_idx](tbb::flow::continue_msg) { // std::cout << "fwd turn at depth " << depth; bool valid_parent = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx - 1]; @@ -288,13 +284,13 @@ class base_parallel_nuts if (!run_serial && fwd_idx != 0) { // in this case this is not the starting node, we // connect this with its predecessor - make_edge(*(fwd_iter - 1), *fwd_iter); + tbb::flow::make_edge(*(fwd_iter - 1), *fwd_iter); } all_builder_idx[depth] = fwd_idx; ++fwd_idx; } else { builder_iter_t bck_iter - = bck_builder.emplace_back(g, [&, depth, bck_idx](continue_msg) { + = bck_builder.emplace_back(g, [&, depth, bck_idx](tbb::flow::continue_msg) { // std::cout << "bck turn at depth " << depth; bool valid_parent = bck_idx == 0 ? true : valid_subtree_bck[bck_idx - 1]; @@ -310,8 +306,8 @@ class base_parallel_nuts if (!run_serial && bck_idx != 0) { // in case this is not the starting node, we connect // this with his predecessor - // make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); - make_edge(*(bck_iter - 1), *bck_iter); + // tbb::flow::make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); + tbb::flow::make_edge(*(bck_iter - 1), *bck_iter); } all_builder_idx[depth] = bck_idx; ++bck_idx; @@ -322,7 +318,7 @@ class base_parallel_nuts // proposed states from the subtrees // typedef function_node< tbb::flow::tuple, bool> checker_t; // typedef join_node< tbb::flow::tuple > joiner_t; - typedef continue_node checker_t; + using checker_t = tbb::flow::continue_node; tbb::concurrent_vector checks; // std::vector joins; @@ -333,7 +329,7 @@ class base_parallel_nuts for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { // joins.push_back(joiner_t(g)); // std::cout << "creating check at depth " << depth << std::endl; - checks.emplace_back(g, [&, depth](continue_msg) { + checks.emplace_back(g, [&, depth](tbb::flow::continue_msg) { const bool is_fwd = fwd_direction[depth]; extend_tree_t& subtree_result = ends[depth]; @@ -407,20 +403,20 @@ class base_parallel_nuts if (fwd_direction[depth]) { // std::cout << "depth " << depth << ": joining fwd node " << // all_builder_idx[depth] << " into join node." << std::endl; - make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); + tbb::flow::make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); } else { // std::cout << "depth " << depth << ": joining bck node " << // all_builder_idx[depth] << " into join node." << std::endl; - make_edge(bck_builder[all_builder_idx[depth]], checks.back()); + tbb::flow::make_edge(bck_builder[all_builder_idx[depth]], checks.back()); } if (!run_serial && depth != 0) { - make_edge(checks[depth - 1], checks.back()); + tbb::flow::make_edge(checks[depth - 1], checks.back()); } } if (run_serial) { for (std::size_t i = 1; i < this->max_depth_; ++i) { - make_edge(checks[i - 1], fwd_direction[i] + tbb::flow::make_edge(checks[i - 1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] : bck_builder[all_builder_idx[i]]); } @@ -428,14 +424,14 @@ class base_parallel_nuts // kick off work if (fwd_direction[0]) { - fwd_builder[0].try_put(continue_msg()); + fwd_builder[0].try_put(tbb::flow::continue_msg()); // the first turn is fwd, so kick off the bck walker if needed if (!run_serial && num_bck != 0) - bck_builder[0].try_put(continue_msg()); + bck_builder[0].try_put(tbb::flow::continue_msg()); } else { - bck_builder[0].try_put(continue_msg()); + bck_builder[0].try_put(tbb::flow::continue_msg()); if (!run_serial && num_fwd != 0) - fwd_builder[0].try_put(continue_msg()); + fwd_builder[0].try_put(tbb::flow::continue_msg()); } g.wait_for_all(); @@ -775,12 +771,14 @@ class base_parallel_nuts // Uniform(0, 1) RNG std::vector> rand_uniform_vec_; }; + template class Hamiltonian, template class Integrator, class BaseRNG> -using base_parallel_nuts_ct = std::conditional_t< +using base_nuts_ct = std::conditional_t< ParallelBase, base_parallel_nuts, - base_parallel_nuts>; + base_nuts>; + } // namespace mcmc } // namespace stan #endif diff --git a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp index 896d219e098..f3b8f95af3b 100644 --- a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp @@ -19,8 +19,10 @@ class diag_e_nuts : public base_nuts_ct { using base_nuts_t = base_nuts_ct; public: + template * = nullptr> diag_e_nuts(const Model& model, BaseRNG& rng) : base_nuts_t(model, rng) {} + template * = nullptr> diag_e_nuts(const Model& model, std::vector& thread_rngs) : base_nuts_t(model, thread_rngs) {} diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp index 74b3669cfe3..9be63df721d 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -19,147 +20,6 @@ namespace stan { namespace services { namespace sample { -/** - * Runs HMC with NUTS with adaptation using diagonal Euclidean metric - * with a pre-specified Euclidean metric. - * - * @tparam Model Model class - * @tparam InitContextPtr A type derived from `stan::io::var_context` - * @tparam InitMetricContext A type derived from `stan::io::var_context` - * @tparam SamplerWriter A type derived from `stan::callbacks::writer` - * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer` - * @tparam InitWriter A type derived from `stan::callbacks::writer` - * @param[in] model Input model to test (with data already instantiated) - * @param[in] init var context for initialization - * @param[in] init_inv_metric var context exposing an initial diagonal - inverse Euclidean metric (must be positive definite) - * @param[in] random_seed random seed for the random number generator - * @param[in] chain chain id to advance the pseudo random number generator - * @param[in] init_radius radius to initialize - * @param[in] num_warmup Number of warmup samples - * @param[in] num_samples Number of samples - * @param[in] num_thin Number to thin the samples - * @param[in] save_warmup Indicates whether to save the warmup iterations - * @param[in] refresh Controls the output - * @param[in] stepsize initial stepsize for discrete evolution - * @param[in] stepsize_jitter uniform random jitter of stepsize - * @param[in] max_depth Maximum tree depth - * @param[in] delta adaptation target acceptance statistic - * @param[in] gamma adaptation regularization scale - * @param[in] kappa adaptation relaxation exponent - * @param[in] t0 adaptation iteration offset - * @param[in] init_buffer width of initial fast adaptation interval - * @param[in] term_buffer width of final fast adaptation interval - * @param[in] window initial width of slow adaptation interval - * @param[in,out] interrupt Callback for interrupts - * @param[in,out] logger Logger for messages - * @param[in,out] init_writer Writer callback for unconstrained inits - * @param[in,out] sample_writer Writer for draws - * @param[in,out] diagnostic_writer Writer for diagnostic information - * @return error_codes::OK if successful - */ -template -int hmc_nuts_diag_e_adapt_parallel( - Model& model, const stan::io::var_context& init, - const stan::io::var_context& init_inv_metric, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, - int num_thin, bool save_warmup, int refresh, double stepsize, - double stepsize_jitter, int max_depth, double delta, double gamma, - double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, - unsigned int window, callbacks::interrupt& interrupt, - callbacks::logger& logger, callbacks::writer& init_writer, - callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) { - const int num_threads = stan::math::get_num_threads(); - std::vector rngs; - rngs.reserve(num_threads) - for (size_t i = 0; i < num_threads; ++i) { - rngs.emplace_back(util::create_rng(random_seed, chain + i)); - } - std::vector cont_vector = util::initialize( - model, init, rngs[0], init_radius, true, logger, init_writer); - - Eigen::VectorXd inv_metric; - try { - inv_metric = util::read_diag_inv_metric(init_inv_metric, - model.num_params_r(), logger); - util::validate_diag_inv_metric(inv_metric, logger); - } catch (const std::domain_error& e) { - return error_codes::CONFIG; - } - - stan::mcmc::adapt_diag_e_nuts sampler(model, rngs); - - sampler.set_metric(inv_metric); - sampler.set_nominal_stepsize(stepsize); - sampler.set_stepsize_jitter(stepsize_jitter); - sampler.set_max_depth(max_depth); - - sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); - sampler.get_stepsize_adaptation().set_delta(delta); - sampler.get_stepsize_adaptation().set_gamma(gamma); - sampler.get_stepsize_adaptation().set_kappa(kappa); - sampler.get_stepsize_adaptation().set_t0(t0); - - sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, - logger); - - util::run_adaptive_sampler( - sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, - save_warmup, rngs[0], interrupt, logger, sample_writer, diagnostic_writer); - - return error_codes::OK; -} - -/** - * Runs HMC with NUTS with adaptation using diagonal Euclidean metric. - * - * @tparam Model Model class - * @param[in] model Input model to test (with data already instantiated) - * @param[in] init var context for initialization - * @param[in] random_seed random seed for the random number generator - * @param[in] chain chain id to advance the pseudo random number generator - * @param[in] init_radius radius to initialize - * @param[in] num_warmup Number of warmup samples - * @param[in] num_samples Number of samples - * @param[in] num_thin Number to thin the samples - * @param[in] save_warmup Indicates whether to save the warmup iterations - * @param[in] refresh Controls the output - * @param[in] stepsize initial stepsize for discrete evolution - * @param[in] stepsize_jitter uniform random jitter of stepsize - * @param[in] max_depth Maximum tree depth - * @param[in] delta adaptation target acceptance statistic - * @param[in] gamma adaptation regularization scale - * @param[in] kappa adaptation relaxation exponent - * @param[in] t0 adaptation iteration offset - * @param[in] init_buffer width of initial fast adaptation interval - * @param[in] term_buffer width of final fast adaptation interval - * @param[in] window initial width of slow adaptation interval - * @param[in,out] interrupt Callback for interrupts - * @param[in,out] logger Logger for messages - * @param[in,out] init_writer Writer callback for unconstrained inits - * @param[in,out] sample_writer Writer for draws - * @param[in,out] diagnostic_writer Writer for diagnostic information - * @return error_codes::OK if successful - */ -template -int hmc_nuts_diag_e_adapt_parallel( - Model& model, const stan::io::var_context& init, unsigned int random_seed, - unsigned int chain, double init_radius, int num_warmup, int num_samples, - int num_thin, bool save_warmup, int refresh, double stepsize, - double stepsize_jitter, int max_depth, double delta, double gamma, - double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer, - unsigned int window, callbacks::interrupt& interrupt, - callbacks::logger& logger, callbacks::writer& init_writer, - callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) { - stan::io::dump unit_e_metric - = util::create_unit_e_diag_inv_metric(model.num_params_r()); - return hmc_nuts_diag_e_adapt_parallel( - model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, - max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, - interrupt, logger, init_writer, sample_writer, diagnostic_writer); -} - /** * Runs multiple chains of HMC with NUTS with adaptation using diagonal * Euclidean metric with a pre-specified Euclidean metric. @@ -224,15 +84,15 @@ int hmc_nuts_diag_e_adapt_parallel( std::vector& init_writer, std::vector& sample_writer, std::vector& diagnostic_writer) { - if (num_chains == 1 || stan::math::get_num_threads() == 1) { - return hmc_nuts_diag_e_adapt_parallel( - model, *init[0], *init_inv_metric[0], random_seed, init_chain_id, + if (stan::math::internal::get_num_threads() == 1) { + return hmc_nuts_diag_e_adapt( + model, num_chains, init, init_inv_metric, random_seed, init_chain_id, init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, - init_buffer, term_buffer, window, interrupt, logger, init_writer[0], - sample_writer[0], diagnostic_writer[0]); + init_buffer, term_buffer, window, interrupt, logger, init_writer, + sample_writer, diagnostic_writer); } - const int num_threads = stan::math::get_num_threads(); + const int num_threads = stan::math::internal::get_num_threads(); std::vector rngs; rngs.reserve(num_threads); try { @@ -242,51 +102,55 @@ int hmc_nuts_diag_e_adapt_parallel( } catch (const std::domain_error& e) { return error_codes::CONFIG; } - error_codes ret_code; - tbb::parallel_for(tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, - num_chains, init_chain_id, &ret_code, &model, &rngs, - &interrupt, &logger, &sample_writer, - &diagnostic_writer](const tbb::blocked_range& r) { - boost::ecuyer1988& thread_rng = rngs[tbb::this_task_arena::current_thread_index()] - using sample_t = stan::mcmc::adapt_diag_e_nuts; - Eigen::VectorXd inv_metric; - std::vector cont_vector; - for (size_t i = r.begin(); i != r.end(); ++i) { - sample_t sampler(model, rngs); - try { - cont_vector = util::initialize( - model, *init[i], thread_rng, init_radius, true, logger, init_writer[i]); - inv_metric = util::read_diag_inv_metric( - *init_inv_metric[i], model.num_params_r(), logger); - util::validate_diag_inv_metric(inv_metric, logger); - - sampler.set_metric(inv_metric); - sampler.set_nominal_stepsize(stepsize); - sampler.set_stepsize_jitter(stepsize_jitter); - sampler.set_max_depth(max_depth); - - sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); - sampler.get_stepsize_adaptation().set_delta(delta); - sampler.get_stepsize_adaptation().set_gamma(gamma); - sampler.get_stepsize_adaptation().set_kappa(kappa); - sampler.get_stepsize_adaptation().set_t0(t0); - sampler.set_window_params(num_warmup, init_buffer, term_buffer, - window, logger); - } catch (const std::domain_error& e) { - ret_code = error_codes::CONFIG; - return; - } - util::run_adaptive_sampler( - sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, - rngs[i], interrupt, logger, sample_writer[i], - diagnostic_writer[i], init_chain_id + i, - num_chains); - } - }, - tbb::simple_partitioner()); - return ret_code == error_codes::CONFIG ? error_codes::CONFIG : error_codes::OK; + int ret_code = error_codes::OK; + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &ret_code, &model, &rngs, &interrupt, &logger, + &sample_writer, &init, &init_writer, &init_inv_metric, init_radius, delta, stepsize, max_depth, + stepsize_jitter, gamma, kappa, t0, init_buffer, term_buffer, window, + &diagnostic_writer](const tbb::blocked_range& r) { + boost::ecuyer1988& thread_rng + = rngs[tbb::this_task_arena::current_thread_index()]; + using sample_t + = stan::mcmc::adapt_diag_e_nuts; + Eigen::VectorXd inv_metric; + std::vector cont_vector; + for (size_t i = r.begin(); i != r.end(); ++i) { + sample_t sampler(model, rngs); + try { + cont_vector + = util::initialize(model, *init[i], thread_rng, init_radius, + true, logger, init_writer[i]); + inv_metric = util::read_diag_inv_metric( + *init_inv_metric[i], model.num_params_r(), logger); + util::validate_diag_inv_metric(inv_metric, logger); + + sampler.set_metric(inv_metric); + sampler.set_nominal_stepsize(stepsize); + sampler.set_stepsize_jitter(stepsize_jitter); + sampler.set_max_depth(max_depth); + + sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); + sampler.get_stepsize_adaptation().set_delta(delta); + sampler.get_stepsize_adaptation().set_gamma(gamma); + sampler.get_stepsize_adaptation().set_kappa(kappa); + sampler.get_stepsize_adaptation().set_t0(t0); + sampler.set_window_params(num_warmup, init_buffer, term_buffer, + window, logger); + } catch (const std::domain_error& e) { + ret_code = error_codes::CONFIG; + return; + } + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, + save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + return ret_code; } /** @@ -347,13 +211,13 @@ int hmc_nuts_diag_e_adapt_parallel( std::vector& init_writer, std::vector& sample_writer, std::vector& diagnostic_writer) { - if (num_chains == 1 || stan::math::get_num_threads() == 1) { - return hmc_nuts_diag_e_adapt_parallel( - model, *init[0], random_seed, init_chain_id, init_radius, num_warmup, + if (stan::math::internal::get_num_threads() == 1) { + return hmc_nuts_diag_e_adapt( + model, num_chains, init, random_seed, init_chain_id, init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, - interrupt, logger, init_writer[0], sample_writer[0], - diagnostic_writer[0]); + interrupt, logger, init_writer, sample_writer, + diagnostic_writer); } std::vector> unit_e_metrics; unit_e_metrics.reserve(num_chains); diff --git a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp new file mode 100644 index 00000000000..b11c2889bed --- /dev/null +++ b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp @@ -0,0 +1,175 @@ +#include +#include +#include +#include +#include +#include + +auto&& blah = stan::math::init_threadpool_tbb(); + +static constexpr size_t num_chains = 4; +class ServicesSampleHmcNutsDiagEAdaptPar : public testing::Test { + public: + ServicesSampleHmcNutsDiagEAdaptPar() : model(data_context, 0, &model_log) { + for (int i = 0; i < num_chains; ++i) { + init.push_back(stan::test::unit::instrumented_writer{}); + parameter.push_back(stan::test::unit::instrumented_writer{}); + diagnostic.push_back(stan::test::unit::instrumented_writer{}); + context.push_back(std::make_shared()); + } + } + stan::io::empty_var_context data_context; + std::stringstream model_log; + stan::test::unit::instrumented_logger logger; + std::vector init; + std::vector parameter; + std::vector diagnostic; + std::vector> context; + stan_model model; +}; + +TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, call_count) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + int return_code = stan::services::sample::hmc_nuts_diag_e_adapt_parallel( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, parameter, diagnostic); + + EXPECT_EQ(0, return_code); + + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count()); + for (int i = 0; i < num_chains; ++i) { + EXPECT_EQ(1, parameter[i].call_count("vector_string")); + EXPECT_EQ(num_output_lines, parameter[i].call_count("vector_double")); + EXPECT_EQ(1, diagnostic[i].call_count("vector_string")); + EXPECT_EQ(num_output_lines, diagnostic[i].call_count("vector_double")); + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, parameter_checks) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + int return_code = stan::services::sample::hmc_nuts_diag_e_adapt_parallel( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, parameter, diagnostic); + + for (size_t i = 0; i < num_chains; ++i) { + std::vector> parameter_names; + parameter_names = parameter[i].vector_string_values(); + std::vector> parameter_values; + parameter_values = parameter[i].vector_double_values(); + std::vector> diagnostic_names; + diagnostic_names = diagnostic[i].vector_string_values(); + std::vector> diagnostic_values; + diagnostic_values = diagnostic[i].vector_double_values(); + + // Expectations of parameter parameter names. + ASSERT_EQ(9, parameter_names[0].size()); + EXPECT_EQ("lp__", parameter_names[0][0]); + EXPECT_EQ("accept_stat__", parameter_names[0][1]); + EXPECT_EQ("stepsize__", parameter_names[0][2]); + EXPECT_EQ("treedepth__", parameter_names[0][3]); + EXPECT_EQ("n_leapfrog__", parameter_names[0][4]); + EXPECT_EQ("divergent__", parameter_names[0][5]); + EXPECT_EQ("energy__", parameter_names[0][6]); + EXPECT_EQ("x", parameter_names[0][7]); + EXPECT_EQ("y", parameter_names[0][8]); + + // Expect one name per parameter value. + EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size()); + EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size()); + + EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size()); + + // Expect one call to set parameter names, and one set of output per + // iteration. + EXPECT_EQ("lp__", diagnostic_names[0][0]); + EXPECT_EQ("accept_stat__", diagnostic_names[0][1]); + } + EXPECT_EQ(return_code, 0); +} + +TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, output_regression) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt_parallel( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, parameter, diagnostic); + + for (auto&& init_it : init) { + std::vector init_values; + init_values = init_it.string_values(); + + EXPECT_EQ(0, init_values.size()); + } + + EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Total)")); + EXPECT_EQ(0, logger.call_count_error()); +} From f18ae74bb11850d1c73bdabb5134bd78a86b9627 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Wed, 16 Feb 2022 17:05:17 -0500 Subject: [PATCH 6/8] clang format --- src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 73 ++++++++++--------- .../sample/hmc_nuts_diag_e_adapt_parallel.hpp | 13 ++-- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp index e1ab28152a9..adf4da9be68 100644 --- a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -267,20 +267,20 @@ class base_parallel_nuts // if statement here for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { if (fwd_direction[depth]) { - builder_iter_t fwd_iter - = fwd_builder.emplace_back(g, [&, depth, fwd_idx](tbb::flow::continue_msg) { - // std::cout << "fwd turn at depth " << depth; - bool valid_parent - = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx - 1]; - if (valid_parent) { - // std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); - valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); - } else { - valid_subtree_fwd[fwd_idx] = false; - } - // std::cout << " nothing to do." << std::endl; - }); + builder_iter_t fwd_iter = fwd_builder.emplace_back( + g, [&, depth, fwd_idx](tbb::flow::continue_msg) { + // std::cout << "fwd turn at depth " << depth; + bool valid_parent + = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx - 1]; + if (valid_parent) { + // std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); + valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_fwd[fwd_idx] = false; + } + // std::cout << " nothing to do." << std::endl; + }); if (!run_serial && fwd_idx != 0) { // in this case this is not the starting node, we // connect this with its predecessor @@ -289,20 +289,20 @@ class base_parallel_nuts all_builder_idx[depth] = fwd_idx; ++fwd_idx; } else { - builder_iter_t bck_iter - = bck_builder.emplace_back(g, [&, depth, bck_idx](tbb::flow::continue_msg) { - // std::cout << "bck turn at depth " << depth; - bool valid_parent - = bck_idx == 0 ? true : valid_subtree_bck[bck_idx - 1]; - if (valid_parent) { - // std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); - valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); - } else { - valid_subtree_bck[bck_idx] = false; - } - // std::cout << " nothing to do." << std::endl; - }); + builder_iter_t bck_iter = bck_builder.emplace_back( + g, [&, depth, bck_idx](tbb::flow::continue_msg) { + // std::cout << "bck turn at depth " << depth; + bool valid_parent + = bck_idx == 0 ? true : valid_subtree_bck[bck_idx - 1]; + if (valid_parent) { + // std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); + valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_bck[bck_idx] = false; + } + // std::cout << " nothing to do." << std::endl; + }); if (!run_serial && bck_idx != 0) { // in case this is not the starting node, we connect // this with his predecessor @@ -344,8 +344,9 @@ class base_parallel_nuts sum_metro_prob += std::get<6>(subtree_result); } - const bool valid_subtree = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] - : valid_subtree_bck[all_builder_idx[depth]]; + const bool valid_subtree + = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] + : valid_subtree_bck[all_builder_idx[depth]]; const bool is_valid = valid_subtree & this->valid_trees_; @@ -403,11 +404,13 @@ class base_parallel_nuts if (fwd_direction[depth]) { // std::cout << "depth " << depth << ": joining fwd node " << // all_builder_idx[depth] << " into join node." << std::endl; - tbb::flow::make_edge(fwd_builder[all_builder_idx[depth]], checks.back()); + tbb::flow::make_edge(fwd_builder[all_builder_idx[depth]], + checks.back()); } else { // std::cout << "depth " << depth << ": joining bck node " << // all_builder_idx[depth] << " into join node." << std::endl; - tbb::flow::make_edge(bck_builder[all_builder_idx[depth]], checks.back()); + tbb::flow::make_edge(bck_builder[all_builder_idx[depth]], + checks.back()); } if (!run_serial && depth != 0) { tbb::flow::make_edge(checks[depth - 1], checks.back()); @@ -416,9 +419,9 @@ class base_parallel_nuts if (run_serial) { for (std::size_t i = 1; i < this->max_depth_; ++i) { - tbb::flow::make_edge(checks[i - 1], fwd_direction[i] - ? fwd_builder[all_builder_idx[i]] - : bck_builder[all_builder_idx[i]]); + tbb::flow::make_edge( + checks[i - 1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] + : bck_builder[all_builder_idx[i]]); } } diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp index 9be63df721d..13f814f6bbf 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp @@ -107,8 +107,9 @@ int hmc_nuts_diag_e_adapt_parallel( tbb::blocked_range(0, num_chains, 1), [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, init_chain_id, &ret_code, &model, &rngs, &interrupt, &logger, - &sample_writer, &init, &init_writer, &init_inv_metric, init_radius, delta, stepsize, max_depth, - stepsize_jitter, gamma, kappa, t0, init_buffer, term_buffer, window, + &sample_writer, &init, &init_writer, &init_inv_metric, init_radius, + delta, stepsize, max_depth, stepsize_jitter, gamma, kappa, t0, + init_buffer, term_buffer, window, &diagnostic_writer](const tbb::blocked_range& r) { boost::ecuyer1988& thread_rng = rngs[tbb::this_task_arena::current_thread_index()]; @@ -213,10 +214,10 @@ int hmc_nuts_diag_e_adapt_parallel( std::vector& diagnostic_writer) { if (stan::math::internal::get_num_threads() == 1) { return hmc_nuts_diag_e_adapt( - model, num_chains, init, random_seed, init_chain_id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, - max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, - interrupt, logger, init_writer, sample_writer, + model, num_chains, init, random_seed, init_chain_id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, + term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); } std::vector> unit_e_metrics; From ff668c966032f781a6f570e85ad3ef25f69ba4ae Mon Sep 17 00:00:00 2001 From: stevebronder Date: Thu, 17 Feb 2022 14:45:32 -0500 Subject: [PATCH 7/8] update to not use get_num_theads() to detect thread level --- src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 38 ++++++++----------- .../sample/hmc_nuts_diag_e_adapt_parallel.hpp | 8 ++-- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp index adf4da9be68..87b5c31ce51 100644 --- a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -239,9 +239,10 @@ class base_parallel_nuts tbb::concurrent_vector valid_subtree_bck(num_bck, true); // HACK!!! + /* callbacks::logger logger_fwd; callbacks::logger logger_bck; - + */ // build TBB flow graph tbb::flow::graph g; @@ -250,22 +251,19 @@ class base_parallel_nuts tbb::concurrent_vector all_builder_idx(this->max_depth_); tbb::concurrent_vector fwd_builder; + fwd_builder.reserve(this->max_depth_); tbb::concurrent_vector bck_builder; + bck_builder.reserve(this->max_depth_); using builder_iter_t = tbb::concurrent_vector::iterator; // now wire up the fwd and bck build of the trees which // depends on single-core or multi-core run - // TODO (Steve) We should only use this class if get_num_threads > 1 - // Else just use the non-parallel nuts. - const bool run_serial = stan::math::internal::get_num_threads() == 1; - std::size_t fwd_idx = 0; - std::size_t bck_idx = 0; // TODO: the extenders should also check for a global flag if // we want to keep running // TODO: We should also just run depth = 0 outside the loop to avoid the // if statement here - for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { + for (std::size_t depth = 0, fwd_idx = 0, bck_idx = 0; depth != this->max_depth_; ++depth) { if (fwd_direction[depth]) { builder_iter_t fwd_iter = fwd_builder.emplace_back( g, [&, depth, fwd_idx](tbb::flow::continue_msg) { @@ -274,14 +272,14 @@ class base_parallel_nuts = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx - 1]; if (valid_parent) { // std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger_fwd); + ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger); valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); } else { valid_subtree_fwd[fwd_idx] = false; } // std::cout << " nothing to do." << std::endl; }); - if (!run_serial && fwd_idx != 0) { + if (fwd_idx != 0) { // in this case this is not the starting node, we // connect this with its predecessor tbb::flow::make_edge(*(fwd_iter - 1), *fwd_iter); @@ -296,14 +294,14 @@ class base_parallel_nuts = bck_idx == 0 ? true : valid_subtree_bck[bck_idx - 1]; if (valid_parent) { // std::cout << " yes, here we go!" << std::endl; - ends[depth] = extend_tree(depth, tree_bck, z_bck, logger_bck); + ends[depth] = extend_tree(depth, tree_bck, z_bck, logger); valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); } else { valid_subtree_bck[bck_idx] = false; } // std::cout << " nothing to do." << std::endl; }); - if (!run_serial && bck_idx != 0) { + if (bck_idx != 0) { // in case this is not the starting node, we connect // this with his predecessor // tbb::flow::make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); @@ -348,7 +346,7 @@ class base_parallel_nuts = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] : valid_subtree_bck[all_builder_idx[depth]]; - const bool is_valid = valid_subtree & this->valid_trees_; + const bool is_valid = valid_subtree && this->valid_trees_; // std::cout << "CHECK at depth " << depth; @@ -412,29 +410,23 @@ class base_parallel_nuts tbb::flow::make_edge(bck_builder[all_builder_idx[depth]], checks.back()); } - if (!run_serial && depth != 0) { + if (depth != 0) { tbb::flow::make_edge(checks[depth - 1], checks.back()); } } - if (run_serial) { - for (std::size_t i = 1; i < this->max_depth_; ++i) { - tbb::flow::make_edge( - checks[i - 1], fwd_direction[i] ? fwd_builder[all_builder_idx[i]] - : bck_builder[all_builder_idx[i]]); - } - } - // kick off work if (fwd_direction[0]) { fwd_builder[0].try_put(tbb::flow::continue_msg()); // the first turn is fwd, so kick off the bck walker if needed - if (!run_serial && num_bck != 0) + if (num_bck != 0) { bck_builder[0].try_put(tbb::flow::continue_msg()); + } } else { bck_builder[0].try_put(tbb::flow::continue_msg()); - if (!run_serial && num_fwd != 0) + if (num_fwd != 0) { fwd_builder[0].try_put(tbb::flow::continue_msg()); + } } g.wait_for_all(); diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp index 13f814f6bbf..3ebda83bd14 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp @@ -84,7 +84,8 @@ int hmc_nuts_diag_e_adapt_parallel( std::vector& init_writer, std::vector& sample_writer, std::vector& diagnostic_writer) { - if (stan::math::internal::get_num_threads() == 1) { + if (tbb::this_task_arena::max_concurrency() == 1) { + std::cout << "Running serial" << std::endl; return hmc_nuts_diag_e_adapt( model, num_chains, init, init_inv_metric, random_seed, init_chain_id, init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, @@ -92,7 +93,7 @@ int hmc_nuts_diag_e_adapt_parallel( init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); } - const int num_threads = stan::math::internal::get_num_threads(); + const int num_threads = tbb::this_task_arena::max_concurrency(); std::vector rngs; rngs.reserve(num_threads); try { @@ -212,7 +213,8 @@ int hmc_nuts_diag_e_adapt_parallel( std::vector& init_writer, std::vector& sample_writer, std::vector& diagnostic_writer) { - if (stan::math::internal::get_num_threads() == 1) { + if (tbb::this_task_arena::max_concurrency() == 1) { + std::cout << "Running serial" << std::endl; return hmc_nuts_diag_e_adapt( model, num_chains, init, random_seed, init_chain_id, init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, From 334d242b63903587b0ddf43d81c0a6835ef5f68f Mon Sep 17 00:00:00 2001 From: stevebronder Date: Thu, 17 Feb 2022 15:04:56 -0500 Subject: [PATCH 8/8] clang format --- src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp | 13 ++++++++----- src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp | 3 ++- src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp | 18 ++++++++++-------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index d39cc5ef005..c64f970e637 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -16,18 +16,21 @@ template class adapt_diag_e_nuts : public diag_e_nuts, public stepsize_var_adapter { public: - template * = nullptr> + template * = nullptr> adapt_diag_e_nuts(const Model& model, BaseRNG& rng) : diag_e_nuts(model, rng), stepsize_var_adapter(model.num_params_r()) {} - template * = nullptr> + template * = nullptr> adapt_diag_e_nuts(const Model& model, std::vector& thread_rngs) - : diag_e_nuts(model, thread_rngs), - stepsize_var_adapter(model.num_params_r()) {} + : diag_e_nuts(model, thread_rngs), + stepsize_var_adapter(model.num_params_r()) {} inline sample transition(sample& init_sample, callbacks::logger& logger) { - sample s = diag_e_nuts::transition(init_sample, logger); + sample s = diag_e_nuts::transition( + init_sample, logger); if (this->adapt_flag_) { this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp index 87b5c31ce51..2c3f486ad93 100644 --- a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -263,7 +263,8 @@ class base_parallel_nuts // we want to keep running // TODO: We should also just run depth = 0 outside the loop to avoid the // if statement here - for (std::size_t depth = 0, fwd_idx = 0, bck_idx = 0; depth != this->max_depth_; ++depth) { + for (std::size_t depth = 0, fwd_idx = 0, bck_idx = 0; + depth != this->max_depth_; ++depth) { if (fwd_direction[depth]) { builder_iter_t fwd_iter = fwd_builder.emplace_back( g, [&, depth, fwd_idx](tbb::flow::continue_msg) { diff --git a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp index f3b8f95af3b..241283f1f8f 100644 --- a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp @@ -15,17 +15,19 @@ namespace mcmc { * with a Gaussian-Euclidean disintegration and diagonal metric */ template -class diag_e_nuts - : public base_nuts_ct { - using base_nuts_t = base_nuts_ct; +class diag_e_nuts : public base_nuts_ct { + using base_nuts_t = base_nuts_ct; + public: - template * = nullptr> - diag_e_nuts(const Model& model, BaseRNG& rng) - : base_nuts_t(model, rng) {} - template * = nullptr> + template * = nullptr> + diag_e_nuts(const Model& model, BaseRNG& rng) : base_nuts_t(model, rng) {} + template * = nullptr> diag_e_nuts(const Model& model, std::vector& thread_rngs) : base_nuts_t(model, thread_rngs) {} - }; } // namespace mcmc