Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/3301 stan csv reader #3311

Merged
merged 9 commits into from
Oct 3, 2024
115 changes: 58 additions & 57 deletions src/stan/io/stan_csv_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ inline void prettify_stan_csv_name(std::string& variable) {
}
}

// FIXME: should consolidate with the options from
// the command line in stan::lang
struct stan_csv_metadata {
int stan_version_major;
int stan_version_minor;
Expand All @@ -47,6 +45,7 @@ struct stan_csv_metadata {
bool save_warmup;
size_t thin;
bool append_samples;
std::string method;
std::string algorithm;
std::string engine;
int max_depth;
Expand All @@ -64,8 +63,9 @@ struct stan_csv_metadata {
num_samples(0),
num_warmup(0),
save_warmup(false),
thin(0),
thin(1),
append_samples(false),
method(""),
algorithm(""),
engine(""),
max_depth(10) {}
Expand Down Expand Up @@ -101,13 +101,12 @@ class stan_csv_reader {
stan_csv_reader() {}
~stan_csv_reader() {}

static bool read_metadata(std::istream& in, stan_csv_metadata& metadata,
std::ostream* out) {
static void read_metadata(std::istream& in, stan_csv_metadata& metadata) {
std::stringstream ss;
std::string line;

if (in.peek() != '#')
return false;
return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << '\n';
Expand Down Expand Up @@ -161,9 +160,15 @@ class stan_csv_reader {
metadata.model = value;
} else if (name.compare("num_samples") == 0) {
std::stringstream(value) >> metadata.num_samples;
} else if (name.compare("output_samples") == 0) { // ADVI config name
std::stringstream(value) >> metadata.num_samples;
} else if (name.compare("num_warmup") == 0) {
std::stringstream(value) >> metadata.num_warmup;
} else if (name.compare("save_warmup") == 0) {
// cmdstan may write this flag as "true"/"false"; older versions wrote "1"/"0"
if (value.compare("true") == 0) {
value = "1";
}
std::stringstream(value) >> metadata.save_warmup;
} else if (name.compare("thin") == 0) {
std::stringstream(value) >> metadata.thin;
Expand All @@ -177,6 +182,8 @@ class stan_csv_reader {
metadata.random_seed = false;
} else if (name.compare("append_samples") == 0) {
std::stringstream(value) >> metadata.append_samples;
} else if (name.compare("method") == 0) {
metadata.method = value;
} else if (name.compare("algorithm") == 0) {
metadata.algorithm = value;
} else if (name.compare("engine") == 0) {
Expand All @@ -185,14 +192,10 @@ class stan_csv_reader {
std::stringstream(value) >> metadata.max_depth;
}
}
if (ss.good() == true)
return false;

return true;
} // read_metadata

static bool read_header(std::istream& in, std::vector<std::string>& header,
std::ostream* out, bool prettify_name = true) {
bool prettify_name = true) {
std::string line;

if (!std::isalpha(in.peek()))
Expand All @@ -216,81 +219,71 @@ class stan_csv_reader {
return true;
}

static bool read_adaptation(std::istream& in, stan_csv_adaptation& adaptation,
std::ostream* out) {
static void read_adaptation(std::istream& in,
stan_csv_adaptation& adaptation) {
std::stringstream ss;
std::string line;
int lines = 0;

if (in.peek() != '#' || in.good() == false)
return false;

return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << std::endl;
lines++;
}
ss.seekg(std::ios_base::beg);
if (lines < 2)
return;

if (lines < 4)
return false;

char comment; // Buffer for comment indicator, #
std::getline(ss, line); // comment adaptation terminated

// Skip first two lines
std::getline(ss, line);

// Stepsize
std::getline(ss, line, '=');
// parse stepsize
std::getline(ss, line, '='); // stepsize
boost::trim(line);
ss >> adaptation.step_size;
if (lines == 2) // ADVI reports stepsize, no metric
return;

// Metric parameters
std::getline(ss, line);
std::getline(ss, line);
std::getline(ss, line);
std::getline(ss, line); // consume end of stepsize line
std::getline(ss, line); // comment elements of mass matrix
std::getline(ss, line); // diagonal metric or row 1 of dense metric

int rows = lines - 3;
int cols = std::count(line.begin(), line.end(), ',') + 1;
adaptation.metric.resize(rows, cols);
char comment; // Buffer for comment indicator, #

// parse metric, row by row, element by element
for (int row = 0; row < rows; row++) {
std::stringstream line_ss;
line_ss.str(line);
line_ss >> comment;

for (int col = 0; col < cols; col++) {
std::string token;
std::getline(line_ss, token, ',');
boost::trim(token);
std::stringstream(token) >> adaptation.metric(row, col);
}
std::getline(ss, line); // Read in next line
std::getline(ss, line);
}

if (ss.good())
return false;
else
return true;
}

static bool read_samples(std::istream& in, Eigen::MatrixXd& samples,
stan_csv_timing& timing, std::ostream* out) {
stan_csv_timing& timing) {
std::stringstream ss;
std::string line;

int rows = 0;
int cols = -1;

if (in.peek() == '#' || in.good() == false)
return false;
return false; // need at least one data row

while (in.good()) {
bool comment_line = (in.peek() == '#');
bool empty_line = (in.peek() == '\n');

std::getline(in, line);

if (empty_line)
continue;
if (!line.length())
Expand All @@ -316,11 +309,10 @@ class stan_csv_reader {
if (cols == -1) {
cols = current_cols;
} else if (cols != current_cols) {
if (out)
*out << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1
<< std::endl;
return false;
std::stringstream msg;
msg << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1;
throw std::invalid_argument(msg.str());
}
rows++;
}
Expand Down Expand Up @@ -348,36 +340,45 @@ class stan_csv_reader {
/**
* Parses the file.
*
* Throws exception if contents can't be parsed into header + data rows.
*
* Emits a warning message on the output stream if the samples cannot be parsed.
*
* @param[in] in input stream to parse
* @param[out] out output stream that receives messages; may be null
*/
static stan_csv parse(std::istream& in, std::ostream* out) {
stan_csv data;
std::string line;

if (!read_metadata(in, data.metadata, out)) {
if (out)
*out << "Warning: non-fatal error reading metadata" << std::endl;
read_metadata(in, data.metadata);
if (!read_header(in, data.header)) {
throw std::invalid_argument("Error: no column names found in csv file");
}

if (!read_header(in, data.header, out)) {
if (out)
*out << "Error: error reading header" << std::endl;
throw std::invalid_argument("Error with header of input file in parse");
// skip warmup draws, if any
if (data.metadata.algorithm != "fixed_param" && data.metadata.num_warmup > 0
&& data.metadata.save_warmup) {
while (in.peek() != '#') {
std::getline(in, line);
}
}

if (!read_adaptation(in, data.adaptation, out)) {
if (out)
*out << "Warning: non-fatal error reading adaptation data" << std::endl;
if (data.metadata.algorithm != "fixed_param") {
read_adaptation(in, data.adaptation);
}

data.timing.warmup = 0;
data.timing.sampling = 0;

if (!read_samples(in, data.samples, data.timing, out)) {
if (out)
*out << "Warning: non-fatal error reading samples" << std::endl;
if (data.metadata.method == "variational") {
std::getline(in, line); // discard variational estimate
}

if (!read_samples(in, data.samples, data.timing)) {
if (out)
*out << "Unable to parse sample" << std::endl;
}
return data;
}
};
Expand Down
Loading
Loading