Merge pull request #3259 from stan-dev/fix/3258-unconditional-exception-swallowing

Only swallow domain_errors in various algorithms
WardBrian authored Feb 13, 2024
2 parents b6d010f + c95036b commit 6697a24
Showing 26 changed files with 507 additions and 328 deletions.
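
The pattern across all 26 files: code that evaluates the model (log densities, gradients, write_array) previously caught every std::exception at the point of evaluation; after this change only std::domain_error — the exception Stan Math throws for recoverable problems such as a parameter value outside its support — is swallowed in place, while anything else propagates up to the service layer, which logs it and returns error_codes::SOFTWARE. A minimal sketch of the distinction (log_prob and potential_energy here are hypothetical stand-ins, not Stan's API):

#include <cmath>
#include <iostream>
#include <limits>
#include <stdexcept>

// Hypothetical stand-in for evaluating a model's log density. Stan Math
// signals a recoverable problem with the current draw (e.g. a parameter
// outside its support) via std::domain_error; other exception types
// indicate genuine bugs.
double log_prob(double theta) {
  if (theta <= 0.0)
    throw std::domain_error("theta must be positive");
  return -theta * theta + std::log(theta);
}

// Algorithm-level pattern after this commit: swallow only domain_errors,
// treating the draw as having zero density (infinite potential energy).
double potential_energy(double theta) {
  try {
    return -log_prob(theta);
  } catch (const std::domain_error& e) {
    std::cerr << e.what() << std::endl;  // stands in for logger output
    return std::numeric_limits<double>::infinity();
  }
  // Any other exception type now escapes to the caller.
}

int main() {
  std::cout << potential_energy(1.0) << "\n";   // finite: draw usable
  std::cout << potential_energy(-1.0) << "\n";  // inf: rejected, not fatal
}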
4 changes: 2 additions & 2 deletions src/stan/mcmc/hmc/hamiltonians/base_hamiltonian.hpp
@@ -52,7 +52,7 @@ class base_hamiltonian {
   void update_potential(Point& z, callbacks::logger& logger) {
     try {
       z.V = -stan::model::log_prob_propto<true>(model_, z.q);
-    } catch (const std::exception& e) {
+    } catch (const std::domain_error& e) {
       this->write_error_msg_(e, logger);
       z.V = std::numeric_limits<double>::infinity();
     }
@@ -62,7 +62,7 @@ class base_hamiltonian {
     try {
       stan::model::gradient(model_, z.q, z.V, z.g, logger);
       z.V = -z.V;
-    } catch (const std::exception& e) {
+    } catch (const std::domain_error& e) {
       this->write_error_msg_(e, logger);
       z.V = std::numeric_limits<double>::infinity();
     }
4 changes: 2 additions & 2 deletions src/stan/optimization/bfgs.hpp
@@ -309,7 +309,7 @@ class ModelAdaptor {
 
     try {
       f = -log_prob_propto<jacobian>(_model, _x, _params_i, _msgs);
-    } catch (const std::exception &e) {
+    } catch (const std::domain_error &e) {
       if (_msgs)
         (*_msgs) << e.what() << std::endl;
       return 1;
@@ -341,7 +341,7 @@ class ModelAdaptor {
 
     try {
       f = -log_prob_grad<true, jacobian>(_model, _x, _params_i, _g, _msgs);
-    } catch (const std::exception &e) {
+    } catch (const std::domain_error &e) {
       if (_msgs)
         (*_msgs) << e.what() << std::endl;
       return 1;
2 changes: 1 addition & 1 deletion src/stan/optimization/newton.hpp
@@ -60,7 +60,7 @@ double newton_step(M& model, std::vector<double>& params_r,
     try {
       f1 = stan::model::log_prob_grad<true, jacobian>(model, new_params_r,
                                                       params_i, gradient);
-    } catch (std::exception& e) {
+    } catch (std::domain_error& e) {
       // FIXME: this is not a good way to handle a general exception
       f1 = -1e100;
     }
9 changes: 7 additions & 2 deletions src/stan/services/experimental/advi/fullrank.hpp
@@ -86,8 +86,13 @@ int fullrank(Model& model, const stan::io::var_context& init,
                           stan::rng_t>
       cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
                output_samples);
-  cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
-               max_iterations, logger, parameter_writer, diagnostic_writer);
+  try {
+    cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
+                 max_iterations, logger, parameter_writer, diagnostic_writer);
+  } catch (const std::exception& e) {
+    logger.error(e.what());
+    return error_codes::SOFTWARE;
+  }
 
   return stan::services::error_codes::OK;
 }
9 changes: 7 additions & 2 deletions src/stan/services/experimental/advi/meanfield.hpp
@@ -85,8 +85,13 @@ int meanfield(Model& model, const stan::io::var_context& init,
                           stan::rng_t>
       cmd_advi(model, cont_params, rng, grad_samples, elbo_samples, eval_elbo,
                output_samples);
-  cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
-               max_iterations, logger, parameter_writer, diagnostic_writer);
+  try {
+    cmd_advi.run(eta, adapt_engaged, adapt_iterations, tol_rel_obj,
+                 max_iterations, logger, parameter_writer, diagnostic_writer);
+  } catch (const std::exception& e) {
+    logger.error(e.what());
+    return error_codes::SOFTWARE;
+  }
 
   return stan::services::error_codes::OK;
 }
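Both ADVI services now follow the same convention as the optimizers below: the whole run is wrapped so that an escaping exception becomes a logged error and a nonzero return code rather than terminating the caller. A sketch of that convention (run_service and its lambda argument are illustrative, not Stan's API; the error_codes values are assumed to mirror Stan's sysexits-style codes):

#include <functional>
#include <iostream>
#include <stdexcept>

namespace error_codes {
constexpr int OK = 0;         // matches stan::services::error_codes::OK
constexpr int SOFTWARE = 70;  // assumed to match error_codes::SOFTWARE
}

// Hypothetical service-layer wrapper: convert any escaping exception into
// a return code, as fullrank() and meanfield() now do around cmd_advi.run.
int run_service(const std::function<void()>& run_algorithm) {
  try {
    run_algorithm();
  } catch (const std::exception& e) {
    std::cerr << e.what() << std::endl;  // stands in for logger.error(...)
    return error_codes::SOFTWARE;
  }
  return error_codes::OK;
}

int main() {
  std::cout << run_service([] {}) << "\n";  // prints 0
  std::cout << run_service([] { throw std::runtime_error("boom"); })
            << "\n";  // prints 70
}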
133 changes: 79 additions & 54 deletions src/stan/services/optimize/bfgs.hpp
@@ -96,7 +96,16 @@ int bfgs(Model& model, const stan::io::var_context& init,
   if (save_iterations) {
     std::vector<double> values;
     std::stringstream msg;
-    model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
+    try {
+      model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                        &msg);
+    } catch (const std::exception& e) {
+      if (msg.str().length() > 0) {
+        logger.info(msg);
+      }
+      logger.error(e.what());
+      return error_codes::SOFTWARE;
+    }
     if (msg.str().length() > 0)
       logger.info(msg);

@@ -105,66 +114,82 @@ int bfgs(Model& model, const stan::io::var_context& init,
   }
   int ret = 0;
 
-  while (ret == 0) {
-    interrupt();
-    if (refresh > 0
-        && (bfgs.iter_num() == 0 || ((bfgs.iter_num() + 1) % refresh == 0)))
-      logger.info(
-          "  Iter"
-          "      log prob"
-          "        ||dx||"
-          "      ||grad||"
-          "       alpha"
-          "      alpha0"
-          "  # evals"
-          "  Notes ");
-
-    ret = bfgs.step();
-    lp = bfgs.logp();
-    bfgs.params_r(cont_vector);
-
-    if (refresh > 0
-        && (ret != 0 || !bfgs.note().empty() || bfgs.iter_num() == 0
-            || ((bfgs.iter_num() + 1) % refresh == 0))) {
-      std::stringstream msg;
-      msg << " " << std::setw(7) << bfgs.iter_num() << " ";
-      msg << " " << std::setw(12) << std::setprecision(6) << lp << " ";
-      msg << " " << std::setw(12) << std::setprecision(6)
-          << bfgs.prev_step_size() << " ";
-      msg << " " << std::setw(12) << std::setprecision(6)
-          << bfgs.curr_g().norm() << " ";
-      msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha()
-          << " ";
-      msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha0()
-          << " ";
-      msg << " " << std::setw(7) << bfgs.grad_evals() << " ";
-      msg << " " << bfgs.note() << " ";
-      logger.info(msg);
-    }
-
-    if (bfgs_ss.str().length() > 0) {
-      logger.info(bfgs_ss);
-      bfgs_ss.str("");
-    }
-
-    if (save_iterations) {
-      std::vector<double> values;
-      std::stringstream msg;
-      model.write_array(rng, cont_vector, disc_vector, values, true, true,
-                        &msg);
-      // This if is here to match the pre-refactor behavior
-      if (msg.str().length() > 0)
-        logger.info(msg);
-
-      values.insert(values.begin(), lp);
-      parameter_writer(values);
-    }
-  }
+  try {
+    while (ret == 0) {
+      interrupt();
+      if (refresh > 0
+          && (bfgs.iter_num() == 0 || ((bfgs.iter_num() + 1) % refresh == 0)))
+        logger.info(
+            "  Iter"
+            "      log prob"
+            "        ||dx||"
+            "      ||grad||"
+            "       alpha"
+            "      alpha0"
+            "  # evals"
+            "  Notes ");
+
+      ret = bfgs.step();
+
+      lp = bfgs.logp();
+      bfgs.params_r(cont_vector);
+
+      if (refresh > 0
+          && (ret != 0 || !bfgs.note().empty() || bfgs.iter_num() == 0
+              || ((bfgs.iter_num() + 1) % refresh == 0))) {
+        std::stringstream msg;
+        msg << " " << std::setw(7) << bfgs.iter_num() << " ";
+        msg << " " << std::setw(12) << std::setprecision(6) << lp << " ";
+        msg << " " << std::setw(12) << std::setprecision(6)
+            << bfgs.prev_step_size() << " ";
+        msg << " " << std::setw(12) << std::setprecision(6)
+            << bfgs.curr_g().norm() << " ";
+        msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha()
+            << " ";
+        msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha0()
+            << " ";
+        msg << " " << std::setw(7) << bfgs.grad_evals() << " ";
+        msg << " " << bfgs.note() << " ";
+        logger.info(msg);
+      }
+
+      if (bfgs_ss.str().length() > 0) {
+        logger.info(bfgs_ss);
+        bfgs_ss.str("");
+      }
+
+      if (save_iterations) {
+        std::vector<double> values;
+        std::stringstream msg;
+        model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                          &msg);
+
+        // This if is here to match the pre-refactor behavior
+        if (msg.str().length() > 0)
+          logger.info(msg);
+
+        values.insert(values.begin(), lp);
+        parameter_writer(values);
+      }
+    }
+  } catch (const std::exception& e) {
+    logger.error(e.what());
+    return error_codes::SOFTWARE;
+  }

   if (!save_iterations) {
     std::vector<double> values;
     std::stringstream msg;
-    model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
+    try {
+      model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                        &msg);
+    } catch (const std::exception& e) {
+      if (msg.str().length() > 0) {
+        logger.info(msg);
+      }
+      logger.error(e.what());
+      return error_codes::SOFTWARE;
+    }
     if (msg.str().length() > 0)
       logger.info(msg);
     values.insert(values.begin(), lp);
119 changes: 67 additions & 52 deletions src/stan/services/optimize/lbfgs.hpp
@@ -109,65 +109,80 @@ int lbfgs(Model& model, const stan::io::var_context& init,
   }
   int ret = 0;
 
-  while (ret == 0) {
-    interrupt();
-    if (refresh > 0
-        && (lbfgs.iter_num() == 0 || ((lbfgs.iter_num() + 1) % refresh == 0)))
-      logger.info(
-          "  Iter"
-          "      log prob"
-          "        ||dx||"
-          "      ||grad||"
-          "       alpha"
-          "      alpha0"
-          "  # evals"
-          "  Notes ");
-
-    ret = lbfgs.step();
-    lp = lbfgs.logp();
-    lbfgs.params_r(cont_vector);
-
-    if (refresh > 0
-        && (ret != 0 || !lbfgs.note().empty() || lbfgs.iter_num() == 0
-            || ((lbfgs.iter_num() + 1) % refresh == 0))) {
-      std::stringstream msg;
-      msg << " " << std::setw(7) << lbfgs.iter_num() << " ";
-      msg << " " << std::setw(12) << std::setprecision(6) << lp << " ";
-      msg << " " << std::setw(12) << std::setprecision(6)
-          << lbfgs.prev_step_size() << " ";
-      msg << " " << std::setw(12) << std::setprecision(6)
-          << lbfgs.curr_g().norm() << " ";
-      msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha()
-          << " ";
-      msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha0()
-          << " ";
-      msg << " " << std::setw(7) << lbfgs.grad_evals() << " ";
-      msg << " " << lbfgs.note() << " ";
-      logger.info(msg);
-    }
-
-    if (lbfgs_ss.str().length() > 0) {
-      logger.info(lbfgs_ss);
-      lbfgs_ss.str("");
-    }
-
-    if (save_iterations) {
-      std::vector<double> values;
-      std::stringstream msg;
-      model.write_array(rng, cont_vector, disc_vector, values, true, true,
-                        &msg);
-      if (msg.str().length() > 0)
-        logger.info(msg);
-
-      values.insert(values.begin(), lp);
-      parameter_writer(values);
-    }
-  }
+  try {
+    while (ret == 0) {
+      interrupt();
+      if (refresh > 0
+          && (lbfgs.iter_num() == 0 || ((lbfgs.iter_num() + 1) % refresh == 0)))
+        logger.info(
+            "  Iter"
+            "      log prob"
+            "        ||dx||"
+            "      ||grad||"
+            "       alpha"
+            "      alpha0"
+            "  # evals"
+            "  Notes ");
+
+      ret = lbfgs.step();
+
+      lp = lbfgs.logp();
+      lbfgs.params_r(cont_vector);
+
+      if (refresh > 0
+          && (ret != 0 || !lbfgs.note().empty() || lbfgs.iter_num() == 0
+              || ((lbfgs.iter_num() + 1) % refresh == 0))) {
+        std::stringstream msg;
+        msg << " " << std::setw(7) << lbfgs.iter_num() << " ";
+        msg << " " << std::setw(12) << std::setprecision(6) << lp << " ";
+        msg << " " << std::setw(12) << std::setprecision(6)
+            << lbfgs.prev_step_size() << " ";
+        msg << " " << std::setw(12) << std::setprecision(6)
+            << lbfgs.curr_g().norm() << " ";
+        msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha()
+            << " ";
+        msg << " " << std::setw(10) << std::setprecision(4) << lbfgs.alpha0()
+            << " ";
+        msg << " " << std::setw(7) << lbfgs.grad_evals() << " ";
+        msg << " " << lbfgs.note() << " ";
+        logger.info(msg);
+      }
+
+      if (lbfgs_ss.str().length() > 0) {
+        logger.info(lbfgs_ss);
+        lbfgs_ss.str("");
+      }
+
+      if (save_iterations) {
+        std::vector<double> values;
+        std::stringstream msg;
+        model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                          &msg);
+        if (msg.str().length() > 0)
+          logger.info(msg);
+
+        values.insert(values.begin(), lp);
+        parameter_writer(values);
+      }
+    }
+  } catch (const std::exception& e) {
+    logger.error(e.what());
+    return error_codes::SOFTWARE;
+  }

   if (!save_iterations) {
     std::vector<double> values;
     std::stringstream msg;
-    model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
+    try {
+      model.write_array(rng, cont_vector, disc_vector, values, true, true,
+                        &msg);
+    } catch (const std::exception& e) {
+      if (msg.str().length() > 0) {
+        logger.info(msg);
+      }
+      logger.error(e.what());
+      return error_codes::SOFTWARE;
+    }
     if (msg.str().length() > 0)
       logger.info(msg);
 
2 changes: 1 addition & 1 deletion src/stan/services/optimize/newton.hpp
@@ -62,7 +62,7 @@ int newton(Model& model, const stan::io::var_context& init,
     lp = model.template log_prob<false, jacobian>(cont_vector, disc_vector,
                                                   &message);
     logger.info(message);
-  } catch (const std::exception& e) {
+  } catch (const std::domain_error& e) {
     logger.info("");
     logger.info(
         "Informational Message: The current"