Skip to content

Commit

Permalink
Merge branch 'TBendall/SubcyclingInvestigation' into TBendall/Predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
tommbendall committed Sep 16, 2024
2 parents 6eebf46 + ad97838 commit b24b035
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 15 deletions.
18 changes: 13 additions & 5 deletions examples/shallow_water/williamson_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Domain, IO, OutputParameters, SemiImplicitQuasiNewton, SSPRK3, DGUpwind,
TrapeziumRule, ShallowWaterParameters, ShallowWaterEquations, Sum,
lonlatr_from_xyz, GeneralIcosahedralSphereMesh, ZonalComponent,
MeridionalComponent, RelativeVorticity
MeridionalComponent, RelativeVorticity, MoistConvectiveSWSolver
)

williamson_5_defaults = {
Expand Down Expand Up @@ -71,24 +71,32 @@ def williamson_5(
rsq = min_value(R0**2, (lamda - lamda_c)**2 + (phi - phi_c)**2)
r = sqrt(rsq)
tpexpr = mountain_height * (1 - r/R0)
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr, bexpr=tpexpr)
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr, bexpr=tpexpr,
u_transport_option='vector_advection_form')

# I/O
output = OutputParameters(
dirname=dirname, dumplist_latlon=['D'], dumpfreq=dumpfreq,
dump_vtus=True, dump_nc=False, dumplist=['D', 'topography']
dump_vtus=False, dump_nc=True, dumplist=['D', 'topography']
)
diagnostic_fields = [Sum('D', 'topography'), RelativeVorticity(),
MeridionalComponent('u'), ZonalComponent('u')]
io = IO(domain, output, diagnostic_fields=diagnostic_fields)

# Transport schemes
transported_fields = [TrapeziumRule(domain, "u"), SSPRK3(domain, "D")]
transported_fields = [
SSPRK3(domain, "u", subcycle_by_courant=0.25),
SSPRK3(domain, "D", subcycle_by_courant=0.25)
]
transport_methods = [DGUpwind(eqns, "u"), DGUpwind(eqns, "D")]

linear_solver = MoistConvectiveSWSolver(eqns, tau_values={'D': 1.0})

# Time stepper
stepper = SemiImplicitQuasiNewton(
eqns, io, transported_fields, transport_methods
eqns, io, transported_fields, transport_methods,
linear_solver=linear_solver, num_outer=2, num_inner=2,
predictor='D', alpha=0.55, accelerator=True
)

# ------------------------------------------------------------------------ #
Expand Down
82 changes: 72 additions & 10 deletions gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from firedrake import (Function, Constant, TrialFunctions, DirichletBC,
LinearVariationalProblem, LinearVariationalSolver)
LinearVariationalProblem, LinearVariationalSolver,
Interpolator, div)
from firedrake.fml import drop, replace_subject
from pyop2.profiling import timed_stage
from gusto.core import TimeLevelFields, StateFields
Expand Down Expand Up @@ -35,7 +36,8 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
diffusion_schemes=None, physics_schemes=None,
slow_physics_schemes=None, fast_physics_schemes=None,
alpha=Constant(0.5), off_centred_u=False,
num_outer=2, num_inner=2, accelerator=False):
num_outer=2, num_inner=2, accelerator=False,
reference_update_freq=None, predictor=None):

"""
Args:
Expand Down Expand Up @@ -84,13 +86,25 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
implicit forcing (pressure gradient and Coriolis) terms, and the
linear solve. Defaults to 2. Note that default used by the Met
Office's ENDGame and GungHo models is 2.
accelerator (bool, optional): Whether to zero non-wind implicit forcings
for transport terms in order to speed up solver convergence
accelerator (bool, optional): Whether to zero non-wind implicit
forcings for transport terms in order to speed up solver
convergence. Defaults to False.
reference_update_freq (float, optional): frequency with which to
update the reference profile with the n-th time level state
fields. This variable corresponds to time in seconds, and
setting this to zero will update the reference profiles every
time step. Setting it to None turns off the update, and
reference profiles will remain at their initial values.
Defaults to None.
"""

self.num_outer = num_outer
self.num_inner = num_inner
self.alpha = alpha
self.accelerator = accelerator
self.reference_update_freq = reference_update_freq
self.to_update_ref_profile = False
self.predictor = predictor

# default is to not offcentre transporting velocity but if it
# is offcentred then use the same value as alpha
Expand Down Expand Up @@ -188,7 +202,13 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.linear_solver = linear_solver
self.forcing = Forcing(equation_set, self.alpha)
self.bcs = equation_set.bcs
self.accelerator = accelerator

if self.predictor is not None:
V_DG = equation_set.domain.spaces('DG')
div_factor = Constant(1.0) - (Constant(1.0) - self.alpha)*self.dt*div(self.x.star('u'))
self.predictor_interpolator = Interpolator(
self.x.star(predictor)*div_factor, V_DG
)

def _apply_bcs(self):
"""
Expand Down Expand Up @@ -252,6 +272,24 @@ def copy_active_tracers(self, x_in, x_out):
for name in self.tracers_to_copy:
x_out(name).assign(x_in(name))

def update_reference_profiles(self):
"""
Updates the reference profiles and if required also updates them in the
linear solver.
"""

if self.reference_update_freq is not None:
if float(self.t) + self.reference_update_freq > self.last_ref_update_time:
self.equation.X_ref.assign(self.x.n(self.field_name))
self.last_ref_update_time = float(self.t)
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()

elif self.to_update_ref_profile:
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()
self.to_update_ref_profile = False

def timestep(self):
"""Defines the timestep"""
xn = self.x.n
Expand All @@ -264,13 +302,18 @@ def timestep(self):
xrhs_phys = self.xrhs_phys
dy = self.dy

# Update reference profiles --------------------------------------------
self.update_reference_profiles()

# Slow physics ---------------------------------------------------------
x_after_slow(self.field_name).assign(xn(self.field_name))
if len(self.slow_physics_schemes) > 0:
with timed_stage("Slow physics"):
logger.info('Semi-implicit Quasi Newton: Slow physics')
for _, scheme in self.slow_physics_schemes:
scheme.apply(x_after_slow(scheme.field_name), x_after_slow(scheme.field_name))

# Explict forcing ------------------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info('Semi-implicit Quasi Newton: Explicit forcing')
# Put explicit forcing into xstar
Expand All @@ -280,16 +323,27 @@ def timestep(self):
# the correct values
xp(self.field_name).assign(xstar(self.field_name))

# OUTER ----------------------------------------------------------------
for outer in range(self.num_outer):

# Transport --------------------------------------------------------
with timed_stage("Transport"):
self.io.log_courant(self.fields, 'transporting_velocity',
message=f'transporting velocity, outer iteration {outer}')
for name, scheme in self.active_transport:
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}')
# transports a field from xstar and puts result in xp
scheme.apply(xp(name), xstar(name))

if name == self.predictor:
V = xstar(name).function_space()
field_in = Function(V)
field_out = Function(V)
self.predictor_interpolator.interpolate()
scheme.apply(field_out, field_in)
xp(name).assign(xstar(name) + field_out - field_in)
else:
scheme.apply(xp(name), xstar(name))

# Fast physics -----------------------------------------------------
x_after_fast(self.field_name).assign(xp(self.field_name))
if len(self.fast_physics_schemes) > 0:
with timed_stage("Fast physics"):
Expand All @@ -302,8 +356,7 @@ def timestep(self):

for inner in range(self.num_inner):

# TODO: this is where to update the reference state

# Implicit forcing ---------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}')
self.forcing.apply(xp, xnp1, xrhs, "implicit")
Expand All @@ -314,6 +367,7 @@ def timestep(self):
xrhs -= xnp1(self.field_name)
xrhs += xrhs_phys

# Linear solve -------------------------------------------------
with timed_stage("Implicit solve"):
logger.info(f'Semi-implicit Quasi Newton: Mixed solve {(outer, inner)}')
self.linear_solver.solve(xrhs, dy) # solves linear system and places result in dy
Expand Down Expand Up @@ -353,10 +407,18 @@ def run(self, t, tmax, pick_up=False):
pick_up: (bool): specify whether to pick_up from a previous run
"""

if not pick_up:
if not pick_up and self.reference_update_freq is None:
assert self.reference_profiles_initialised, \
'Reference profiles for must be initialised to use Semi-Implicit Timestepper'

if not pick_up and self.reference_update_freq is not None:
# Force reference profiles to be updated on first time step
self.last_ref_update_time = float(t) - float(self.dt)

elif not pick_up or (pick_up and self.reference_update_freq is None):
# Indicate that linear solver profile needs updating
self.to_update_ref_profile = True

super().run(t, tmax, pick_up=pick_up)


Expand Down

0 comments on commit b24b035

Please sign in to comment.