diff --git a/lib/stan_math b/lib/stan_math index 5091bf9aa1..37df3571c7 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 5091bf9aa1cce0fe69b1cc1a7edef81ac9ebe1e9 +Subproject commit 37df3571c7853dfa1939cc7c001bdc7b0e07ca1a diff --git a/src/stan/mcmc/hmc/base_hmc.hpp b/src/stan/mcmc/hmc/base_hmc.hpp index ecffb9aa6c..d170ff4f6b 100644 --- a/src/stan/mcmc/hmc/base_hmc.hpp +++ b/src/stan/mcmc/hmc/base_hmc.hpp @@ -67,8 +67,7 @@ class base_hmc : public base_mcmc { /** * write stepsize and elements of mass matrix as a JSON object */ - void write_sampler_state_struct( - callbacks::structured_writer& struct_writer) { + void write_sampler_state_struct(callbacks::structured_writer& struct_writer) { struct_writer.begin_record(); struct_writer.write("stepsize", get_nominal_stepsize()); struct_writer.write("inv_metric", z_.inv_e_metric_); diff --git a/src/stan/services/sample/fixed_param.hpp b/src/stan/services/sample/fixed_param.hpp index 23454edace..a7b9816866 100644 --- a/src/stan/services/sample/fixed_param.hpp +++ b/src/stan/services/sample/fixed_param.hpp @@ -66,8 +66,8 @@ int fixed_param(Model& model, const stan::io::var_context& init, stan::mcmc::fixed_param_sampler sampler; callbacks::structured_writer dummy_metric_writer; - services::util::mcmc_writer writer( - sample_writer, diagnostic_writer, dummy_metric_writer, logger); + services::util::mcmc_writer writer(sample_writer, diagnostic_writer, + dummy_metric_writer, logger); Eigen::VectorXd cont_params(cont_vector.size()); for (size_t i = 0; i < cont_vector.size(); i++) cont_params[i] = cont_vector[i]; @@ -156,8 +156,7 @@ int fixed_param(Model& model, const std::size_t num_chains, cont_vectors.push_back( Eigen::Map(cont_vector.data(), cont_vector.size())); samples.emplace_back(cont_vectors[i], 0, 0); - dummy_metric_writers.emplace_back( - stan::callbacks::structured_writer()); + dummy_metric_writers.emplace_back(stan::callbacks::structured_writer()); writers.emplace_back(sample_writers[i], diagnostic_writers[i], dummy_metric_writers[i], logger); // Headers diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index 6373ab8382..596a95824a 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -464,8 +464,7 @@ int hmc_nuts_dense_e_adapt( std::vector dummy_metric_writer; dummy_metric_writer.reserve(num_chains); for (size_t i = 0; i < num_chains; ++i) { - dummy_metric_writer.emplace_back( - stan::callbacks::structured_writer()); + dummy_metric_writer.emplace_back(stan::callbacks::structured_writer()); } if (num_chains == 1) { return hmc_nuts_dense_e_adapt( @@ -532,7 +531,8 @@ int hmc_nuts_dense_e_adapt( * @return error_codes::OK if successful */ template + typename SampleWriter, typename DiagnosticWriter, + typename MetricWriter> int hmc_nuts_dense_e_adapt( Model& model, size_t num_chains, const std::vector& init, unsigned int random_seed, unsigned int init_chain_id, double init_radius, @@ -634,8 +634,7 @@ int hmc_nuts_dense_e_adapt( std::vector dummy_metric_writer; dummy_metric_writer.reserve(num_chains); for (size_t i = 0; i < num_chains; ++i) { - dummy_metric_writer.emplace_back( - stan::callbacks::structured_writer()); + dummy_metric_writer.emplace_back(stan::callbacks::structured_writer()); } if (num_chains == 1) { return hmc_nuts_dense_e_adapt( diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index ecbf9e36be..8ffc1f8656 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -473,8 +473,7 @@ int hmc_nuts_diag_e_adapt( std::vector dummy_metric_writer; dummy_metric_writer.reserve(num_chains); for (size_t i = 0; i < num_chains; ++i) { - dummy_metric_writer.emplace_back( - stan::callbacks::structured_writer()); + dummy_metric_writer.emplace_back(stan::callbacks::structured_writer()); } if (num_chains == 1) { return hmc_nuts_diag_e_adapt( @@ -644,8 +643,7 @@ int hmc_nuts_diag_e_adapt( std::vector dummy_metric_writer; dummy_metric_writer.reserve(num_chains); for (size_t i = 0; i < num_chains; ++i) { - dummy_metric_writer.emplace_back( - stan::callbacks::structured_writer()); + dummy_metric_writer.emplace_back(stan::callbacks::structured_writer()); } if (num_chains == 1) { return hmc_nuts_diag_e_adapt( diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index dc921889de..ed7d3b540b 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -301,8 +301,7 @@ int hmc_nuts_unit_e_adapt( std::vector dummy_metric_writer; dummy_metric_writer.reserve(num_chains); for (size_t i = 0; i < num_chains; ++i) { - dummy_metric_writer.emplace_back( - stan::callbacks::structured_writer()); + dummy_metric_writer.emplace_back(stan::callbacks::structured_writer()); } if (num_chains == 1) { return hmc_nuts_unit_e_adapt( diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index 0e82d164e7..fb0218314a 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -42,14 +42,16 @@ namespace util { * (optional, default == 1) */ template -void run_adaptive_sampler( - Sampler& sampler, Model& model, std::vector& cont_vector, - int num_warmup, int num_samples, int num_thin, int refresh, - bool save_warmup, RNG& rng, callbacks::interrupt& interrupt, - callbacks::logger& logger, callbacks::writer& sample_writer, - callbacks::writer& diagnostic_writer, - callbacks::structured_writer& metric_writer, - size_t chain_id = 1, size_t num_chains = 1) { +void run_adaptive_sampler(Sampler& sampler, Model& model, + std::vector& cont_vector, int num_warmup, + int num_samples, int num_thin, int refresh, + bool save_warmup, RNG& rng, + callbacks::interrupt& interrupt, + callbacks::logger& logger, + callbacks::writer& sample_writer, + callbacks::writer& diagnostic_writer, + callbacks::structured_writer& metric_writer, + size_t chain_id = 1, size_t num_chains = 1) { Eigen::Map cont_params(cont_vector.data(), cont_vector.size()); @@ -63,8 +65,8 @@ void run_adaptive_sampler( return; } - services::util::mcmc_writer writer( - sample_writer, diagnostic_writer, metric_writer, logger); + services::util::mcmc_writer writer(sample_writer, diagnostic_writer, + metric_writer, logger); stan::mcmc::sample s(cont_params, 0, 0); // Headers diff --git a/src/stan/services/util/run_sampler.hpp b/src/stan/services/util/run_sampler.hpp index 341cbb8d41..8ac7c16f0e 100644 --- a/src/stan/services/util/run_sampler.hpp +++ b/src/stan/services/util/run_sampler.hpp @@ -49,8 +49,8 @@ void run_sampler(stan::mcmc::base_mcmc& sampler, Model& model, Eigen::Map cont_params(cont_vector.data(), cont_vector.size()); callbacks::structured_writer dummy_metric_writer; - services::util::mcmc_writer writer( - sample_writer, diagnostic_writer, dummy_metric_writer, logger); + services::util::mcmc_writer writer(sample_writer, diagnostic_writer, + dummy_metric_writer, logger); stan::mcmc::sample s(cont_params, 0, 0); // Headers diff --git a/src/test/unit/services/sample/hmc_nuts_unit_e_parallel_test.cpp b/src/test/unit/services/sample/hmc_nuts_unit_e_parallel_test.cpp index 6c4489b7d9..23038d304a 100644 --- a/src/test/unit/services/sample/hmc_nuts_unit_e_parallel_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_unit_e_parallel_test.cpp @@ -95,98 +95,97 @@ TEST_F(ServicesSampleHmcNutsUnitEPar, parameter_checks) { std::vector ss_metric; std::vector metric; for (int i = 0; i < num_chains; ++i) { - metric.emplace_back( - stan::callbacks::json_writer( - std::unique_ptr(&ss_metric[i]))); - - - int return_code = stan::services::sample::hmc_nuts_unit_e( - model, num_chains, context, random_seed, chain, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, - max_depth, interrupt, logger, init, parameter, diagnostic, metric); - - for (size_t i = 0; i < num_chains; ++i) { - std::vector> parameter_names; - parameter_names = parameter[i].vector_string_values(); - std::vector> parameter_values; - parameter_values = parameter[i].vector_double_values(); - std::vector> diagnostic_names; - diagnostic_names = diagnostic[i].vector_string_values(); - std::vector> diagnostic_values; - diagnostic_values = diagnostic[i].vector_double_values(); - std::vector metrics; - metrics[i] = ss_metric[i].str(); - // Adapted metric - rapidjson::Document document; - ASSERT_FALSE(document.Parse<0>(metrics[i].c_str()).HasParseError()); - EXPECT_EQ(count_matches("stepsize", metrics[i]), 1); - EXPECT_EQ(count_matches("inv_metric", metrics[i]), 1); - EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list - EXPECT_EQ(count_matches("[ 0, 0 ]", par_metrics[i]), 1); // unit diagonal - - // Expectations of parameter parameter names. - ASSERT_EQ(9, parameter_names[0].size()); - EXPECT_EQ("lp__", parameter_names[0][0]); - EXPECT_EQ("accept_stat__", parameter_names[0][1]); - EXPECT_EQ("stepsize__", parameter_names[0][2]); - EXPECT_EQ("treedepth__", parameter_names[0][3]); - EXPECT_EQ("n_leapfrog__", parameter_names[0][4]); - EXPECT_EQ("divergent__", parameter_names[0][5]); - EXPECT_EQ("energy__", parameter_names[0][6]); - EXPECT_EQ("x", parameter_names[0][7]); - EXPECT_EQ("y", parameter_names[0][8]); - - // Expect one name per parameter value. - EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size()); - EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size()); - - EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size()); - - // Expect one call to set parameter names, and one set of output per - // iteration. - EXPECT_EQ("lp__", diagnostic_names[0][0]); - EXPECT_EQ("accept_stat__", diagnostic_names[0][1]); + metric.emplace_back( + stan::callbacks::json_writer( + std::unique_ptr(&ss_metric[i]))); + + int return_code = stan::services::sample::hmc_nuts_unit_e( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, interrupt, logger, init, parameter, diagnostic, metric); + + for (size_t i = 0; i < num_chains; ++i) { + std::vector> parameter_names; + parameter_names = parameter[i].vector_string_values(); + std::vector> parameter_values; + parameter_values = parameter[i].vector_double_values(); + std::vector> diagnostic_names; + diagnostic_names = diagnostic[i].vector_string_values(); + std::vector> diagnostic_values; + diagnostic_values = diagnostic[i].vector_double_values(); + std::vector metrics; + metrics[i] = ss_metric[i].str(); + // Adapted metric + rapidjson::Document document; + ASSERT_FALSE(document.Parse<0>(metrics[i].c_str()).HasParseError()); + EXPECT_EQ(count_matches("stepsize", metrics[i]), 1); + EXPECT_EQ(count_matches("inv_metric", metrics[i]), 1); + EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list + EXPECT_EQ(count_matches("[ 0, 0 ]", par_metrics[i]), 1); // unit diagonal + + // Expectations of parameter parameter names. + ASSERT_EQ(9, parameter_names[0].size()); + EXPECT_EQ("lp__", parameter_names[0][0]); + EXPECT_EQ("accept_stat__", parameter_names[0][1]); + EXPECT_EQ("stepsize__", parameter_names[0][2]); + EXPECT_EQ("treedepth__", parameter_names[0][3]); + EXPECT_EQ("n_leapfrog__", parameter_names[0][4]); + EXPECT_EQ("divergent__", parameter_names[0][5]); + EXPECT_EQ("energy__", parameter_names[0][6]); + EXPECT_EQ("x", parameter_names[0][7]); + EXPECT_EQ("y", parameter_names[0][8]); + + // Expect one name per parameter value. + EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size()); + EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size()); + + EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size()); + + // Expect one call to set parameter names, and one set of output per + // iteration. + EXPECT_EQ("lp__", diagnostic_names[0][0]); + EXPECT_EQ("accept_stat__", diagnostic_names[0][1]); + } + EXPECT_EQ(return_code, 0); } - EXPECT_EQ(return_code, 0); -} -TEST_F(ServicesSampleHmcNutsUnitEPar, output_regression) { - unsigned int random_seed = 0; - unsigned int chain = 1; - double init_radius = 0; - int num_warmup = 200; - int num_samples = 400; - int num_thin = 5; - bool save_warmup = true; - int refresh = 0; - double stepsize = 0.1; - double stepsize_jitter = 0; - int max_depth = 8; - double delta = .1; - double gamma = .1; - double kappa = .1; - double t0 = .1; - unsigned int init_buffer = 50; - unsigned int term_buffer = 50; - unsigned int window = 100; - stan::test::unit::instrumented_interrupt interrupt; - EXPECT_EQ(interrupt.call_count(), 0); - - stan::services::sample::hmc_nuts_unit_e( - model, num_chains, context, random_seed, chain, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, - max_depth, interrupt, logger, init, parameter, diagnostic); - - for (auto&& init_it : init) { - std::vector init_values; - init_values = init_it.string_values(); + TEST_F(ServicesSampleHmcNutsUnitEPar, output_regression) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_unit_e( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, interrupt, logger, init, parameter, diagnostic); + + for (auto&& init_it : init) { + std::vector init_values; + init_values = init_it.string_values(); + + EXPECT_EQ(0, init_values.size()); + } - EXPECT_EQ(0, init_values.size()); + EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Total)")); + EXPECT_EQ(0, logger.call_count_error()); } - - EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:")); - EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)")); - EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)")); - EXPECT_EQ(num_chains, logger.find_info("seconds (Total)")); - EXPECT_EQ(0, logger.call_count_error()); -}