Skip to content

Commit

Permalink
changes per code review
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Sep 19, 2023
1 parent 7ccd590 commit d99cea2
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 35 deletions.
12 changes: 2 additions & 10 deletions src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,7 @@ int hmc_nuts_dense_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(num_chains);
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
Expand Down Expand Up @@ -632,11 +628,7 @@ int hmc_nuts_dense_e_adapt(
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_dense_inv_metric(model.num_params_r())));
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(num_chains);
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
model, *init[0], *unit_e_metric[0], random_seed, init_chain_id,
Expand Down
12 changes: 2 additions & 10 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,7 @@ int hmc_nuts_diag_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(num_chains);
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
Expand Down Expand Up @@ -634,11 +630,7 @@ int hmc_nuts_diag_e_adapt(
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_diag_inv_metric(model.num_params_r())));
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(num_chains);
if (num_chains == 1) {
return hmc_nuts_diag_e_adapt(
model, *init[0], *unit_e_metric[0], random_seed, init_chain_id,
Expand Down
6 changes: 1 addition & 5 deletions src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,7 @@ int hmc_nuts_unit_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(stan::callbacks::structured_writer());
}
std::vector<stan::callbacks::structured_writer> dummy_metric_writer(num_chains);
if (num_chains == 1) {
return hmc_nuts_unit_e_adapt(
model, *init[0], random_seed, init_chain_id, init_radius, num_warmup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <src/test/unit/services/util.hpp>
#include <test/test-models/good/optimization/rosenbrock.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <rapidjson/document.h>
#include <gtest/gtest.h>
#include <iostream>

Expand Down Expand Up @@ -103,7 +102,7 @@ TEST_F(ServicesSampleHmcNutsDenseEAdaptParMatch, single_multi_match) {
int num_output_lines = (num_warmup + num_samples) / num_thin;
EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count());
for (int i = 0; i < num_chains; ++i) {
stan::test::unit::instrumented_writer seq_init; // MM: what does this do?
stan::test::unit::instrumented_writer seq_init;
stan::test::unit::instrumented_writer seq_diagnostic;
return_code = stan::services::sample::hmc_nuts_dense_e_adapt(
*model, *(context[i]), random_seed, i, init_radius, num_warmup,
Expand All @@ -124,8 +123,7 @@ TEST_F(ServicesSampleHmcNutsDenseEAdaptParMatch, single_multi_match) {
par_res.push_back(par_mat);

par_metrics.push_back(ss_metric[i].str());
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(par_metrics[i].c_str()).HasParseError());
ASSERT_TRUE(stan::test::is_valid_JSON(par_metrics[i]));
EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1);
EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1);
EXPECT_EQ(count_matches("[", par_metrics[i]), 3); // list has 2 rows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <test/test-models/good/optimization/rosenbrock.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <test/unit/util.hpp>
#include <rapidjson/document.h>
#include <gtest/gtest.h>
#include <iostream>

Expand Down Expand Up @@ -123,8 +122,7 @@ TEST_F(ServicesSampleHmcNutsDiagEAdaptParMatch, single_multi_match) {
= stan::test::read_stan_sample_csv(sub_par_stream, 80, 9);
par_res.push_back(par_mat);
par_metrics.push_back(ss_metric[i].str());
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(par_metrics[i].c_str()).HasParseError());
ASSERT_TRUE(stan::test::is_valid_JSON(par_metrics[i]));
EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1);
EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1);
EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <test/test-models/good/optimization/rosenbrock.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <test/unit/util.hpp>
#include <rapidjson/document.h>
#include <gtest/gtest.h>
#include <iostream>

Expand Down Expand Up @@ -127,8 +126,7 @@ TEST_F(ServicesSampleHmcNutsUnitEAdaptPar, parameter_checks) {
diagnostic_values = diagnostic[i].vector_double_values();
std::string metric = ss_metric[i].str();
// Adapted metric
rapidjson::Document document;
ASSERT_FALSE(document.Parse<0>(metric.c_str()).HasParseError());
ASSERT_TRUE(stan::test::is_valid_JSON(metric));
EXPECT_EQ(count_matches("stepsize", metric), 1);
EXPECT_EQ(count_matches("inv_metric", metric), 1);
EXPECT_EQ(count_matches("[", metric), 1); // single list
Expand Down
1 change: 1 addition & 0 deletions src/test/unit/services/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Eigen::MatrixXd read_stan_sample_csv(std::istringstream& in, int rows,
}
return res;
}

} // namespace test
} // namespace stan
#endif
10 changes: 10 additions & 0 deletions src/test/unit/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <boost/algorithm/string.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <gtest/gtest.h>
#include <rapidjson/document.h>
#include <iostream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -166,6 +167,15 @@ void reset_std_streams() {
cerr_buf = 0;
}

/**
* Validate JSON using rapidjson parser.
* @param text String of JSON
*/
bool is_valid_JSON(std::string& text) {
rapidjson::Document document;
return !document.Parse<0>(text.c_str()).HasParseError();
}

} // namespace test
} // namespace stan

Expand Down

0 comments on commit d99cea2

Please sign in to comment.