Skip to content

Commit

Permalink
unit tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Sep 16, 2023
1 parent 56d725d commit bb1ff9d
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 132 deletions.
2 changes: 1 addition & 1 deletion src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ int fixed_param(Model& model, const std::size_t num_chains,
}
std::vector<boost::ecuyer1988> rngs;
std::vector<Eigen::VectorXd> cont_vectors;
std::vector<callbacks::structured_writer> dummy_metric_writers;
std::vector<stan::callbacks::structured_writer> dummy_metric_writers;
std::vector<util::mcmc_writer> writers;
std::vector<stan::mcmc::sample> samples;
std::vector<stan::mcmc::fixed_param_sampler> samplers(num_chains);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ int hmc_nuts_dense_e_adapt(
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_dense_inv_metric(model.num_params_r())));
}
std::vector<callbacks::structured_writer> dummy_metric_writer;
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
Expand Down
17 changes: 10 additions & 7 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ int hmc_nuts_diag_e_adapt(
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
typename InitWriter, typename SampleWriter, typename DiagnosticWriter,
typename MetricWriter>
int hmc_nuts_diag_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
const std::vector<InitInvContextPtr>& init_inv_metric,
Expand All @@ -349,7 +350,7 @@ int hmc_nuts_diag_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer,
std::vector<callbacks::structured_writer>& metric_writer) {
std::vector<MetricWriter>& metric_writer) {
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
Expand Down Expand Up @@ -470,7 +471,7 @@ int hmc_nuts_diag_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<callbacks::structured_writer> dummy_metric_writer;
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
Expand Down Expand Up @@ -499,9 +500,10 @@ int hmc_nuts_diag_e_adapt(
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam MetricWriter A type derived from `stan::callbacks::structured_writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
Expand Down Expand Up @@ -540,7 +542,8 @@ int hmc_nuts_diag_e_adapt(
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
typename SampleWriter, typename DiagnosticWriter,
typename MetricWriter>
int hmc_nuts_diag_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
Expand All @@ -552,7 +555,7 @@ int hmc_nuts_diag_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer,
std::vector<callbacks::structured_writer>& metric_writer) {
std::vector<MetricWriter>& metric_writer) {
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metric;
unit_e_metric.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
Expand Down Expand Up @@ -640,7 +643,7 @@ int hmc_nuts_diag_e_adapt(
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_diag_inv_metric(model.num_params_r())));
}
std::vector<callbacks::structured_writer> dummy_metric_writer;
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
Expand Down
9 changes: 4 additions & 5 deletions src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ int hmc_nuts_unit_e_adapt(
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam Stream A type with with a valid `operator<<(std::string)`
* @tparam Deleter A class with a valid `operator()` method for deleting the
* @tparam MetricWriter A type derived from `stan::callbacks::structured_writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
Expand Down Expand Up @@ -181,7 +180,7 @@ int hmc_nuts_unit_e_adapt(
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
typename SampleWriter, typename DiagnosticWriter, typename MetricWriter>
int hmc_nuts_unit_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
Expand All @@ -192,7 +191,7 @@ int hmc_nuts_unit_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer,
std::vector<callbacks::structured_writer>& metric_writer) {
std::vector<MetricWriter>& metric_writer) {
if (num_chains == 1) {
return hmc_nuts_unit_e_adapt(
model, *init[0], random_seed, init_chain_id, init_radius, num_warmup,
Expand Down Expand Up @@ -298,7 +297,7 @@ int hmc_nuts_unit_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<callbacks::structured_writer> dummy_metric_writer;
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <stan/services/sample/hmc_nuts_dense_e_adapt.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/callbacks/json_writer.hpp>
#include <stan/callbacks/unique_stream_writer.hpp>
#include <stan/io/empty_var_context.hpp>
#include <test/unit/util.hpp>
#include <src/test/unit/services/util.hpp>
#include <test/test-models/good/optimization/rosenbrock.hpp>
Expand Down Expand Up @@ -44,8 +44,6 @@ class ServicesSampleHmcNutsDenseEAdaptParMatch : public testing::Test {
for (int i = 0; i < num_chains; ++i) {
ss_metric[i].str(std::string());
ss_metric[i].clear();
metrics[i].begin_record();
metrics[i].end_record();
}
}

Expand All @@ -60,7 +58,8 @@ class ServicesSampleHmcNutsDenseEAdaptParMatch : public testing::Test {
= stan::callbacks::unique_stream_writer<std::stringstream, deleter_noop>;
std::vector<str_writer> par_parameters;
std::vector<str_writer> seq_parameters;
std::vector<stan::callbacks::structured_writer> metrics;
std::vector<stan::callbacks::json_writer<std::stringstream, deleter_noop>>
metrics;
std::vector<stan::test::unit::instrumented_writer> diagnostics;
std::vector<std::shared_ptr<stan::io::empty_var_context>> context;
std::unique_ptr<rosenbrock_model_namespace::rosenbrock_model> model;
Expand Down Expand Up @@ -125,12 +124,11 @@ TEST_F(ServicesSampleHmcNutsDenseEAdaptParMatch, single_multi_match) {
par_res.push_back(par_mat);

par_metrics.push_back(ss_metric[i].str());
std::cout << "metric " << par_metrics[i] << std::endl << std::flush;
// rapidjson::Document document;
// ASSERT_FALSE(document.Parse<0>(par_metrics[i].c_str()).HasParseError());
// EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1);
// EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1);
// EXPECT_EQ(count_matches("[", par_metrics[i]), 3); // list has 2 rows
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(par_metrics[i].c_str()).HasParseError());
EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1);
EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1);
EXPECT_EQ(count_matches("[", par_metrics[i]), 3); // list has 2 rows
}

std::vector<Eigen::MatrixXd> seq_res;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include <stan/services/sample/hmc_nuts_diag_e_adapt.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/callbacks/json_writer.hpp>
#include <stan/callbacks/unique_stream_writer.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <test/unit/util.hpp>
#include <stan/io/empty_var_context.hpp>
#include <src/test/unit/services/util.hpp>
#include <test/test-models/good/optimization/rosenbrock.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <test/unit/util.hpp>
#include <rapidjson/document.h>
#include <gtest/gtest.h>
#include <iostream>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,58 @@
#include <stan/services/sample/hmc_nuts_unit_e_adapt.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <stan/callbacks/json_writer.hpp>
#include <stan/callbacks/unique_stream_writer.hpp>
#include <stan/io/empty_var_context.hpp>
#include <src/test/unit/services/util.hpp>
#include <test/test-models/good/optimization/rosenbrock.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <iostream>
#include <test/unit/util.hpp>
#include <rapidjson/document.h>
#include <gtest/gtest.h>
#include <iostream>

auto&& blah = stan::math::init_threadpool_tbb();

static constexpr size_t num_chains = 4;

struct deleter_noop {
template <typename T>
constexpr void operator()(T* arg) const {}
};

class ServicesSampleHmcNutsUnitEAdaptPar : public testing::Test {
public:
ServicesSampleHmcNutsUnitEAdaptPar() : model(data_context, 0, &model_log) {
ServicesSampleHmcNutsUnitEAdaptPar()
: ss_metric(num_chains),
model(data_context, 0, &model_log) {
for (int i = 0; i < num_chains; ++i) {
init.push_back(stan::test::unit::instrumented_writer{});
parameter.push_back(stan::test::unit::instrumented_writer{});
diagnostic.push_back(stan::test::unit::instrumented_writer{});
metric.push_back(stan::callbacks::structured_writer(
std::unique_ptr<std::ofstream>(nullptr)));
metric.push_back(
stan::callbacks::json_writer<std::stringstream, deleter_noop>(
std::unique_ptr<std::stringstream, deleter_noop>(&ss_metric[i])));
context.push_back(std::make_shared<stan::io::empty_var_context>());
}
}

void SetUp() {
for (int i = 0; i < num_chains; ++i) {
ss_metric[i].str(std::string());
ss_metric[i].clear();
}
}

stan::io::empty_var_context data_context;
std::stringstream model_log;
stan::test::unit::instrumented_logger logger;
std::vector<stan::test::unit::instrumented_writer> init;
std::vector<stan::test::unit::instrumented_writer> parameter;
std::vector<stan::test::unit::instrumented_writer> diagnostic;
std::vector<stan::callbacks::structured_writer> metric;
std::vector<std::shared_ptr<stan::io::empty_var_context>> context;
stan_model model;
std::vector<std::stringstream> ss_metric;
std::vector<stan::callbacks::json_writer<std::stringstream, deleter_noop>>
metric;
};

TEST_F(ServicesSampleHmcNutsUnitEAdaptPar, call_count) {
Expand Down Expand Up @@ -92,7 +115,7 @@ TEST_F(ServicesSampleHmcNutsUnitEAdaptPar, parameter_checks) {
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, delta, gamma, kappa, t0, interrupt, logger, init, parameter,
diagnostic);
diagnostic, metric);

for (size_t i = 0; i < num_chains; ++i) {
std::vector<std::vector<std::string>> parameter_names;
Expand All @@ -103,6 +126,14 @@ TEST_F(ServicesSampleHmcNutsUnitEAdaptPar, parameter_checks) {
diagnostic_names = diagnostic[i].vector_string_values();
std::vector<std::vector<double>> diagnostic_values;
diagnostic_values = diagnostic[i].vector_double_values();
std::string metric = ss_metric[i].str();
// Adapted metric
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(metric.c_str()).HasParseError());
EXPECT_EQ(count_matches("stepsize", metric), 1);
EXPECT_EQ(count_matches("inv_metric", metric), 1);
EXPECT_EQ(count_matches("[", metric), 1); // single list
EXPECT_EQ(count_matches("[ 1, 1 ]", metric), 1); // unit diagonal

// Expectations of parameter parameter names.
ASSERT_EQ(9, parameter_names[0].size());
Expand Down
Loading

0 comments on commit bb1ff9d

Please sign in to comment.