diff --git a/src/algorithms/quadratic_matrix_equation.jl b/src/algorithms/quadratic_matrix_equation.jl index efb3eb68..3ed25dbe 100644 --- a/src/algorithms/quadratic_matrix_equation.jl +++ b/src/algorithms/quadratic_matrix_equation.jl @@ -181,7 +181,7 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, reached_tol = โ„’.norm(AXX) / AXXnorm - return X, true, iter, reached_tol # schur is always successful + return X, reached_tol < tol, iter, reached_tol # schur can fail end diff --git a/src/algorithms/sylvester.jl b/src/algorithms/sylvester.jl index 59a2206a..fe3948ed 100644 --- a/src/algorithms/sylvester.jl +++ b/src/algorithms/sylvester.jl @@ -2,7 +2,7 @@ # :doubling - fast, expensive part: B^2 # :sylvester - fast, dense matrices only # :bicgstab - fastest for large problems, might not reach desired precision, warm start not always helpful -# :gmres - fastest for large problems, might not reach desired precision +# :dqgmres - fastest for large problems, might not reach desired precision # :iterative - slow # :speedmapping - slow @@ -14,26 +14,9 @@ function solve_sylvester_equation(A::M, C::O; initial_guess::AbstractMatrix{<:AbstractFloat} = zeros(0,0), sylvester_algorithm::Symbol = :doubling, - tol::AbstractFloat = 1e-14, + tol::AbstractFloat = 1e-12, timer::TimerOutput = TimerOutput(), verbose::Bool = false) where {M <: AbstractMatrix{Float64}, N <: AbstractMatrix{Float64}, O <: AbstractMatrix{Float64}} - @timeit_debug timer "Check if guess solves it already" begin - - if length(initial_guess) > 0 - ๐‚ = A * initial_guess * B + C - initial_guess - - reached_tol = โ„’.norm(๐‚) / โ„’.norm(initial_guess) - - if reached_tol < tol - if verbose println("Sylvester equation - previous solution achieves relative tol of $reached_tol") end - - # X = choose_matrix_format(initial_guess) - - return initial_guess, true - end - end - - end # timeit_debug @timeit_debug timer "Choose matrix formats" begin if sylvester_algorithm == :sylvester @@ -53,7 +36,23 @@ function solve_sylvester_equation(A::M, end end # timeit_debug + @timeit_debug timer "Check if guess solves it already" begin + + if length(initial_guess) > 0 + ๐‚ = a * initial_guess * b + c - initial_guess + + reached_tol = โ„’.norm(๐‚) / โ„’.norm(initial_guess) + + if reached_tol < tol + if verbose println("Sylvester equation - previous solution achieves relative tol of $reached_tol") end + + # X = choose_matrix_format(initial_guess) + + return initial_guess, true + end + end + end # timeit_debug @timeit_debug timer "Solve sylvester equation" begin x, solved, i, reached_tol = solve_sylvester_equation(a, b, c, Val(sylvester_algorithm), @@ -65,55 +64,122 @@ function solve_sylvester_equation(A::M, if verbose && i != 0 println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: $sylvester_algorithm") end - + if !solved - if reached_tol < sqrt(tol) && sylvester_algorithm โ‰  :bicgstab - a = collect(A) - - c = collect(C) + if sylvester_algorithm โ‰  :bicgstab + if reached_tol < sqrt(tol) + aa = collect(A) + + cc = collect(C) + + x, solved, i, reached_tol = solve_sylvester_equation(aa, b, cc, + Val(:bicgstab), + initial_guess = x, + tol = tol, + verbose = verbose, + timer = timer) + if verbose && i != 0 + println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: bicgstab (refinement of previous solution)") + end + else + aa = collect(A) + + cc = collect(C) + + x, solved, i, reached_tol = solve_sylvester_equation(aa, b, cc, + Val(:bicgstab), + initial_guess = zeros(0,0), + tol = tol, + verbose = verbose, + timer = timer) + + if verbose && i != 0 + println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: bicgstab") + end + end + end - x, solved, i, reached_tol = solve_sylvester_equation(a, b, c, - Val(:bicgstab), - initial_guess = x, - tol = tol, - verbose = verbose, - timer = timer) - if verbose && i != 0 - println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: bicgstab (refinement of previous solution)") + if sylvester_algorithm โ‰  :dqgmres + if reached_tol < sqrt(tol) + aa = collect(A) + + cc = collect(C) + + x, solved, i, reached_tol = solve_sylvester_equation(aa, b, cc, + Val(:dqgmres), + initial_guess = x, + tol = tol, + verbose = verbose, + timer = timer) + if verbose && i != 0 + println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: dqgmres (refinement of previous solution)") + end + else + aa = collect(A) + + cc = collect(C) + + x, solved, i, reached_tol = solve_sylvester_equation(aa, b, cc, + Val(:dqgmres), + initial_guess = zeros(0,0), + tol = tol, + verbose = verbose, + timer = timer) + + if verbose && i != 0 + println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: dqgmres") + end end - else - a = collect(A) + end + + if !solved && sylvester_algorithm โ‰  :sylvester && length(B) < 5e7 # try sylvester if previous one didn't solve it + aa = collect(A) + + bb = collect(B) - c = collect(C) + cc = collect(C) - x, solved, i, reached_tol = solve_sylvester_equation(a, b, c, - Val(:bicgstab), + x, solved, i, reached_tol = solve_sylvester_equation(aa, bb, cc, + Val(:sylvester), initial_guess = zeros(0,0), tol = tol, verbose = verbose, timer = timer) + if verbose && i != 0 - println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: bicgstab") + println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: sylvester") + end + + if !solved && reached_tol < sqrt(tol) + x, solved, i, reached_tol = solve_sylvester_equation(aa, b, cc, + Val(:bicgstab), + initial_guess = x, + tol = tol, + verbose = verbose, + timer = timer) + if verbose && i != 0 + println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: bicgstab (refinement of previous solution)") + end end end end - if !solved # && sylvester_algorithm != :sylvester # try schur if previous one didn't solve it - # a = collect(A) + # if !solved # && sylvester_algorithm != :sylvester # try schur if previous one didn't solve it + # # a = collect(A) - # c = collect(C) + # # c = collect(C) - x, solved, i, reached_tol = solve_sylvester_equation(a, b, c, - Val(sylvester_algorithm), - initial_guess = zeros(0,0), - tol = tol, - verbose = verbose, - timer = timer) + # x, solved, i, reached_tol = solve_sylvester_equation(a, b, c, + # Val(sylvester_algorithm), + # initial_guess = zeros(0,0), + # tol = tol, + # verbose = verbose, + # timer = timer) - if verbose && i != 0 - println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: $sylvester_algorithm (no initial guess)") - end - end + # if verbose && i != 0 + # println("Sylvester equation - converged to tol $tol: $solved; iterations: $i; reached tol: $reached_tol; algorithm: $sylvester_algorithm (no initial guess)") + # end + # end end # timeit_debug @@ -961,6 +1027,8 @@ function solve_sylvester_equation(A::DenseMatrix{Float64}, # if length(init) == 0 # ๐‚, info = Krylov.bicgstab(sylvester, C[idxs], rtol = tol / 10, atol = tol / 10)#, M = precond) ๐‚, info = Krylov.bicgstab(sylvester, [vec(๐‚ยน);], + itmax = 1000, + timemax = 10.0, rtol = tol / 100, atol = tol / 100)#, M = precond) # else @@ -987,32 +1055,6 @@ function solve_sylvester_equation(A::DenseMatrix{Float64}, end # timeit_debug - if reached_tol > tol || !isfinite(reached_tol) - @timeit_debug timer "GMRES refinement" begin - - ๐‚, info = Krylov.gmres(sylvester, [vec(C);], - [vec(๐‚);], # start value helps - rtol = tol / 100, atol = tol / 100)#, M = precond) - - # @inbounds ๐—[idxs] = ๐‚ - copyto!(๐—, ๐‚) - - โ„’.mul!(tmpฬ„, A, ๐— * B) - โ„’.axpy!(1, C, tmpฬ„) - - # denom = max(โ„’.norm(๐—), โ„’.norm(tmpฬ„)) - - โ„’.axpy!(-1, ๐—, tmpฬ„) - - # reached_tol = denom == 0 ? 0.0 : โ„’.norm(tmpฬ„) / denom - - ๐— += initial_guess - - reached_tol = โ„’.norm(A * ๐— * B + C - ๐—) / โ„’.norm(๐—) - - end # timeit_debug - end - if !(typeof(C) <: DenseMatrix) ๐— = choose_matrix_format(๐—, density_threshold = 1.0) end @@ -1024,7 +1066,7 @@ end function solve_sylvester_equation(A::DenseMatrix{Float64}, B::AbstractMatrix{Float64}, C::DenseMatrix{Float64}, - ::Val{:gmres}; + ::Val{:dqgmres}; initial_guess::AbstractMatrix{<:AbstractFloat} = zeros(0,0), timer::TimerOutput = TimerOutput(), verbose::Bool = false, @@ -1087,9 +1129,14 @@ function solve_sylvester_equation(A::DenseMatrix{Float64}, # precond = LinearOperators.LinearOperator(Float64, length(C), length(C), false, false, preconditioner!) - @timeit_debug timer "GMRES solve" begin + @timeit_debug timer "DQGMRES solve" begin # if length(init) == 0 - ๐‚, info = Krylov.gmres(sylvester, [vec(๐‚ยน);], rtol = tol / 100, atol = tol / 100)#, M = precond) + ๐‚, info = Krylov.dqgmres(sylvester, + [vec(๐‚ยน);], # start value helps + itmax = 1000, + timemax = 10.0, + rtol = tol / 100, + atol = tol / 100)#, M = precond) # else # ๐‚, info = Krylov.gmres(sylvester, [vec(C);], [vec(init);], rtol = tol / 10)#, restart = true, M = precond) # end diff --git a/src/filter/find_shocks.jl b/src/filter/find_shocks.jl index aecdb960..a3232a44 100644 --- a/src/filter/find_shocks.jl +++ b/src/filter/find_shocks.jl @@ -70,6 +70,12 @@ function find_shocks(::Val{:LagrangeNewton}, # fXฮปp = [reshape(2 * ๐’โฑยฒแต‰' * ฮป, size(๐’โฑ, 2), size(๐’โฑ, 2)) - 2*โ„’.I(size(๐’โฑ, 2)) (๐’โฑ + 2 * ๐’โฑยฒแต‰ * โ„’.kron(โ„’.I(length(x)), x))' # -(๐’โฑ + 2 * ๐’โฑยฒแต‰ * โ„’.kron(โ„’.I(length(x)), x)) zeros(size(๐’โฑ, 1),size(๐’โฑ, 1))] + # fฬ‚xฮปp = โ„’.lu(fxฮปp, check = false) + + # if !โ„’.issuccess(fฬ‚xฮปp) + # return x, false + # end + fฬ‚xฮปp = try โ„’.factorize(fxฮปp) catch