Skip to content

Commit

Permalink
ENH: Allow for Alternative and Custom ODE Solvers.
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Dec 6, 2024
1 parent d93666a commit e778883
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@
"pytest",
"pytz",
"quantile",
"Radau",
"Rdot",
"referece",
"relativetoground",
Expand Down
63 changes: 52 additions & 11 deletions rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import simplekml
from scipy import integrate
from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau

from ..mathutils.function import Function, funcify_method
from ..mathutils.vector_matrix import Matrix, Vector
Expand All @@ -24,8 +24,19 @@
quaternions_to_spin,
)

ODE_SOLVER_MAP = {
'RK23': RK23,
'RK45': RK45,
'DOP853': DOP853,
'Radau': Radau,
'BDF': BDF,
'LSODA': LSODA,
}

class Flight: # pylint: disable=too-many-public-methods

# pylint: disable=too-many-public-methods
# pylint: disable=too-many-instance-attributes
class Flight:
"""Keeps all flight information and has a method to simulate flight.
Attributes
Expand Down Expand Up @@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
verbose=False,
name="Flight",
equations_of_motion="standard",
ode_solver="LSODA",
):
"""Run a trajectory simulation.
Expand Down Expand Up @@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
more restricted set of equations of motion that only works for
solid propulsion rockets. Such equations were used in RocketPy v0
and are kept here for backwards compatibility.
ode_solver : str, ``scipy.integrate.OdeSolver``, optional
Integration method to use to solve the equations of motion ODE.
Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
'LSODA' from ``scipy.integrate.solve_ivp``.
Default is 'LSODA', which is recommended for most flights.
A custom ``scipy.integrate.OdeSolver`` can be passed as well.
For more information on the integration methods, see the scipy
documentation [1]_.
Returns
-------
None
References
----------
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
"""
# Save arguments
self.env = environment
Expand All @@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
self.terminate_on_apogee = terminate_on_apogee
self.name = name
self.equations_of_motion = equations_of_motion
self.ode_solver = ode_solver

# Controller initialization
self.__init_controllers()
Expand Down Expand Up @@ -651,15 +677,16 @@ def __simulate(self, verbose):

# Create solver for this flight phase # TODO: allow different integrators
self.function_evaluations.append(0)
phase.solver = integrate.LSODA(

phase.solver = self._solver(
phase.derivative,
t0=phase.t,
y0=self.y_sol,
t_bound=phase.time_bound,
min_step=self.min_time_step,
max_step=self.max_time_step,
rtol=self.rtol,
atol=self.atol,
max_step=self.max_time_step,
min_step=self.min_time_step,
)

# Initialize phase time nodes
Expand Down Expand Up @@ -691,13 +718,14 @@ def __simulate(self, verbose):
for node_index, node in self.time_iterator(phase.time_nodes):
# Determine time bound for this time node
node.time_bound = phase.time_nodes[node_index + 1].t
# NOTE: Setting the time bound and status for the phase solver,
# and updating its internal state for the next integration step.
phase.solver.t_bound = node.time_bound
phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
if self.__is_lsoda:
phase.solver._lsoda_solver._integrator.rwork[0] = (
phase.solver.t_bound
)
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
phase.solver.status = "running"

# Feed required parachute and discrete controller triggers
Expand Down Expand Up @@ -1185,6 +1213,19 @@ def __init_solver_monitors(self):
self.t = self.solution[-1][0]
self.y_sol = self.solution[-1][1:]

if isinstance(self.ode_solver, OdeSolver):
self._solver = self.ode_solver
else:
try:
self._solver = ODE_SOLVER_MAP[self.ode_solver]
except KeyError as e:
raise ValueError(
f"Invalid ``ode_solver`` input: {self.ode_solver}. "
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
) from e

self.__is_lsoda = hasattr(self._solver, "_lsoda_solver")

def __init_equations_of_motion(self):
"""Initialize equations of motion."""
if self.equations_of_motion == "solid_propulsion":
Expand Down

0 comments on commit e778883

Please sign in to comment.