Skip to content

Commit

Permalink
Merge pull request #3263 from stan-dev/refactor/switch-rngs
Browse files Browse the repository at this point in the history
Refactor: Avoid hardcoding boost::ecuyer1988
  • Loading branch information
WardBrian authored Jan 31, 2024
2 parents a49908b + 35444e1 commit 476a356
Show file tree
Hide file tree
Showing 77 changed files with 444 additions and 529 deletions.
4 changes: 1 addition & 3 deletions src/stan/mcmc/chains.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
#include <boost/accumulators/statistics/variance.hpp>
#include <boost/accumulators/statistics/covariance.hpp>
#include <boost/accumulators/statistics/variates/covariate.hpp>
#include <boost/random/uniform_int_distribution.hpp>
#include <boost/random/additive_combine.hpp>
#include <algorithm>
#include <cmath>
#include <iostream>
Expand Down Expand Up @@ -44,7 +42,7 @@ using Eigen::Dynamic;
*
* <p><b>Storage Order</b>: Storage is column/last-index major.
*/
template <class RNG = boost::random::ecuyer1988>
template <typename Unused = void*>
class chains {
private:
std::vector<std::string> param_names_;
Expand Down
8 changes: 3 additions & 5 deletions src/stan/model/model_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stan/math/mix.hpp>
#endif
#include <stan/model/prob_grad.hpp>
#include <boost/random/additive_combine.hpp>
#include <stan/services/util/create_rng.hpp>
#include <ostream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -367,8 +367,7 @@ class model_base : public prob_grad {
* in output
* @param[in,out] msgs msgs stream to which messages are written
*/
virtual void write_array(boost::ecuyer1988& base_rng,
Eigen::VectorXd& params_r,
virtual void write_array(stan::rng_t& base_rng, Eigen::VectorXd& params_r,
Eigen::VectorXd& params_constrained_r,
bool include_tparams = true, bool include_gqs = true,
std::ostream* msgs = 0) const = 0;
Expand Down Expand Up @@ -618,8 +617,7 @@ class model_base : public prob_grad {
* in output
* @param[in,out] msgs msgs stream to which messages are written
*/
virtual void write_array(boost::ecuyer1988& base_rng,
std::vector<double>& params_r,
virtual void write_array(stan::rng_t& base_rng, std::vector<double>& params_r,
std::vector<int>& params_i,
std::vector<double>& params_r_constrained,
bool include_tparams = true, bool include_gqs = true,
Expand Down
4 changes: 2 additions & 2 deletions src/stan/model/model_base_crtp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class model_base_crtp : public stan::model::model_base {
msgs);
}

void write_array(boost::ecuyer1988& rng, Eigen::VectorXd& theta,
void write_array(stan::rng_t& rng, Eigen::VectorXd& theta,
Eigen::VectorXd& vars, bool include_tparams = true,
bool include_gqs = true,
std::ostream* msgs = 0) const override {
Expand Down Expand Up @@ -202,7 +202,7 @@ class model_base_crtp : public stan::model::model_base {
theta, theta_i, msgs);
}

void write_array(boost::ecuyer1988& rng, std::vector<double>& theta,
void write_array(stan::rng_t& rng, std::vector<double>& theta,
std::vector<int>& theta_i, std::vector<double>& vars,
bool include_tparams = true, bool include_gqs = true,
std::ostream* msgs = 0) const override {
Expand Down
3 changes: 0 additions & 3 deletions src/stan/model/model_header.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
#include <stan/model/indexing.hpp>
#include <stan/services/util/create_rng.hpp>

#include <boost/random/additive_combine.hpp>
#include <boost/random/linear_congruential.hpp>

#include <cmath>
#include <cstddef>
#include <fstream>
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/diagnose/diagnose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ int diagnose(Model& model, const stan::io::var_context& init,
double epsilon, double error, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector = util::initialize(
Expand Down
5 changes: 2 additions & 3 deletions src/stan/services/experimental/advi/fullrank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <stan/services/error_codes.hpp>
#include <stan/io/var_context.hpp>
#include <stan/variational/advi.hpp>
#include <boost/random/additive_combine.hpp>
#include <string>
#include <vector>

Expand Down Expand Up @@ -60,7 +59,7 @@ int fullrank(Model& model, const stan::io::var_context& init,
callbacks::writer& diagnostic_writer) {
util::experimental_message(logger);

boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand All @@ -84,7 +83,7 @@ int fullrank(Model& model, const stan::io::var_context& init,
= Eigen::Map<Eigen::VectorXd>(&cont_vector[0], cont_vector.size(), 1);

stan::variational::advi<Model, stan::variational::normal_fullrank,
boost::ecuyer1988>
stan::rng_t>
cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
output_samples);
cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
Expand Down
5 changes: 2 additions & 3 deletions src/stan/services/experimental/advi/meanfield.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <stan/services/error_codes.hpp>
#include <stan/io/var_context.hpp>
#include <stan/variational/advi.hpp>
#include <boost/random/additive_combine.hpp>
#include <string>
#include <vector>

Expand Down Expand Up @@ -60,7 +59,7 @@ int meanfield(Model& model, const stan::io::var_context& init,
callbacks::writer& diagnostic_writer) {
util::experimental_message(logger);

boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand All @@ -83,7 +82,7 @@ int meanfield(Model& model, const stan::io::var_context& init,
= Eigen::Map<Eigen::VectorXd>(&cont_vector[0], cont_vector.size(), 1);

stan::variational::advi<Model, stan::variational::normal_meanfield,
boost::ecuyer1988>
stan::rng_t>
cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
output_samples);
cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/optimize/bfgs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ int bfgs(Model& model, const stan::io::var_context& init,
bool save_iterations, int refresh, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/optimize/laplace_sample.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void laplace_sample(const Model& model, const Eigen::VectorXd& theta_hat,
}
// generate draws
std::stringstream refresh_msg;
boost::ecuyer1988 rng = util::create_rng(random_seed, 0);
stan::rng_t rng = util::create_rng(random_seed, 0);
Eigen::VectorXd draw_vec; // declare draw_vec, msgs here to avoid re-alloc
for (int m = 0; m < draws; ++m) {
interrupt(); // allow interpution each iteration
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/optimize/lbfgs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ int lbfgs(Model& model, const stan::io::var_context& init,
int refresh, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/optimize/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ int newton(Model& model, const stan::io::var_context& init,
callbacks::interrupt& interrupt, callbacks::logger& logger,
callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand Down
8 changes: 3 additions & 5 deletions src/stan/services/pathfinder/multi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,9 @@ inline int pathfinder_lbfgs_multi(
3 * std::sqrt(num_returned_samples));
Eigen::Array<double, Eigen::Dynamic, 1> weight_vals
= stan::services::psis::psis_weights(lp_ratios, tail_len, logger);
boost::ecuyer1988 rng
= util::create_rng<boost::ecuyer1988>(random_seed, stride_id);
boost::variate_generator<
boost::ecuyer1988&,
boost::random::discrete_distribution<Eigen::Index, double>>
stan::rng_t rng = util::create_rng(random_seed, stride_id);
boost::variate_generator<stan::rng_t&, boost::random::discrete_distribution<
Eigen::Index, double>>
rand_psis_idx(
rng, boost::random::discrete_distribution<Eigen::Index, double>(
boost::iterator_range<double*>(
Expand Down
5 changes: 2 additions & 3 deletions src/stan/services/pathfinder/single.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ inline elbo_est_t est_approx_draws(LPF&& lp_fun, ConstrainF&& constrain_fun,
size_t num_samples, const EigVec& alpha,
const std::string& iter_msg, Logger&& logger,
bool calculate_lp = true) {
boost::variate_generator<boost::ecuyer1988&, boost::normal_distribution<>>
boost::variate_generator<stan::rng_t&, boost::normal_distribution<>>
rand_unit_gaus(rng, boost::normal_distribution<>());
const auto num_params = taylor_approx.x_center.size();
size_t lp_fun_calls = 0;
Expand Down Expand Up @@ -607,8 +607,7 @@ inline auto pathfinder_lbfgs_single(
callbacks::writer& init_writer, ParamWriter& parameter_writer,
DiagnosticWriter& diagnostic_writer, bool calculate_lp = true) {
const auto start_pathfinder_time = std::chrono::steady_clock::now();
boost::ecuyer1988 rng
= util::create_rng<boost::ecuyer1988>(random_seed, stride_id);
stan::rng_t rng = util::create_rng(random_seed, stride_id);
std::vector<int> disc_vector;
std::vector<double> cont_vector;

Expand Down
4 changes: 2 additions & 2 deletions src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int fixed_param(Model& model, const stan::io::var_context& init,
callbacks::writer& init_writer,
callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand Down Expand Up @@ -134,7 +134,7 @@ int fixed_param(Model& model, const std::size_t num_chains,
init_writer[0], sample_writers[0],
diagnostic_writers[0]);
}
std::vector<boost::ecuyer1988> rngs;
std::vector<stan::rng_t> rngs;
std::vector<Eigen::VectorXd> cont_vectors;
std::vector<util::mcmc_writer> writers;
std::vector<stan::mcmc::sample> samples;
Expand Down
8 changes: 4 additions & 4 deletions src/stan/services/sample/hmc_nuts_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init,
callbacks::writer& init_writer,
callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand All @@ -73,7 +73,7 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init,
return error_codes::CONFIG;
}

stan::mcmc::dense_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
stan::mcmc::dense_e_nuts<Model, stan::rng_t> sampler(model, rng);

sampler.set_metric(inv_metric);

Expand Down Expand Up @@ -197,11 +197,11 @@ int hmc_nuts_dense_e(Model& model, size_t num_chains,
stepsize, stepsize_jitter, max_depth, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
}
std::vector<boost::ecuyer1988> rngs;
std::vector<stan::rng_t> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
using sample_t = stan::mcmc::dense_e_nuts<Model, boost::ecuyer1988>;
using sample_t = stan::mcmc::dense_e_nuts<Model, stan::rng_t>;
std::vector<sample_t> samplers;
samplers.reserve(num_chains);
try {
Expand Down
8 changes: 4 additions & 4 deletions src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int hmc_nuts_dense_e_adapt(
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer,
callbacks::structured_writer& metric_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<double> cont_vector;

Expand All @@ -82,7 +82,7 @@ int hmc_nuts_dense_e_adapt(
return error_codes::CONFIG;
}

stan::mcmc::adapt_dense_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
stan::mcmc::adapt_dense_e_nuts<Model, stan::rng_t> sampler(model, rng);

sampler.set_metric(inv_metric);

Expand Down Expand Up @@ -347,8 +347,8 @@ int hmc_nuts_dense_e_adapt(
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0], metric_writer[0]);
}
using sample_t = stan::mcmc::adapt_dense_e_nuts<Model, boost::ecuyer1988>;
std::vector<boost::ecuyer1988> rngs;
using sample_t = stan::mcmc::adapt_dense_e_nuts<Model, stan::rng_t>;
std::vector<stan::rng_t> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
Expand Down
8 changes: 4 additions & 4 deletions src/stan/services/sample/hmc_nuts_diag_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init,
callbacks::writer& init_writer,
callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);
std::vector<int> disc_vector;
std::vector<double> cont_vector;

Expand All @@ -72,7 +72,7 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init,
return error_codes::CONFIG;
}

stan::mcmc::diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
stan::mcmc::diag_e_nuts<Model, stan::rng_t> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_nominal_stepsize(stepsize);
Expand Down Expand Up @@ -194,11 +194,11 @@ int hmc_nuts_diag_e(Model& model, size_t num_chains,
stepsize, stepsize_jitter, max_depth, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
}
std::vector<boost::ecuyer1988> rngs;
std::vector<stan::rng_t> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
using sample_t = stan::mcmc::diag_e_nuts<Model, boost::ecuyer1988>;
using sample_t = stan::mcmc::diag_e_nuts<Model, stan::rng_t>;
std::vector<sample_t> samplers;
samplers.reserve(num_chains);
try {
Expand Down
8 changes: 4 additions & 4 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int hmc_nuts_diag_e_adapt(
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer,
callbacks::structured_writer& metric_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<double> cont_vector;

Expand All @@ -83,7 +83,7 @@ int hmc_nuts_diag_e_adapt(
return error_codes::CONFIG;
}

stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
stan::mcmc::adapt_diag_e_nuts<Model, stan::rng_t> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_nominal_stepsize(stepsize);
Expand Down Expand Up @@ -347,8 +347,8 @@ int hmc_nuts_diag_e_adapt(
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0], metric_writer[0]);
}
using sample_t = stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988>;
std::vector<boost::ecuyer1988> rngs;
using sample_t = stan::mcmc::adapt_diag_e_nuts<Model, stan::rng_t>;
std::vector<stan::rng_t> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
Expand Down
8 changes: 4 additions & 4 deletions src/stan/services/sample/hmc_nuts_unit_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int hmc_nuts_unit_e(Model& model, const stan::io::var_context& init,
callbacks::writer& init_writer,
callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
std::vector<double> cont_vector;
Expand All @@ -64,7 +64,7 @@ int hmc_nuts_unit_e(Model& model, const stan::io::var_context& init,
logger.error(e.what());
return error_codes::CONFIG;
}
stan::mcmc::unit_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
stan::mcmc::unit_e_nuts<Model, stan::rng_t> sampler(model, rng);
sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
sampler.set_max_depth(max_depth);
Expand Down Expand Up @@ -133,8 +133,8 @@ int hmc_nuts_unit_e(Model& model, size_t num_chains,
max_depth, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
}
using sample_t = stan::mcmc::unit_e_nuts<Model, boost::ecuyer1988>;
std::vector<boost::ecuyer1988> rngs;
using sample_t = stan::mcmc::unit_e_nuts<Model, stan::rng_t>;
std::vector<stan::rng_t> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
Expand Down
Loading

0 comments on commit 476a356

Please sign in to comment.