From b69751e838e5212600cfaba799deef5acec2cf60 Mon Sep 17 00:00:00 2001 From: vlakir Date: Fri, 17 Sep 2021 13:30:33 +0300 Subject: [PATCH] + du_dt returning --- cleanode/ode_solvers.py | 37 ++++++++++++++++++++++++++++--------- scalar_ode_example.py | 6 +++--- system_ode_example.py | 7 +++---- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/cleanode/ode_solvers.py b/cleanode/ode_solvers.py index 45621fd..8820ef9 100644 --- a/cleanode/ode_solvers.py +++ b/cleanode/ode_solvers.py @@ -1,6 +1,6 @@ import numpy import numpy as np -from typing import Callable, Tuple +from typing import Callable, Tuple, Optional from funnydeco import benchmark import quadpy from scipy import interpolate @@ -111,7 +111,7 @@ def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, n i += 1 if self.is_adaptive_step and self.is_interpolate: - self.u, self.t = _interpolate_result(self.u, self.t, self.t0, self.tmax, self.dt0) + self.u, __, self.t = _interpolate_result(self.u, None, self.t, self.t0, self.tmax, self.dt0) return self.u, self.t @@ -894,7 +894,7 @@ def __init__(self, order, # noinspection PyUnusedLocal @benchmark - def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, np.ndarray]: + def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ ODE solution :param print_benchmark: output the execution time to the console @@ -927,9 +927,9 @@ def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, n i += 1 if self.is_adaptive_step and self.is_interpolate: - self.u, self.t = _interpolate_result(self.u, self.t, self.t0, self.tmax, self.dt0) + self.u, self.du_dt, self.t = _interpolate_result(self.u, self.du_dt, self.t, self.t0, self.tmax, self.dt0) - return self.u, self.t + return self.u, self.du_dt, self.t def _do_step(self, u, du_dt, t, f, dt, h, alfa) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.longdouble]: """ @@ -1294,12 +1294,14 @@ def __init__(self, f2: Callable, is_interpolate=is_interpolate, tolerance=tolerance) -def _interpolate_result(u: numpy.array, t: numpy.array, t0: numpy.longdouble, tmax: numpy.longdouble, - dt: numpy.longdouble) -> Tuple[numpy.array, numpy.array]: +def _interpolate_result(u: numpy.array, du_dt: Optional[numpy.array], t: numpy.array, t0: numpy.longdouble, + tmax: numpy.longdouble, dt: numpy.longdouble) -> Tuple[numpy.array, numpy.array, numpy.array]: """ Interpolate ODE solution to uniform dt0 step :param u: solution :type u: numpy.array + :param du_dt: solution's derivative + :type du_dt: Optional[numpy.array] :param t: time :type t: numpy.array :param t0: desired lower limit @@ -1309,15 +1311,32 @@ def _interpolate_result(u: numpy.array, t: numpy.array, t0: numpy.longdouble, tm :param dt: desired step size :type dt: numpy.longdouble :return: interpolated solution - :rtype: numpy.array + :rtype: Tuple[numpy.array, numpy.array, numpy.array] """ + points_number = int((tmax - t0) / dt) t_result = np.linspace(t0, t0 + dt * points_number, points_number + 1) u_result = np.zeros((len(u[0]), len(t_result)), dtype='longdouble') + + if du_dt is not None: + du_dt_result = np.zeros((len(du_dt[0]), len(t_result)), dtype='longdouble') + else: + du_dt_result = None + for i in range(len(u[0])): solution = u[:, -1 - i] fu = interpolate.interp1d(t, solution, kind='cubic', fill_value="extrapolate") solution_result = fu(t_result) u_result[i] = solution_result + + if du_dt is not None: + solution = du_dt[:, -1 - i] + fdu = interpolate.interp1d(t, solution, kind='cubic', fill_value="extrapolate") + solution_result = fdu(t_result) + du_dt_result[i] = solution_result + u_result = numpy.rot90(u_result, k=3) - return u_result, t_result + if du_dt is not None: + du_dt_result = numpy.rot90(du_dt_result, k=3) + + return u_result, du_dt_result, t_result diff --git a/scalar_ode_example.py b/scalar_ode_example.py index 9a30055..8cbd1d9 100644 --- a/scalar_ode_example.py +++ b/scalar_ode_example.py @@ -68,12 +68,12 @@ def f1(u: List[float], t: Union[np.ndarray, np.float64]) -> List: u0 = np.array([x0], dtype='longdouble') # начальное положение du_dt0 = np.array([v0], dtype='longdouble') # начальная скорость - solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8) - u3, t3 = solver.solve(print_benchmark=True, benchmark_name=solver.name) + solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=False, tolerance=1e-8) + u3, du3, t3 = solver.solve(print_benchmark=True, benchmark_name=solver.name) plt.plot(t3, u3, label=solver.name) u0 = np.array([x0, v0], dtype='longdouble') - solver = RungeKutta4ODESolver(f1, u0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8) + solver = RungeKutta4ODESolver(f1, u0, t0, tmax, dt0, is_adaptive_step=False, tolerance=1e-8) u1, t1 = solver.solve(print_benchmark=True, benchmark_name=solver.name) solution_x = u1[:, 0] plt.plot(t1, solution_x, label=solver.name) diff --git a/system_ode_example.py b/system_ode_example.py index 501ebc7..149bafa 100644 --- a/system_ode_example.py +++ b/system_ode_example.py @@ -9,7 +9,6 @@ # Example of the system ODE solving: cannon firing if __name__ == '__main__': - # noinspection PyUnusedLocal def f(u: List[float], t: Union[np.ndarray, np.float64]) -> List: """ @@ -60,7 +59,7 @@ def f2(u: np.longdouble, du_dt: np.longdouble, t: Union[np.ndarray, np.longdoubl # Mathematically, the ODE system looks like this: # d(dx)/dt^2 = -x / sqrt(x^2 + y^2)^3 - # d(dy)/dt^2 = -x / sqrt(x^2 + y^2)^3 + # d(dy)/dt^2 = -y / sqrt(x^2 + y^2)^3 x = u[0] y = u[1] @@ -72,7 +71,7 @@ def f2(u: np.longdouble, du_dt: np.longdouble, t: Union[np.ndarray, np.longdoubl return right_sides - # noinspection PyUnusedLocal + def exact_f(t): x = np.sin(t) y = np.cos(t) @@ -99,7 +98,7 @@ def exact_f(t): u0 = np.array([x0, y0], dtype='longdouble') du_dt0 = np.array([vx0, vy0], dtype='longdouble') solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8) - solution, time_points = solver.solve(print_benchmark=True, benchmark_name=solver.name) + solution, d_solution, time_points = solver.solve(print_benchmark=True, benchmark_name=solver.name) x_solution1 = solution[:, 0] y_solution1 = solution[:, 1] plt.plot(time_points, x_solution1, label=solver.name)