Skip to content

Commit

Permalink
Simplify time stepping schemes and add type hints and documentation (#…
Browse files Browse the repository at this point in the history
…572)

* Improve naming in time stepping schemes.
* Add documentation and type hints. Remove support for monolithic problem to simplify code.
* Tidy up and improve inheritance and documentation.
* Do not instantiate objects from problemDefinition
  • Loading branch information
BenjaminRodenberg authored Oct 19, 2024
1 parent a67bad1 commit 9d99538
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 142 deletions.
7 changes: 7 additions & 0 deletions oscillator/solver-python/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[mypy]

[mypy-scipy.*]
ignore_missing_imports = True

[mypy-precice.*]
ignore_missing_imports = True
64 changes: 30 additions & 34 deletions oscillator/solver-python/oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from enum import Enum
import csv
import os
from typing import Type

import problemDefinition
import timeSteppers
from timeSteppers import TimeStepper, TimeSteppingSchemes, GeneralizedAlpha, RungeKutta4, RadauIIA


class Participant(Enum):
Expand All @@ -25,20 +26,24 @@ class Participant(Enum):
help="Time stepping scheme being used.",
type=str,
choices=[
s.value for s in timeSteppers.TimeSteppingSchemes],
default=timeSteppers.TimeSteppingSchemes.NEWMARK_BETA.value)
s.value for s in TimeSteppingSchemes],
default=TimeSteppingSchemes.NEWMARK_BETA.value)
args = parser.parse_args()

participant_name = args.participantName

this_mass: Type[problemDefinition.Mass]
other_mass: Type[problemDefinition.Mass]
this_spring: Type[problemDefinition.Spring]
connecting_spring = problemDefinition.SpringMiddle

if participant_name == Participant.MASS_LEFT.value:
write_data_name = 'Force-Left'
read_data_name = 'Force-Right'
mesh_name = 'Mass-Left-Mesh'

this_mass = problemDefinition.MassLeft
this_spring = problemDefinition.SpringLeft
connecting_spring = problemDefinition.SpringMiddle
other_mass = problemDefinition.MassRight

elif participant_name == Participant.MASS_RIGHT.value:
Expand All @@ -48,7 +53,6 @@ class Participant(Enum):

this_mass = problemDefinition.MassRight
this_spring = problemDefinition.SpringRight
connecting_spring = problemDefinition.SpringMiddle
other_mass = problemDefinition.MassLeft

else:
Expand Down Expand Up @@ -89,25 +93,19 @@ class Participant(Enum):
a = a0
t = 0

if args.time_stepping == timeSteppers.TimeSteppingSchemes.GENERALIZED_ALPHA.value:
time_stepper = timeSteppers.GeneralizedAlpha(stiffness=stiffness, mass=mass, alpha_f=0.4, alpha_m=0.2)
elif args.time_stepping == timeSteppers.TimeSteppingSchemes.NEWMARK_BETA.value:
time_stepper = timeSteppers.GeneralizedAlpha(stiffness=stiffness, mass=mass, alpha_f=0.0, alpha_m=0.0)
elif args.time_stepping == timeSteppers.TimeSteppingSchemes.RUNGE_KUTTA_4.value:
ode_system = np.array([
[0, 1], # du
[-stiffness / mass, 0], # dv
])
time_stepper = timeSteppers.RungeKutta4(ode_system=ode_system)
elif args.time_stepping == timeSteppers.TimeSteppingSchemes.Radau_IIA.value:
ode_system = np.array([
[0, 1], # du
[-stiffness / mass, 0], # dv
])
time_stepper = timeSteppers.RadauIIA(ode_system=ode_system)
time_stepper: TimeStepper

if args.time_stepping == TimeSteppingSchemes.GENERALIZED_ALPHA.value:
time_stepper = GeneralizedAlpha(stiffness=stiffness, mass=mass, alpha_f=0.4, alpha_m=0.2)
elif args.time_stepping == TimeSteppingSchemes.NEWMARK_BETA.value:
time_stepper = GeneralizedAlpha(stiffness=stiffness, mass=mass, alpha_f=0.0, alpha_m=0.0)
elif args.time_stepping == TimeSteppingSchemes.RUNGE_KUTTA_4.value:
time_stepper = RungeKutta4(stiffness=stiffness, mass=mass)
elif args.time_stepping == TimeSteppingSchemes.Radau_IIA.value:
time_stepper = RadauIIA(stiffness=stiffness, mass=mass)
else:
raise Exception(
f"Invalid time stepping scheme {args.time_stepping}. Please use one of {[ts.value for ts in timeSteppers.TimeSteppingSchemes]}")
f"Invalid time stepping scheme {args.time_stepping}. Please use one of {[ts.value for ts in TimeSteppingSchemes]}")


positions = []
Expand All @@ -133,34 +131,32 @@ class Participant(Enum):
precice_dt = participant.get_max_time_step_size()
dt = np.min([precice_dt, my_dt])

def f(t): return participant.read_data(mesh_name, read_data_name, vertex_ids, t)[0]
def f(t: float) -> float: return participant.read_data(mesh_name, read_data_name, vertex_ids, t)[0]

# do time step, write data, and advance
# performs adaptive time stepping with dense output; multiple write calls per time step
if args.time_stepping == timeSteppers.TimeSteppingSchemes.Radau_IIA.value:
u_new, v_new, a_new, sol = time_stepper.do_step(u, v, a, f, dt)
t_new = t + dt
u_new, v_new, a_new = time_stepper.do_step(u, v, a, f, dt)

t_new = t + dt

# RadauIIA time stepper provides dense output. Do multiple write calls per time step.
if isinstance(time_stepper, RadauIIA):
# create n samples_per_step of time stepping scheme. Degree of dense
# interpolating function is usually larger 1 and, therefore, we need
# multiple samples per step.
samples_per_step = 5
n_time_steps = len(sol.ts) # number of time steps performed by adaptive time stepper
n_time_steps = len(time_stepper.dense_output.ts) # number of time steps performed by adaptive time stepper
n_pseudo = samples_per_step * n_time_steps # samples_per_step times no. time steps per window.

t_pseudo = 0
for i in range(n_pseudo):
# perform n_pseudo pseudosteps
dt_pseudo = dt / n_pseudo
t_pseudo += dt_pseudo
write_data = [connecting_spring.k * sol(t_pseudo)[0]]
write_data = np.array([connecting_spring.k * time_stepper.dense_output(t_pseudo)[0]])
participant.write_data(mesh_name, write_data_name, vertex_ids, write_data)
participant.advance(dt_pseudo)

else: # simple time stepping without dense output; only a single write call per time step
u_new, v_new, a_new = time_stepper.do_step(u, v, a, f, dt)
t_new = t + dt

write_data = [connecting_spring.k * u_new]
write_data = np.array([connecting_spring.k * u_new])
participant.write_data(mesh_name, write_data_name, vertex_ids, write_data)
participant.advance(dt)

Expand Down
27 changes: 18 additions & 9 deletions oscillator/solver-python/problemDefinition.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
import numpy as np
from numpy.linalg import eig
from typing import Callable


class SpringLeft:
class Spring:
k: float


class SpringLeft(Spring):
k = 4 * np.pi**2


class SpringMiddle:
class SpringMiddle(Spring):
k = 16 * (np.pi**2)


class SpringRight:
class SpringRight(Spring):
k = 4 * np.pi**2


class MassLeft:
class Mass:
m: float
u0: float
v0: float
u_analytical: Callable[[float | np.ndarray], float | np.ndarray]
v_analytical: Callable[[float | np.ndarray], float | np.ndarray]


class MassLeft(Mass):
# mass
m = 1

Expand All @@ -23,10 +36,8 @@ class MassLeft:
u0 = 1.0
v0 = 0.0

u_analytical, v_analytical = None, None # will be defined below


class MassRight:
class MassRight(Mass):
# mass
m = 1

Expand All @@ -35,8 +46,6 @@ class MassRight:
u0 = 0.0
v0 = 0.0

u_analytical, v_analytical = None, None # will be defined below


# Mass matrix
M = np.array([
Expand Down
Loading

0 comments on commit 9d99538

Please sign in to comment.