Skip to content

Commit

Permalink
Merge pull request #52 from bjodah/additional-time-data
Browse files Browse the repository at this point in the history
Add time_rhs, time_jac, etc.
  • Loading branch information
bjodah authored Feb 28, 2018
2 parents 89ec3ad + 00d5461 commit 22c13c8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
39 changes: 37 additions & 2 deletions pycvodes/include/cvodes_anyode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,37 @@ namespace cvodes_anyode {

template<class OdeSys>
int rhs_cb(realtype t, N_Vector y, N_Vector ydot, void *user_data){
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
if (odesys.record_rhs_xvals)
odesys.last_integration_info_vecdbl["rhs_xvals"].push_back(t);
AnyODE::Status status = odesys.rhs(t, NV_DATA_S(y), NV_DATA_S(ydot));
static_cast<Integrator*>(odesys.integrator)->time_rhs += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

template<class OdeSys>
int roots_cb(realtype t, N_Vector y, realtype *gout, void *user_data){
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
AnyODE::Status status = odesys.roots(t, NV_DATA_S(y), gout);
if (status == AnyODE::Status::recoverable_error)
throw std::runtime_error("There are only unrecoverable errors for roots().");
static_cast<Integrator*>(odesys.integrator)->time_roots += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

template<class OdeSys>
int quads_cb(realtype t, N_Vector y, N_Vector yQdot, void *user_data){
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
AnyODE::Status status = odesys.quads(t, NV_DATA_S(y), NV_DATA_S(yQdot));
if (status == AnyODE::Status::recoverable_error)
throw std::runtime_error("There are only unrecoverable errors for quads().");
static_cast<Integrator*>(odesys.integrator)->time_quads += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

Expand All @@ -76,6 +85,7 @@ namespace cvodes_anyode {
AnyODE::ignore(N);
#endif
AnyODE::ignore(tmp1); AnyODE::ignore(tmp2); AnyODE::ignore(tmp3);
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
if (odesys.record_jac_xvals)
odesys.last_integration_info_vecdbl["jac_xvals"].push_back(t);
Expand All @@ -86,6 +96,9 @@ namespace cvodes_anyode {
SM_DATA_D(Jac), odesys.get_ny()
#endif
);

static_cast<Integrator*>(odesys.integrator)->time_jac += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

Expand All @@ -104,6 +117,7 @@ namespace cvodes_anyode {
void *user_data,
N_Vector tmp1, N_Vector tmp2, N_Vector tmp3){
AnyODE::ignore(tmp1); AnyODE::ignore(tmp2); AnyODE::ignore(tmp3);
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
#if SUNDIALS_VERSION_MAJOR < 3
if (odesys.get_mupper() != mupper)
Expand All @@ -118,6 +132,8 @@ namespace cvodes_anyode {
if (odesys.record_jac_xvals)
odesys.last_integration_info_vecdbl["jac_xvals"].push_back(t);
AnyODE::Status status = odesys.banded_jac_cmaj(t, NV_DATA_S(y), NV_DATA_S(fy), Jac_->data + Jac_->s_mu - Jac_->mu, Jac_->ldim);
static_cast<Integrator*>(odesys.integrator)->time_jac += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

Expand All @@ -127,10 +143,13 @@ namespace cvodes_anyode {
N_Vector fy, void *user_data, N_Vector tmp){
// callback of req. signature wrapping OdeSys method.
AnyODE::ignore(tmp);
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
AnyODE::Status status = odesys.jac_times_vec(NV_DATA_S(v), NV_DATA_S(Jv), t, NV_DATA_S(y), NV_DATA_S(fy));
if (status == AnyODE::Status::recoverable_error)
throw std::runtime_error("There are only unrecoverable errors for JacTimesVec().");
static_cast<Integrator*>(odesys.integrator)->time_jtimes += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

Expand All @@ -146,12 +165,15 @@ namespace cvodes_anyode {
#if SUNDIALS_VERSION_MAJOR < 3
AnyODE::ignore(tmp); // delta used for iterative methods
#endif
auto t_start = std::chrono::high_resolution_clock::now();
double * ewt {nullptr};
auto& odesys = *static_cast<OdeSys*>(user_data);
if (lr != 1)
throw std::runtime_error("Only left preconditioning implemented.");
AnyODE::Status status = odesys.prec_solve_left(t, NV_DATA_S(y), NV_DATA_S(fy), NV_DATA_S(r),
NV_DATA_S(z), gamma, delta, ewt);
static_cast<Integrator*>(odesys.integrator)->time_prec += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

Expand All @@ -166,10 +188,13 @@ namespace cvodes_anyode {
#if SUNDIALS_VERSION_MAJOR < 3
AnyODE::ignore(tmp1); AnyODE::ignore(tmp2); AnyODE::ignore(tmp3);
#endif
auto t_start = std::chrono::high_resolution_clock::now();
auto& odesys = *static_cast<OdeSys*>(user_data);
bool jac_recomputed = false;
AnyODE::Status status = odesys.prec_setup(t, NV_DATA_S(y), NV_DATA_S(fy), jok, jac_recomputed, gamma);
(*jcurPtr) = (jac_recomputed) ? SUNTRUE : SUNFALSE;
static_cast<Integrator*>(odesys.integrator)->time_prec += std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();
return handle_status_(status);
}

Expand Down Expand Up @@ -357,7 +382,12 @@ namespace cvodes_anyode {
odesys->last_integration_info_dbl["time_cpu"] = (std::clock() - cput0) / (double)CLOCKS_PER_SEC;
odesys->last_integration_info_dbl["time_wall"] = std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();

odesys->last_integration_info_dbl["time_rhs"] = integr->time_rhs;
odesys->last_integration_info_dbl["time_quads"] = integr->time_quads;
odesys->last_integration_info_dbl["time_roots"] = integr->time_roots;
odesys->last_integration_info_dbl["time_jac"] = integr->time_jac;
odesys->last_integration_info_dbl["time_jtimes"] = integr->time_jtimes;
odesys->last_integration_info_dbl["time_prec"] = integr->time_prec;
if (odesys->record_order)
odesys->last_integration_info_vecint["orders"] = integr->orders_seen;
if (odesys->record_fpe)
Expand Down Expand Up @@ -435,7 +465,12 @@ namespace cvodes_anyode {
odesys->last_integration_info_dbl["time_cpu"] = (std::clock() - cput0) / (double)CLOCKS_PER_SEC;
odesys->last_integration_info_dbl["time_wall"] = std::chrono::duration<double>(
std::chrono::high_resolution_clock::now() - t_start).count();

odesys->last_integration_info_dbl["time_rhs"] = integr->time_rhs;
odesys->last_integration_info_dbl["time_quads"] = integr->time_quads;
odesys->last_integration_info_dbl["time_roots"] = integr->time_roots;
odesys->last_integration_info_dbl["time_jac"] = integr->time_jac;
odesys->last_integration_info_dbl["time_jtimes"] = integr->time_jtimes;
odesys->last_integration_info_dbl["time_prec"] = integr->time_prec;
if (odesys->record_order)
odesys->last_integration_info_vecint["orders"] = integr->orders_seen;
if (odesys->record_fpe)
Expand Down
1 change: 1 addition & 0 deletions pycvodes/include/cvodes_cxx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ namespace cvodes_cxx {
int verbosity = 50; // "50%" -- plenty of room for future tuning.
bool autonomous_exprs = false;
bool record_order = false, record_fpe = false, record_steps = false;
double time_rhs {0}, time_jac {0}, time_roots {0}, time_quads {0}, time_prec {0}, time_jtimes {0};
std::vector<int> orders_seen, fpes_seen;
std::vector<double> steps_seen; // Conversion from float / long double not a problem.
Integrator(const LMM lmm, const IterType iter) {
Expand Down
8 changes: 6 additions & 2 deletions pycvodes/tests/test_cvodes_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,14 @@ def test_integrate_predefined(method, forgiveness, banded):
atol=forgiveness*atol)
assert nfo['atol'] == [1e-8, 3e-9, 2e-9] and nfo['rtol'] == 1e-8
assert nfo['nfev'] > 0
assert nfo['time_cpu'] > 1e-9
assert nfo['time_wall'] > 1e-9
if os.name == 'posix':
assert nfo['time_cpu'] > 1e-9
assert nfo['time_wall'] > 1e-9
assert nfo['time_rhs'] > 1e-9
if method in requires_jac and j is not None:
assert nfo['njev'] > 0
if os.name == 'posix':
assert nfo['time_jac'] > 1e-9


def test_integrate_adaptive_tstop0():
Expand Down

0 comments on commit 22c13c8

Please sign in to comment.