Skip to content

Commit

Permalink
cleanup stan_csv_reader, chainset
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Sep 26, 2024
1 parent ce31e48 commit ed7f0a0
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 80 deletions.
36 changes: 12 additions & 24 deletions src/stan/io/stan_csv_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,12 @@ class stan_csv_reader {
stan_csv_reader() {}
~stan_csv_reader() {}

static bool read_metadata(std::istream& in, stan_csv_metadata& metadata,
std::ostream* out) {
static void read_metadata(std::istream& in, stan_csv_metadata& metadata) {
std::stringstream ss;
std::string line;

if (in.peek() != '#')
return false;
return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << '\n';
Expand Down Expand Up @@ -193,10 +192,6 @@ class stan_csv_reader {
std::stringstream(value) >> metadata.max_depth;
}
}
if (ss.good() == true)
return false;

return true;
} // read_metadata

static bool read_header(std::istream& in, std::vector<std::string>& header,
Expand Down Expand Up @@ -325,11 +320,10 @@ class stan_csv_reader {
if (cols == -1) {
cols = current_cols;
} else if (cols != current_cols) {
if (out)
*out << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1
<< std::endl;
return false;
std::stringstream msg;
msg << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1;
throw std::invalid_argument(msg.str());
}
rows++;
}
Expand Down Expand Up @@ -357,10 +351,10 @@ class stan_csv_reader {
/**
* Parses the file.
*
* Warns if missing metatdata, inconsistencies between metadata config
* and parsed data rows.
* Throws exception if contents can't be parsed into header + data rows
* or if sample size doesn't match metadata config.
*
* Throws exception if no header row found.
* Emits warning message if can't parse sampler adaptation.
*
* @param[in] in input stream to parse
* @param[out] out output stream to send messages
Expand All @@ -369,13 +363,8 @@ class stan_csv_reader {
stan_csv data;
std::string line;

if (!read_metadata(in, data.metadata, out)) {
if (out)
*out << "Warning: non-fatal error reading metadata" << std::endl;
}
read_metadata(in, data.metadata);
if (!read_header(in, data.header, out)) {
if (out)
*out << "Error: error reading header" << std::endl;
throw std::invalid_argument("Error with header of input file in parse");
}

Expand Down Expand Up @@ -408,10 +397,9 @@ class stan_csv_reader {
int expected_samples = data.metadata.num_samples / data.metadata.thin;
if (expected_samples != data.samples.rows()) {
std::stringstream msg;
msg << ", expecting " << expected_samples << " samples, found "
msg << "Error reading samples, expecting " << expected_samples << " samples, found "
<< data.samples.rows();
if (out)
*out << "Warning: error reading samples" << msg.str() << std::endl;
throw std::invalid_argument(msg.str());
}
}
return data;
Expand Down
65 changes: 36 additions & 29 deletions src/stan/mcmc/chainset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,47 +28,56 @@ namespace stan {
namespace mcmc {
using Eigen::Dynamic;

/**
* Checks that a Stan CSV file contains both a header row
* and a set of draws from the posterior.
* Throws exception if either are missing.
*
* @param stan_csv parsed csv file object
*/
void validate_sample(const stan::io::stan_csv& stan_csv) {
if (stan_csv.header.empty()) {
throw std::invalid_argument("Error: Stan CSV file missing header row");
}
if (stan_csv.samples.size() == 0) {
throw std::invalid_argument("Error: no sample found in Stan CSV file");
}
}

/**
* Reports the expected number of post-warmup draws in the CSV output file.
*
* @param stan_csv parsed csv file object
* @return expected number of draws
*/
size_t thinned_samples(const stan::io::stan_csv& stan_csv) {
size_t thinned_samples = stan_csv.metadata.num_samples;
if (stan_csv.metadata.thin > 1) {
thinned_samples = thinned_samples / stan_csv.metadata.thin;
}
return thinned_samples;
}

/**
* An <code>mcmc::chainset</code> object manages the post-warmup draws
* across a set of MCMC chains, which all have the same number or samples.
* across a set of MCMC chains, which all have the same number of samples.
*
* <p><b>Storage Order</b>: Storage is column/last-index major.
* @note samples are stored in column major, i.e., each column corresponds to
* an output variable (element).
*
*/
class chainset {
private:
size_t num_samples_;
std::vector<std::string> param_names_;
std::vector<Eigen::MatrixXd> chains_;

static size_t thinned_samples(const stan::io::stan_csv& stan_csv) {
size_t thinned_samples = stan_csv.metadata.num_samples;
if (stan_csv.metadata.thin > 1) {
thinned_samples = thinned_samples / stan_csv.metadata.thin;
}
return thinned_samples;
}

static bool is_valid(const stan::io::stan_csv& stan_csv) {
if (stan_csv.header.empty()) {
return false;
}
if (stan_csv.samples.size() == 0) {
return false;
}
if (stan_csv.samples.rows() != thinned_samples(stan_csv)) {
return false;
}
return true;
}

/**
* Process first chain: record header, thinned samples,
* add samples to vector chains.
*/
void init_from_stan_csv(const stan::io::stan_csv& stan_csv) {
if (!is_valid(stan_csv)) {
throw std::invalid_argument("Invalid sample");
}
validate_sample(stan_csv);
if (chains_.size() > 0) {
throw std::invalid_argument("Cannot re-initialize chains object");
}
Expand All @@ -82,9 +91,7 @@ class chainset {
* append to vector chains.
*/
void add(const stan::io::stan_csv& stan_csv) {
if (!is_valid(stan_csv)) {
throw std::invalid_argument("Invalid sample");
}
validate_sample(stan_csv);
if (stan_csv.header.size() != num_params()) {
throw std::invalid_argument(
"Error add(stan_csv): number of columns in"
Expand Down
File renamed without changes.
18 changes: 11 additions & 7 deletions src/test/unit/io/stan_csv_reader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class StanIoStanCsvReader : public testing::Test {

bernoulli_thin_stream.open(
"src/test/unit/io/test_csv_files/bernoulli_thin.csv");
bernoulli_trunc_stream.open(
"src/test/unit/io/test_csv_files/bernoulli_corrupt.csv");
bernoulli_warmup_stream.open(
"src/test/unit/io/test_csv_files/bernoulli_warmup.csv");
missing_draws_stream.open(
Expand All @@ -46,6 +48,7 @@ class StanIoStanCsvReader : public testing::Test {
epil0_stream.close();
blocker_nondiag0_stream.close();
bernoulli_thin_stream.close();
bernoulli_trunc_stream.close();
bernoulli_warmup_stream.close();
missing_draws_stream.close();
fixed_param_stream.close();
Expand All @@ -59,15 +62,15 @@ class StanIoStanCsvReader : public testing::Test {
std::ifstream eight_schools_stream;
std::ifstream header3_stream;
std::ifstream bernoulli_thin_stream;
std::ifstream bernoulli_trunc_stream;
std::ifstream bernoulli_warmup_stream;
std::ifstream missing_draws_stream;
std::ifstream fixed_param_stream;
};

TEST_F(StanIoStanCsvReader, read_metadata1) {
stan::io::stan_csv_metadata metadata;
EXPECT_TRUE(
stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata, 0));
stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata);

EXPECT_EQ(2, metadata.stan_version_major);
EXPECT_EQ(9, metadata.stan_version_minor);
Expand All @@ -93,9 +96,7 @@ TEST_F(StanIoStanCsvReader, read_metadata1) {

TEST_F(StanIoStanCsvReader, read_metadata3) {
stan::io::stan_csv_metadata metadata;

EXPECT_TRUE(
stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata, 0));
stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata);

EXPECT_EQ(2, metadata.stan_version_major);
EXPECT_EQ(9, metadata.stan_version_minor);
Expand Down Expand Up @@ -569,8 +570,11 @@ TEST_F(StanIoStanCsvReader, skip_warmup) {
TEST_F(StanIoStanCsvReader, missing_data) {
stan::io::stan_csv missing_draws;
std::stringstream out;
missing_draws = stan::io::stan_csv_reader::parse(missing_draws_stream, &out);
ASSERT_TRUE(boost::algorithm::starts_with(out.str(), "Warning:"));
EXPECT_THROW(stan::io::stan_csv_reader::parse(missing_draws_stream, &out),
std::invalid_argument);
stan::io::stan_csv bernoulli_trunc;
EXPECT_THROW(stan::io::stan_csv_reader::parse(bernoulli_trunc_stream, &out),
std::invalid_argument);
}

TEST_F(StanIoStanCsvReader, thinned_data) {
Expand Down
28 changes: 8 additions & 20 deletions src/test/unit/mcmc/chainset_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ class McmcChains : public testing::Test {
bernoulli_500_stream.open(
"src/test/unit/mcmc/test_csv_files/bernoulli_500.csv",
std::ifstream::in);
bernoulli_corrupt_stream.open(
"src/test/unit/mcmc/test_csv_files/bernoulli_corrupt.csv",
std::ifstream::in);
bernoulli_default_stream.open(
"src/test/unit/mcmc/test_csv_files/bernoulli_default.csv",
std::ifstream::in);
Expand All @@ -43,15 +40,14 @@ class McmcChains : public testing::Test {
"src/test/unit/mcmc/test_csv_files/eight_schools_5iters_2.csv",
std::ifstream::in);

if (!bernoulli_500_stream || !bernoulli_corrupt_stream
if (!bernoulli_500_stream
|| !bernoulli_default_stream || !bernoulli_thin_stream
|| !bernoulli_warmup_stream || !bernoulli_zeta_stream
|| !eight_schools_1_stream || !eight_schools_2_stream
|| !eight_schools_5iters_1_stream || !eight_schools_5iters_2_stream) {
FAIL() << "Failed to open one or more test files";
}
bernoulli_500_stream.seekg(0, std::ios::beg);
bernoulli_corrupt_stream.seekg(0, std::ios::beg);
bernoulli_default_stream.seekg(0, std::ios::beg);
bernoulli_thin_stream.seekg(0, std::ios::beg);
bernoulli_warmup_stream.seekg(0, std::ios::beg);
Expand All @@ -63,8 +59,6 @@ class McmcChains : public testing::Test {

bernoulli_500
= stan::io::stan_csv_reader::parse(bernoulli_500_stream, &out);
bernoulli_corrupt
= stan::io::stan_csv_reader::parse(bernoulli_corrupt_stream, &out);
bernoulli_default
= stan::io::stan_csv_reader::parse(bernoulli_default_stream, &out);
bernoulli_thin
Expand All @@ -85,7 +79,6 @@ class McmcChains : public testing::Test {

void TearDown() override {
bernoulli_500_stream.close();
bernoulli_corrupt_stream.close();
bernoulli_default_stream.close();
bernoulli_thin_stream.close();
bernoulli_warmup_stream.close();
Expand All @@ -98,14 +91,14 @@ class McmcChains : public testing::Test {

std::stringstream out;

std::ifstream bernoulli_500_stream, bernoulli_corrupt_stream,
bernoulli_default_stream, bernoulli_thin_stream, bernoulli_warmup_stream,
bernoulli_zeta_stream, eight_schools_1_stream, eight_schools_2_stream,
eight_schools_5iters_1_stream, eight_schools_5iters_2_stream;
std::ifstream bernoulli_500_stream,
bernoulli_default_stream, bernoulli_thin_stream, bernoulli_warmup_stream,
bernoulli_zeta_stream, eight_schools_1_stream, eight_schools_2_stream,
eight_schools_5iters_1_stream, eight_schools_5iters_2_stream;

stan::io::stan_csv bernoulli_500, bernoulli_corrupt, bernoulli_default,
bernoulli_thin, bernoulli_warmup, bernoulli_zeta, eight_schools_1,
eight_schools_2, eight_schools_5iters_1, eight_schools_5iters_2;
stan::io::stan_csv bernoulli_500, bernoulli_default,
bernoulli_thin, bernoulli_warmup, bernoulli_zeta, eight_schools_1,
eight_schools_2, eight_schools_5iters_1, eight_schools_5iters_2;
};

TEST_F(McmcChains, constructor) {
Expand Down Expand Up @@ -139,11 +132,6 @@ TEST_F(McmcChains, constructor) {
TEST_F(McmcChains, addFail) {
std::vector<stan::io::stan_csv> bad;
bad.push_back(bernoulli_default);
bad.push_back(bernoulli_corrupt);
EXPECT_THROW(stan::mcmc::chainset fail(bad), std::invalid_argument);

bad.clear();
bad.push_back(bernoulli_default);
bad.push_back(bernoulli_500);
EXPECT_THROW(stan::mcmc::chainset fail(bad), std::invalid_argument);

Expand Down

0 comments on commit ed7f0a0

Please sign in to comment.