Skip to content

Commit

Permalink
Maintain exception safety in top-level services functions
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Feb 8, 2024
1 parent 5681a81 commit 2709df7
Show file tree
Hide file tree
Showing 21 changed files with 395 additions and 221 deletions.
9 changes: 7 additions & 2 deletions src/stan/services/experimental/advi/fullrank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,13 @@ int fullrank(Model& model, const stan::io::var_context& init,
stan::rng_t>
cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
output_samples);
cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
max_iterations, logger, parameter_writer, diagnostic_writer);
try {
cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
max_iterations, logger, parameter_writer, diagnostic_writer);
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}

return stan::services::error_codes::OK;
}
Expand Down
9 changes: 7 additions & 2 deletions src/stan/services/experimental/advi/meanfield.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ int meanfield(Model& model, const stan::io::var_context& init,
stan::rng_t>
cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
output_samples);
cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
max_iterations, logger, parameter_writer, diagnostic_writer);
try {
cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
max_iterations, logger, parameter_writer, diagnostic_writer);
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}

return stan::services::error_codes::OK;
}
Expand Down
42 changes: 37 additions & 5 deletions src/stan/services/optimize/bfgs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
if (save_iterations) {
std::vector<double> values;
std::stringstream msg;
model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
try {
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
} catch (const std::exception& e) {
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(e.what());
return error_codes::SOFTWARE;
}
if (msg.str().length() > 0)
logger.info(msg);

Expand All @@ -119,7 +128,13 @@ int bfgs(Model& model, const stan::io::var_context& init,
" # evals"
" Notes ");

ret = bfgs.step();
try {
ret = bfgs.step();
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}

lp = bfgs.logp();
bfgs.params_r(cont_vector);

Expand Down Expand Up @@ -150,8 +165,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
if (save_iterations) {
std::vector<double> values;
std::stringstream msg;
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
try {
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
} catch (const std::exception& e) {
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(e.what());
return error_codes::SOFTWARE;
}
// This if is here to match the pre-refactor behavior
if (msg.str().length() > 0)
logger.info(msg);
Expand All @@ -164,7 +187,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
if (!save_iterations) {
std::vector<double> values;
std::stringstream msg;
model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
try {
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
} catch (const std::exception& e) {
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(e.what());
return error_codes::SOFTWARE;
}
if (msg.str().length() > 0)
logger.info(msg);
values.insert(values.begin(), lp);
Expand Down
30 changes: 26 additions & 4 deletions src/stan/services/optimize/lbfgs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ int lbfgs(Model& model, const stan::io::var_context& init,
" # evals"
" Notes ");

ret = lbfgs.step();
try {
ret = lbfgs.step();
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}
lp = lbfgs.logp();
lbfgs.params_r(cont_vector);

Expand Down Expand Up @@ -154,8 +159,16 @@ int lbfgs(Model& model, const stan::io::var_context& init,
if (save_iterations) {
std::vector<double> values;
std::stringstream msg;
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
try {
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
} catch (const std::exception& e) {
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(e.what());
return error_codes::SOFTWARE;
}
if (msg.str().length() > 0)
logger.info(msg);

Expand All @@ -167,7 +180,16 @@ int lbfgs(Model& model, const stan::io::var_context& init,
if (!save_iterations) {
std::vector<double> values;
std::stringstream msg;
model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
try {
model.write_array(rng, cont_vector, disc_vector, values, true, true,
&msg);
} catch (const std::exception& e) {
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(e.what());
return error_codes::SOFTWARE;
}
if (msg.str().length() > 0)
logger.info(msg);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/optimize/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ int newton(Model& model, const stan::io::var_context& init,
lp = model.template log_prob<false, jacobian>(cont_vector, disc_vector,
&message);
logger.info(message);
} catch (const std::exception& e) {
} catch (const std::domain_error& e) {
logger.info("");
logger.info(
"Informational Message: The current"
Expand Down
49 changes: 27 additions & 22 deletions src/stan/services/pathfinder/multi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,33 @@ inline int pathfinder_lbfgs_multi(
individual_samples;
individual_samples.resize(num_paths);
std::atomic<size_t> lp_calls{0};
tbb::parallel_for(
tbb::blocked_range<int>(0, num_paths), [&](tbb::blocked_range<int> r) {
for (int iter = r.begin(); iter < r.end(); ++iter) {
auto pathfinder_ret
= stan::services::pathfinder::pathfinder_lbfgs_single<true>(
model, *(init[iter]), random_seed, stride_id + iter,
init_radius, history_size, init_alpha, tol_obj, tol_rel_obj,
tol_grad, tol_rel_grad, tol_param, num_iterations,
num_elbo_draws, num_draws, save_iterations, refresh,
interrupt, logger, init_writers[iter],
single_path_parameter_writer[iter],
single_path_diagnostic_writer[iter], calculate_lp);
if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) {
logger.error(std::string("Pathfinder iteration: ")
+ std::to_string(iter) + " failed.");
return;
try {
tbb::parallel_for(
tbb::blocked_range<int>(0, num_paths), [&](tbb::blocked_range<int> r) {
for (int iter = r.begin(); iter < r.end(); ++iter) {
auto pathfinder_ret
= stan::services::pathfinder::pathfinder_lbfgs_single<true>(
model, *(init[iter]), random_seed, stride_id + iter,
init_radius, history_size, init_alpha, tol_obj, tol_rel_obj,
tol_grad, tol_rel_grad, tol_param, num_iterations,
num_elbo_draws, num_draws, save_iterations, refresh,
interrupt, logger, init_writers[iter],
single_path_parameter_writer[iter],
single_path_diagnostic_writer[iter], calculate_lp);
if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) {
logger.error(std::string("Pathfinder iteration: ")
+ std::to_string(iter) + " failed.");
return;
}
individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret));
individual_samples[iter] = std::move(std::get<2>(pathfinder_ret));
lp_calls += std::get<3>(pathfinder_ret);
}
individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret));
individual_samples[iter] = std::move(std::get<2>(pathfinder_ret));
lp_calls += std::get<3>(pathfinder_ret);
}
});
});
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}

// if any pathfinders failed, we want to remove their empty results
individual_lp_ratios.erase(
Expand Down Expand Up @@ -231,7 +236,7 @@ inline int pathfinder_lbfgs_multi(
parameter_writer(total_time_str);
}
parameter_writer();
return 0;
return error_codes::OK;
}
} // namespace pathfinder
} // namespace services
Expand Down
7 changes: 6 additions & 1 deletion src/stan/services/pathfinder/single.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,12 @@ inline auto pathfinder_lbfgs_single(
logger.info(lbfgs_ss);
lbfgs_ss.str("");
}
throw;
if (ReturnLpSamples) {
throw;
} else {
logger.error(e.what());
return error_codes::SOFTWARE;
}
}
}
if (unlikely(save_iterations)) {
Expand Down
58 changes: 34 additions & 24 deletions src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,14 @@ int fixed_param(Model& model, const stan::io::var_context& init,
writer.write_diagnostic_names(s, sampler, model);

auto start = std::chrono::steady_clock::now();
util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin,
refresh, true, false, writer, s, model, rng,
interrupt, logger);
try {
util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin,
refresh, true, false, writer, s, model, rng,
interrupt, logger);
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}
auto end = std::chrono::steady_clock::now();
double sample_delta_t
= std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
Expand Down Expand Up @@ -156,27 +161,32 @@ int fixed_param(Model& model, const std::size_t num_chains,
writers[i].write_diagnostic_names(samples[i], samplers[i], model);
}

tbb::parallel_for(
tbb::blocked_range<size_t>(0, num_chains, 1),
[&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger,
num_samples, num_thin, refresh, chain,
num_chains](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
auto start = std::chrono::steady_clock::now();
util::generate_transitions(samplers[i], num_samples, 0, num_samples,
num_thin, refresh, true, false, writers[i],
samples[i], model, rngs[i], interrupt,
logger, chain + i, num_chains);
auto end = std::chrono::steady_clock::now();
double sample_delta_t
= std::chrono::duration_cast<std::chrono::milliseconds>(end
- start)
.count()
/ 1000.0;
writers[i].write_timing(0.0, sample_delta_t);
}
},
tbb::simple_partitioner());
try {
tbb::parallel_for(
tbb::blocked_range<size_t>(0, num_chains, 1),
[&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger,
num_samples, num_thin, refresh, chain,
num_chains](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
auto start = std::chrono::steady_clock::now();
util::generate_transitions(
samplers[i], num_samples, 0, num_samples, num_thin, refresh,
true, false, writers[i], samples[i], model, rngs[i], interrupt,
logger, chain + i, num_chains);
auto end = std::chrono::steady_clock::now();
double sample_delta_t
= std::chrono::duration_cast<std::chrono::milliseconds>(end
- start)
.count()
/ 1000.0;
writers[i].write_timing(0.0, sample_delta_t);
}
},
tbb::simple_partitioner());
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}
return error_codes::OK;
}

Expand Down
44 changes: 27 additions & 17 deletions src/stan/services/sample/hmc_nuts_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init,
sampler.set_stepsize_jitter(stepsize_jitter);
sampler.set_max_depth(max_depth);

util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples,
num_thin, refresh, save_warmup, rng, interrupt, logger,
sample_writer, diagnostic_writer);
try {
util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples,
num_thin, refresh, save_warmup, rng, interrupt, logger,
sample_writer, diagnostic_writer);
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}
return error_codes::OK;
}

Expand Down Expand Up @@ -221,20 +226,25 @@ int hmc_nuts_dense_e(Model& model, size_t num_chains,
logger.error(e.what());
return error_codes::CONFIG;
}
tbb::parallel_for(
tbb::blocked_range<size_t>(0, num_chains, 1),
[num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains,
init_chain_id, &samplers, &model, &rngs, &interrupt, &logger,
&sample_writer, &cont_vectors,
&diagnostic_writer](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup,
num_samples, num_thin, refresh, save_warmup,
rngs[i], interrupt, logger, sample_writer[i],
diagnostic_writer[i], init_chain_id + i);
}
},
tbb::simple_partitioner());
try {
tbb::parallel_for(
tbb::blocked_range<size_t>(0, num_chains, 1),
[num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains,
init_chain_id, &samplers, &model, &rngs, &interrupt, &logger,
&sample_writer, &cont_vectors,
&diagnostic_writer](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup,
num_samples, num_thin, refresh, save_warmup,
rngs[i], interrupt, logger, sample_writer[i],
diagnostic_writer[i], init_chain_id + i);
}
},
tbb::simple_partitioner());
} catch (const std::exception& e) {
logger.error(e.what());
return error_codes::SOFTWARE;
}
return error_codes::OK;
}

Expand Down
Loading

0 comments on commit 2709df7

Please sign in to comment.