diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index 792040ff54..fa0686fc1a 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -210,7 +210,7 @@ int command(int argc, const char *argv[]) { } std::vector> init_contexts - = get_vec_var_context(init, num_chains); + = get_vec_var_context(init, num_chains, id); std::vector model_compile_info = model.model_compile_info(); for (int i = 0; i < num_chains; ++i) { @@ -510,7 +510,7 @@ int command(int argc, const char *argv[]) { dynamic_cast(algo->arg("hmc")->arg("metric_file")) ->value()); context_vector metric_contexts - = get_vec_var_context(metric_filename, num_chains); + = get_vec_var_context(metric_filename, num_chains, id); categorical_argument *adapt = dynamic_cast(sample_arg->arg("adapt")); categorical_argument *hmc diff --git a/src/cmdstan/command_helper.hpp b/src/cmdstan/command_helper.hpp index 3077db5de5..0f0de2265c 100644 --- a/src/cmdstan/command_helper.hpp +++ b/src/cmdstan/command_helper.hpp @@ -192,6 +192,29 @@ inline shared_context_ptr get_var_context(const std::string &file) { return std::make_shared(var_context); } +std::vector make_filenames(const std::string &filename, + const std::string &tag, + const std::string &type, + unsigned int num_chains, + unsigned int id) { + std::vector names(num_chains); + auto base_sfx = get_basename_suffix(filename); + if (base_sfx.second.empty()) { + base_sfx.second = type; + } + auto name_iterator = [num_chains, id](auto i) { + if (num_chains == 1) { + return std::string(""); + } else { + return std::string("_" + std::to_string(i + id)); + } + }; + for (int i = 0; i < num_chains; ++i) { + names[i] = base_sfx.first + tag + name_iterator(i) + base_sfx.second; + } + return names; +} + using context_vector = std::vector; /** * Make a vector of shared pointers to contexts. @@ -201,7 +224,8 @@ using context_vector = std::vector; * @param num_chains The number of chains to run * @return a std vector of shared pointers to var contexts */ -context_vector get_vec_var_context(const std::string &file, size_t num_chains) { +context_vector get_vec_var_context(const std::string &file, size_t num_chains, + unsigned int id) { using stan::io::var_context; if (num_chains == 1) { return context_vector(1, get_var_context(file)); @@ -249,8 +273,9 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains) { "\tConsider saving your data in JSON format instead." << std::endl; } - std::string file_1 - = std::string(file_name + "_" + std::to_string(1) + file_ending); + + auto filenames = make_filenames(file_name, "", file_ending, num_chains, id); + auto &file_1 = filenames[0]; std::fstream stream_1(file_1.c_str(), std::fstream::in); // if file_1 exists we'll assume num_chains of these files exist if (stream_1.rdstate() & std::ifstream::failbit) { @@ -274,9 +299,8 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains) { ret.reserve(num_chains); ret.push_back(make_context(file_1, stream_1, file_ending)); for (size_t i = 1; i < num_chains; ++i) { - std::string file_i - = std::string(file_name + "_" + std::to_string(i) + file_ending); - std::fstream stream_i(file_1.c_str(), std::fstream::in); + auto &file_i = filenames[i]; + std::fstream stream_i(file_i.c_str(), std::fstream::in); // If any stream fails here something went wrong with file names if (stream_i.rdstate() & std::ifstream::failbit) { std::string file_name_err = std::string( @@ -737,29 +761,6 @@ void check_file_config(argument_parser &parser) { } } -std::vector make_filenames(const std::string &filename, - const std::string &tag, - const std::string &type, - unsigned int num_chains, - unsigned int id) { - std::vector names(num_chains); - auto base_sfx = get_basename_suffix(filename); - if (base_sfx.second.empty()) { - base_sfx.second = type; - } - auto name_iterator = [num_chains, id](auto i) { - if (num_chains == 1) { - return std::string(""); - } else { - return std::string("_" + std::to_string(i + id)); - } - }; - for (int i = 0; i < num_chains; ++i) { - names[i] = base_sfx.first + tag + name_iterator(i) + base_sfx.second; - } - return names; -} - void init_callbacks( argument_parser &parser, std::vector> diff --git a/src/test/interface/multi_chain_init_test.cpp b/src/test/interface/multi_chain_init_test.cpp index 742b898427..5cb8d3a0f5 100644 --- a/src/test/interface/multi_chain_init_test.cpp +++ b/src/test/interface/multi_chain_init_test.cpp @@ -18,6 +18,7 @@ class CmdStan : public testing::Test { init_data = {"src", "test", "test-models", "bern_init.json"}; init2_data = {"src", "test", "test-models", "bern_init2.json"}; init3_data = {"src", "test", "test-models", "bern_init2.R"}; + init_bad_data = {"src", "test", "test-models", "bern_init_bad.json"}; dev_null_path = {"/dev", "null"}; } std::vector bern_model; @@ -26,6 +27,7 @@ class CmdStan : public testing::Test { std::vector init_data; std::vector init2_data; std::vector init3_data; + std::vector init_bad_data; }; TEST_F(CmdStan, multi_chain_single_init_file_good) { @@ -52,6 +54,44 @@ TEST_F(CmdStan, multi_chain_multi_init_file_good) { ASSERT_FALSE(out.hasError); } +TEST_F(CmdStan, multi_chain_multi_init_file_id_good) { + std::stringstream ss; + ss << convert_model_path(bern_model) + << " data file=" << convert_model_path(bern_data) + << " output file=" << convert_model_path(dev_null_path) + << " init=" << convert_model_path(init2_data) << " id=2" + << " method=sample num_chains=2"; + std::string cmd = ss.str(); + run_command_output out = run_command(cmd); + ASSERT_FALSE(out.hasError) << out.output; +} + +TEST_F(CmdStan, multi_chain_multi_init_file_id_bad) { + // this will start by requesting ..._4.json, which doesn't exist + std::stringstream ss; + ss << convert_model_path(bern_model) + << " data file=" << convert_model_path(bern_data) + << " output file=" << convert_model_path(dev_null_path) + << " init=" << convert_model_path(init2_data) << " id=4" + << " method=sample num_chains=3"; + std::string cmd = ss.str(); + run_command_output out = run_command(cmd); + ASSERT_TRUE(out.hasError); +} + +TEST_F(CmdStan, multi_chain_multi_init_file_actually_used) { + // the second chain has a bad init value + std::stringstream ss; + ss << convert_model_path(bern_model) + << " data file=" << convert_model_path(bern_data) + << " output file=" << convert_model_path(dev_null_path) + << " init=" << convert_model_path(init_bad_data) + << " method=sample num_chains=2"; + std::string cmd = ss.str(); + run_command_output out = run_command(cmd); + ASSERT_TRUE(out.hasError) << out.output; +} + TEST_F(CmdStan, multi_chain_multi_init_file_R) { std::stringstream ss; ss << convert_model_path(bern_model) diff --git a/src/test/test-models/bern_init_bad_1.json b/src/test/test-models/bern_init_bad_1.json new file mode 100644 index 0000000000..7de066920b --- /dev/null +++ b/src/test/test-models/bern_init_bad_1.json @@ -0,0 +1,3 @@ +{ + "theta" : 0.1 +} diff --git a/src/test/test-models/bern_init_bad_2.json b/src/test/test-models/bern_init_bad_2.json new file mode 100644 index 0000000000..c330dfacbf --- /dev/null +++ b/src/test/test-models/bern_init_bad_2.json @@ -0,0 +1,3 @@ +{ + "theta" : 3.0 +}