Skip to content

Commit

Permalink
merge mpi_warmup_framework
Browse files Browse the repository at this point in the history
  • Loading branch information
yiz committed Feb 19, 2020
2 parents 83a450a + 6052421 commit 117664b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 32 deletions.
14 changes: 10 additions & 4 deletions src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,13 @@ namespace mcmc {

inline void
write_num_cross_chain_warmup(callbacks::writer& sample_writer,
int num_thin) {
size_t n = num_cross_chain_draws();
sample_writer("num_warmup = " + std::to_string(n / num_thin));
int num_thin, int num_warmup) {
if (use_cross_chain_adapt()) {
size_t n = num_cross_chain_draws();
sample_writer("num_warmup = " + std::to_string(n / num_thin));
} else {
sample_writer("num_warmup = " + std::to_string(num_warmup));
}
}

/*
Expand Down Expand Up @@ -454,7 +458,9 @@ namespace mcmc {
sampler.set_nominal_stepsize(new_stepsize);
}

inline bool use_cross_chain_adapt() { return true; }
inline bool use_cross_chain_adapt() {
return num_chains_ > 1;
}
};

#else // sequential version
Expand Down
25 changes: 14 additions & 11 deletions src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,30 @@ class adapt_dense_e_nuts : public dense_e_nuts<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update;
if (this -> use_cross_chain_adapt()) {
this -> add_cross_chain_sample(s.log_prob());
update = this -> cross_chain_adaptation(logger);
bool update = this -> cross_chain_adaptation(logger);
if (this -> is_cross_chain_adapted()) {
update = false;
}
} else {
update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_,
this->z_.q);
}

if (update) {
this->init_stepsize(logger);
if (update) {
this->init_stepsize(logger);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();

if (this -> use_cross_chain_adapt()) {
this->set_cross_chain_stepsize();
}
} else {
bool update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_,
this->z_.q);
if (update) {
this->init_stepsize(logger);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
}
}
return s;
Expand Down
26 changes: 15 additions & 11 deletions src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,31 @@ class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update;

if (this -> use_cross_chain_adapt()) {
this -> add_cross_chain_sample(s.log_prob());
update = this -> cross_chain_adaptation(logger);
bool update = this -> cross_chain_adaptation(logger);
if (this -> is_cross_chain_adapted()) {
update = false;
}
} else {
update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
this->z_.q);
}

if (update) {
this->init_stepsize(logger);
if (update) {
this->init_stepsize(logger);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();

if (this -> use_cross_chain_adapt()) {
this->set_cross_chain_stepsize();
}
} else {
bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
this->z_.q);
if (update) {
this->init_stepsize(logger);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
}
}
return s;
Expand Down
10 changes: 5 additions & 5 deletions src/stan/services/util/mpi_cross_chain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace util {

static void write_num_warmup(Sampler& sampler,
callbacks::writer& sample_writer,
int num_thin) {}
int num_thin, int num_warmup) {}
};

/*
Expand All @@ -81,8 +81,8 @@ namespace util {

static void write_num_warmup(Sampler& sampler,
callbacks::writer& sample_writer,
int num_thin) {
sampler.write_num_cross_chain_warmup(sample_writer, num_thin);
int num_thin, int num_warmup) {
sampler.write_num_cross_chain_warmup(sample_writer, num_thin, num_warmup);
}
};
#endif
Expand Down Expand Up @@ -111,9 +111,9 @@ namespace util {

static void write_num_warmup(Sampler& sampler,
callbacks::writer& sample_writer,
int num_thin) {
int num_thin, int num_warmup) {
mpi_cross_chain_impl<Sampler, has_cross_chain_warmup<Sampler>::value>::
write_num_warmup(sampler, sample_writer, num_thin);
write_num_warmup(sampler, sample_writer, num_thin, num_warmup);
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/util/run_adaptive_sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void run_adaptive_sampler(Sampler& sampler, Model& model,
mpi_cross_chain<Sampler>::num_draws(sampler),
num_warmup + num_samples, num_thin, refresh, save_warmup,
true, writer, s, model, rng, interrupt, logger);
mpi_cross_chain<Sampler>::write_num_warmup(sampler, sample_writer, num_thin);
mpi_cross_chain<Sampler>::write_num_warmup(sampler, sample_writer, num_thin, num_warmup);

clock_t end = clock();
double warm_delta_t = static_cast<double>(end - start) / CLOCKS_PER_SEC;
Expand Down

0 comments on commit 117664b

Please sign in to comment.