Skip to content

Commit

Permalink
Merge branch 'feature/3181-json-hmc-tuning-params' of https://github.…
Browse files Browse the repository at this point in the history
…com/stan-dev/stan into feature/3181-json-hmc-tuning-params
  • Loading branch information
mitzimorris committed Sep 20, 2023
2 parents da64fdc + dd455a9 commit c9ed01d
Show file tree
Hide file tree
Showing 6 changed files with 406 additions and 16 deletions.
163 changes: 163 additions & 0 deletions src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,87 @@ int hmc_nuts_dense_e_adapt(
return error_codes::OK;
}

/**
* Runs multiple chains of NUTS with adaptation using dense Euclidean metric,
* with a pre-specified dense metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitInvContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer`
* must be the same length as this value.
* @param[in] init A std vector of init var contexts for initialization
* of each chain.
* @param[in] init_inv_metric var context exposing an initial dense
* inverse Euclidean metric (must be positive definite)
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in] delta adaptation target acceptance statistic
* @param[in] gamma adaptation regularization scale
* @param[in] kappa adaptation relaxation exponent
* @param[in] t0 adaptation iteration offset
* @param[in] init_buffer width of initial fast adaptation interval
* @param[in] term_buffer width of final fast adaptation interval
* @param[in] window initial width of slow adaptation interval
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
const std::vector<InitInvContextPtr>& init_inv_metric,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
int num_warmup, int num_samples, int num_thin, bool save_warmup,
int refresh, double stepsize, double stepsize_jitter, int max_depth,
double delta, double gamma, double kappa, double t0,
unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(
num_chains);
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0], dummy_metric_writer[0]);
}
return hmc_nuts_dense_e_adapt(
model, num_chains, init, init_inv_metric, random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer, dummy_metric_writer);
}

/**
* Runs multiple chains of NUTS with adaptation using dense Euclidean metric,
* with identity matrix as initial inv_metric and saves adapted tuning
Expand Down Expand Up @@ -484,6 +565,88 @@ int hmc_nuts_dense_e_adapt(
sample_writer, diagnostic_writer, metric_writer);
}

/**
* Runs multiple chains of NUTS with adaptation using dense Euclidean metric,
* with identity matrix as initial inv_metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer`
* must be the same length as this value.
* @param[in] init A std vector of init var contexts for initialization of each
* chain.
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in] delta adaptation target acceptance statistic
* @param[in] gamma adaptation regularization scale
* @param[in] kappa adaptation relaxation exponent
* @param[in] t0 adaptation iteration offset
* @param[in] init_buffer width of initial fast adaptation interval
* @param[in] term_buffer width of final fast adaptation interval
* @param[in] window initial width of slow adaptation interval
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
int num_warmup, int num_samples, int num_thin, bool save_warmup,
int refresh, double stepsize, double stepsize_jitter, int max_depth,
double delta, double gamma, double kappa, double t0,
unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metric;
unit_e_metric.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_dense_inv_metric(model.num_params_r())));
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(
num_chains);
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
model, *init[0], *unit_e_metric[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0], dummy_metric_writer[0]);
}
return hmc_nuts_dense_e_adapt(
model, num_chains, init, unit_e_metric, random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer, dummy_metric_writer);
}

} // namespace sample
} // namespace services
} // namespace stan
Expand Down
165 changes: 165 additions & 0 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,87 @@ int hmc_nuts_diag_e_adapt(
return error_codes::OK;
}

/**
* Runs multiple chains of HMC with NUTS with adaptation using diagonal
* Euclidean metric with a pre-specified diagonal metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @param[in] init A std vector of init var contexts for initialization
* of each chain.
* @param[in] init_inv_metric A std vector of var contexts exposing an initial
* diagonal inverse Euclidean metric for each chain (must be positive definite)
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in] delta adaptation target acceptance statistic
* @param[in] gamma adaptation regularization scale
* @param[in] kappa adaptation relaxation exponent
* @param[in] t0 adaptation iteration offset
* @param[in] init_buffer width of initial fast adaptation interval
* @param[in] term_buffer width of final fast adaptation interval
* @param[in] window initial width of slow adaptation interval
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_diag_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
const std::vector<InitInvContextPtr>& init_inv_metric,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
int num_warmup, int num_samples, int num_thin, bool save_warmup,
int refresh, double stepsize, double stepsize_jitter, int max_depth,
double delta, double gamma, double kappa, double t0,
unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(
num_chains);
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0], dummy_metric_writer[0]);
}
return hmc_nuts_diag_e_adapt(
model, num_chains, init, init_inv_metric, random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer, dummy_metric_writer);
}

/**
* Runs multiple chains of HMC with NUTS with adaptation using diagonal
* with identity matrix as initial inv_metric and saves adapted tuning
Expand Down Expand Up @@ -484,6 +565,90 @@ int hmc_nuts_diag_e_adapt(
sample_writer, diagnostic_writer, metric_writer);
}

/**
* Runs multiple chains of HMC with NUTS with adaptation using diagonal
* with identity matrix as initial inv_metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @param[in] init A std vector of init var contexts for initialization of each
* chain.
* @param[in] init_inv_metric var context exposing an initial diagonal
* inverse Euclidean metric (must be positive definite)
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in] delta adaptation target acceptance statistic
* @param[in] gamma adaptation regularization scale
* @param[in] kappa adaptation relaxation exponent
* @param[in] t0 adaptation iteration offset
* @param[in] init_buffer width of initial fast adaptation interval
* @param[in] term_buffer width of final fast adaptation interval
* @param[in] window initial width of slow adaptation interval
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_diag_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
int num_warmup, int num_samples, int num_thin, bool save_warmup,
int refresh, double stepsize, double stepsize_jitter, int max_depth,
double delta, double gamma, double kappa, double t0,
unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metric;
unit_e_metric.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_diag_inv_metric(model.num_params_r())));
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(
num_chains);
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
model, *init[0], *unit_e_metric[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0], dummy_metric_writer[0]);
}
return hmc_nuts_diag_e_adapt(
model, num_chains, init, unit_e_metric, random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer, dummy_metric_writer);
}

} // namespace sample
} // namespace services
} // namespace stan
Expand Down
Loading

0 comments on commit c9ed01d

Please sign in to comment.