diff --git a/src/stan/io/stan_csv_reader.hpp b/src/stan/io/stan_csv_reader.hpp index 09e378eb54..6718ef053f 100644 --- a/src/stan/io/stan_csv_reader.hpp +++ b/src/stan/io/stan_csv_reader.hpp @@ -101,13 +101,12 @@ class stan_csv_reader { stan_csv_reader() {} ~stan_csv_reader() {} - static bool read_metadata(std::istream& in, stan_csv_metadata& metadata, - std::ostream* out) { + static void read_metadata(std::istream& in, stan_csv_metadata& metadata) { std::stringstream ss; std::string line; if (in.peek() != '#') - return false; + return; while (in.peek() == '#') { std::getline(in, line); ss << line << '\n'; @@ -193,10 +192,6 @@ class stan_csv_reader { std::stringstream(value) >> metadata.max_depth; } } - if (ss.good() == true) - return false; - - return true; } // read_metadata static bool read_header(std::istream& in, std::vector& header, @@ -325,11 +320,10 @@ class stan_csv_reader { if (cols == -1) { cols = current_cols; } else if (cols != current_cols) { - if (out) - *out << "Error: expected " << cols << " columns, but found " - << current_cols << " instead for row " << rows + 1 - << std::endl; - return false; + std::stringstream msg; + msg << "Error: expected " << cols << " columns, but found " + << current_cols << " instead for row " << rows + 1; + throw std::invalid_argument(msg.str()); } rows++; } @@ -357,10 +351,10 @@ class stan_csv_reader { /** * Parses the file. * - * Warns if missing metatdata, inconsistencies between metadata config - * and parsed data rows. + * Throws exception if contents can't be parsed into header + data rows + * or if sample size doesn't match metadata config. * - * Throws exception if no header row found. + * Emits warning message if can't parse sampler adaptation. 
* * @param[in] in input stream to parse * @param[out] out output stream to send messages @@ -369,13 +363,8 @@ class stan_csv_reader { stan_csv data; std::string line; - if (!read_metadata(in, data.metadata, out)) { - if (out) - *out << "Warning: non-fatal error reading metadata" << std::endl; - } + read_metadata(in, data.metadata); if (!read_header(in, data.header, out)) { - if (out) - *out << "Error: error reading header" << std::endl; throw std::invalid_argument("Error with header of input file in parse"); } @@ -408,10 +397,9 @@ class stan_csv_reader { int expected_samples = data.metadata.num_samples / data.metadata.thin; if (expected_samples != data.samples.rows()) { std::stringstream msg; - msg << ", expecting " << expected_samples << " samples, found " + msg << "Error reading samples, expecting " << expected_samples << " samples, found " << data.samples.rows(); - if (out) - *out << "Warning: error reading samples" << msg.str() << std::endl; + throw std::invalid_argument(msg.str()); } } return data; diff --git a/src/stan/mcmc/chainset.hpp b/src/stan/mcmc/chainset.hpp index fbc8851174..ccfc03058c 100644 --- a/src/stan/mcmc/chainset.hpp +++ b/src/stan/mcmc/chainset.hpp @@ -28,11 +28,43 @@ namespace stan { namespace mcmc { using Eigen::Dynamic; +/** + * Checks that a Stan CSV file contains both a header row + * and a set of draws from the posterior. + * Throws std::invalid_argument if either is missing. + * + * @param stan_csv parsed csv file object + */ +inline void validate_sample(const stan::io::stan_csv& stan_csv) { + if (stan_csv.header.empty()) { + throw std::invalid_argument("Error: Stan CSV file missing header row"); + } + if (stan_csv.samples.size() == 0) { + throw std::invalid_argument("Error: no sample found in Stan CSV file"); + } +} + +/** + * Reports the expected number of post-warmup draws in the CSV output file.
+ * * @param stan_csv parsed csv file object + * @return num_samples, divided by thin (integer division) when thin > 1 + */ +inline size_t thinned_samples(const stan::io::stan_csv& stan_csv) { + size_t thinned_samples = stan_csv.metadata.num_samples; + if (stan_csv.metadata.thin > 1) { + thinned_samples = thinned_samples / stan_csv.metadata.thin; + } + return thinned_samples; +} + /** * An mcmc::chainset object manages the post-warmup draws - * across a set of MCMC chains, which all have the same number or samples. + * across a set of MCMC chains, which all have the same number of samples. * - *

Storage Order: Storage is column/last-index major. + * @note samples are stored in column major, i.e., each column corresponds to + * an output variable (element). + * */ class chainset { private: @@ -40,35 +72,12 @@ class chainset { std::vector param_names_; std::vector chains_; - static size_t thinned_samples(const stan::io::stan_csv& stan_csv) { - size_t thinned_samples = stan_csv.metadata.num_samples; - if (stan_csv.metadata.thin > 1) { - thinned_samples = thinned_samples / stan_csv.metadata.thin; - } - return thinned_samples; - } - - static bool is_valid(const stan::io::stan_csv& stan_csv) { - if (stan_csv.header.empty()) { - return false; - } - if (stan_csv.samples.size() == 0) { - return false; - } - if (stan_csv.samples.rows() != thinned_samples(stan_csv)) { - return false; - } - return true; - } - /** * Process first chain: record header, thinned samples, * add samples to vector chains. */ void init_from_stan_csv(const stan::io::stan_csv& stan_csv) { - if (!is_valid(stan_csv)) { - throw std::invalid_argument("Invalid sample"); - } + validate_sample(stan_csv); if (chains_.size() > 0) { throw std::invalid_argument("Cannot re-initialize chains object"); } @@ -82,9 +91,7 @@ class chainset { * append to vector chains. 
*/ void add(const stan::io::stan_csv& stan_csv) { - if (!is_valid(stan_csv)) { - throw std::invalid_argument("Invalid sample"); - } + validate_sample(stan_csv); if (stan_csv.header.size() != num_params()) { throw std::invalid_argument( "Error add(stan_csv): number of columns in" diff --git a/src/test/unit/mcmc/test_csv_files/bernoulli_corrupt.csv b/src/test/unit/io/test_csv_files/bernoulli_corrupt.csv similarity index 100% rename from src/test/unit/mcmc/test_csv_files/bernoulli_corrupt.csv rename to src/test/unit/io/test_csv_files/bernoulli_corrupt.csv diff --git a/src/test/unit/io/stan_csv_reader_test.cpp b/src/test/unit/io/stan_csv_reader_test.cpp index 98744b147d..ffe75aae32 100644 --- a/src/test/unit/io/stan_csv_reader_test.cpp +++ b/src/test/unit/io/stan_csv_reader_test.cpp @@ -26,6 +26,8 @@ class StanIoStanCsvReader : public testing::Test { bernoulli_thin_stream.open( "src/test/unit/io/test_csv_files/bernoulli_thin.csv"); + bernoulli_trunc_stream.open( + "src/test/unit/io/test_csv_files/bernoulli_corrupt.csv"); bernoulli_warmup_stream.open( "src/test/unit/io/test_csv_files/bernoulli_warmup.csv"); missing_draws_stream.open( @@ -46,6 +48,7 @@ class StanIoStanCsvReader : public testing::Test { epil0_stream.close(); blocker_nondiag0_stream.close(); bernoulli_thin_stream.close(); + bernoulli_trunc_stream.close(); bernoulli_warmup_stream.close(); missing_draws_stream.close(); fixed_param_stream.close(); @@ -59,6 +62,7 @@ class StanIoStanCsvReader : public testing::Test { std::ifstream eight_schools_stream; std::ifstream header3_stream; std::ifstream bernoulli_thin_stream; + std::ifstream bernoulli_trunc_stream; std::ifstream bernoulli_warmup_stream; std::ifstream missing_draws_stream; std::ifstream fixed_param_stream; @@ -66,8 +70,7 @@ class StanIoStanCsvReader : public testing::Test { TEST_F(StanIoStanCsvReader, read_metadata1) { stan::io::stan_csv_metadata metadata; - EXPECT_TRUE( - stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata, 0)); +
stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata); EXPECT_EQ(2, metadata.stan_version_major); EXPECT_EQ(9, metadata.stan_version_minor); @@ -93,9 +96,7 @@ TEST_F(StanIoStanCsvReader, read_metadata1) { TEST_F(StanIoStanCsvReader, read_metadata3) { stan::io::stan_csv_metadata metadata; - - EXPECT_TRUE( - stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata, 0)); + stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata); EXPECT_EQ(2, metadata.stan_version_major); EXPECT_EQ(9, metadata.stan_version_minor); @@ -569,8 +570,11 @@ TEST_F(StanIoStanCsvReader, skip_warmup) { TEST_F(StanIoStanCsvReader, missing_data) { stan::io::stan_csv missing_draws; std::stringstream out; - missing_draws = stan::io::stan_csv_reader::parse(missing_draws_stream, &out); - ASSERT_TRUE(boost::algorithm::starts_with(out.str(), "Warning:")); + EXPECT_THROW(stan::io::stan_csv_reader::parse(missing_draws_stream, &out), + std::invalid_argument); + stan::io::stan_csv bernoulli_trunc; + EXPECT_THROW(stan::io::stan_csv_reader::parse(bernoulli_trunc_stream, &out), + std::invalid_argument); } TEST_F(StanIoStanCsvReader, thinned_data) { diff --git a/src/test/unit/mcmc/chainset_test.cpp b/src/test/unit/mcmc/chainset_test.cpp index e30175d6e4..47ba2c16f5 100644 --- a/src/test/unit/mcmc/chainset_test.cpp +++ b/src/test/unit/mcmc/chainset_test.cpp @@ -15,9 +15,6 @@ class McmcChains : public testing::Test { bernoulli_500_stream.open( "src/test/unit/mcmc/test_csv_files/bernoulli_500.csv", std::ifstream::in); - bernoulli_corrupt_stream.open( - "src/test/unit/mcmc/test_csv_files/bernoulli_corrupt.csv", - std::ifstream::in); bernoulli_default_stream.open( "src/test/unit/mcmc/test_csv_files/bernoulli_default.csv", std::ifstream::in); @@ -43,7 +40,7 @@ class McmcChains : public testing::Test { "src/test/unit/mcmc/test_csv_files/eight_schools_5iters_2.csv", std::ifstream::in); - if (!bernoulli_500_stream || !bernoulli_corrupt_stream + if (!bernoulli_500_stream || 
!bernoulli_default_stream || !bernoulli_thin_stream || !bernoulli_warmup_stream || !bernoulli_zeta_stream || !eight_schools_1_stream || !eight_schools_2_stream @@ -51,7 +48,6 @@ class McmcChains : public testing::Test { FAIL() << "Failed to open one or more test files"; } bernoulli_500_stream.seekg(0, std::ios::beg); - bernoulli_corrupt_stream.seekg(0, std::ios::beg); bernoulli_default_stream.seekg(0, std::ios::beg); bernoulli_thin_stream.seekg(0, std::ios::beg); bernoulli_warmup_stream.seekg(0, std::ios::beg); @@ -63,8 +59,6 @@ class McmcChains : public testing::Test { bernoulli_500 = stan::io::stan_csv_reader::parse(bernoulli_500_stream, &out); - bernoulli_corrupt - = stan::io::stan_csv_reader::parse(bernoulli_corrupt_stream, &out); bernoulli_default = stan::io::stan_csv_reader::parse(bernoulli_default_stream, &out); bernoulli_thin @@ -85,7 +79,6 @@ class McmcChains : public testing::Test { void TearDown() override { bernoulli_500_stream.close(); - bernoulli_corrupt_stream.close(); bernoulli_default_stream.close(); bernoulli_thin_stream.close(); bernoulli_warmup_stream.close(); @@ -98,14 +91,14 @@ class McmcChains : public testing::Test { std::stringstream out; - std::ifstream bernoulli_500_stream, bernoulli_corrupt_stream, - bernoulli_default_stream, bernoulli_thin_stream, bernoulli_warmup_stream, - bernoulli_zeta_stream, eight_schools_1_stream, eight_schools_2_stream, - eight_schools_5iters_1_stream, eight_schools_5iters_2_stream; + std::ifstream bernoulli_500_stream, + bernoulli_default_stream, bernoulli_thin_stream, bernoulli_warmup_stream, + bernoulli_zeta_stream, eight_schools_1_stream, eight_schools_2_stream, + eight_schools_5iters_1_stream, eight_schools_5iters_2_stream; - stan::io::stan_csv bernoulli_500, bernoulli_corrupt, bernoulli_default, - bernoulli_thin, bernoulli_warmup, bernoulli_zeta, eight_schools_1, - eight_schools_2, eight_schools_5iters_1, eight_schools_5iters_2; + stan::io::stan_csv bernoulli_500, bernoulli_default, + bernoulli_thin, 
bernoulli_warmup, bernoulli_zeta, eight_schools_1, + eight_schools_2, eight_schools_5iters_1, eight_schools_5iters_2; }; TEST_F(McmcChains, constructor) { @@ -139,11 +132,6 @@ TEST_F(McmcChains, constructor) { TEST_F(McmcChains, addFail) { std::vector bad; bad.push_back(bernoulli_default); - bad.push_back(bernoulli_corrupt); - EXPECT_THROW(stan::mcmc::chainset fail(bad), std::invalid_argument); - - bad.clear(); - bad.push_back(bernoulli_default); bad.push_back(bernoulli_500); EXPECT_THROW(stan::mcmc::chainset fail(bad), std::invalid_argument);