Skip to content

Commit

Permalink
ready for release
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Oct 6, 2023
1 parent ea098cb commit 87aa733
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
8 changes: 6 additions & 2 deletions src/quartic_solver/_tests/test_solve_quartic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from quartic_solver import Solvers


# Coefficients are in the order of constant, x^1, x^2, x^3, x^4
@pytest.mark.parametrize(
"input_coeffs, expected_roots",
[
([6, -5, 1, 0, 0], [2, 3]),
([2, 0, -4, 3, 0], [-0.588911]),
([2, 0, -4, 0, 1], [-0.76537, 0.76537, 1.84776, -1.84776]),
([6, 5, -6, -3, 2], [1.5, 2, -1]),
],
)
def test_solve_quartic(input_coeffs, expected_roots):
Expand All @@ -15,10 +19,10 @@ def test_solve_quartic(input_coeffs, expected_roots):
for _, root in roots:
print(root)
assert any(
pytest.approx(root, rel=1e-6) == expected_root
pytest.approx(root, rel=1e-4) == expected_root
for expected_root in expected_roots
)


if __name__ == "__main__":
test_solve_quartic([6, -5, 1, 0, 0], [2, 3])
pytest.main()
22 changes: 11 additions & 11 deletions src/quartic_solver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def solve_depressed_cubic(c0: float, c1: float) -> List[Tuple[int, float]]:

if c1 == zero:
if c0 > zero:
root0 = -(c0**one_third)
root0 = -math.pow(c0, one_third)
else:
root0 = (-c0) ** one_third
root0 = math.pow(-c0, one_third)

root_map.append((1, root0))
return root_map
Expand All @@ -209,16 +209,16 @@ def solve_depressed_cubic(c0: float, c1: float) -> List[Tuple[int, float]]:
if delta > zero:
delta_div_108 = delta / rat108
beta_re = -c0 / rat2
beta_im = (delta_div_108) ** 0.5
beta_im = math.sqrt(delta_div_108)
theta = math.atan2(beta_im, beta_re)
theta_div_3 = theta / rat3
angle = theta_div_3
cs = math.cos(angle)
sn = math.sin(angle)
rho_sqr = beta_re * beta_re + beta_im * beta_im
rho_pow_third = rho_sqr ** (1.0 / 6.0)
rho_pow_third = math.pow(rho_sqr, 1.0 / 6.0)
temp0 = rho_pow_third * cs
temp1 = rho_pow_third * sn * (3**0.5)
temp1 = rho_pow_third * sn * math.sqrt(3)
root0 = rat2 * temp0
root1 = -temp0 - temp1
root2 = -temp0 + temp1
Expand All @@ -228,19 +228,19 @@ def solve_depressed_cubic(c0: float, c1: float) -> List[Tuple[int, float]]:
elif delta < zero:
delta_div_108 = delta / rat108
temp0 = -c0 / rat2
temp1 = (-delta_div_108) ** 0.5
temp1 = math.sqrt(-delta_div_108)
temp2 = temp0 - temp1
temp3 = temp0 + temp1

if temp2 >= zero:
temp22 = temp2**one_third
temp22 = math.pow(temp2, one_third)
else:
temp22 = (-temp2) ** one_third
temp22 = -math.pow(-temp2, one_third)

if temp3 >= zero:
temp33 = temp3**one_third
temp33 = math.pow(temp3, one_third)
else:
temp33 = (-temp3) ** one_third
temp33 = -math.pow(-temp3, one_third)

root0 = temp22 + temp33
root_map.append((1, root0))
Expand Down Expand Up @@ -303,7 +303,7 @@ def solve_cubic(p: List[float]) -> List[Tuple[int, float]]:

q2third = q2 / rat3
c0 = q0 - q2third * (q1 - rat2 * q2third * q2third)
c1 = q1 - rat2 * q2third * q2third
c1 = q1 - q2 * q2third
RootmapLocal = Solvers.solve_depressed_cubic(c0, c1)

for rm in RootmapLocal:
Expand Down

0 comments on commit 87aa733

Please sign in to comment.