diff --git a/src/stan/io/stan_csv_reader.hpp b/src/stan/io/stan_csv_reader.hpp index 5ded3d49c9..6d8fa9490d 100644 --- a/src/stan/io/stan_csv_reader.hpp +++ b/src/stan/io/stan_csv_reader.hpp @@ -29,8 +29,6 @@ inline void prettify_stan_csv_name(std::string& variable) { } } -// FIXME: should consolidate with the options from -// the command line in stan::lang struct stan_csv_metadata { int stan_version_major; int stan_version_minor; @@ -47,6 +45,7 @@ struct stan_csv_metadata { bool save_warmup; size_t thin; bool append_samples; + std::string method; std::string algorithm; std::string engine; int max_depth; @@ -64,8 +63,9 @@ struct stan_csv_metadata { num_samples(0), num_warmup(0), save_warmup(false), - thin(0), + thin(1), append_samples(false), + method(""), algorithm(""), engine(""), max_depth(10) {} @@ -101,13 +101,12 @@ class stan_csv_reader { stan_csv_reader() {} ~stan_csv_reader() {} - static bool read_metadata(std::istream& in, stan_csv_metadata& metadata, - std::ostream* out) { + static void read_metadata(std::istream& in, stan_csv_metadata& metadata) { std::stringstream ss; std::string line; if (in.peek() != '#') - return false; + return; while (in.peek() == '#') { std::getline(in, line); ss << line << '\n'; @@ -161,9 +160,15 @@ class stan_csv_reader { metadata.model = value; } else if (name.compare("num_samples") == 0) { std::stringstream(value) >> metadata.num_samples; + } else if (name.compare("output_samples") == 0) { // ADVI config name + std::stringstream(value) >> metadata.num_samples; } else if (name.compare("num_warmup") == 0) { std::stringstream(value) >> metadata.num_warmup; } else if (name.compare("save_warmup") == 0) { + // cmdstan args can be "true" and "false", was "1", "0" + if (value.compare("true") == 0) { + value = "1"; + } std::stringstream(value) >> metadata.save_warmup; } else if (name.compare("thin") == 0) { std::stringstream(value) >> metadata.thin; @@ -177,6 +182,8 @@ class stan_csv_reader { metadata.random_seed = false; } else if (name.compare("append_samples") == 0) { std::stringstream(value) >> metadata.append_samples; + } else if (name.compare("method") == 0) { + metadata.method = value; } else if (name.compare("algorithm") == 0) { metadata.algorithm = value; } else if (name.compare("engine") == 0) { @@ -185,14 +192,10 @@ class stan_csv_reader { std::stringstream(value) >> metadata.max_depth; } } - if (ss.good() == true) - return false; - - return true; } // read_metadata static bool read_header(std::istream& in, std::vector& header, - std::ostream* out, bool prettify_name = true) { + bool prettify_name = true) { std::string line; if (!std::isalpha(in.peek())) @@ -216,62 +219,53 @@ class stan_csv_reader { return true; } - static bool read_adaptation(std::istream& in, stan_csv_adaptation& adaptation, - std::ostream* out) { + static void read_adaptation(std::istream& in, + stan_csv_adaptation& adaptation) { std::stringstream ss; std::string line; int lines = 0; - if (in.peek() != '#' || in.good() == false) - return false; - + return; while (in.peek() == '#') { std::getline(in, line); ss << line << std::endl; lines++; } ss.seekg(std::ios_base::beg); + if (lines < 2) + return; - if (lines < 4) - return false; + std::getline(ss, line); // comment adaptation terminated - char comment; // Buffer for comment indicator, # - - // Skip first two lines - std::getline(ss, line); - - // Stepsize - std::getline(ss, line, '='); + // parse stepsize + std::getline(ss, line, '='); // stepsize boost::trim(line); ss >> adaptation.step_size; + if (lines == 2) // ADVI reports stepsize, no metric + return; - // Metric parameters - std::getline(ss, line); - std::getline(ss, line); - std::getline(ss, line); + std::getline(ss, line); // consume end of stepsize line + std::getline(ss, line); // comment elements of mass matrix + std::getline(ss, line); // diagonal metric or row 1 of dense metric int rows = lines - 3; int cols = std::count(line.begin(), line.end(), ',') + 1; adaptation.metric.resize(rows, cols); + char comment; // Buffer for comment indicator, # + // parse metric, row by row, element by element for (int row = 0; row < rows; row++) { std::stringstream line_ss; line_ss.str(line); line_ss >> comment; - for (int col = 0; col < cols; col++) { std::string token; std::getline(line_ss, token, ','); boost::trim(token); std::stringstream(token) >> adaptation.metric(row, col); } - std::getline(ss, line); // Read in next line + std::getline(ss, line); } - - if (ss.good()) - return false; - else - return true; } static bool read_samples(std::istream& in, Eigen::MatrixXd& samples, @@ -290,7 +284,6 @@ class stan_csv_reader { bool empty_line = (in.peek() == '\n'); std::getline(in, line); - if (empty_line) continue; if (!line.length()) @@ -316,11 +309,10 @@ class stan_csv_reader { if (cols == -1) { cols = current_cols; } else if (cols != current_cols) { - if (out) - *out << "Error: expected " << cols << " columns, but found " - << current_cols << " instead for row " << rows + 1 - << std::endl; - return false; + std::stringstream msg; + msg << "Error: expected " << cols << " columns, but found " + << current_cols << " instead for row " << rows + 1; + throw std::invalid_argument(msg.str()); } rows++; } @@ -348,36 +340,45 @@ class stan_csv_reader { /** * Parses the file. * + * Throws exception if contents can't be parsed into header + data rows. + * + * Emits warning message + * * @param[in] in input stream to parse * @param[out] out output stream to send messages */ static stan_csv parse(std::istream& in, std::ostream* out) { stan_csv data; + std::string line; - if (!read_metadata(in, data.metadata, out)) { - if (out) - *out << "Warning: non-fatal error reading metadata" << std::endl; + read_metadata(in, data.metadata); + if (!read_header(in, data.header)) { + throw std::invalid_argument("Error: no column names found in csv file"); } - if (!read_header(in, data.header, out)) { - if (out) - *out << "Error: error reading header" << std::endl; - throw std::invalid_argument("Error with header of input file in parse"); + // skip warmup draws, if any + if (data.metadata.algorithm != "fixed_param" && data.metadata.num_warmup > 0 + && data.metadata.save_warmup) { + while (in.peek() != '#') { + std::getline(in, line); + } } - if (!read_adaptation(in, data.adaptation, out)) { - if (out) - *out << "Warning: non-fatal error reading adaptation data" << std::endl; + if (data.metadata.algorithm != "fixed_param") { + read_adaptation(in, data.adaptation); } data.timing.warmup = 0; data.timing.sampling = 0; + if (data.metadata.method == "variational") { + std::getline(in, line); // discard variational estimate + } + if (!read_samples(in, data.samples, data.timing, out)) { if (out) *out << "Warning: non-fatal error reading samples" << std::endl; } - return data; } }; diff --git a/src/test/unit/io/stan_csv_reader_test.cpp b/src/test/unit/io/stan_csv_reader_test.cpp index 6de6a9d99d..43b8f68dfa 100644 --- a/src/test/unit/io/stan_csv_reader_test.cpp +++ b/src/test/unit/io/stan_csv_reader_test.cpp @@ -16,11 +16,14 @@ class StanIoStanCsvReader : public testing::Test { header3_stream.open("src/test/unit/io/test_csv_files/header3.csv"); adaptation1_stream.open("src/test/unit/io/test_csv_files/adaptation1.csv"); samples1_stream.open("src/test/unit/io/test_csv_files/samples1.csv"); + epil0_stream.open("src/test/unit/io/test_csv_files/epil.0.csv"); + blocker_nondiag0_stream.open( "src/test/unit/io/test_csv_files/blocker_nondiag.0.csv"); eight_schools_stream.open( "src/test/unit/io/test_csv_files/eight_schools.csv"); + bernoulli_thin_stream.open( "src/test/unit/io/test_csv_files/bernoulli_thin.csv"); bernoulli_warmup_stream.open( @@ -59,8 +62,7 @@ class StanIoStanCsvReader : public testing::Test { TEST_F(StanIoStanCsvReader, read_metadata1) { stan::io::stan_csv_metadata metadata; - EXPECT_TRUE( - stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata, 0)); + stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata); EXPECT_EQ(2, metadata.stan_version_major); EXPECT_EQ(9, metadata.stan_version_minor); @@ -86,8 +88,7 @@ TEST_F(StanIoStanCsvReader, read_metadata1) { TEST_F(StanIoStanCsvReader, read_metadata3) { stan::io::stan_csv_metadata metadata; - EXPECT_TRUE( - stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata, 0)); + stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata); EXPECT_EQ(2, metadata.stan_version_major); EXPECT_EQ(9, metadata.stan_version_minor); @@ -113,8 +114,7 @@ TEST_F(StanIoStanCsvReader, read_metadata3) { TEST_F(StanIoStanCsvReader, read_header1) { std::vector header; - EXPECT_TRUE( - stan::io::stan_csv_reader::read_header(header1_stream, header, 0)); + EXPECT_TRUE(stan::io::stan_csv_reader::read_header(header1_stream, header)); ASSERT_EQ(55, header.size()); EXPECT_EQ("lp__", header[0]); @@ -141,8 +141,7 @@ TEST_F(StanIoStanCsvReader, read_header1) { TEST_F(StanIoStanCsvReader, read_header2) { std::vector header; - EXPECT_TRUE( - stan::io::stan_csv_reader::read_header(header2_stream, header, 0)); + EXPECT_TRUE(stan::io::stan_csv_reader::read_header(header2_stream, header)); ASSERT_EQ(5, header.size()); EXPECT_EQ("d", header[0]); @@ -156,8 +155,7 @@ TEST_F(StanIoStanCsvReader, read_header2) { TEST_F(StanIoStanCsvReader, read_header_tuples) { std::vector header; - EXPECT_TRUE( - stan::io::stan_csv_reader::read_header(header3_stream, header, 0)); + EXPECT_TRUE(stan::io::stan_csv_reader::read_header(header3_stream, header)); ASSERT_EQ(46, header.size()); @@ -190,8 +188,7 @@ TEST_F(StanIoStanCsvReader, read_header_tuples) { TEST_F(StanIoStanCsvReader, read_adaptation1) { stan::io::stan_csv_adaptation adaptation; - EXPECT_TRUE(stan::io::stan_csv_reader::read_adaptation(adaptation1_stream, - adaptation, 0)); + stan::io::stan_csv_reader::read_adaptation(adaptation1_stream, adaptation); EXPECT_FLOAT_EQ(0.118745, adaptation.step_size); ASSERT_EQ(47, adaptation.metric.size());