Skip to content

Commit

Permalink
Increment form for implicit RK added and tested
Browse files Browse the repository at this point in the history
  • Loading branch information
atb1995 committed Oct 24, 2024
1 parent 7fc82a1 commit 90d1f7c
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 42 deletions.
6 changes: 1 addition & 5 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
self.butcher_matrix = butcher_matrix
self.nbutcher = int(np.shape(self.butcher_matrix)[0])
self.nStages = int(np.shape(self.butcher_matrix)[0])
self.increment_form = increment_form

@property
def nStages(self):
return self.nbutcher

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Set up the time discretisation based on the equation.
Expand Down
161 changes: 127 additions & 34 deletions gusto/time_discretisation/implicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from firedrake import (Function, split, NonlinearVariationalProblem,
NonlinearVariationalSolver)
NonlinearVariationalSolver, Constant)
from firedrake.fml import replace_subject, all_terms, drop
from firedrake.utils import cached_property

Expand Down Expand Up @@ -56,7 +56,8 @@ class ImplicitRungeKutta(TimeDiscretisation):
# ---------------------------------------------------------------------------

def __init__(self, domain, butcher_matrix, field_name=None,
solver_parameters=None, limiter=None, options=None,):
increment_form=True, solver_parameters=None,
limiter=None, options=None,):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -66,6 +67,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
discretisation.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
increment_form (bool, optional): whether to write the RK scheme in
"increment form", solving for increments rather than updated
fields. Defaults to True.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
Expand All @@ -80,6 +84,7 @@ def __init__(self, domain, butcher_matrix, field_name=None,
limiter=limiter, options=options)
self.butcher_matrix = butcher_matrix
self.nStages = int(np.shape(self.butcher_matrix)[1])
self.increment_form = increment_form

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Expand All @@ -94,30 +99,98 @@ def setup(self, equation, apply_bcs=True, *active_labels):
super().setup(equation, apply_bcs, *active_labels)

self.k = [Function(self.fs) for i in range(self.nStages)]
self.xs = [Function(self.fs) for i in range(self.nStages)]

def lhs(self):
return super().lhs

def rhs(self):
return super().rhs

def solver(self, stage):
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.xnph, self.idx),
)
def res(self, stage):
"""Set up the discretisation's residual for a given stage."""
# Add time derivative terms y_s - y^n for stage s
mass_form = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual += mass_form.label_map(all_terms,
replace_subject(self.x_out, self.idx))
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
# dt*(a_s1*F(y_1) + a_s2*F(y_2)+ ... + a_{s,s-1}*F(y_{s-1}))
print(stage)
for i in range(stage):
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.xs[i], old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[stage, i])*self.dt*t)
residual += r_imp
# Calculate and add on dt*a_ss*F(y_s)
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.x_out, old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[stage, stage])*self.dt*t)
residual += r_imp
return residual.form

@property
def final_res(self):
"""Set up the discretisation's final residual."""
# Add time derivative terms y^{n+1} - y^n
mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
# dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s))
for i in range(self.nStages):
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.xs[i], old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[self.nStages, i])*self.dt*t)
residual += r_imp
return residual.form

problem = NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
def solver(self, stage):
if self.increment_form:
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.xnph, self.idx),
)
mass_form = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual += mass_form.label_map(all_terms,
replace_subject(self.x_out, self.idx))

problem = NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)

else:
problem = NonlinearVariationalProblem(self.res(stage), self.x_out, bcs=self.bcs)

solver_name = self.field_name+self.__class__.__name__ + "%s" % (stage)
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters,
options_prefix=solver_name)
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@cached_property
def final_solver(self):
"""Set up a solver for the final solve to evaluate time level n+1."""
# setup solver using lhs and rhs defined in derived class
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@cached_property
def solvers(self):
Expand All @@ -128,31 +201,43 @@ def solvers(self):

def solve_stage(self, x0, stage):
self.x1.assign(x0)
for i in range(stage):
self.x1.assign(self.x1 + self.butcher_matrix[stage, i]*self.dt*self.k[i])

if self.limiter is not None:
self.limiter.apply(self.x1)

if self.idx is None and len(self.fs) > 1:
self.xnph = tuple([self.dt*self.butcher_matrix[stage, stage]*a + b
for a, b in zip(split(self.x_out), split(self.x1))])
if self.increment_form:
for i in range(stage):
self.x1.assign(self.x1 + self.butcher_matrix[stage, i]*self.dt*self.k[i])

if self.limiter is not None:
self.limiter.apply(self.x1)

if self.idx is None and len(self.fs) > 1:
self.xnph = tuple([self.dt*self.butcher_matrix[stage, stage]*a + b
for a, b in zip(split(self.x_out), split(self.x1))])
else:
self.xnph = self.x1 + self.butcher_matrix[stage, stage]*self.dt*self.x_out
solver = self.solvers[stage]
solver.solve()

self.k[stage].assign(self.x_out)
else:
self.xnph = self.x1 + self.butcher_matrix[stage, stage]*self.dt*self.x_out
solver = self.solvers[stage]
solver.solve()
if (stage > 0):
self.x_out.assign(self.xs[stage-1])
solver = self.solvers[stage]
solver.solve()

self.k[stage].assign(self.x_out)
self.xs[stage].assign(self.x_out)

@wrapper_apply
def apply(self, x_out, x_in):

self.x_out.assign(x_in)
for i in range(self.nStages):
self.solve_stage(x_in, i)

x_out.assign(x_in)
for i in range(self.nStages):
x_out.assign(x_out + self.butcher_matrix[self.nStages, i]*self.dt*self.k[i])
if self.increment_form:
x_out.assign(x_in)
for i in range(self.nStages):
x_out.assign(x_out + self.butcher_matrix[self.nStages, i]*self.dt*self.k[i])
else:
self.final_solver.solve()
x_out.assign(self.x_out)

if self.limiter is not None:
self.limiter.apply(x_out)
Expand All @@ -168,14 +253,17 @@ class ImplicitMidpoint(ImplicitRungeKutta):
k0 = F[y^n + 0.5*dt*k0] \n
y^(n+1) = y^n + dt*k0 \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
limiter=None, options=None):
def __init__(self, domain, field_name=None, increment_form=True,
solver_parameters=None, limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
increment_form (bool, optional): whether to write the RK scheme in
"increment form", solving for increments rather than updated
fields. Defaults to True.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
Expand All @@ -187,6 +275,7 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
"""
butcher_matrix = np.array([[0.5], [1.]])
super().__init__(domain, butcher_matrix, field_name,
increment_form=increment_form,
solver_parameters=solver_parameters,
limiter=limiter, options=options)

Expand All @@ -202,14 +291,17 @@ class QinZhang(ImplicitRungeKutta):
k1 = F[y^n + 0.5*dt*k0 + 0.25*dt*k1] \n
y^(n+1) = y^n + 0.5*dt*(k0 + k1) \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
limiter=None, options=None):
def __init__(self, domain, field_name=None, increment_form=True,
solver_parameters=None, limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
increment_form (bool, optional): whether to write the RK scheme in
"increment form", solving for increments rather than updated
fields. Defaults to True.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
Expand All @@ -221,5 +313,6 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
"""
butcher_matrix = np.array([[0.25, 0], [0.5, 0.25], [0.5, 0.5]])
super().__init__(domain, butcher_matrix, field_name,
increment_form=increment_form,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
9 changes: 6 additions & 3 deletions integration-tests/model/test_time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ def run(timestepper, tmax, f_end):

@pytest.mark.parametrize(
"scheme", ["ssprk3_increment", "TrapeziumRule", "ImplicitMidpoint",
"QinZhang", "RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth",
"QinZhang_increment", "QinZhang_predictor",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth",
"Leapfrog", "AdamsMoulton", "AdamsMoulton", "ssprk3_predictor"])
def test_time_discretisation(tmpdir, scheme, tracer_setup):
if (scheme == "AdamsBashforth"):
Expand All @@ -35,8 +36,10 @@ def test_time_discretisation(tmpdir, scheme, tracer_setup):
transport_scheme = TrapeziumRule(domain)
elif scheme == "ImplicitMidpoint":
transport_scheme = ImplicitMidpoint(domain)
elif scheme == "QinZhang":
transport_scheme = QinZhang(domain)
elif scheme == "QinZhang_increment":
transport_scheme = QinZhang(domain, increment_form=True)
elif scheme == "QinZhang_predictor":
transport_scheme = QinZhang(domain, increment_form=False)
elif scheme == "RK4":
transport_scheme = RK4(domain)
elif scheme == "Heun":
Expand Down

0 comments on commit 90d1f7c

Please sign in to comment.