Skip to content

Commit

Permalink
carry through for all 3 mat eqs. qme back to indvdl calls
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Dec 22, 2024
1 parent 32d2de0 commit 612a243
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 116 deletions.
22 changes: 11 additions & 11 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4114,7 +4114,7 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M},
# end # timeit_debug

if !solved
if verbose println("1st order solution not found") end
if opts.verbose println("1st order solution not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0)
end

Expand All @@ -4141,7 +4141,7 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M},
# end # timeit_debug

if !solved2
if verbose println("2nd order solution not found") end
if opts.verbose println("2nd order solution not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0)
end

Expand All @@ -4156,7 +4156,7 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M},
tmp̄ = @ignore_derivatives.lu(tmp, check = false)

if !.issuccess(tmp̄)
if verbose println("SSS not found") end
if opts.verbose println("SSS not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0)
end

Expand All @@ -4178,7 +4178,7 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M},
SSSstates, converged = calculate_second_order_stochastic_steady_state(Val(:newton), 𝐒₁, 𝐒₂, SSSstates, 𝓂) # , timer = timer)

if !converged
if verbose println("SSS not found") end
if opts.verbose println("SSS not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0)
end

Expand Down Expand Up @@ -4416,7 +4416,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameters, opts = opts) # , timer = timer)

if solution_error > opts.tol || isnan(solution_error)
if verbose println("NSSS not found") end
if opts.verbose println("NSSS not found") end
return zeros(𝓂.timings.nVars), false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0), spzeros(0,0)
end

Expand All @@ -4432,7 +4432,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

if !solved
if verbose println("1st order solution not found") end
if opts.verbose println("1st order solution not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0), spzeros(0,0)
end

Expand All @@ -4446,7 +4446,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
opts = opts)

if !solved2
if verbose println("2nd order solution not found") end
if opts.verbose println("2nd order solution not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0), spzeros(0,0)
end

Expand All @@ -4467,7 +4467,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
opts = opts)

if !solved3
if verbose println("3rd order solution not found") end
if opts.verbose println("3rd order solution not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0), spzeros(0,0)
end

Expand All @@ -4486,7 +4486,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
tmp̄ = @ignore_derivatives.lu(tmp, check = false)

if !.issuccess(tmp̄)
if verbose println("SSS not found") end
if opts.verbose println("SSS not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0), spzeros(0,0)
end

Expand All @@ -4511,7 +4511,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
SSSstates, converged = calculate_third_order_stochastic_steady_state(Val(:newton), 𝐒₁, 𝐒₂, 𝐒₃, SSSstates, 𝓂)

if !converged
if verbose println("SSS not found") end
if opts.verbose println("SSS not found") end
return all_SS, false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0), spzeros(0,0)
end

Expand Down Expand Up @@ -7371,7 +7371,7 @@ function get_relevant_steady_state_and_state_update(::Val{:first_order},

if solution_error > opts.tol # || isnan(solution_error) if it's NaN the fisrt condition is false anyway
# println("NSSS not found")
return TT, SS_and_pars, zeros(S, 0, 0), [state], solution_error < tol
return TT, SS_and_pars, zeros(S, 0, 0), [state], solution_error < opts.tol
end

∇₁ = calculate_jacobian(parameter_values, SS_and_pars, 𝓂) # , timer = timer)# |> Matrix
Expand Down
56 changes: 21 additions & 35 deletions src/algorithms/lyapunov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
function solve_lyapunov_equation(A::AbstractMatrix{Float64},
C::AbstractMatrix{Float64};
lyapunov_algorithm::Symbol = :doubling,
tol::AbstractFloat = 1e-12,
# timer::TimerOutput = TimerOutput(),
tol::AbstractFloat = 1e-14,
acceptance_tol::AbstractFloat = 1e-12,
verbose::Bool = false)
# timer::TimerOutput = TimerOutput(),
# @timeit_debug timer "Solve lyapunov equation" begin
# @timeit_debug timer "Choose matrix formats" begin

Expand All @@ -28,39 +29,30 @@ function solve_lyapunov_equation(A::AbstractMatrix{Float64},
# end # timeit_debug
# @timeit_debug timer "Solve" begin

X, i, reached_tol = solve_lyapunov_equation(A, C,
Val(lyapunov_algorithm))#,
# tol = tol,
# timer = timer)
X, i, reached_tol = solve_lyapunov_equation(A, C, Val(lyapunov_algorithm), tol = tol) # timer = timer)

if verbose
println("Lyapunov equation - converged to tol $tol: $(reached_tol < tol); iterations: $i; reached tol: $reached_tol; algorithm: $lyapunov_algorithm")
println("Lyapunov equation - converged to tol $acceptance_tol: $(reached_tol < acceptance_tol); iterations: $i; reached tol: $reached_tol; algorithm: $lyapunov_algorithm")
end

if reached_tol > tol
if (reached_tol < sqrt(tol) || A isa AbstractSparseMatrix) && lyapunov_algorithm :bicgstab
if reached_tol > acceptance_tol
if (reached_tol < sqrt(acceptance_tol) || A isa AbstractSparseMatrix) && lyapunov_algorithm :bicgstab
C = collect(C)

X, i, reached_tol = solve_lyapunov_equation(A, C,
Val(:bicgstab))#,
# tol = tol,
# timer = timer)
X, i, reached_tol = solve_lyapunov_equation(A, C, Val(:bicgstab), tol = tol) # timer = timer)

if verbose
println("Lyapunov equation - converged to tol $tol: $(reached_tol < tol); iterations: $i; reached tol: $reached_tol; algorithm: gmres")
println("Lyapunov equation - converged to tol $acceptance_tol: $(reached_tol < acceptance_tol); iterations: $i; reached tol: $reached_tol; algorithm: gmres")
end
else
A = collect(A)

C = collect(C)

X, i, reached_tol = solve_lyapunov_equation(A, C,
Val(:bartels_stewart))#,
# tol = tol,
# timer = timer)
X, i, reached_tol = solve_lyapunov_equation(A, C, Val(:bartels_stewart), tol = tol) # timer = timer)

if verbose
println("Lyapunov equation - converged to tol $tol: $(reached_tol < tol); iterations: $i; reached tol: $reached_tol; algorithm: lyapunov")
println("Lyapunov equation - converged to tol $acceptance_tol: $(reached_tol < acceptance_tol); iterations: $i; reached tol: $reached_tol; algorithm: lyapunov")
end
end
end
Expand All @@ -70,27 +62,24 @@ function solve_lyapunov_equation(A::AbstractMatrix{Float64},

# if (reached_tol > tol) println("Lyapunov failed: $reached_tol") end

return X, reached_tol < tol
return X, reached_tol < acceptance_tol
end

function rrule(::typeof(solve_lyapunov_equation),
A::AbstractMatrix{Float64},
C::AbstractMatrix{Float64};
lyapunov_algorithm::Symbol = :doubling,
tol::AbstractFloat = 1e-12,
tol::AbstractFloat = 1e-14,
acceptance_tol::AbstractFloat = 1e-12,
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false)

P, solved = solve_lyapunov_equation(A, C, lyapunov_algorithm = lyapunov_algorithm,
# tol = tol,
verbose = verbose)
P, solved = solve_lyapunov_equation(A, C, lyapunov_algorithm = lyapunov_algorithm, tol = tol, verbose = verbose)

# pullback
# https://arxiv.org/abs/2011.11430
function solve_lyapunov_equation_pullback(∂P)
∂C, solved = solve_lyapunov_equation(A', ∂P[1], lyapunov_algorithm = lyapunov_algorithm,
# tol = tol,
verbose = verbose)
∂C, solved = solve_lyapunov_equation(A', ∂P[1], lyapunov_algorithm = lyapunov_algorithm, tol = tol, verbose = verbose)

∂A = ∂C * A * P' + ∂C' * A * P

Expand All @@ -105,16 +94,15 @@ end
function solve_lyapunov_equation( A::AbstractMatrix{ℱ.Dual{Z,S,N}},
C::AbstractMatrix{ℱ.Dual{Z,S,N}};
lyapunov_algorithm::Symbol = :doubling,
tol::AbstractFloat = 1e-12,
tol::AbstractFloat = 1e-14,
acceptance_tol::AbstractFloat = 1e-12,
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false) where {Z,S,N}
# unpack: AoS -> SoA
=.value.(A)
=.value.(C)

P̂, solved = solve_lyapunov_equation(Â, Ĉ, lyapunov_algorithm = lyapunov_algorithm,
# tol = tol,
verbose = verbose)
P̂, solved = solve_lyapunov_equation(Â, Ĉ, lyapunov_algorithm = lyapunov_algorithm, tol = tol, verbose = verbose)

= copy(Â)
= copy(Ĉ)
Expand All @@ -130,9 +118,7 @@ function solve_lyapunov_equation( A::AbstractMatrix{ℱ.Dual{Z,S,N}},

if.norm(X) < eps() continue end

P, solved = solve_lyapunov_equation(Â, X, lyapunov_algorithm = lyapunov_algorithm,
# tol = tol,
verbose = verbose)
P, solved = solve_lyapunov_equation(Â, X, lyapunov_algorithm = lyapunov_algorithm, tol = tol, verbose = verbose)

P̃[:,i] = vec(P)
end
Expand All @@ -148,7 +134,7 @@ function solve_lyapunov_equation(A::Union{ℒ.Adjoint{Float64,Matrix{Float64}},D
C::Union{ℒ.Adjoint{Float64,Matrix{Float64}},DenseMatrix{Float64}},
::Val{:bartels_stewart};
# timer::TimerOutput = TimerOutput(),
tol::AbstractFloat = 1e-12)
tol::AbstractFloat = 1e-14)
𝐂 = try
MatrixEquations.lyapd(A, C)
catch
Expand Down
13 changes: 4 additions & 9 deletions src/algorithms/quadratic_matrix_equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,10 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
C::AbstractMatrix{R},
T::timings;
initial_guess::AbstractMatrix{R} = zeros(0,0),
opts::CalculationOptions = merge_calculation_options()) where R <: Real

quadratic_matrix_equation_algorithm = opts.quadratic_matrix_equation_algorithm

tol = opts.qme_tol

acceptance_tol = opts.qme_acceptance_tol

verbose = opts.verbose
quadratic_matrix_equation_algorithm::Symbol = :schur,
tol::AbstractFloat = 1e-14,
acceptance_tol::AbstractFloat = 1e-8,
verbose::Bool = false) where R <: Real

if length(initial_guess) > 0
X = initial_guess
Expand Down
4 changes: 2 additions & 2 deletions src/algorithms/sylvester.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ function solve_sylvester_equation(A::M,
sylvester_algorithm::Symbol = :doubling,
acceptance_tol::AbstractFloat = 1e-10,
tol::AbstractFloat = 1e-14,
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false) where {M <: AbstractMatrix{Float64}, N <: AbstractMatrix{Float64}, O <: AbstractMatrix{Float64}}
# timer::TimerOutput = TimerOutput(),
# @timeit_debug timer "Choose matrix formats" begin

if sylvester_algorithm == :bartels_stewart
Expand Down Expand Up @@ -992,7 +992,7 @@ function solve_sylvester_equation(A::DenseMatrix{Float64},
initial_guess::AbstractMatrix{<:AbstractFloat} = zeros(0,0),
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false,
tol::AbstractFloat = 1e-12)
tol::AbstractFloat = 1e-14)
# guess_provided = true

if length(initial_guess) == 0
Expand Down
2 changes: 1 addition & 1 deletion src/filter/inversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3370,7 +3370,7 @@ function filter_data_with_model(𝓂::ℳ,

SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, 𝓂.parameter_values, opts = opts)

if solution_error > 1e-12 || isnan(solution_error)
if solution_error > opts.tol || isnan(solution_error)
@error "No solution for these parameters."
return variables, shocks, [], decomposition
end
Expand Down
6 changes: 5 additions & 1 deletion src/filter/kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ function get_initial_covariance(::Val{:theoretical},
B::AbstractMatrix{S};
opts::CalculationOptions = merge_calculation_options())::Matrix{S} where S <: Real
# timer::TimerOutput = TimerOutput(),
P, _ = solve_lyapunov_equation(A, B, lyapunov_algorithm = opts.lyapunov_algorithm, verbose = opts.verbose) # timer = timer,
P, _ = solve_lyapunov_equation(A, B,
lyapunov_algorithm = opts.lyapunov_algorithm,
tol = opts.lyapunov_tol,
acceptance_tol = opts.lyapunov_acceptance_tol,
verbose = opts.verbose) # timer = timer,

return P
end
Expand Down
Loading

0 comments on commit 612a243

Please sign in to comment.