Skip to content

Commit

Permalink
structured writer - compiler error
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Sep 15, 2023
1 parent 502b9ac commit c8e616c
Show file tree
Hide file tree
Showing 20 changed files with 126 additions and 147 deletions.
15 changes: 7 additions & 8 deletions src/stan/mcmc/hmc/base_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/writer.hpp>
#include <stan/callbacks/json_writer.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <stan/mcmc/base_mcmc.hpp>
#include <stan/mcmc/hmc/hamiltonians/ps_point.hpp>
#include <boost/random/uniform_01.hpp>
Expand Down Expand Up @@ -67,13 +67,12 @@ class base_hmc : public base_mcmc {
/**
* write stepsize and elements of mass matrix as a JSON object
*/
template <typename Stream, typename Deleter = std::default_delete<Stream>>
void write_sampler_state_json(
callbacks::json_writer<Stream, Deleter>& json_writer) {
json_writer.begin_record();
json_writer.write("stepsize", get_nominal_stepsize());
json_writer.write("inv_metric", z_.inv_e_metric_);
json_writer.end_record();
void write_sampler_state_struct(
callbacks::structured_writer& struct_writer) {
struct_writer.begin_record();
struct_writer.write("stepsize", get_nominal_stepsize());
struct_writer.write("inv_metric", z_.inv_e_metric_);
struct_writer.end_record();
}

void get_sampler_diagnostic_names(std::vector<std::string>& model_names,
Expand Down
16 changes: 8 additions & 8 deletions src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
#define STAN_SERVICES_SAMPLE_FIXED_PARAM_HPP

#include <stan/callbacks/interrupt.hpp>
#include <stan/callbacks/json_writer.hpp>
#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <stan/callbacks/writer.hpp>
#include <stan/math/prim.hpp>
#include <stan/mcmc/fixed_param_sampler.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/services/util/mcmc_writer.hpp>
#include <stan/services/util/generate_transitions.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/generate_transitions.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/mcmc_writer.hpp>
#include <chrono>
#include <vector>
#include <iostream>
Expand Down Expand Up @@ -65,8 +65,8 @@ int fixed_param(Model& model, const stan::io::var_context& init,
}

stan::mcmc::fixed_param_sampler sampler;
callbacks::json_writer<std::ofstream> dummy_metric_writer;
services::util::mcmc_writer<std::ofstream> writer(
callbacks::structured_writer dummy_metric_writer;
services::util::mcmc_writer writer(
sample_writer, diagnostic_writer, dummy_metric_writer, logger);
Eigen::VectorXd cont_params(cont_vector.size());
for (size_t i = 0; i < cont_vector.size(); i++)
Expand Down Expand Up @@ -140,8 +140,8 @@ int fixed_param(Model& model, const std::size_t num_chains,
}
std::vector<boost::ecuyer1988> rngs;
std::vector<Eigen::VectorXd> cont_vectors;
std::vector<callbacks::json_writer<std::ofstream>> dummy_metric_writers;
std::vector<util::mcmc_writer<std::ofstream>> writers;
std::vector<callbacks::structured_writer> dummy_metric_writers;
std::vector<util::mcmc_writer> writers;
std::vector<stan::mcmc::sample> samples;
std::vector<stan::mcmc::fixed_param_sampler> samplers(num_chains);
rngs.reserve(num_chains);
Expand All @@ -157,7 +157,7 @@ int fixed_param(Model& model, const std::size_t num_chains,
Eigen::Map<Eigen::VectorXd>(cont_vector.data(), cont_vector.size()));
samples.emplace_back(cont_vectors[i], 0, 0);
dummy_metric_writers.emplace_back(
stan::callbacks::json_writer<std::ofstream>());
stan::callbacks::structured_writer());
writers.emplace_back(sample_writers[i], diagnostic_writers[i],
dummy_metric_writers[i], logger);
// Headers
Expand Down
48 changes: 19 additions & 29 deletions src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

#include <stan/callbacks/interrupt.hpp>
#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <stan/callbacks/writer.hpp>
#include <stan/callbacks/json_writer.hpp>
#include <stan/io/var_context.hpp>
#include <stan/math/prim.hpp>
#include <stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/services/util/run_adaptive_sampler.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/inv_metric.hpp>
#include <iostream>
#include <stan/services/util/run_adaptive_sampler.hpp>
#include <vector>

namespace stan {
Expand All @@ -26,8 +25,6 @@ namespace sample {
* stepsize and inverse metric.
*
* @tparam Model Model class
* @tparam Stream A type with with a valid `operator<<(std::string)`
* @tparam Deleter A class with a valid `operator()` method for deleting the
* @param[in] model Input model (with data already instantiated)
* @param[in] init var context for initialization
* @param[in] init_inv_metric var context exposing an initial dense
Expand Down Expand Up @@ -58,8 +55,7 @@ namespace sample {
* @param[in,out] metric_writer Writer for tuning params
* @return error_codes::OK if successful
*/
template <class Model, typename Stream,
typename Deleter = std::default_delete<Stream>>
template <class Model>
int hmc_nuts_dense_e_adapt(
Model& model, const stan::io::var_context& init,
const stan::io::var_context& init_inv_metric, unsigned int random_seed,
Expand All @@ -70,7 +66,7 @@ int hmc_nuts_dense_e_adapt(
unsigned int window, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer,
callbacks::json_writer<Stream, Deleter>& metric_writer) {
callbacks::structured_writer& metric_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);

std::vector<double> cont_vector;
Expand Down Expand Up @@ -117,8 +113,6 @@ int hmc_nuts_dense_e_adapt(
* with a pre-specified dense metric.
*
* @tparam Model Model class
* @tparam Stream A type with with a valid `operator<<(std::string)`
* @tparam Deleter A class with a valid `operator()` method for deleting the
* @param[in] model Input model (with data already instantiated)
* @param[in] init var context for initialization
* @param[in] init_inv_metric var context exposing an initial dense
Expand Down Expand Up @@ -159,7 +153,7 @@ int hmc_nuts_dense_e_adapt(
unsigned int window, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) {
callbacks::json_writer<std::ofstream> dummy_metric_writer;
callbacks::structured_writer dummy_metric_writer;
return hmc_nuts_dense_e_adapt(
model, init, init_inv_metric, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
Expand All @@ -174,8 +168,6 @@ int hmc_nuts_dense_e_adapt(
* parameters stepsize and inverse metric.
*
* @tparam Model Model class
* @tparam Stream A type with with a valid `operator<<(std::string)`
* @tparam Deleter A class with a valid `operator()` method for deleting the
* @param[in] model Input model (with data already instantiated)
* @param[in] init var context for initialization
* @param[in] random_seed random seed for the random number generator
Expand Down Expand Up @@ -204,8 +196,7 @@ int hmc_nuts_dense_e_adapt(
* @param[in,out] metric_writer Writer for tuning params
* @return error_codes::OK if successful
*/
template <class Model, typename Stream,
typename Deleter = std::default_delete<Stream>>
template <class Model>
int hmc_nuts_dense_e_adapt(
Model& model, const stan::io::var_context& init, unsigned int random_seed,
unsigned int chain, double init_radius, int num_warmup, int num_samples,
Expand All @@ -215,7 +206,7 @@ int hmc_nuts_dense_e_adapt(
unsigned int window, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer,
callbacks::json_writer<Stream, Deleter>& metric_writer) {
callbacks::structured_writer& metric_writer) {
stan::io::dump dmp
= util::create_unit_e_dense_inv_metric(model.num_params_r());
stan::io::var_context& unit_e_metric = dmp;
Expand Down Expand Up @@ -272,7 +263,7 @@ int hmc_nuts_dense_e_adapt(
stan::io::dump dmp
= util::create_unit_e_dense_inv_metric(model.num_params_r());
stan::io::var_context& unit_e_metric = dmp;
callbacks::json_writer<std::ofstream> dummy_metric_writer;
callbacks::structured_writer dummy_metric_writer;
return hmc_nuts_dense_e_adapt(
model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
Expand All @@ -287,15 +278,14 @@ int hmc_nuts_dense_e_adapt(
* stepsize and inverse metric.
*
* @tparam Model Model class
* @tparam Stream A type with with a valid `operator<<(std::string)`
* @tparam Deleter A class with a valid `operator()` method for deleting the
* @tparam InitContextPtr A pointer with underlying type derived from
`stan::io::var_context`
* @tparam InitInvContextPtr A pointer with underlying type derived from
`stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam MetricWriter A type derived from `stan::callbacks::structured_writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
Expand Down Expand Up @@ -337,7 +327,7 @@ int hmc_nuts_dense_e_adapt(
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
typename InitWriter, typename SampleWriter, typename DiagnosticWriter,
typename Stream, typename Deleter = std::default_delete<Stream>>
typename MetricWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
const std::vector<InitInvContextPtr>& init_inv_metric,
Expand All @@ -350,7 +340,7 @@ int hmc_nuts_dense_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer,
std::vector<callbacks::json_writer<Stream, Deleter>>& metric_writer) {
std::vector<MetricWriter>& metric_writer) {
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
Expand Down Expand Up @@ -471,11 +461,11 @@ int hmc_nuts_dense_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
std::vector<callbacks::json_writer<std::ofstream>> dummy_metric_writer;
std::vector<stan::callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::json_writer<std::ofstream>());
stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
Expand Down Expand Up @@ -504,6 +494,7 @@ int hmc_nuts_dense_e_adapt(
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam MetricWriter A type derived from `stan::callbacks::structured_writer`
* @param[in] model Input model (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
Expand Down Expand Up @@ -541,8 +532,7 @@ int hmc_nuts_dense_e_adapt(
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter, typename Stream,
typename Deleter = std::default_delete<Stream>>
typename SampleWriter, typename DiagnosticWriter, typename MetricWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
Expand All @@ -554,7 +544,7 @@ int hmc_nuts_dense_e_adapt(
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer,
std::vector<callbacks::json_writer<Stream, Deleter>>& metric_writer) {
std::vector<MetricWriter>& metric_writer) {
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metric;
unit_e_metric.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
Expand Down Expand Up @@ -641,11 +631,11 @@ int hmc_nuts_dense_e_adapt(
unit_e_metric.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_dense_inv_metric(model.num_params_r())));
}
std::vector<callbacks::json_writer<std::ofstream>> dummy_metric_writer;
std::vector<callbacks::structured_writer> dummy_metric_writer;
dummy_metric_writer.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
dummy_metric_writer.emplace_back(
stan::callbacks::json_writer<std::ofstream>());
stan::callbacks::structured_writer());
}
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
Expand Down
Loading

0 comments on commit c8e616c

Please sign in to comment.