Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARKODE_MRI interface #3013

Draft
wants to merge 5 commits into
base: next
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ set(BOUT_SOURCES
./src/solver/impls/adams_bashforth/adams_bashforth.hxx
./src/solver/impls/arkode/arkode.cxx
./src/solver/impls/arkode/arkode.hxx
./src/solver/impls/arkode/arkode_mri.cxx
./src/solver/impls/arkode/arkode_mri.hxx
./src/solver/impls/cvode/cvode.cxx
./src/solver/impls/cvode/cvode.hxx
./src/solver/impls/euler/euler.cxx
Expand Down
34 changes: 22 additions & 12 deletions include/bout/mesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,43 @@ class Mesh;
#ifndef BOUT_MESH_H
#define BOUT_MESH_H

#include "bout/bout_enum_class.hxx"
#include "mpi.h"

#include <bout/deriv_store.hxx>
#include <bout/index_derivs_interface.hxx>
#include <bout/mpi_wrapper.hxx>

#include "bout/bout_types.hxx"
#include "bout/coordinates.hxx" // Coordinates class
#include "bout/field2d.hxx"
#include "bout/field3d.hxx"
#include "bout/field_data.hxx"
#include "bout/fieldgroup.hxx"
#include "bout/generic_factory.hxx"
#include "bout/index_derivs_interface.hxx"
#include "bout/mpi_wrapper.hxx"
#include "bout/options.hxx"
#include "bout/region.hxx"

#include "bout/fieldgroup.hxx"

class BoundaryRegion;
class BoundaryRegionPar;

#include "bout/sys/range.hxx" // RangeIterator

#include <bout/griddata.hxx>

#include "bout/coordinates.hxx" // Coordinates class

#include "bout/unused.hxx"

#include "mpi.h"
#include "bout/generic_factory.hxx"
#include <bout/region.hxx>

#include <bout/bout_enum_class.hxx>

#include <list>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>

class BoundaryRegion;
class BoundaryRegionPar;
class GridDataSource;

class MeshFactory : public Factory<Mesh, MeshFactory, GridDataSource*, Options*> {
public:
static constexpr auto type_name = "Mesh";
Expand Down
14 changes: 11 additions & 3 deletions include/bout/monitor.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,18 @@ public:

/// number of RHS calls
int ncalls = 0;
/// number of RHS calls for fast timescale
/// number of RHS calls for explicit timescale
int ncalls_e = 0;
/// number of RHS calls for slow timescale
/// number of RHS calls for implicit timescale
int ncalls_i = 0;
/// number of RHS calls for slow explicit timescale
int ncalls_se = 0;
/// number of RHS calls for slow implicit timescale
int ncalls_si = 0;
/// number of RHS calls for fast explicit timescale
int ncalls_fe = 0;
/// number of RHS calls for fast implicit timescale
int ncalls_fi = 0;

/// wall time spent calculating RHS
BoutReal wtime_rhs = 0;
Expand Down Expand Up @@ -122,7 +130,7 @@ public:
/*!
* Write job progress to screen
*/
void writeProgress(BoutReal simtime, bool output_split);
void writeProgress(BoutReal simtime, bool output_split, bool output_splitmri);
};

#endif // BOUT_MONITOR_H
35 changes: 35 additions & 0 deletions include/bout/physicsmodel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,24 @@ public:
*
* Returns a flag: 0 indicates success, non-zero an error flag
*/
int runRHS_se(BoutReal time, bool linear = false);
int runRHS_si(BoutReal time, bool linear = false);
int runRHS_fe(BoutReal time, bool linear = false);
int runRHS_fi(BoutReal time, bool linear = false);
int runRHS_s(BoutReal time, bool linear = false);
int runRHS_f(BoutReal time, bool linear = false);
int runRHS(BoutReal time, bool linear = false);

/*!
* True if this model uses split operators
*/
bool splitOperator();

/*!
* True if this model uses split operators
*/
bool splitOperatorMRI();

/*!
* Run the convective (usually explicit) part of the model
*
Expand Down Expand Up @@ -267,6 +278,24 @@ protected:
* which is set to true when the rhs() function can be
* linearised. This is used in e.g. linear iterative solves.
*/
virtual int rhs_se(BoutReal UNUSED(t)) { return 1; }
virtual int rhs_se(BoutReal t, bool UNUSED(linear)) { return rhs_se(t); }

virtual int rhs_si(BoutReal UNUSED(t)) { return 1; }
virtual int rhs_si(BoutReal t, bool UNUSED(linear)) { return rhs_si(t); }

virtual int rhs_fe(BoutReal UNUSED(t)) { return 1; }
virtual int rhs_fe(BoutReal t, bool UNUSED(linear)) { return rhs_fe(t); }

virtual int rhs_fi(BoutReal UNUSED(t)) { return 1; }
virtual int rhs_fi(BoutReal t, bool UNUSED(linear)) { return rhs_fi(t); }

virtual int rhs_s(BoutReal UNUSED(t)) { return 1; }
virtual int rhs_s(BoutReal t, bool UNUSED(linear)) { return rhs_s(t); }

virtual int rhs_f(BoutReal UNUSED(t)) { return 1; }
virtual int rhs_f(BoutReal t, bool UNUSED(linear)) { return rhs_f(t); }

virtual int rhs(BoutReal UNUSED(t)) { return 1; }
virtual int rhs(BoutReal t, bool UNUSED(linear)) { return rhs(t); }

Expand Down Expand Up @@ -309,6 +338,10 @@ protected:
/// Specify that this model is split into a convective and diffusive part
void setSplitOperator(bool split = true) { splitop = split; }

/// Specify that this model is split into a convective and diffusive part
void setSplitOperatorMRI(bool split = true) { splitopmri = split; }


/// Specify a preconditioner function
void setPrecon(preconfunc pset) { userprecon = pset; }
template <class Model>
Expand Down Expand Up @@ -391,6 +424,8 @@ private:
bool restart_enabled{true};
/// Split operator model?
bool splitop{false};
/// Split operator model?
bool splitopmri{false};
/// Pointer to user-supplied preconditioner function
preconfunc userprecon{nullptr};
/// Pointer to user-supplied Jacobian-vector multiply function
Expand Down
28 changes: 28 additions & 0 deletions include/bout/solver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ constexpr auto SOLVEREULER = "euler";
constexpr auto SOLVERRK3SSP = "rk3ssp";
constexpr auto SOLVERPOWER = "power";
constexpr auto SOLVERARKODE = "arkode";
constexpr auto SOLVERARKODEMRI = "arkodemri";
constexpr auto SOLVERIMEXBDF2 = "imexbdf2";
constexpr auto SOLVERSNES = "snes";
constexpr auto SOLVERRKGENERIC = "rkgeneric";
Expand Down Expand Up @@ -310,9 +311,21 @@ public:
/// Same but fur implicit timestep counter - for IMEX
int resetRHSCounter_i();

/// Same but for slow explicit timestep counter - for MRI IMEX
int resetRHSCounter_se();
/// Same but for slow implicit timestep counter - for MRI IMEX
int resetRHSCounter_si();
/// Same but for fast explicit timestep counter - for MRI IMEX
int resetRHSCounter_fe();
/// Same but for fast implicit timestep counter - for MRI IMEX
int resetRHSCounter_fi();

/// Test if this solver supports split operators (e.g. implicit/explicit)
bool splitOperator();

/// Test if this solver supports split operators (e.g. implicit/explicit)
bool splitOperatorMRI();

bool canReset{false};

/// Add evolving variables to output (dump) file or restart file
Expand Down Expand Up @@ -438,6 +451,12 @@ protected:
BoutReal simtime{0.0};

/// Run the user's RHS function
int run_rhs_se(BoutReal t, bool linear = false);
int run_rhs_si(BoutReal t, bool linear = false);
int run_rhs_fe(BoutReal t, bool linear = false);
int run_rhs_fi(BoutReal t, bool linear = false);
int run_rhs_s(BoutReal t, bool linear = false);
int run_rhs_f(BoutReal t, bool linear = false);
int run_rhs(BoutReal t, bool linear = false);
/// Calculate only the convective parts
int run_convective(BoutReal t, bool linear = false);
Expand Down Expand Up @@ -542,6 +561,15 @@ private:
int rhs_ncalls_e{0};
/// Number of calls to the implicit (diffusive) RHS function
int rhs_ncalls_i{0};
/// number of RHS calls for slow explicit timescale
int rhs_ncalls_se = 0;
/// number of RHS calls for slow implicit timescale
int rhs_ncalls_si = 0;
/// number of RHS calls for fast explicit timescale
int rhs_ncalls_fe = 0;
/// number of RHS calls for fast implicit timescale
int rhs_ncalls_fi = 0;

/// Default sampling rate at which to call monitors - same as output to screen
int default_monitor_period{1};
/// timestep - shouldn't be changed after init is called.
Expand Down
62 changes: 45 additions & 17 deletions src/bout++.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,12 @@ int BoutMonitor::call(Solver* solver, BoutReal t, [[maybe_unused]] int iter, int
run_data.ncalls = solver->resetRHSCounter();
run_data.ncalls_e = solver->resetRHSCounter_e();
run_data.ncalls_i = solver->resetRHSCounter_i();

run_data.ncalls_se = solver->resetRHSCounter_se();
run_data.ncalls_si = solver->resetRHSCounter_si();
run_data.ncalls_fe = solver->resetRHSCounter_fe();
run_data.ncalls_fi = solver->resetRHSCounter_fi();

const bool output_split = solver->splitOperator();
run_data.wtime_rhs = Timer::resetTime("rhs");
run_data.wtime_invert = Timer::resetTime("invert");
// Time spent communicating (part of RHS)
Expand All @@ -872,16 +876,23 @@ int BoutMonitor::call(Solver* solver, BoutReal t, [[maybe_unused]] int iter, int
first_time = false;

// Print the column header for timing info
if (!output_split) {
output_progress.write(_("Sim Time | RHS evals | Wall Time | Calc Inv Comm "
" I/O SOLVER\n\n"));
} else {
if (solver->splitOperator()) {
output_progress.write(_("Sim Time | RHS_e evals | RHS_I evals | Wall Time | "
"Calc Inv Comm I/O SOLVER\n\n"));
}
else if (solver->splitOperatorMRI()) {
output_progress.write(_("Sim Time | RHS_se evals | RHS_si evals | RHS_fe evals |"
"RHS_fi evals | Wall Time | "
"Calc Inv Comm I/O SOLVER\n\n"));
}
else
{
output_progress.write(_("Sim Time | RHS evals | Wall Time | Calc Inv Comm "
" I/O SOLVER\n\n"));
}
}

run_data.writeProgress(simtime, output_split);
run_data.writeProgress(simtime, solver->splitOperator(), solver->splitOperatorMRI());

// This bit only to screen, not log file

Expand Down Expand Up @@ -1011,6 +1022,10 @@ void RunMetrics::outputVars(Options& output_options) const {
output_options["wtime"].assignRepeat(wtime, "t", true, "Output");
output_options["ncalls"].assignRepeat(ncalls, "t", true, "Output");
output_options["ncalls_e"].assignRepeat(ncalls_e, "t", true, "Output");
output_options["ncalls_se"].assignRepeat(ncalls_se, "t", true, "Output");
output_options["ncalls_si"].assignRepeat(ncalls_si, "t", true, "Output");
output_options["ncalls_fe"].assignRepeat(ncalls_fe, "t", true, "Output");
output_options["ncalls_fi"].assignRepeat(ncalls_fi, "t", true, "Output");
output_options["ncalls_i"].assignRepeat(ncalls_i, "t", true, "Output");
output_options["wtime_rhs"].assignRepeat(wtime_rhs, "t", true, "Output");
output_options["wtime_invert"].assignRepeat(wtime_invert, "t", true, "Output");
Expand All @@ -1036,17 +1051,8 @@ void RunMetrics::calculateDerivedMetrics() {
wtime_per_rhs_i = wtime / ncalls_i;
}

void RunMetrics::writeProgress(BoutReal simtime, bool output_split) {
if (!output_split) {
output_progress.write(
"{:.3e} {:5d} {:.2e} {:5.1f} {:5.1f} {:5.1f} {:5.1f} {:5.1f}\n",
simtime, ncalls, wtime, 100. * (wtime_rhs - wtime_comms - wtime_invert) / wtime,
100. * wtime_invert / wtime, // Inversions
100. * wtime_comms / wtime, // Communications
100. * wtime_io / wtime, // I/O
100. * (wtime - wtime_io - wtime_rhs) / wtime); // Everything else

} else {
void RunMetrics::writeProgress(BoutReal simtime, bool output_split, bool output_splitmri) {
if (output_split) {
output_progress.write("{:.3e} {:5d} {:5d} {:.2e} {:5.1f} "
"{:5.1f} {:5.1f} {:5.1f} {:5.1f}\n",
simtime, ncalls_e, ncalls_i, wtime,
Expand All @@ -1057,4 +1063,26 @@ void RunMetrics::writeProgress(BoutReal simtime, bool output_split) {
100. * (wtime - wtime_io - wtime_rhs)
/ wtime); // Everything else
}
else if (output_splitmri) {
output_progress.write("{:.3e} {:8d} {:8d} {:8d} {:8d} {:.2e} {:5.1f} "
"{:5.1f} {:5.1f} {:5.1f} {:5.1f}\n",
simtime, ncalls_se, ncalls_si, ncalls_fe, ncalls_fi, wtime,
100. * (wtime_rhs - wtime_comms - wtime_invert) / wtime,
100. * wtime_invert / wtime, // Inversions
100. * wtime_comms / wtime, // Communications
100. * wtime_io / wtime, // I/O
100. * (wtime - wtime_io - wtime_rhs)
/ wtime); // Everything else
}
else
{
output_progress.write(
"{:.3e} {:5d} {:.2e} {:5.1f} {:5.1f} {:5.1f} {:5.1f} {:5.1f}\n",
simtime, ncalls, wtime, 100. * (wtime_rhs - wtime_comms - wtime_invert) / wtime,
100. * wtime_invert / wtime, // Inversions
100. * wtime_comms / wtime, // Communications
100. * wtime_io / wtime, // I/O
100. * (wtime - wtime_io - wtime_rhs) / wtime); // Everything else

}
}
1 change: 0 additions & 1 deletion src/mesh/impls/bout/boutmesh.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
#include <bout/dcomplex.hxx>
#include <bout/derivs.hxx>
#include <bout/fft.hxx>
#include <bout/griddata.hxx>
#include <bout/msg_stack.hxx>
#include <bout/options.hxx>
#include <bout/output.hxx>
Expand Down
6 changes: 3 additions & 3 deletions src/mesh/mesh.cxx
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include <bout/boutcomm.hxx>
#include <bout/coordinates.hxx>
#include <bout/derivs.hxx>
#include <bout/globals.hxx>
#include <bout/griddata.hxx>
#include <bout/mesh.hxx>
#include <bout/msg_stack.hxx>
#include <bout/output.hxx>
#include <bout/utils.hxx>

#include <cmath>

#include <bout/boutcomm.hxx>
#include <bout/output.hxx>

#include "impls/bout/boutmesh.hxx"

MeshFactory::ReturnType MeshFactory::create(Options* options,
Expand Down
9 changes: 8 additions & 1 deletion src/physics/physicsmodel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,16 @@ void PhysicsModel::initialise(Solver* s) {
}
}

int PhysicsModel::runRHS(BoutReal time, bool linear) { return rhs(time, linear); }
int PhysicsModel::runRHS_se(BoutReal time, bool linear) { return PhysicsModel::rhs_se(time, linear); }
int PhysicsModel::runRHS_si(BoutReal time, bool linear) { return PhysicsModel::rhs_si(time, linear); }
int PhysicsModel::runRHS_fe(BoutReal time, bool linear) { return PhysicsModel::rhs_fe(time, linear); }
int PhysicsModel::runRHS_fi(BoutReal time, bool linear) { return PhysicsModel::rhs_fi(time, linear); }
int PhysicsModel::runRHS_s(BoutReal time, bool linear) { return PhysicsModel::rhs_s(time, linear); }
int PhysicsModel::runRHS_f(BoutReal time, bool linear) { return PhysicsModel::rhs_f(time, linear); }
int PhysicsModel::runRHS(BoutReal time, bool linear) { return PhysicsModel::rhs(time, linear); }

bool PhysicsModel::splitOperator() { return splitop; }
bool PhysicsModel::splitOperatorMRI() { return splitopmri; }

int PhysicsModel::runConvective(BoutReal time, bool linear) {
return convective(time, linear);
Expand Down
Loading