diff --git a/src/algorithms/quadratic_matrix_equation.jl b/src/algorithms/quadratic_matrix_equation.jl index c393a2cf..721d9f8d 100644 --- a/src/algorithms/quadratic_matrix_equation.jl +++ b/src/algorithms/quadratic_matrix_equation.jl @@ -13,7 +13,7 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, initial_guess::AbstractMatrix{R} = zeros(0,0), quadratic_matrix_equation_solver::Symbol = :doubling, timer::TimerOutput = TimerOutput(), - tol::AbstractFloat = 1e-14, + tol::AbstractFloat = 1e-12, verbose::Bool = false) where R <: Real sol, solved, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, Val(quadratic_matrix_equation_solver), @@ -25,16 +25,24 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, if verbose println("Quadratic matrix equation solver: $quadratic_matrix_equation_solver - converged: $solved in $iterations iterations to tolerance: $reached_tol") end - if !solved && quadratic_matrix_equation_solver != :schur # try schur if previous one didn't solve it + if !solved # try schur if previous one didn't solve it + if quadratic_matrix_equation_solver == :schur + initial_guess = reached_tol < sqrt(tol) ? sol : zeros(0,0) + + other_algo = :doubling + else + other_algo = :schur + end + sol, solved, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, - Val(:schur), + Val(other_algo), T; initial_guess = initial_guess, tol = tol, timer = timer, verbose = verbose) - if verbose println("Quadratic matrix equation solver: schur - converged: $solved in $iterations iterations to tolerance: $reached_tol") end + if verbose println("Quadratic matrix equation solver: $other_algo - converged: $solved in $iterations iterations to tolerance: $reached_tol") end end return sol, solved @@ -149,11 +157,21 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, X = sol[T.dynamic_order,:] * ℒ.I(length(comb))[past_not_future_and_mixed_in_comb,:] - reached_tol = ℒ.norm(A * X * X + B * X + C) / ℒ.norm(A * X * X) + iter = 0 + + AXX = A * X^2 + + AXXnorm = ℒ.norm(AXX) + + ℒ.mul!(AXX, B, X, 1, 1) + + ℒ.axpy!(1, C, AXX) + + reached_tol = ℒ.norm(AXX) / AXXnorm converged = reached_tol < tol - return X, converged, 0, reached_tol + return X, converged, iter, reached_tol end @@ -313,7 +331,7 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, # println("Iter: $i; xtol: $Xtol; ytol: $Ytol; rel ytol: $relYtol; rel xtol: $relXtol") # Check for convergence - if Xtol < tol# && Yreltol < tol # i % 2 == 0 && + if Xtol < tol / 100# && Yreltol < tol # i % 2 == 0 && solved = true iter = i break @@ -347,7 +365,7 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, converged = reached_tol < tol - return X_new, converged, iter, reached_tol + return X, converged, iter, reached_tol end