diff --git a/.vscode/settings.json b/.vscode/settings.json index e36a259e0..eac107253 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -245,6 +245,7 @@ "pytest", "pytz", "quantile", + "Radau", "Rdot", "referece", "relativetoground", diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index a43ec15de..e274b7f63 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -7,7 +7,7 @@ import numpy as np import simplekml -from scipy import integrate +from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau from ..mathutils.function import Function, funcify_method from ..mathutils.vector_matrix import Matrix, Vector @@ -24,8 +24,19 @@ quaternions_to_spin, ) +ODE_SOLVER_MAP = { + 'RK23': RK23, + 'RK45': RK45, + 'DOP853': DOP853, + 'Radau': Radau, + 'BDF': BDF, + 'LSODA': LSODA, +} -class Flight: # pylint: disable=too-many-public-methods + +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-instance-attributes +class Flight: """Keeps all flight information and has a method to simulate flight. Attributes @@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements verbose=False, name="Flight", equations_of_motion="standard", + ode_solver="LSODA", ): """Run a trajectory simulation. @@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements more restricted set of equations of motion that only works for solid propulsion rockets. Such equations were used in RocketPy v0 and are kept here for backwards compatibility. + ode_solver : str, ``scipy.integrate.OdeSolver``, optional + Integration method to use to solve the equations of motion ODE. + Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF', + 'LSODA' from ``scipy.integrate.solve_ivp``. + Default is 'LSODA', which is recommended for most flights. + A custom ``scipy.integrate.OdeSolver`` can be passed as well. + For more information on the integration methods, see the scipy + documentation [1]_. + Returns ------- None + + References + ---------- + .. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html """ # Save arguments self.env = environment @@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements self.terminate_on_apogee = terminate_on_apogee self.name = name self.equations_of_motion = equations_of_motion + self.ode_solver = ode_solver # Controller initialization self.__init_controllers() @@ -651,15 +677,16 @@ def __simulate(self, verbose): # Create solver for this flight phase # TODO: allow different integrators self.function_evaluations.append(0) - phase.solver = integrate.LSODA( + + phase.solver = self._solver( phase.derivative, t0=phase.t, y0=self.y_sol, t_bound=phase.time_bound, - min_step=self.min_time_step, - max_step=self.max_time_step, rtol=self.rtol, atol=self.atol, + max_step=self.max_time_step, + min_step=self.min_time_step, ) # Initialize phase time nodes @@ -691,13 +718,14 @@ def __simulate(self, verbose): for node_index, node in self.time_iterator(phase.time_nodes): # Determine time bound for this time node node.time_bound = phase.time_nodes[node_index + 1].t - # NOTE: Setting the time bound and status for the phase solver, - # and updating its internal state for the next integration step. phase.solver.t_bound = node.time_bound - phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound - phase.solver._lsoda_solver._integrator.call_args[4] = ( - phase.solver._lsoda_solver._integrator.rwork - ) + if self.__is_lsoda: + phase.solver._lsoda_solver._integrator.rwork[0] = ( + phase.solver.t_bound + ) + phase.solver._lsoda_solver._integrator.call_args[4] = ( + phase.solver._lsoda_solver._integrator.rwork + ) phase.solver.status = "running" # Feed required parachute and discrete controller triggers @@ -1185,6 +1213,19 @@ def __init_solver_monitors(self): self.t = self.solution[-1][0] self.y_sol = self.solution[-1][1:] + if isinstance(self.ode_solver, OdeSolver): + self._solver = self.ode_solver + else: + try: + self._solver = ODE_SOLVER_MAP[self.ode_solver] + except KeyError as e: + raise ValueError( + f"Invalid ``ode_solver`` input: {self.ode_solver}. " + f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}" + ) from e + + self.__is_lsoda = hasattr(self._solver, "_lsoda_solver") + def __init_equations_of_motion(self): """Initialize equations of motion.""" if self.equations_of_motion == "solid_propulsion":