Skip to content

Commit

Permalink
bugfix in, all unit tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Oct 2, 2024
1 parent dd24b5c commit 7704ba4
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 65 deletions.
107 changes: 54 additions & 53 deletions src/stan/io/stan_csv_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ inline void prettify_stan_csv_name(std::string& variable) {
}
}

// FIXME: should consolidate with the options from
// the command line in stan::lang
struct stan_csv_metadata {
int stan_version_major;
int stan_version_minor;
Expand All @@ -47,6 +45,7 @@ struct stan_csv_metadata {
bool save_warmup;
size_t thin;
bool append_samples;
std::string method;
std::string algorithm;
std::string engine;
int max_depth;
Expand All @@ -64,8 +63,9 @@ struct stan_csv_metadata {
num_samples(0),
num_warmup(0),
save_warmup(false),
thin(0),
thin(1),
append_samples(false),
method(""),
algorithm(""),
engine(""),
max_depth(10) {}
Expand Down Expand Up @@ -101,13 +101,12 @@ class stan_csv_reader {
stan_csv_reader() {}
~stan_csv_reader() {}

static bool read_metadata(std::istream& in, stan_csv_metadata& metadata,
std::ostream* out) {
static void read_metadata(std::istream& in, stan_csv_metadata& metadata) {
std::stringstream ss;
std::string line;

if (in.peek() != '#')
return false;
return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << '\n';
Expand Down Expand Up @@ -161,9 +160,15 @@ class stan_csv_reader {
metadata.model = value;
} else if (name.compare("num_samples") == 0) {
std::stringstream(value) >> metadata.num_samples;
} else if (name.compare("output_samples") == 0) { // ADVI config name
std::stringstream(value) >> metadata.num_samples;
} else if (name.compare("num_warmup") == 0) {
std::stringstream(value) >> metadata.num_warmup;
} else if (name.compare("save_warmup") == 0) {
// CmdStan may write this flag as "true"/"false"; it was previously "1"/"0"
if (value.compare("true") == 0) {
value = "1";
}
std::stringstream(value) >> metadata.save_warmup;
} else if (name.compare("thin") == 0) {
std::stringstream(value) >> metadata.thin;
Expand All @@ -177,6 +182,8 @@ class stan_csv_reader {
metadata.random_seed = false;
} else if (name.compare("append_samples") == 0) {
std::stringstream(value) >> metadata.append_samples;
} else if (name.compare("method") == 0) {
metadata.method = value;
} else if (name.compare("algorithm") == 0) {
metadata.algorithm = value;
} else if (name.compare("engine") == 0) {
Expand All @@ -185,14 +192,10 @@ class stan_csv_reader {
std::stringstream(value) >> metadata.max_depth;
}
}
if (ss.good() == true)
return false;

return true;
} // read_metadata

static bool read_header(std::istream& in, std::vector<std::string>& header,
std::ostream* out, bool prettify_name = true) {
bool prettify_name = true) {
std::string line;

if (!std::isalpha(in.peek()))
Expand All @@ -216,62 +219,53 @@ class stan_csv_reader {
return true;
}

static bool read_adaptation(std::istream& in, stan_csv_adaptation& adaptation,
std::ostream* out) {
static void read_adaptation(std::istream& in,
stan_csv_adaptation& adaptation) {
std::stringstream ss;
std::string line;
int lines = 0;

if (in.peek() != '#' || in.good() == false)
return false;

return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << std::endl;
lines++;
}
ss.seekg(std::ios_base::beg);
if (lines < 2)
return;

if (lines < 4)
return false;
std::getline(ss, line); // comment adaptation terminated

char comment; // Buffer for comment indicator, #

// Skip first two lines
std::getline(ss, line);

// Stepsize
std::getline(ss, line, '=');
// parse stepsize
std::getline(ss, line, '='); // stepsize
boost::trim(line);
ss >> adaptation.step_size;
if (lines == 2) // ADVI reports stepsize, no metric
return;

// Metric parameters
std::getline(ss, line);
std::getline(ss, line);
std::getline(ss, line);
std::getline(ss, line); // consume end of stepsize line
std::getline(ss, line); // comment elements of mass matrix
std::getline(ss, line); // diagonal metric or row 1 of dense metric

int rows = lines - 3;
int cols = std::count(line.begin(), line.end(), ',') + 1;
adaptation.metric.resize(rows, cols);
char comment; // Buffer for comment indicator, #

// parse metric, row by row, element by element
for (int row = 0; row < rows; row++) {
std::stringstream line_ss;
line_ss.str(line);
line_ss >> comment;

for (int col = 0; col < cols; col++) {
std::string token;
std::getline(line_ss, token, ',');
boost::trim(token);
std::stringstream(token) >> adaptation.metric(row, col);
}
std::getline(ss, line); // Read in next line
std::getline(ss, line);
}

if (ss.good())
return false;
else
return true;
}

static bool read_samples(std::istream& in, Eigen::MatrixXd& samples,
Expand All @@ -290,7 +284,6 @@ class stan_csv_reader {
bool empty_line = (in.peek() == '\n');

std::getline(in, line);

if (empty_line)
continue;
if (!line.length())
Expand All @@ -316,11 +309,10 @@ class stan_csv_reader {
if (cols == -1) {
cols = current_cols;
} else if (cols != current_cols) {
if (out)
*out << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1
<< std::endl;
return false;
std::stringstream msg;
msg << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1;
throw std::invalid_argument(msg.str());
}
rows++;
}
Expand Down Expand Up @@ -348,36 +340,45 @@ class stan_csv_reader {
/**
* Parses the file.
*
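* Throws std::invalid_argument if no column names are found or if a data
* row has an inconsistent number of columns.
*
* Emits a warning message to `out` (when non-null) if read_samples reports
* a failure.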
*
* @param[in] in input stream to parse
* @param[out] out output stream to send messages
*/
static stan_csv parse(std::istream& in, std::ostream* out) {
stan_csv data;
std::string line;

if (!read_metadata(in, data.metadata, out)) {
if (out)
*out << "Warning: non-fatal error reading metadata" << std::endl;
read_metadata(in, data.metadata);
if (!read_header(in, data.header)) {
throw std::invalid_argument("Error: no column names found in csv file");
}

if (!read_header(in, data.header, out)) {
if (out)
*out << "Error: error reading header" << std::endl;
throw std::invalid_argument("Error with header of input file in parse");
// skip warmup draws, if any
if (data.metadata.algorithm != "fixed_param" && data.metadata.num_warmup > 0
&& data.metadata.save_warmup) {
while (in.peek() != '#') {
std::getline(in, line);
}
}

if (!read_adaptation(in, data.adaptation, out)) {
if (out)
*out << "Warning: non-fatal error reading adaptation data" << std::endl;
if (data.metadata.algorithm != "fixed_param") {
read_adaptation(in, data.adaptation);
}

data.timing.warmup = 0;
data.timing.sampling = 0;

if (data.metadata.method == "variational") {
std::getline(in, line); // discard variational estimate
}

if (!read_samples(in, data.samples, data.timing, out)) {
if (out)
*out << "Warning: non-fatal error reading samples" << std::endl;
}

return data;
}
};
Expand Down
21 changes: 9 additions & 12 deletions src/test/unit/io/stan_csv_reader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ class StanIoStanCsvReader : public testing::Test {
header3_stream.open("src/test/unit/io/test_csv_files/header3.csv");
adaptation1_stream.open("src/test/unit/io/test_csv_files/adaptation1.csv");
samples1_stream.open("src/test/unit/io/test_csv_files/samples1.csv");

epil0_stream.open("src/test/unit/io/test_csv_files/epil.0.csv");

blocker_nondiag0_stream.open(
"src/test/unit/io/test_csv_files/blocker_nondiag.0.csv");
eight_schools_stream.open(
"src/test/unit/io/test_csv_files/eight_schools.csv");

bernoulli_thin_stream.open(
"src/test/unit/io/test_csv_files/bernoulli_thin.csv");
bernoulli_warmup_stream.open(
Expand Down Expand Up @@ -59,8 +62,7 @@ class StanIoStanCsvReader : public testing::Test {

TEST_F(StanIoStanCsvReader, read_metadata1) {
stan::io::stan_csv_metadata metadata;
EXPECT_TRUE(
stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata, 0));
stan::io::stan_csv_reader::read_metadata(metadata1_stream, metadata);

EXPECT_EQ(2, metadata.stan_version_major);
EXPECT_EQ(9, metadata.stan_version_minor);
Expand All @@ -86,8 +88,7 @@ TEST_F(StanIoStanCsvReader, read_metadata1) {

TEST_F(StanIoStanCsvReader, read_metadata3) {
stan::io::stan_csv_metadata metadata;
EXPECT_TRUE(
stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata, 0));
stan::io::stan_csv_reader::read_metadata(metadata3_stream, metadata);

EXPECT_EQ(2, metadata.stan_version_major);
EXPECT_EQ(9, metadata.stan_version_minor);
Expand All @@ -113,8 +114,7 @@ TEST_F(StanIoStanCsvReader, read_metadata3) {

TEST_F(StanIoStanCsvReader, read_header1) {
std::vector<std::string> header;
EXPECT_TRUE(
stan::io::stan_csv_reader::read_header(header1_stream, header, 0));
EXPECT_TRUE(stan::io::stan_csv_reader::read_header(header1_stream, header));

ASSERT_EQ(55, header.size());
EXPECT_EQ("lp__", header[0]);
Expand All @@ -141,8 +141,7 @@ TEST_F(StanIoStanCsvReader, read_header1) {

TEST_F(StanIoStanCsvReader, read_header2) {
std::vector<std::string> header;
EXPECT_TRUE(
stan::io::stan_csv_reader::read_header(header2_stream, header, 0));
EXPECT_TRUE(stan::io::stan_csv_reader::read_header(header2_stream, header));

ASSERT_EQ(5, header.size());
EXPECT_EQ("d", header[0]);
Expand All @@ -156,8 +155,7 @@ TEST_F(StanIoStanCsvReader, read_header2) {

TEST_F(StanIoStanCsvReader, read_header_tuples) {
std::vector<std::string> header;
EXPECT_TRUE(
stan::io::stan_csv_reader::read_header(header3_stream, header, 0));
EXPECT_TRUE(stan::io::stan_csv_reader::read_header(header3_stream, header));

ASSERT_EQ(46, header.size());

Expand Down Expand Up @@ -190,8 +188,7 @@ TEST_F(StanIoStanCsvReader, read_header_tuples) {

TEST_F(StanIoStanCsvReader, read_adaptation1) {
stan::io::stan_csv_adaptation adaptation;
EXPECT_TRUE(stan::io::stan_csv_reader::read_adaptation(adaptation1_stream,
adaptation, 0));
stan::io::stan_csv_reader::read_adaptation(adaptation1_stream, adaptation);

EXPECT_FLOAT_EQ(0.118745, adaptation.step_size);
ASSERT_EQ(47, adaptation.metric.size());
Expand Down

0 comments on commit 7704ba4

Please sign in to comment.