Skip to content

Commit

Permalink
Check that atol is of correct length.
Browse files Browse the repository at this point in the history
Sundials' VCopy_Serial uses the input as a template for length.
  • Loading branch information
bjodah committed Feb 27, 2018
1 parent 704cbad commit 7601d03
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
10 changes: 7 additions & 3 deletions pycvodes/include/cvodes_anyode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,14 @@ namespace cvodes_anyode {
}
if (atol.size() == 1){
integr.set_tol(rtol, atol[0]);
} else if (atol.size() != (size_t)ny and atol.size() != (size_t)(ny+nq)) {
throw std::runtime_error("atol of incorrect length");
} else {
} else if (atol.size() == (size_t)ny) {
integr.set_tol(rtol, atol);
} else if (atol.size() == (size_t)(ny+nq)) {
sundials_cxx::nvector_serial::VectorView atol_(ny, atol.data());
integr.set_tol(rtol, atol_.n_vec);
} else {
throw std::runtime_error("atol of incorrect length");

}
integr.set_init_step(dx0);
if (dx_min != 0.0)
Expand Down
4 changes: 4 additions & 0 deletions pycvodes/include/cvodes_cxx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ namespace cvodes_cxx {
throw std::runtime_error("CVodeSStolerances failed.");
}
void set_tol(realtype rtol, N_Vector atol){
if (NV_LENGTH_S(atol) != ny)
throw std::runtime_error("atol of incorrect length");
int status = CVodeSVtolerances(this->mem, rtol, atol);
if (status < 0)
throw std::runtime_error("CVodeSVtolerances failed.");
Expand All @@ -316,6 +318,8 @@ namespace cvodes_cxx {
}
}
void set_quad_tol(realtype reltolQ, const N_Vector abstolQ){
if (NV_LENGTH_S(abstolQ) != nq)
throw std::runtime_error("abstolQ of incorrect length");
int flag = CVodeQuadSVtolerances(this->mem, reltolQ, abstolQ);
switch(flag){
case CV_ILL_INPUT:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cvodes_anyode_quad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ TEST_CASE( "quadrature_adaptive", "[simple_adaptive]" ) {
int autorestart=2;

auto nout = cvodes_anyode::simple_adaptive(
&xyqout, &td, &odesys, {1e-10}, 1e-10, cvodes_cxx::LMM::BDF, tend, root_indices,
&xyqout, &td, &odesys, {1e-10, 1e-11, 1e-10}, 1e-10, cvodes_cxx::LMM::BDF, tend, root_indices,
mxsteps, dx0, dx_min, dx_max, with_jacobian, iter_type, linear_solver,
maxl, eps_lin, nderiv, return_on_root, autorestart);
REQUIRE((nout + 1) == td);
Expand Down Expand Up @@ -111,7 +111,7 @@ TEST_CASE( "quadrature_predefined", "[simple_predefined]" ) {
std::vector<double> root_out;

auto nout = cvodes_anyode::simple_predefined(
&odesys, {1e-10}, 1e-10, cvodes_cxx::LMM::BDF, yqout.data(), nt, tout.data(),
&odesys, {1e-10, 1e-11, 1e-10}, 1e-10, cvodes_cxx::LMM::BDF, yqout.data(), nt, tout.data(),
yqout.data(), root_indices, root_out);
REQUIRE(nout == nt);

Expand Down

0 comments on commit 7601d03

Please sign in to comment.