diff --git a/.vscode/settings.json b/.vscode/settings.json index e36a259e0..eac107253 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -245,6 +245,7 @@ "pytest", "pytz", "quantile", + "Radau", "Rdot", "referece", "relativetoground", diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aa74969b..be9ae3bbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ Attention: The newest changes should be on top --> ### Added -- +- ENH: Allow for Alternative and Custom ODE Solvers. [#748](https://github.com/RocketPy-Team/RocketPy/pull/748) ### Changed diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index a43ec15de..e4be64f5a 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -7,7 +7,7 @@ import numpy as np import simplekml -from scipy import integrate +from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau from ..mathutils.function import Function, funcify_method from ..mathutils.vector_matrix import Matrix, Vector @@ -24,8 +24,19 @@ quaternions_to_spin, ) +ODE_SOLVER_MAP = { + 'RK23': RK23, + 'RK45': RK45, + 'DOP853': DOP853, + 'Radau': Radau, + 'BDF': BDF, + 'LSODA': LSODA, +} -class Flight: # pylint: disable=too-many-public-methods + +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-instance-attributes +class Flight: """Keeps all flight information and has a method to simulate flight. Attributes @@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements verbose=False, name="Flight", equations_of_motion="standard", + ode_solver="LSODA", ): """Run a trajectory simulation. @@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements more restricted set of equations of motion that only works for solid propulsion rockets. Such equations were used in RocketPy v0 and are kept here for backwards compatibility. + ode_solver : str, ``scipy.integrate.OdeSolver``, optional + Integration method to use to solve the equations of motion ODE. + Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF', + 'LSODA' from ``scipy.integrate.solve_ivp``. + Default is 'LSODA', which is recommended for most flights. + A custom ``scipy.integrate.OdeSolver`` can be passed as well. + For more information on the integration methods, see the scipy + documentation [1]_. + Returns ------- None + + References + ---------- + .. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html """ # Save arguments self.env = environment @@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements self.terminate_on_apogee = terminate_on_apogee self.name = name self.equations_of_motion = equations_of_motion + self.ode_solver = ode_solver # Controller initialization self.__init_controllers() @@ -651,15 +677,16 @@ def __simulate(self, verbose): # Create solver for this flight phase # TODO: allow different integrators self.function_evaluations.append(0) - phase.solver = integrate.LSODA( + + phase.solver = self._solver( phase.derivative, t0=phase.t, y0=self.y_sol, t_bound=phase.time_bound, - min_step=self.min_time_step, - max_step=self.max_time_step, rtol=self.rtol, atol=self.atol, + max_step=self.max_time_step, + min_step=self.min_time_step, ) # Initialize phase time nodes @@ -691,13 +718,14 @@ def __simulate(self, verbose): for node_index, node in self.time_iterator(phase.time_nodes): # Determine time bound for this time node node.time_bound = phase.time_nodes[node_index + 1].t - # NOTE: Setting the time bound and status for the phase solver, - # and updating its internal state for the next integration step. phase.solver.t_bound = node.time_bound - phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound - phase.solver._lsoda_solver._integrator.call_args[4] = ( - phase.solver._lsoda_solver._integrator.rwork - ) + if self.__is_lsoda: + phase.solver._lsoda_solver._integrator.rwork[0] = ( + phase.solver.t_bound + ) + phase.solver._lsoda_solver._integrator.call_args[4] = ( + phase.solver._lsoda_solver._integrator.rwork + ) phase.solver.status = "running" # Feed required parachute and discrete controller triggers @@ -1185,6 +1213,8 @@ def __init_solver_monitors(self): self.t = self.solution[-1][0] self.y_sol = self.solution[-1][1:] + self.__set_ode_solver(self.ode_solver) + def __init_equations_of_motion(self): """Initialize equations of motion.""" if self.equations_of_motion == "solid_propulsion": @@ -1222,6 +1252,28 @@ def __cache_sensor_data(self): sensor_data[sensor] = sensor.measured_data[:] self.sensor_data = sensor_data + def __set_ode_solver(self, solver): + """Sets the ODE solver to be used in the simulation. + + Parameters + ---------- + solver : str, ``scipy.integrate.OdeSolver`` + Integration method to use to solve the equations of motion ODE, + or a custom ``scipy.integrate.OdeSolver``. + """ + if isinstance(solver, OdeSolver): + self._solver = solver + else: + try: + self._solver = ODE_SOLVER_MAP[solver] + except KeyError as e: + raise ValueError( + f"Invalid ``ode_solver`` input: {solver}. " + f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}" + ) from e + + self.__is_lsoda = hasattr(self._solver, "_lsoda_solver") + @cached_property def effective_1rl(self): """Original rail length minus the distance measured from nozzle exit diff --git a/tests/integration/test_flight.py b/tests/integration/test_flight.py index fc1dd1956..5fdeb9c3d 100644 --- a/tests/integration/test_flight.py +++ b/tests/integration/test_flight.py @@ -11,7 +11,8 @@ @patch("matplotlib.pyplot.show") -def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-argument +# pylint: disable=unused-argument +def test_all_info(mock_show, flight_calisto_robust): """Test that the flight class is working as intended. This basically calls the all_info() method and checks if it returns None. It is not testing if the values are correct, but whether the method is working without errors. @@ -27,6 +28,42 @@ def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-a assert flight_calisto_robust.all_info() is None +@pytest.mark.slow +@patch("matplotlib.pyplot.show") +@pytest.mark.parametrize("solver_method", ["RK45", "DOP853", "Radau", "BDF"]) +# RK23 is unstable and requires a very low tolerance to work +# pylint: disable=unused-argument +def test_all_info_different_solvers( + mock_show, calisto_robust, example_spaceport_env, solver_method +): + """Test that the flight class is working as intended with different solver + methods. This basically calls the all_info() method and checks if it returns + None. It is not testing if the values are correct, but whether the method is + working without errors. + + Parameters + ---------- + mock_show : unittest.mock.MagicMock + Mock object to replace matplotlib.pyplot.show + calisto_robust : rocketpy.Rocket + Rocket to be simulated. See the conftest.py file for more info. + example_spaceport_env : rocketpy.Environment + Environment to be simulated. See the conftest.py file for more info. + solver_method : str + The solver method to be used in the simulation. + """ + test_flight = Flight( + environment=example_spaceport_env, + rocket=calisto_robust, + rail_length=5.2, + inclination=85, + heading=0, + terminate_on_apogee=False, + ode_solver=solver_method, + ) + assert test_flight.all_info() is None + + class TestExportData: """Tests the export_data method of the Flight class."""