-
-
Notifications
You must be signed in to change notification settings - Fork 371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MPI warmup auto #2886
Closed
Closed
MPI warmup auto #2886
Changes from all commits
Commits
Show all changes
78 commits
Select commit
Hold shift + click to select a range
1952caf
Added in switching adaptation
bbbales2 4dc0fff
Simplified code and added some comments
bbbales2 8000691
Changed 'switching' to 'auto'
bbbales2 b1205f3
Added try/catch so that if auto adaptation fails it falls back to dia…
bbbales2 a6ae242
Merge branch 'develop' of https://github.com/stan-dev/stan into exper…
bbbales2 baacc34
Fixed bug with regularization of automatically picked metric. Added t…
bbbales2 76bfd27
first draft of master-slave communication setup for warmup
7bfe044
unit test for mpi warmup communication
d57e3d9
inter chain and intra chain MPI communicators
a98744d
warmup loader MPI test
8d09106
unit test with mpi gathering of stepsize
b69492b
update math submodule to 7f4e3a4af5
1e11650
rhat as convergence critierior for mpi warmup
bae89e1
campfire warmup with rhat and synced stepsize
6e1fce8
pass num_chains through run_mpi_adaptive_sampler
45ed1d4
update submodule
54f5fa3
cross chain rhat calculation unit test
ef31851
unit test for cross-chain adapted warmup
d54c26a
use cross chain warmup for nuts adapted
1a0f743
stream writer only writes for rank == 0
84d9807
use degenerated version of compute ess for warmup
e4c6b50
fix branch of gitmodules
731b310
rename campfire warmup to cross chain warmup
b393841
rename campfire warmup to cross chain
524f750
add mpi cross chain adapter
c67a5e0
rename mpi adapter; add unit tests
4547a62
adapt unit/diag nuts inherit cross chain adapter
c6d98f7
check cross-chain convergence in during sampler transition
ab8ab7e
1st draft of adapting diag metric
536d789
use harmonic mean for ESS test
fc8c3ce
fix broadcasting metrics issue
a58e2a9
rm old mpi adapt implementation
7befdde
stream writer only writes to inter chain ranks
99cc236
rng seed should be from cmdstan
32f552e
pass `num_cross_chains` to diag_e_adapt function
16eac20
fix sequential compile failure when no MPI flags issued
26ca4f3
turn off metric aggregation. update submoduel
a8da936
unit test for mpi var adaptation
a703c4d
adapter with logging function and new mpi_var_adaptation
1a5be75
track all windows for end-of-warmup metric update
36bd480
use minimum instead of harmonic mean for ESS test
a34a661
mpi_var_adaptatin under #ifdef STAN_LANG_MPI
a299c50
clean up cross-chain adapter calls in sampler
77065b3
pass cross chain window size as function call arg
1410f2c
add target_ess to arg list
c35985e
cross_chain_rhat argument
efe33a5
output min ESS among the chains
57949e3
mpi_stream_writer
f53fa46
hack: replace num_warmup in output.csv with cross_chain num_warmup
8f0f0b1
Revert "hack: replace num_warmup in output.csv with cross_chain num_w…
ad706bc
write `num_warmup` before `adaption terminated` in csv
6301595
use Stan's ESS for cross-chain ESS test
0a1c7ed
no stepsize reset after cross-chain converges
2c75a62
init commit mpi_warmup_v2
de0b754
use mpi_cross_chain as single entry to cross-chain methods
88cfa38
move MPI_ADAPTED_WARMUP macro into adapter
da2f7b0
simplify cross chain interface
203e271
mpi warmup for dense_e nuts
cce0945
tests for mpi_covar_adpt. type traits for cross-chain warmup sampler
4acb8f7
change sed seed to set id
2982626
fix: unit nuts doesn't need var; no post warmup when max out
6a95b53
win loop to decrease counter in croos-chain adapter
9d5c99a
full campfire version where every window has aggregation
6da0fa6
harmonic mean for stepsize
7f6ec86
separate stepsize & metric aggregation
87356a1
no restart for mpi_var_adapt
92b8f3f
exclude initial buffer from covar calculation
83a450a
ignore init buffer draws in mpi var adaptation
f2835ce
Updated mpi file name adjuster to work with diagnostic files
bbbales2 9791601
Merge remote-tracking branch 'bbbales2/experimental/warmup-auto' into…
bbbales2 97af9ac
auto_e metric might be working
bbbales2 f5c903a
Merge branch 'mpi_warmup_v2' into mpi_warmup_auto
bbbales2 227a0a6
Updating stan math
bbbales2 37a3fa1
auto adaptation terminates properly now
bbbales2 02e93b4
Removed some unused files
bbbales2 278b1f3
Merge remote-tracking branch 'origin/mpi_warmup_v2' into mpi_warmup_auto
bbbales2 36af38d
Updated math includes
bbbales2 a93d90d
Merge remote-tracking branch 'origin/mpi_warmup_v2' into mpi_warmup_auto
bbbales2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#ifndef STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_METRIC_HPP | ||
#define STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_METRIC_HPP | ||
|
||
#include <stan/callbacks/logger.hpp> | ||
#include <stan/math/prim/fun.hpp> | ||
#include <stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp> | ||
#include <stan/mcmc/hmc/hamiltonians/auto_e_point.hpp> | ||
#include <boost/random/variate_generator.hpp> | ||
#include <boost/random/normal_distribution.hpp> | ||
#include <Eigen/Cholesky> | ||
|
||
namespace stan { | ||
namespace mcmc { | ||
|
||
// Euclidean manifold with dense metric | ||
template <class Model, class BaseRNG> | ||
class auto_e_metric | ||
: public base_hamiltonian<Model, auto_e_point, BaseRNG> { | ||
public: | ||
explicit auto_e_metric(const Model& model) | ||
: base_hamiltonian<Model, auto_e_point, BaseRNG>(model) {} | ||
|
||
double T(auto_e_point& z) { | ||
return 0.5 * z.p.transpose() * z.inv_e_metric_ * z.p; | ||
} | ||
|
||
double tau(auto_e_point& z) { | ||
return T(z); | ||
} | ||
|
||
double phi(auto_e_point& z) { | ||
return this->V(z); | ||
} | ||
|
||
double dG_dt(auto_e_point& z, callbacks::logger& logger) { | ||
return 2 * T(z) - z.q.dot(z.g); | ||
} | ||
|
||
Eigen::VectorXd dtau_dq(auto_e_point& z, callbacks::logger& logger) { | ||
return Eigen::VectorXd::Zero(this->model_.num_params_r()); | ||
} | ||
|
||
Eigen::VectorXd dtau_dp(auto_e_point& z) { | ||
if(z.is_diagonal_) { | ||
return z.inv_e_metric_.diagonal().cwiseProduct(z.p); | ||
} else { | ||
return z.inv_e_metric_ * z.p; | ||
} | ||
} | ||
|
||
Eigen::VectorXd dphi_dq(auto_e_point& z, callbacks::logger& logger) { | ||
return z.g; | ||
} | ||
|
||
void sample_p(auto_e_point& z, BaseRNG& rng) { | ||
typedef typename stan::math::index_type<Eigen::VectorXd>::type idx_t; | ||
boost::variate_generator<BaseRNG&, boost::normal_distribution<> > | ||
rand_gaus(rng, boost::normal_distribution<>()); | ||
|
||
if(z.is_diagonal_) { | ||
for (int i = 0; i < z.p.size(); ++i) | ||
z.p(i) = rand_gaus() / sqrt(z.inv_e_metric_(i, i)); | ||
} else { | ||
Eigen::VectorXd u(z.p.size()); | ||
|
||
for (idx_t i = 0; i < u.size(); ++i) | ||
u(i) = rand_gaus(); | ||
|
||
z.p = z.inv_e_metric_.llt().matrixU().solve(u); | ||
} | ||
} | ||
}; | ||
|
||
} // mcmc | ||
} // stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#ifndef STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_POINT_HPP | ||
#define STAN_MCMC_HMC_HAMILTONIANS_AUTO_E_POINT_HPP | ||
|
||
#include <stan/callbacks/writer.hpp> | ||
#include <stan/mcmc/hmc/hamiltonians/ps_point.hpp> | ||
|
||
namespace stan { | ||
namespace mcmc { | ||
/** | ||
* Point in a phase space with a base | ||
* Euclidean manifold with auto metric | ||
*/ | ||
class auto_e_point: public ps_point { | ||
public: | ||
/** | ||
* Inverse mass matrix. | ||
*/ | ||
Eigen::MatrixXd inv_e_metric_; | ||
|
||
/** | ||
* Is inv_e_metric_ diagonal or not | ||
*/ | ||
bool is_diagonal_; | ||
|
||
/** | ||
* Construct a auto point in n-dimensional phase space | ||
* with identity matrix as inverse mass matrix. | ||
* | ||
* @param n number of dimensions | ||
*/ | ||
explicit auto_e_point(int n) | ||
: ps_point(n), inv_e_metric_(n, n), is_diagonal_(true) { | ||
inv_e_metric_.setIdentity(); | ||
} | ||
|
||
/** | ||
* Set elements of mass matrix | ||
* | ||
* @param inv_e_metric initial mass matrix | ||
*/ | ||
void | ||
set_metric(const Eigen::MatrixXd& inv_e_metric) { | ||
inv_e_metric_ = inv_e_metric; | ||
is_diagonal_ = false; | ||
} | ||
|
||
/** | ||
* Write elements of mass matrix to string and handoff to writer. | ||
* | ||
* @param writer Stan writer callback | ||
*/ | ||
inline | ||
void | ||
write_metric(stan::callbacks::writer& writer) { | ||
writer("Elements of inverse mass matrix:"); | ||
for (int i = 0; i < inv_e_metric_.rows(); ++i) { | ||
std::stringstream inv_e_metric_ss; | ||
inv_e_metric_ss << inv_e_metric_(i, 0); | ||
for (int j = 1; j < inv_e_metric_.cols(); ++j) | ||
inv_e_metric_ss << ", " << inv_e_metric_(i, j); | ||
writer(inv_e_metric_ss.str()); | ||
} | ||
} | ||
}; | ||
|
||
} // mcmc | ||
} // stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#ifndef STAN_MCMC_HMC_NUTS_ADAPT_AUTO_E_NUTS_HPP | ||
#define STAN_MCMC_HMC_NUTS_ADAPT_AUTO_E_NUTS_HPP | ||
|
||
#include <stan/callbacks/logger.hpp> | ||
#include <stan/mcmc/hmc/nuts/auto_e_nuts.hpp> | ||
#include <stan/mcmc/hmc/mpi_cross_chain_adapter.hpp> | ||
#include <stan/mcmc/mpi_auto_adaptation.hpp> | ||
|
||
namespace stan { | ||
namespace mcmc { | ||
/** | ||
* The No-U-Turn sampler (NUTS) with multinomial sampling | ||
* with a Gaussian-Euclidean disintegration and adaptive | ||
* dense or diagonal metric and adaptive step size | ||
*/ | ||
template <class Model, class BaseRNG> | ||
class adapt_auto_e_nuts : public auto_e_nuts<Model, BaseRNG>, | ||
public mpi_cross_chain_adapter<adapt_auto_e_nuts<Model, BaseRNG>>, | ||
public stepsize_covar_adapter { | ||
protected: | ||
const Model& model_; | ||
public: | ||
adapt_auto_e_nuts(const Model& model, BaseRNG& rng) | ||
: model_(model), auto_e_nuts<Model, BaseRNG>(model, rng), | ||
stepsize_covar_adapter(model.num_params_r()) {} | ||
|
||
~adapt_auto_e_nuts() {} | ||
|
||
sample | ||
transition(sample& init_sample, callbacks::logger& logger) { | ||
sample s = auto_e_nuts<Model, BaseRNG>::transition(init_sample, | ||
logger); | ||
|
||
if (this->adapt_flag_) { | ||
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, | ||
s.accept_stat()); | ||
|
||
if (this -> use_cross_chain_adapt()) { | ||
this -> add_cross_chain_sample(s.log_prob()); | ||
bool update = this -> cross_chain_adaptation(logger); | ||
if (this -> is_cross_chain_adapted()) { | ||
update = false; | ||
} | ||
|
||
if (update) { | ||
this->z_.is_diagonal_ = reinterpret_cast<mpi_auto_adaptation<Model> *>(this->metric_adapt)->is_diagonal_; | ||
|
||
this->init_stepsize(logger); | ||
|
||
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); | ||
this->stepsize_adaptation_.restart(); | ||
|
||
this->set_cross_chain_stepsize(); | ||
} | ||
} else { | ||
bool update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_, | ||
this->z_.q); | ||
if (update) { | ||
this->init_stepsize(logger); | ||
|
||
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); | ||
this->stepsize_adaptation_.restart(); | ||
} | ||
} | ||
} | ||
return s; | ||
} | ||
|
||
void disengage_adaptation() { | ||
base_adapter::disengage_adaptation(); | ||
if (!this -> is_cross_chain_adapted()) { | ||
this->stepsize_adaptation_.complete_adaptation(this->nom_epsilon_); | ||
} | ||
} | ||
}; | ||
|
||
} // mcmc | ||
} // stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#ifndef STAN_MCMC_HMC_NUTS_AUTO_E_NUTS_HPP | ||
#define STAN_MCMC_HMC_NUTS_AUTO_E_NUTS_HPP | ||
|
||
#include <stan/mcmc/hmc/nuts/base_nuts.hpp> | ||
#include <stan/mcmc/hmc/hamiltonians/dense_e_point.hpp> | ||
#include <stan/mcmc/hmc/hamiltonians/auto_e_metric.hpp> | ||
#include <stan/mcmc/hmc/integrators/expl_leapfrog.hpp> | ||
|
||
namespace stan { | ||
namespace mcmc { | ||
/** | ||
* The No-U-Turn sampler (NUTS) with multinomial sampling | ||
* with a Gaussian-Euclidean disintegration and dense metric | ||
*/ | ||
template <class Model, class BaseRNG> | ||
class auto_e_nuts : public base_nuts<Model, auto_e_metric, | ||
expl_leapfrog, BaseRNG> { | ||
public: | ||
auto_e_nuts(const Model& model, BaseRNG& rng) | ||
: base_nuts<Model, auto_e_metric, expl_leapfrog, | ||
BaseRNG>(model, rng) { } | ||
}; | ||
|
||
} // mcmc | ||
} // stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated the call to mpi adapter, see
stan/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp
Line 34 in 117664b