Skip to content

Commit

Permalink
Added cvode functionality to SUNDIALS integrator (AMReX-Codes#3436)
Browse files Browse the repository at this point in the history
## Summary

## Additional background

## Checklist

The proposed changes:
- [ ] fix a bug or incorrect behavior in AMReX
- [x] add new capabilities to AMReX
- [ ] changes answers in the test suite to more than roundoff level
- [ ] are likely to significantly affect the results of downstream AMReX
users
- [ ] include documentation in the code and/or rst files, if appropriate

---------

Co-authored-by: Nicholas Deak <[email protected]>
Co-authored-by: Ann Almgren <[email protected]>
  • Loading branch information
3 people authored Sep 18, 2023
1 parent ad2a89a commit e9f4435
Showing 1 changed file with 131 additions and 2 deletions.
133 changes: 131 additions & 2 deletions Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
#include <arkode/arkode_erkstep.h> /* prototypes for ERKStep fcts., consts */
#include <arkode/arkode_arkstep.h> /* prototypes for ARKStep fcts., consts */
#include <arkode/arkode_mristep.h> /* prototypes for MRIStep fcts., consts */
#include <cvode/cvode.h> /* access to CVODE solver */
#include <nvector/nvector_manyvector.h>/* manyvector N_Vector types, fcts. etc */
#include <AMReX_NVector_MultiFab.H> /* MultiFab N_Vector types, fcts., macros */
#include <AMReX_Sundials.H> /* MultiFab N_Vector types, fcts., macros */
#include <sunlinsol/sunlinsol_spgmr.h> /* access to SPGMR SUNLinearSolver */
#include <sunlinsol/sunlinsol_spfgmr.h> /* access to SPGMR SUNLinearSolver */
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h> /* access to FixedPoint SUNNonlinearSolver */
#include <sundials/sundials_types.h> /* defs. of realtype, sunindextype, etc */

Expand Down Expand Up @@ -70,11 +72,13 @@ private:
bool use_erk_strategy;
bool use_mri_strategy;
bool use_mri_strategy_test;
bool use_cvode_strategy;
bool use_implicit_inner;

SUNNonlinearSolver NLS; /* empty nonlinear solver object */
SUNLinearSolver LS; /* empty linear solver object */
void *arkode_mem; /* empty ARKode memory structure */
void *cvode_mem; /* empty CVODE memory structure */
SUNNonlinearSolver NLSf; /* empty nonlinear solver object */
SUNLinearSolver LSf; /* empty linear solver object */
void *inner_mem; /* empty ARKode memory structure */
Expand All @@ -101,6 +105,7 @@ private:
use_erk_strategy=false;
use_mri_strategy=false;
use_mri_strategy_test=false;
use_cvode_strategy=false;

amrex::ParmParse pp("integration.sundials");

Expand All @@ -124,6 +129,10 @@ private:
use_mri_strategy=true;
use_mri_strategy_test=true;
}
else if (theStrategy == "CVODE")
{
use_cvode_strategy=true;
}
else
{
std::string msg("Unknown strategy: ");
Expand All @@ -146,6 +155,7 @@ private:
NLS = nullptr; /* empty nonlinear solver object */
LS = nullptr; /* empty linear solver object */
arkode_mem = nullptr; /* empty ARKode memory structure */
cvode_mem = nullptr; /* empty CVODE memory structure */
NLSf = nullptr; /* empty nonlinear solver object */
LSf = nullptr; /* empty linear solver object */
inner_mem = nullptr; /* empty ARKode memory structure */
Expand Down Expand Up @@ -188,8 +198,10 @@ public:
return advance_mri(S_old, S_new, time, time_step);
} else if (use_erk_strategy) {
return advance_erk(S_old, S_new, time, time_step);
} else {
Error("SUNDIALS integrator backend not specified (ERK or MRI).");
} else if (use_cvode_strategy) {
return advance_cvode(S_old, S_new, time, time_step);
}else {
Error("SUNDIALS integrator backend not specified (ERK, MRI, or CVODE).");
}

return 0;
Expand Down Expand Up @@ -643,6 +655,123 @@ public:
return timestep;
}

amrex::Real advance_cvode (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step)
{
t = time;
tout = time+time_step;
hfixed = time_step;
timestep = time_step;

// We use S_new as our working space, so first copy S_old to S_new
IntegratorOps<T>::Copy(S_new, S_old);

// Create an N_Vector wrapper for the solution MultiFab
auto get_length = [&](int index) -> sunindextype {
auto* p_mf = &S_new[index];
return p_mf->nComp() * (p_mf->boxArray()).numPts();
};

/* Create manyvector for solution using S_new */
NVar = S_new.size(); // NOTE: expects S_new to be a Vector<MultiFab>
nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */

for (int i = 0; i < NVar; ++i) {
sunindextype length = get_length(i);
N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]);
nv_many_arr[i] = nvi;
}

nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx);
nv_stage_data = N_VClone(nv_S);

/* Create a temporary storage space for MRI */
Vector<std::unique_ptr<T> > temp_storage;
IntegratorOps<T>::CreateLike(temp_storage, S_old);
T& state_store = *temp_storage.back();

SundialsUserData udata;

/* Begin Section: SUNDIALS FUNCTION HOOKS */
/* f routine to compute the ODE RHS function f(t,y). */
udata.f = [&](realtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int {
amrex::Vector<amrex::MultiFab> S_data;
amrex::Vector<amrex::MultiFab> S_rhs;

const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data);
S_data.resize(num_vecs);
S_rhs.resize(num_vecs);

for(int i=0; i<num_vecs; i++)
{
S_data.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp());
S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp());
}

BaseT::post_update(S_data, rhs_time);
BaseT::rhs(S_rhs, S_data, rhs_time);

return 0;
};

udata.ProcessStage = [&](realtype rhs_time, N_Vector y_data, void * /* user_data */) -> int {
amrex::Vector<amrex::MultiFab > S_data;

const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data);
S_data.resize(num_vecs);

for (int i=0; i<num_vecs; i++)
{
S_data.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp());
}

BaseT::post_update(S_data, rhs_time);

return 0;
};
/* End Section: SUNDIALS FUNCTION HOOKS */

/* Set up CVODE BDF solver */
cvode_mem = CVodeCreate(CV_BDF, sunctx);
CVodeSetUserData(cvode_mem, &udata);
CVodeInit(cvode_mem, SundialsUserFun::f, time, nv_S);
CVodeSStolerances(cvode_mem, reltol, abstol);
CVodeSetMaxNumSteps(cvode_mem, 100000);

for(int i=0; i<N_VGetNumSubvectors_ManyVector(nv_S); i++)
{
MultiFab::Copy(state_store[i], *amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_S, i)), 0, 0, state_store[i].nComp(), state_store[i].nGrow());
MultiFab::Copy(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_stage_data, i)), *amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_S, i)), 0, 0, state_store[i].nComp(), state_store[i].nGrow());
}

// Set up and assign the linear solver (GMRES)
LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx);
CVodeSetLinearSolver(cvode_mem, LS, nullptr);

// Use CVode to evolve state_old data (wrapped in nv_S) from t to tout=t+dt
auto flag = CVode(cvode_mem, tout, nv_S, &t, CV_NORMAL);
AMREX_ALWAYS_ASSERT(flag >= 0);

// Copy the result stored in nv_S to state_new
for(int i=0; i<N_VGetNumSubvectors_ManyVector(nv_S); i++)
{
MultiFab::Copy(S_new[i], *amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_S, i)), 0, 0, S_new[i].nComp(), S_new[i].nGrow());
}

// Clean up allocated memory
for (int i = 0; i < Nvar; ++i) {
N_VDestroy(nv_many_arr[i]);
}
delete[] nv_many_arr;
N_VDestroy(nv_S);
N_VDestroy(nv_stage_data);

CVodeFree(&cvode_mem);
SUNLinSolFree(LS);

// Return timestep
return timestep;
}

void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override {}

void map_data (std::function<void(T&)> /* Map */) override {}
Expand Down

0 comments on commit e9f4435

Please sign in to comment.