Skip to content

Commit

Permalink
+ du_dt returning
Browse files Browse the repository at this point in the history
  • Loading branch information
vlakir committed Sep 17, 2021
1 parent a288b77 commit b69751e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
37 changes: 28 additions & 9 deletions cleanode/ode_solvers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy
import numpy as np
from typing import Callable, Tuple
from typing import Callable, Tuple, Optional
from funnydeco import benchmark
import quadpy
from scipy import interpolate
Expand Down Expand Up @@ -111,7 +111,7 @@ def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, n
i += 1

if self.is_adaptive_step and self.is_interpolate:
self.u, self.t = _interpolate_result(self.u, self.t, self.t0, self.tmax, self.dt0)
self.u, __, self.t = _interpolate_result(self.u, None, self.t, self.t0, self.tmax, self.dt0)

return self.u, self.t

Expand Down Expand Up @@ -894,7 +894,7 @@ def __init__(self, order,

# noinspection PyUnusedLocal
@benchmark
def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, np.ndarray]:
def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
ODE solution
:param print_benchmark: output the execution time to the console
Expand Down Expand Up @@ -927,9 +927,9 @@ def solve(self, print_benchmark=False, benchmark_name='') -> Tuple[np.ndarray, n
i += 1

if self.is_adaptive_step and self.is_interpolate:
self.u, self.t = _interpolate_result(self.u, self.t, self.t0, self.tmax, self.dt0)
self.u, self.du_dt, self.t = _interpolate_result(self.u, self.du_dt, self.t, self.t0, self.tmax, self.dt0)

return self.u, self.t
return self.u, self.du_dt, self.t

def _do_step(self, u, du_dt, t, f, dt, h, alfa) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.longdouble]:
"""
Expand Down Expand Up @@ -1294,12 +1294,14 @@ def __init__(self, f2: Callable,
is_interpolate=is_interpolate, tolerance=tolerance)


def _interpolate_result(u: numpy.array, t: numpy.array, t0: numpy.longdouble, tmax: numpy.longdouble,
dt: numpy.longdouble) -> Tuple[numpy.array, numpy.array]:
def _interpolate_result(u: numpy.array, du_dt: Optional[numpy.array], t: numpy.array, t0: numpy.longdouble,
tmax: numpy.longdouble, dt: numpy.longdouble) -> Tuple[numpy.array, numpy.array, numpy.array]:
"""
Interpolate ODE solution to uniform dt0 step
:param u: solution
:type u: numpy.array
:param du_dt: solution's derivative
:type du_dt: Optional[numpy.array]
:param t: time
:type t: numpy.array
:param t0: desired lower limit
Expand All @@ -1309,15 +1311,32 @@ def _interpolate_result(u: numpy.array, t: numpy.array, t0: numpy.longdouble, tm
:param dt: desired step size
:type dt: numpy.longdouble
:return: interpolated solution
:rtype: numpy.array
:rtype: Tuple[numpy.array, numpy.array, numpy.array]
"""

points_number = int((tmax - t0) / dt)
t_result = np.linspace(t0, t0 + dt * points_number, points_number + 1)
u_result = np.zeros((len(u[0]), len(t_result)), dtype='longdouble')

if du_dt is not None:
du_dt_result = np.zeros((len(du_dt[0]), len(t_result)), dtype='longdouble')
else:
du_dt_result = None

for i in range(len(u[0])):
solution = u[:, -1 - i]
fu = interpolate.interp1d(t, solution, kind='cubic', fill_value="extrapolate")
solution_result = fu(t_result)
u_result[i] = solution_result

if du_dt is not None:
solution = du_dt[:, -1 - i]
fdu = interpolate.interp1d(t, solution, kind='cubic', fill_value="extrapolate")
solution_result = fdu(t_result)
du_dt_result[i] = solution_result

u_result = numpy.rot90(u_result, k=3)
return u_result, t_result
if du_dt is not None:
du_dt_result = numpy.rot90(du_dt_result, k=3)

return u_result, du_dt_result, t_result
6 changes: 3 additions & 3 deletions scalar_ode_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ def f1(u: List[float], t: Union[np.ndarray, np.float64]) -> List:

u0 = np.array([x0], dtype='longdouble') # начальное положение
du_dt0 = np.array([v0], dtype='longdouble') # начальная скорость
solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8)
u3, t3 = solver.solve(print_benchmark=True, benchmark_name=solver.name)
solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=False, tolerance=1e-8)
u3, du3, t3 = solver.solve(print_benchmark=True, benchmark_name=solver.name)
plt.plot(t3, u3, label=solver.name)

u0 = np.array([x0, v0], dtype='longdouble')
solver = RungeKutta4ODESolver(f1, u0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8)
solver = RungeKutta4ODESolver(f1, u0, t0, tmax, dt0, is_adaptive_step=False, tolerance=1e-8)
u1, t1 = solver.solve(print_benchmark=True, benchmark_name=solver.name)
solution_x = u1[:, 0]
plt.plot(t1, solution_x, label=solver.name)
Expand Down
7 changes: 3 additions & 4 deletions system_ode_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

# Example of the system ODE solving: cannon firing
if __name__ == '__main__':

# noinspection PyUnusedLocal
def f(u: List[float], t: Union[np.ndarray, np.float64]) -> List:
"""
Expand Down Expand Up @@ -60,7 +59,7 @@ def f2(u: np.longdouble, du_dt: np.longdouble, t: Union[np.ndarray, np.longdoubl

# Mathematically, the ODE system looks like this:
# d(dx)/dt^2 = -x / sqrt(x^2 + y^2)^3
# d(dy)/dt^2 = -x / sqrt(x^2 + y^2)^3
# d(dy)/dt^2 = -y / sqrt(x^2 + y^2)^3

x = u[0]
y = u[1]
Expand All @@ -72,7 +71,7 @@ def f2(u: np.longdouble, du_dt: np.longdouble, t: Union[np.ndarray, np.longdoubl

return right_sides

# noinspection PyUnusedLocal

def exact_f(t):
x = np.sin(t)
y = np.cos(t)
Expand All @@ -99,7 +98,7 @@ def exact_f(t):
u0 = np.array([x0, y0], dtype='longdouble')
du_dt0 = np.array([vx0, vy0], dtype='longdouble')
solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8)
solution, time_points = solver.solve(print_benchmark=True, benchmark_name=solver.name)
solution, d_solution, time_points = solver.solve(print_benchmark=True, benchmark_name=solver.name)
x_solution1 = solution[:, 0]
y_solution1 = solution[:, 1]
plt.plot(time_points, x_solution1, label=solver.name)
Expand Down

0 comments on commit b69751e

Please sign in to comment.