From cabdcd2aefb64ea412cef70a2361e0027bf197c2 Mon Sep 17 00:00:00 2001 From: Gui-FernandesBR Date: Thu, 18 Apr 2024 18:35:14 -0400 Subject: [PATCH] MNT: Refactor Flight class root finding algorithm for rail exit and impact time calculations --- rocketpy/simulation/flight.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 13aa40659..a3ab6fda3 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -818,18 +818,13 @@ def __simulate__(self, verbose): b = float((3 * y1 - yp1 * D - 2 * c * D - 3 * d) / (D**2)) a = float(-(2 * y1 - yp1 * D - c * D - 2 * d) / (D**3)) + 1e-5 # Find roots - d0 = b**2 - 3 * a * c - d1 = 2 * b**3 - 9 * a * b * c + 27 * d * a**2 - c1 = ((d1 + (d1**2 - 4 * d0**3) ** (0.5)) / 2) ** (1 / 3) - t_roots = [] - for k in [0, 1, 2]: - c2 = c1 * (-1 / 2 + 1j * (3**0.5) / 2) ** k - t_roots.append(-(1 / (3 * a)) * (b + c2 + d0 / c2)) + t_roots = Function.cardanos_root_finding(a, b, c, d) # Find correct root - valid_t_root = [] - for t_root in t_roots: - if 0 < t_root.real < t1 and abs(t_root.imag) < 0.001: - valid_t_root.append(t_root.real) + valid_t_root = [ + t_root.real + for t_root in t_roots + if 0 < t_root.real < t1 and abs(t_root.imag) < 0.001 + ] if len(valid_t_root) > 1: raise ValueError( "Multiple roots found when solving for rail exit time." @@ -914,18 +909,13 @@ def __simulate__(self, verbose): b = float((3 * y1 - yp1 * D - 2 * c * D - 3 * d) / (D**2)) a = float(-(2 * y1 - yp1 * D - c * D - 2 * d) / (D**3)) # Find roots - d0 = b**2 - 3 * a * c - d1 = 2 * b**3 - 9 * a * b * c + 27 * d * a**2 - c1 = ((d1 + (d1**2 - 4 * d0**3) ** (0.5)) / 2) ** (1 / 3) - t_roots = [] - for k in [0, 1, 2]: - c2 = c1 * (-1 / 2 + 1j * (3**0.5) / 2) ** k - t_roots.append(-(1 / (3 * a)) * (b + c2 + d0 / c2)) + t_roots = Function.cardanos_root_finding(a, b, c, d) # Find correct root - valid_t_root = [] - for t_root in t_roots: - if 0 < t_root.real < t1 and abs(t_root.imag) < 0.001: - valid_t_root.append(t_root.real) + valid_t_root = [ + t_root.real + for t_root in t_roots + if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1 + ] if len(valid_t_root) > 1: raise ValueError( "Multiple roots found when solving for impact time."