Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
vlakir committed Sep 17, 2021
1 parent b69751e commit de39277
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions system_ode_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,28 @@ def f(u: List[float], t: Union[np.ndarray, np.float64]) -> List:

# Mathematically, the ODE system looks like this:
# dx/dt = Vx
# dVx/dt = -x / sqrt(x^2 + y^2)^3
# dVx/dt = -x / sqrt(x^2 + y^2 + z^2)^3
# dy/dt = Vy
# dVy/dt = -x / sqrt(x^2 + y^2)^3
# dVy/dt = -y / sqrt(x^2 + y^2 + z^2)^3
# dz/dt = Vz
# dVz/dt = -z / sqrt(x^2 + y^2 + z^2)^3

g = const.g

x = u[0]
vx = u[1]
y = u[2]
vy = u[3]
z = u[4]
vz = u[5]

right_sides = [
vx,
-x / math.sqrt(x**2 + y**2)**3,
-x / math.sqrt(x**2 + y**2 + z**2)**3,
vy,
-y / math.sqrt(x**2 + y**2)**3
-y / math.sqrt(x**2 + y**2 + z**2)**3,
vz,
-z / math.sqrt(x**2 + y**2 + z**2)**3
]

return right_sides
Expand All @@ -58,15 +64,18 @@ def f2(u: np.longdouble, du_dt: np.longdouble, t: Union[np.ndarray, np.longdoubl
"""

# Mathematically, the ODE system looks like this:
# d(dx)/dt^2 = -x / sqrt(x^2 + y^2)^3
# d(dy)/dt^2 = -y / sqrt(x^2 + y^2)^3
# d(dx)/dt^2 = -x / sqrt(x^2 + y^2 + z^2)^3
# d(dy)/dt^2 = -y / sqrt(x^2 + y^2 + z^2)^3
# d(dz)/dt^2 = -z / sqrt(x^2 + y^2 + z^2)^3

x = u[0]
y = u[1]
z = u[2]

right_sides = np.array([
-x / math.sqrt(x**2 + y**2)**3,
-y / math.sqrt(x**2 + y**2)**3,
-x / math.sqrt(x**2 + y**2 + z**2)**3,
-y / math.sqrt(x**2 + y**2 + z**2)**3,
-z / math.sqrt(x**2 + y**2 + z**2)**3
], dtype='longdouble')

return right_sides
Expand All @@ -82,25 +91,33 @@ def exact_f(t):
tmax = np.longdouble(2 * math.pi)
dt0 = np.longdouble(0.01)

is_adaptive_step = True
tolerance = 1e-8

# initial conditions:
x0 = np.longdouble(0)
y0 = np.longdouble(1)
z0 = np.longdouble(0)
vx0 = np.longdouble(1)
vy0 = np.longdouble(0)
vz0 = np.longdouble(0)

u0 = np.array([x0, vx0, y0, vy0], dtype='longdouble')
solver = RungeKutta4ODESolver(f, u0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8)
u0 = np.array([x0, vx0, y0, vy0, z0, vz0], dtype='longdouble')
solver = RungeKutta4ODESolver(f, u0, t0, tmax, dt0, is_adaptive_step=is_adaptive_step, tolerance=tolerance)
solution, time_points = solver.solve(print_benchmark=True, benchmark_name=solver.name)
x_solution = solution[:, 0]
y_solution = solution[:, 2]
z_solution = solution[:, 4]
plt.plot(time_points, x_solution, label=solver.name)

u0 = np.array([x0, y0], dtype='longdouble')
du_dt0 = np.array([vx0, vy0], dtype='longdouble')
solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=True, tolerance=1e-8)
u0 = np.array([x0, y0, z0], dtype='longdouble')
du_dt0 = np.array([vx0, vy0, vz0], dtype='longdouble')
solver = EverhartIIRadau7ODESolver(f2, u0, du_dt0, t0, tmax, dt0, is_adaptive_step=is_adaptive_step,
tolerance=tolerance)
solution, d_solution, time_points = solver.solve(print_benchmark=True, benchmark_name=solver.name)
x_solution1 = solution[:, 0]
y_solution1 = solution[:, 1]
z_solution1 = solution[:, 2]
plt.plot(time_points, x_solution1, label=solver.name)

points_number = int((tmax - t0) / dt0)
Expand Down

0 comments on commit de39277

Please sign in to comment.