Skip to content

Commit

Permalink
Merge branch 'feature/3181-json-hmc-tuning-params' of https://github.…
Browse files Browse the repository at this point in the history
…com/stan-dev/stan into feature/3181-json-hmc-tuning-params
  • Loading branch information
mitzimorris committed Sep 16, 2023
2 parents c8e616c + 2183f66 commit 56d725d
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 121 deletions.
2 changes: 1 addition & 1 deletion lib/stan_math
3 changes: 1 addition & 2 deletions src/stan/mcmc/hmc/base_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ class base_hmc : public base_mcmc {
/**
* write stepsize and elements of mass matrix as a JSON object
*/
void write_sampler_state_struct(
callbacks::structured_writer& struct_writer) {
void write_sampler_state_struct(callbacks::structured_writer& struct_writer) {
struct_writer.begin_record();
struct_writer.write("stepsize", get_nominal_stepsize());
struct_writer.write("inv_metric", z_.inv_e_metric_);
Expand Down
7 changes: 3 additions & 4 deletions src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ int fixed_param(Model& model, const stan::io::var_context& init,

stan::mcmc::fixed_param_sampler sampler;
callbacks::structured_writer dummy_metric_writer;
services::util::mcmc_writer writer(
sample_writer, diagnostic_writer, dummy_metric_writer, logger);
services::util::mcmc_writer writer(sample_writer, diagnostic_writer,
dummy_metric_writer, logger);
Eigen::VectorXd cont_params(cont_vector.size());
for (size_t i = 0; i < cont_vector.size(); i++)
cont_params[i] = cont_vector[i];
Expand Down Expand Up @@ -156,8 +156,7 @@ int fixed_param(Model& model, const std::size_t num_chains,
cont_vectors.push_back(
Eigen::Map<Eigen::VectorXd>(cont_vector.data(), cont_vector.size()));
samples.emplace_back(cont_vectors[i], 0, 0);
dummy_metric_writers.emplace_back(
stan::callbacks::structured_writer());
dummy_metric_writers.emplace_back(stan::callbacks::structured_writer());
writers.emplace_back(sample_writers[i], diagnostic_writers[i],
dummy_metric_writers[i], logger);
// Headers
Expand Down
9 changes: 4 additions & 5 deletions src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,7 @@ int hmc_nuts_dense_e_adapt(
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::structured_writer());
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
Expand Down Expand Up @@ -532,7 +531,8 @@ int hmc_nuts_dense_e_adapt(
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter, typename MetricWriter>
typename SampleWriter, typename DiagnosticWriter,
typename MetricWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
Expand Down Expand Up @@ -634,8 +634,7 @@ int hmc_nuts_dense_e_adapt(
std::vector<callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::structured_writer());
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
Expand Down
6 changes: 2 additions & 4 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,7 @@ int hmc_nuts_diag_e_adapt(
std::vector<callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::structured_writer());
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
Expand Down Expand Up @@ -644,8 +643,7 @@ int hmc_nuts_diag_e_adapt(
std::vector<callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::structured_writer());
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
Expand Down
3 changes: 1 addition & 2 deletions src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,7 @@ int hmc_nuts_unit_e_adapt(
std::vector<callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::structured_writer());
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_unit_e_adapt(
Expand Down
22 changes: 12 additions & 10 deletions src/stan/services/util/run_adaptive_sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ namespace util {
* (optional, default == 1)
*/
template <typename Sampler, typename Model, typename RNG>
void run_adaptive_sampler(
Sampler& sampler, Model& model, std::vector<double>& cont_vector,
int num_warmup, int num_samples, int num_thin, int refresh,
bool save_warmup, RNG& rng, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer,
callbacks::structured_writer& metric_writer,
size_t chain_id = 1, size_t num_chains = 1) {
void run_adaptive_sampler(Sampler& sampler, Model& model,
std::vector<double>& cont_vector, int num_warmup,
int num_samples, int num_thin, int refresh,
bool save_warmup, RNG& rng,
callbacks::interrupt& interrupt,
callbacks::logger& logger,
callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer,
callbacks::structured_writer& metric_writer,
size_t chain_id = 1, size_t num_chains = 1) {
Eigen::Map<Eigen::VectorXd> cont_params(cont_vector.data(),
cont_vector.size());

Expand All @@ -63,8 +65,8 @@ void run_adaptive_sampler(
return;
}

services::util::mcmc_writer writer(
sample_writer, diagnostic_writer, metric_writer, logger);
services::util::mcmc_writer writer(sample_writer, diagnostic_writer,
metric_writer, logger);
stan::mcmc::sample s(cont_params, 0, 0);

// Headers
Expand Down
4 changes: 2 additions & 2 deletions src/stan/services/util/run_sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ void run_sampler(stan::mcmc::base_mcmc& sampler, Model& model,
Eigen::Map<Eigen::VectorXd> cont_params(cont_vector.data(),
cont_vector.size());
callbacks::structured_writer dummy_metric_writer;
services::util::mcmc_writer writer(
sample_writer, diagnostic_writer, dummy_metric_writer, logger);
services::util::mcmc_writer writer(sample_writer, diagnostic_writer,
dummy_metric_writer, logger);
stan::mcmc::sample s(cont_params, 0, 0);

// Headers
Expand Down
181 changes: 90 additions & 91 deletions src/test/unit/services/sample/hmc_nuts_unit_e_parallel_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,98 +95,97 @@ TEST_F(ServicesSampleHmcNutsUnitEPar, parameter_checks) {
std::vector<std::stringstream> ss_metric;
std::vector<stan::callbacks::structured_writer> metric;
for (int i = 0; i < num_chains; ++i) {
metric.emplace_back(
stan::callbacks::json_writer<std::stringstream, deleter_noop>(
std::unique_ptr<std::stringstream, deleter_noop>(&ss_metric[i])));


int return_code = stan::services::sample::hmc_nuts_unit_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic, metric);

for (size_t i = 0; i < num_chains; ++i) {
std::vector<std::vector<std::string>> parameter_names;
parameter_names = parameter[i].vector_string_values();
std::vector<std::vector<double>> parameter_values;
parameter_values = parameter[i].vector_double_values();
std::vector<std::vector<std::string>> diagnostic_names;
diagnostic_names = diagnostic[i].vector_string_values();
std::vector<std::vector<double>> diagnostic_values;
diagnostic_values = diagnostic[i].vector_double_values();
std::vector<std::string> metrics;
metrics[i] = ss_metric[i].str();
// Adapted metric
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(metrics[i].c_str()).HasParseError());
EXPECT_EQ(count_matches("stepsize", metrics[i]), 1);
EXPECT_EQ(count_matches("inv_metric", metrics[i]), 1);
EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list
EXPECT_EQ(count_matches("[ 0, 0 ]", par_metrics[i]), 1); // unit diagonal

// Expectations of parameter parameter names.
ASSERT_EQ(9, parameter_names[0].size());
EXPECT_EQ("lp__", parameter_names[0][0]);
EXPECT_EQ("accept_stat__", parameter_names[0][1]);
EXPECT_EQ("stepsize__", parameter_names[0][2]);
EXPECT_EQ("treedepth__", parameter_names[0][3]);
EXPECT_EQ("n_leapfrog__", parameter_names[0][4]);
EXPECT_EQ("divergent__", parameter_names[0][5]);
EXPECT_EQ("energy__", parameter_names[0][6]);
EXPECT_EQ("x", parameter_names[0][7]);
EXPECT_EQ("y", parameter_names[0][8]);

// Expect one name per parameter value.
EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());

EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size());

// Expect one call to set parameter names, and one set of output per
// iteration.
EXPECT_EQ("lp__", diagnostic_names[0][0]);
EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
metric.emplace_back(
stan::callbacks::json_writer<std::stringstream, deleter_noop>(
std::unique_ptr<std::stringstream, deleter_noop>(&ss_metric[i])));

int return_code = stan::services::sample::hmc_nuts_unit_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic, metric);

for (size_t i = 0; i < num_chains; ++i) {
std::vector<std::vector<std::string>> parameter_names;
parameter_names = parameter[i].vector_string_values();
std::vector<std::vector<double>> parameter_values;
parameter_values = parameter[i].vector_double_values();
std::vector<std::vector<std::string>> diagnostic_names;
diagnostic_names = diagnostic[i].vector_string_values();
std::vector<std::vector<double>> diagnostic_values;
diagnostic_values = diagnostic[i].vector_double_values();
std::vector<std::string> metrics;
metrics[i] = ss_metric[i].str();
// Adapted metric
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(metrics[i].c_str()).HasParseError());
EXPECT_EQ(count_matches("stepsize", metrics[i]), 1);
EXPECT_EQ(count_matches("inv_metric", metrics[i]), 1);
EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list
EXPECT_EQ(count_matches("[ 0, 0 ]", par_metrics[i]), 1); // unit diagonal

// Expectations of parameter parameter names.
ASSERT_EQ(9, parameter_names[0].size());
EXPECT_EQ("lp__", parameter_names[0][0]);
EXPECT_EQ("accept_stat__", parameter_names[0][1]);
EXPECT_EQ("stepsize__", parameter_names[0][2]);
EXPECT_EQ("treedepth__", parameter_names[0][3]);
EXPECT_EQ("n_leapfrog__", parameter_names[0][4]);
EXPECT_EQ("divergent__", parameter_names[0][5]);
EXPECT_EQ("energy__", parameter_names[0][6]);
EXPECT_EQ("x", parameter_names[0][7]);
EXPECT_EQ("y", parameter_names[0][8]);

// Expect one name per parameter value.
EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());

EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size());

// Expect one call to set parameter names, and one set of output per
// iteration.
EXPECT_EQ("lp__", diagnostic_names[0][0]);
EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
}
EXPECT_EQ(return_code, 0);
}
EXPECT_EQ(return_code, 0);
}

TEST_F(ServicesSampleHmcNutsUnitEPar, output_regression) {
unsigned int random_seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int num_warmup = 200;
int num_samples = 400;
int num_thin = 5;
bool save_warmup = true;
int refresh = 0;
double stepsize = 0.1;
double stepsize_jitter = 0;
int max_depth = 8;
double delta = .1;
double gamma = .1;
double kappa = .1;
double t0 = .1;
unsigned int init_buffer = 50;
unsigned int term_buffer = 50;
unsigned int window = 100;
stan::test::unit::instrumented_interrupt interrupt;
EXPECT_EQ(interrupt.call_count(), 0);

stan::services::sample::hmc_nuts_unit_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic);

for (auto&& init_it : init) {
std::vector<std::string> init_values;
init_values = init_it.string_values();
TEST_F(ServicesSampleHmcNutsUnitEPar, output_regression) {
unsigned int random_seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int num_warmup = 200;
int num_samples = 400;
int num_thin = 5;
bool save_warmup = true;
int refresh = 0;
double stepsize = 0.1;
double stepsize_jitter = 0;
int max_depth = 8;
double delta = .1;
double gamma = .1;
double kappa = .1;
double t0 = .1;
unsigned int init_buffer = 50;
unsigned int term_buffer = 50;
unsigned int window = 100;
stan::test::unit::instrumented_interrupt interrupt;
EXPECT_EQ(interrupt.call_count(), 0);

stan::services::sample::hmc_nuts_unit_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic);

for (auto&& init_it : init) {
std::vector<std::string> init_values;
init_values = init_it.string_values();

EXPECT_EQ(0, init_values.size());
}

EXPECT_EQ(0, init_values.size());
EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Total)"));
EXPECT_EQ(0, logger.call_count_error());
}

EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Total)"));
EXPECT_EQ(0, logger.call_count_error());
}

0 comments on commit 56d725d

Please sign in to comment.