From 2709df7f11af708477088f5609f6f70163ee5135 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 25 Jan 2024 10:40:49 -0500 Subject: [PATCH] Maintain exception safety in top-level services functions --- .../services/experimental/advi/fullrank.hpp | 9 +- .../services/experimental/advi/meanfield.hpp | 9 +- src/stan/services/optimize/bfgs.hpp | 42 +++++++-- src/stan/services/optimize/lbfgs.hpp | 30 ++++++- src/stan/services/optimize/newton.hpp | 2 +- src/stan/services/pathfinder/multi.hpp | 49 ++++++----- src/stan/services/pathfinder/single.hpp | 7 +- src/stan/services/sample/fixed_param.hpp | 58 +++++++------ src/stan/services/sample/hmc_nuts_dense_e.hpp | 44 ++++++---- .../sample/hmc_nuts_dense_e_adapt.hpp | 49 ++++++----- src/stan/services/sample/hmc_nuts_diag_e.hpp | 12 ++- .../services/sample/hmc_nuts_diag_e_adapt.hpp | 49 ++++++----- src/stan/services/sample/hmc_nuts_unit_e.hpp | 47 ++++++---- .../services/sample/hmc_nuts_unit_e_adapt.hpp | 49 ++++++----- .../services/sample/hmc_static_dense_e.hpp | 11 ++- .../sample/hmc_static_dense_e_adapt.hpp | 13 ++- .../services/sample/hmc_static_diag_e.hpp | 12 ++- .../sample/hmc_static_diag_e_adapt.hpp | 14 ++- .../services/sample/hmc_static_unit_e.hpp | 11 ++- .../sample/hmc_static_unit_e_adapt.hpp | 13 ++- src/stan/services/sample/standalone_gqs.hpp | 86 +++++++++++-------- 21 files changed, 395 insertions(+), 221 deletions(-) diff --git a/src/stan/services/experimental/advi/fullrank.hpp b/src/stan/services/experimental/advi/fullrank.hpp index 96cf5d05c0..5fba2e4e02 100644 --- a/src/stan/services/experimental/advi/fullrank.hpp +++ b/src/stan/services/experimental/advi/fullrank.hpp @@ -86,8 +86,13 @@ int fullrank(Model& model, const stan::io::var_context& init, stan::rng_t> cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo, output_samples); - cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, - max_iterations, logger, parameter_writer, diagnostic_writer); + try { + cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, + max_iterations, logger, parameter_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return stan::services::error_codes::OK; } diff --git a/src/stan/services/experimental/advi/meanfield.hpp b/src/stan/services/experimental/advi/meanfield.hpp index 6cffe548ac..49bee28505 100644 --- a/src/stan/services/experimental/advi/meanfield.hpp +++ b/src/stan/services/experimental/advi/meanfield.hpp @@ -85,8 +85,13 @@ int meanfield(Model& model, const stan::io::var_context& init, stan::rng_t> cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo, output_samples); - cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, - max_iterations, logger, parameter_writer, diagnostic_writer); + try { + cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj, + max_iterations, logger, parameter_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return stan::services::error_codes::OK; } diff --git a/src/stan/services/optimize/bfgs.hpp b/src/stan/services/optimize/bfgs.hpp index 2819b853a6..bcb0e49f31 100644 --- a/src/stan/services/optimize/bfgs.hpp +++ b/src/stan/services/optimize/bfgs.hpp @@ -96,7 +96,16 @@ int bfgs(Model& model, const stan::io::var_context& init, if (save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); @@ -119,7 +128,13 @@ int bfgs(Model& model, const stan::io::var_context& init, " # evals" " Notes "); - ret = bfgs.step(); + try { + ret = bfgs.step(); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } + lp = bfgs.logp(); bfgs.params_r(cont_vector); @@ -150,8 +165,16 @@ int bfgs(Model& model, const stan::io::var_context& init, if (save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, - &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } // This if is here to match the pre-refactor behavior if (msg.str().length() > 0) logger.info(msg); @@ -164,7 +187,16 @@ int bfgs(Model& model, const stan::io::var_context& init, if (!save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); values.insert(values.begin(), lp); diff --git a/src/stan/services/optimize/lbfgs.hpp b/src/stan/services/optimize/lbfgs.hpp index 083e37ffed..9045b5470e 100644 --- a/src/stan/services/optimize/lbfgs.hpp +++ b/src/stan/services/optimize/lbfgs.hpp @@ -123,7 +123,12 @@ int lbfgs(Model& model, const stan::io::var_context& init, " # evals" " Notes "); - ret = lbfgs.step(); + try { + ret = lbfgs.step(); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } lp = lbfgs.logp(); lbfgs.params_r(cont_vector); @@ -154,8 +159,16 @@ int lbfgs(Model& model, const stan::io::var_context& init, if (save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, - &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); @@ -167,7 +180,16 @@ int lbfgs(Model& model, const stan::io::var_context& init, if (!save_iterations) { std::vector values; std::stringstream msg; - model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg); + try { + model.write_array(rng, cont_vector, disc_vector, values, true, true, + &msg); + } catch (const std::exception& e) { + if (msg.str().length() > 0) { + logger.info(msg); + } + logger.error(e.what()); + return error_codes::SOFTWARE; + } if (msg.str().length() > 0) logger.info(msg); diff --git a/src/stan/services/optimize/newton.hpp b/src/stan/services/optimize/newton.hpp index 081365f0a9..db64f6e46c 100644 --- a/src/stan/services/optimize/newton.hpp +++ b/src/stan/services/optimize/newton.hpp @@ -62,7 +62,7 @@ int newton(Model& model, const stan::io::var_context& init, lp = model.template log_prob(cont_vector, disc_vector, &message); logger.info(message); - } catch (const std::exception& e) { + } catch (const std::domain_error& e) { logger.info(""); logger.info( "Informational Message: The current" diff --git a/src/stan/services/pathfinder/multi.hpp b/src/stan/services/pathfinder/multi.hpp index e87eaa63e3..924f0806b3 100644 --- a/src/stan/services/pathfinder/multi.hpp +++ b/src/stan/services/pathfinder/multi.hpp @@ -117,28 +117,33 @@ inline int pathfinder_lbfgs_multi( individual_samples; individual_samples.resize(num_paths); std::atomic lp_calls{0}; - tbb::parallel_for( - tbb::blocked_range(0, num_paths), [&](tbb::blocked_range r) { - for (int iter = r.begin(); iter < r.end(); ++iter) { - auto pathfinder_ret - = stan::services::pathfinder::pathfinder_lbfgs_single( - model, *(init[iter]), random_seed, stride_id + iter, - init_radius, history_size, init_alpha, tol_obj, tol_rel_obj, - tol_grad, tol_rel_grad, tol_param, num_iterations, - num_elbo_draws, num_draws, save_iterations, refresh, - interrupt, logger, init_writers[iter], - single_path_parameter_writer[iter], - single_path_diagnostic_writer[iter], calculate_lp); - if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) { - logger.error(std::string("Pathfinder iteration: ") - + std::to_string(iter) + " failed."); - return; + try { + tbb::parallel_for( + tbb::blocked_range(0, num_paths), [&](tbb::blocked_range r) { + for (int iter = r.begin(); iter < r.end(); ++iter) { + auto pathfinder_ret + = stan::services::pathfinder::pathfinder_lbfgs_single( + model, *(init[iter]), random_seed, stride_id + iter, + init_radius, history_size, init_alpha, tol_obj, tol_rel_obj, + tol_grad, tol_rel_grad, tol_param, num_iterations, + num_elbo_draws, num_draws, save_iterations, refresh, + interrupt, logger, init_writers[iter], + single_path_parameter_writer[iter], + single_path_diagnostic_writer[iter], calculate_lp); + if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) { + logger.error(std::string("Pathfinder iteration: ") + + std::to_string(iter) + " failed."); + return; + } + individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret)); + individual_samples[iter] = std::move(std::get<2>(pathfinder_ret)); + lp_calls += std::get<3>(pathfinder_ret); } - individual_lp_ratios[iter] = std::move(std::get<1>(pathfinder_ret)); - individual_samples[iter] = std::move(std::get<2>(pathfinder_ret)); - lp_calls += std::get<3>(pathfinder_ret); - } - }); + }); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } // if any pathfinders failed, we want to remove their empty results individual_lp_ratios.erase( @@ -231,7 +236,7 @@ inline int pathfinder_lbfgs_multi( parameter_writer(total_time_str); } parameter_writer(); - return 0; + return error_codes::OK; } } // namespace pathfinder } // namespace services diff --git a/src/stan/services/pathfinder/single.hpp b/src/stan/services/pathfinder/single.hpp index 0f0c7457ba..4719a40e1e 100644 --- a/src/stan/services/pathfinder/single.hpp +++ b/src/stan/services/pathfinder/single.hpp @@ -820,7 +820,12 @@ inline auto pathfinder_lbfgs_single( logger.info(lbfgs_ss); lbfgs_ss.str(""); } - throw; + if (ReturnLpSamples) { + throw; + } else { + logger.error(e.what()); + return error_codes::SOFTWARE; + } } } if (unlikely(save_iterations)) { diff --git a/src/stan/services/sample/fixed_param.hpp b/src/stan/services/sample/fixed_param.hpp index f407b14f57..17e04a621e 100644 --- a/src/stan/services/sample/fixed_param.hpp +++ b/src/stan/services/sample/fixed_param.hpp @@ -74,9 +74,14 @@ int fixed_param(Model& model, const stan::io::var_context& init, writer.write_diagnostic_names(s, sampler, model); auto start = std::chrono::steady_clock::now(); - util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin, - refresh, true, false, writer, s, model, rng, - interrupt, logger); + try { + util::generate_transitions(sampler, num_samples, 0, num_samples, num_thin, + refresh, true, false, writer, s, model, rng, + interrupt, logger); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } auto end = std::chrono::steady_clock::now(); double sample_delta_t = std::chrono::duration_cast(end - start) @@ -156,27 +161,32 @@ int fixed_param(Model& model, const std::size_t num_chains, writers[i].write_diagnostic_names(samples[i], samplers[i], model); } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger, - num_samples, num_thin, refresh, chain, - num_chains](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - auto start = std::chrono::steady_clock::now(); - util::generate_transitions(samplers[i], num_samples, 0, num_samples, - num_thin, refresh, true, false, writers[i], - samples[i], model, rngs[i], interrupt, - logger, chain + i, num_chains); - auto end = std::chrono::steady_clock::now(); - double sample_delta_t - = std::chrono::duration_cast(end - - start) - .count() - / 1000.0; - writers[i].write_timing(0.0, sample_delta_t); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger, + num_samples, num_thin, refresh, chain, + num_chains](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + auto start = std::chrono::steady_clock::now(); + util::generate_transitions( + samplers[i], num_samples, 0, num_samples, num_thin, refresh, + true, false, writers[i], samples[i], model, rngs[i], interrupt, + logger, chain + i, num_chains); + auto end = std::chrono::steady_clock::now(); + double sample_delta_t + = std::chrono::duration_cast(end + - start) + .count() + / 1000.0; + writers[i].write_timing(0.0, sample_delta_t); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_dense_e.hpp b/src/stan/services/sample/hmc_nuts_dense_e.hpp index 73e0a2af1d..0fb818b19c 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e.hpp @@ -81,9 +81,14 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init, sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -221,20 +226,25 @@ int hmc_nuts_dense_e(Model& model, size_t num_chains, logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, - &diagnostic_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, - num_samples, num_thin, refresh, save_warmup, - rngs[i], interrupt, logger, sample_writer[i], - diagnostic_writer[i], init_chain_id + i); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, + &diagnostic_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index 913c50152a..ce1befff87 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -98,11 +98,15 @@ int hmc_nuts_dense_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); - - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, metric_writer); + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -379,21 +383,26 @@ int hmc_nuts_dense_e_adapt( logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, &diagnostic_writer, - &metric_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], metric_writer[i], - init_chain_id + i, num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, &diagnostic_writer, + &metric_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, num_samples, + num_thin, refresh, save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], metric_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_diag_e.hpp b/src/stan/services/sample/hmc_nuts_diag_e.hpp index e693ed0a83..bb789c0151 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e.hpp @@ -79,10 +79,14 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init, sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); - + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 1044d9ed53..ec48ade0b5 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -99,11 +99,15 @@ int hmc_nuts_diag_e_adapt( sampler.set_window_params(num_warmup, init_buffer, term_buffer, window, logger); - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, metric_writer); - + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -379,21 +383,26 @@ int hmc_nuts_diag_e_adapt( logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, &diagnostic_writer, - &metric_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], metric_writer[i], - init_chain_id + i, num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, &diagnostic_writer, + &metric_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, num_samples, + num_thin, refresh, save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], metric_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_unit_e.hpp b/src/stan/services/sample/hmc_nuts_unit_e.hpp index 01c9fe2e1b..c7a2ca7c1b 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e.hpp @@ -69,10 +69,14 @@ int hmc_nuts_unit_e(Model& model, const stan::io::var_context& init, sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); - + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -155,21 +159,26 @@ int hmc_nuts_unit_e(Model& model, size_t num_chains, logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, - &diagnostic_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, - num_samples, num_thin, refresh, save_warmup, - rngs[i], interrupt, logger, sample_writer[i], - diagnostic_writer[i], init_chain_id + i, - num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, + &diagnostic_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i, + num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp index 889c6d8920..5d74dae176 100644 --- a/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_unit_e_adapt.hpp @@ -83,11 +83,15 @@ int hmc_nuts_unit_e_adapt( sampler.get_stepsize_adaptation().set_kappa(kappa); sampler.get_stepsize_adaptation().set_t0(t0); - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, metric_writer); - + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } @@ -228,21 +232,26 @@ int hmc_nuts_unit_e_adapt( logger.error(e.what()); return error_codes::CONFIG; } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, &diagnostic_writer, - &metric_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], metric_writer[i], - init_chain_id + i, num_chains); - } - }, - tbb::simple_partitioner()); + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, + &sample_writer, &cont_vectors, &diagnostic_writer, + &metric_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, num_samples, + num_thin, refresh, save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], metric_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_dense_e.hpp b/src/stan/services/sample/hmc_static_dense_e.hpp index c337636f9c..c093161738 100644 --- a/src/stan/services/sample/hmc_static_dense_e.hpp +++ b/src/stan/services/sample/hmc_static_dense_e.hpp @@ -76,9 +76,14 @@ int hmc_static_dense_e( sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp index 21bd6d711d..b56082620a 100644 --- a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp @@ -98,10 +98,15 @@ int hmc_static_dense_e_adapt( logger); callbacks::structured_writer dummy_metric_writer; - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, dummy_metric_writer); + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, dummy_metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_diag_e.hpp b/src/stan/services/sample/hmc_static_diag_e.hpp index 87ea955f84..b19c211047 100644 --- a/src/stan/services/sample/hmc_static_diag_e.hpp +++ b/src/stan/services/sample/hmc_static_diag_e.hpp @@ -78,10 +78,14 @@ int hmc_static_diag_e(Model& model, const stan::io::var_context& init, sampler.set_metric(inv_metric); sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); - - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp index 88979dc341..e4041b59b3 100644 --- a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp @@ -96,10 +96,16 @@ int hmc_static_diag_e_adapt( logger); callbacks::structured_writer dummy_metric_writer; - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, dummy_metric_writer); + + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, dummy_metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_unit_e.hpp b/src/stan/services/sample/hmc_static_unit_e.hpp index d50c902479..8e2b8428fb 100644 --- a/src/stan/services/sample/hmc_static_unit_e.hpp +++ b/src/stan/services/sample/hmc_static_unit_e.hpp @@ -68,9 +68,14 @@ int hmc_static_unit_e(Model& model, const stan::io::var_context& init, sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); - util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, - num_thin, refresh, save_warmup, rng, interrupt, logger, - sample_writer, diagnostic_writer); + try { + util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples, + num_thin, refresh, save_warmup, rng, interrupt, logger, + sample_writer, diagnostic_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/hmc_static_unit_e_adapt.hpp b/src/stan/services/sample/hmc_static_unit_e_adapt.hpp index fb0da9aff5..bf4bf5c17e 100644 --- a/src/stan/services/sample/hmc_static_unit_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_unit_e_adapt.hpp @@ -80,10 +80,15 @@ int hmc_static_unit_e_adapt( sampler.get_stepsize_adaptation().set_t0(t0); callbacks::structured_writer dummy_metric_writer; - util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, - num_samples, num_thin, refresh, save_warmup, rng, - interrupt, logger, sample_writer, - diagnostic_writer, dummy_metric_writer); + try { + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, save_warmup, rng, + interrupt, logger, sample_writer, + diagnostic_writer, dummy_metric_writer); + } catch (const std::exception& e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_codes::OK; } diff --git a/src/stan/services/sample/standalone_gqs.hpp b/src/stan/services/sample/standalone_gqs.hpp index 378be61106..c792acd08f 100644 --- a/src/stan/services/sample/standalone_gqs.hpp +++ b/src/stan/services/sample/standalone_gqs.hpp @@ -67,19 +67,23 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws, std::vector unconstrained_params_r; std::vector row(draws.cols()); - - for (size_t i = 0; i < draws.rows(); ++i) { - Eigen::Map(&row[0], draws.cols()) = draws.row(i); - try { - model.unconstrain_array(row, unconstrained_params_r, &msg); - } catch (const std::exception &e) { - if (msg.str().length() > 0) - logger.error(msg); - logger.error(e.what()); - return error_codes::DATAERR; + try { + for (size_t i = 0; i < draws.rows(); ++i) { + Eigen::Map(&row[0], draws.cols()) = draws.row(i); + try { + model.unconstrain_array(row, unconstrained_params_r, &msg); + } catch (const std::exception &e) { + if (msg.str().length() > 0) + logger.error(msg); + logger.error(e.what()); + return error_codes::DATAERR; + } + interrupt(); // call out to interrupt and fail + writer.write_gq_values(model, rng, unconstrained_params_r); } - interrupt(); // call out to interrupt and fail - writer.write_gq_values(model, rng, unconstrained_params_r); + } catch (const std::exception &e) { + logger.error(e.what()); + return error_codes::SOFTWARE; } return error_codes::OK; } @@ -147,34 +151,40 @@ int standalone_generate(const Model &model, const int num_chains, rngs.emplace_back(util::create_rng(seed, i + 1)); } bool error_any = false; - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [&draws, &model, &logger, &interrupt, &writers, &rngs, - &error_any](const tbb::blocked_range &r) { - Eigen::VectorXd unconstrained_params_r(draws[0].cols()); - Eigen::VectorXd row(draws[0].cols()); - std::stringstream msg; - for (size_t slice_idx = r.begin(); slice_idx != r.end(); ++slice_idx) { - for (size_t i = 0; i < draws[slice_idx].rows(); ++i) { - if (error_any) - return; - try { - row = draws[slice_idx].row(i); - model.unconstrain_array(row, unconstrained_params_r, &msg); - } catch (const std::exception &e) { - if (msg.str().length() > 0) - logger.error(msg); - logger.error(e.what()); - error_any = true; - return; + try { + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [&draws, &model, &logger, &interrupt, &writers, &rngs, + &error_any](const tbb::blocked_range &r) { + Eigen::VectorXd unconstrained_params_r(draws[0].cols()); + Eigen::VectorXd row(draws[0].cols()); + std::stringstream msg; + for (size_t slice_idx = r.begin(); slice_idx != r.end(); + ++slice_idx) { + for (size_t i = 0; i < draws[slice_idx].rows(); ++i) { + if (error_any) + return; + try { + row = draws[slice_idx].row(i); + model.unconstrain_array(row, unconstrained_params_r, &msg); + } catch (const std::domain_error &e) { + if (msg.str().length() > 0) + logger.error(msg); + logger.error(e.what()); + error_any = true; + return; + } + interrupt(); // call out to interrupt and fail + writers[slice_idx].write_gq_values(model, rngs[slice_idx], + unconstrained_params_r); } - interrupt(); // call out to interrupt and fail - writers[slice_idx].write_gq_values(model, rngs[slice_idx], - unconstrained_params_r); } - } - }, - tbb::simple_partitioner()); + }, + tbb::simple_partitioner()); + } catch (const std::exception &e) { + logger.error(e.what()); + return error_codes::SOFTWARE; + } return error_any ? error_codes::DATAERR : error_codes::OK; }