Skip to content

Commit

Permalink
more opts
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Dec 22, 2024
1 parent 669e6be commit 32d2de0
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 371 deletions.
108 changes: 39 additions & 69 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4085,7 +4085,7 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M},
# tol::AbstractFloat = 1e-12)::Tuple{Vector{M}, Bool, Vector{M}, M, AbstractMatrix{M}, SparseMatrixCSC{M}, AbstractMatrix{M}, SparseMatrixCSC{M}} where M
# @timeit_debug timer "Calculate NSSS" begin

SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameters, verbose = opts.verbose) # , timer = timer)
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameters, opts = opts) # , timer = timer)

# end # timeit_debug

Expand Down Expand Up @@ -4413,7 +4413,7 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},
pruning::Bool = false)::Tuple{Vector{M}, Bool, Vector{M}, M, AbstractMatrix{M}, SparseMatrixCSC{M}, SparseMatrixCSC{M}, AbstractMatrix{M}, SparseMatrixCSC{M}, SparseMatrixCSC{M}} where M
# timer::TimerOutput = TimerOutput(),
# tol::AbstractFloat = 1e-12)::Tuple{Vector{M}, Bool, Vector{M}, M, AbstractMatrix{M}, SparseMatrixCSC{M}, SparseMatrixCSC{M}, AbstractMatrix{M}, SparseMatrixCSC{M}, SparseMatrixCSC{M}} where M
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameters, verbose = opts.verbose) # , timer = timer)
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameters, opts = opts) # , timer = timer)

if solution_error > opts.tol || isnan(solution_error)
if verbose println("NSSS not found") end
Expand Down Expand Up @@ -4778,7 +4778,7 @@ function solve!(𝓂::β„³;

# @timeit_debug timer "Solve for NSSS (if necessary)" begin

SS_and_pars, (solution_error, iters) = 𝓂.solution.outdated_NSSS ? get_NSSS_and_parameters(𝓂, 𝓂.parameter_values, verbose = opts.verbose) : (𝓂.solution.non_stochastic_steady_state, (eps(), 0))
SS_and_pars, (solution_error, iters) = 𝓂.solution.outdated_NSSS ? get_NSSS_and_parameters(𝓂, 𝓂.parameter_values, opts = opts) : (𝓂.solution.non_stochastic_steady_state, (eps(), 0))

# end # timeit_debug

Expand Down Expand Up @@ -6985,14 +6985,13 @@ end

function get_NSSS_and_parameters(𝓂::β„³,
parameter_values::Vector{S};
verbose::Bool = false,
opts::CalculationOptions = merge_calculation_options()) where S <: Float64
# timer::TimerOutput = TimerOutput(),
tol::AbstractFloat = 1e-12) where S <: Float64
# @timeit_debug timer "Calculate NSSS" begin
SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, verbose, false, 𝓂.solver_parameters)
SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, opts.verbose, false, 𝓂.solver_parameters)

if solution_error > tol || isnan(solution_error)
if verbose
if solution_error > opts.tol || isnan(solution_error)
if opts.verbose
println("Failed to find NSSS")
end

Expand All @@ -7007,16 +7006,15 @@ end
function rrule(::typeof(get_NSSS_and_parameters),
𝓂,
parameter_values;
verbose = false,
opts::CalculationOptions = merge_calculation_options())
# timer::TimerOutput = TimerOutput(),
tol::AbstractFloat = 1e-12)
# @timeit_debug timer "Calculate NSSS - forward" begin

SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, verbose, false, 𝓂.solver_parameters)
SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, opts.verbose, false, 𝓂.solver_parameters)

# end # timeit_debug

if solution_error > tol || isnan(solution_error)
if solution_error > opts.tol || isnan(solution_error)
return (SS_and_pars, (solution_error, iters)), x -> (NoTangent(), NoTangent(), NoTangent(), NoTangent())
end

Expand Down Expand Up @@ -7138,15 +7136,14 @@ end

function get_NSSS_and_parameters(𝓂::β„³,
parameter_values_dual::Vector{β„±.Dual{Z,S,N}};
verbose::Bool = false,
opts::CalculationOptions = merge_calculation_options()) where {Z,S,N}
# timer::TimerOutput = TimerOutput(),
tol::AbstractFloat = 1e-12) where {Z,S,N}
parameter_values = β„±.value.(parameter_values_dual)

SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, verbose, false, 𝓂.solver_parameters)
SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, opts.verbose, false, 𝓂.solver_parameters)

if solution_error > tol || isnan(solution_error)
if verbose println("Failed to find NSSS") end
if solution_error > opts.tol || isnan(solution_error)
if opts.verbose println("Failed to find NSSS") end
return (SS_and_pars, (10, iters))#, x -> (NoTangent(), NoTangent(), NoTangent(), NoTangent())
end

Expand Down Expand Up @@ -7224,7 +7221,7 @@ function get_NSSS_and_parameters(𝓂::β„³,
βˆ‚SS_equations_βˆ‚SS_and_pars_lu = RF.lu!(βˆ‚SS_equations_βˆ‚SS_and_pars, check = false)

if !β„’.issuccess(βˆ‚SS_equations_βˆ‚SS_and_pars_lu)
if verbose println("Failed to calculate implicit derivative of NSSS") end
if opts.verbose println("Failed to calculate implicit derivative of NSSS") end
return (SS_and_pars, (10, iters))#, x -> (NoTangent(), NoTangent(), NoTangent(), NoTangent())
end

Expand Down Expand Up @@ -7272,21 +7269,15 @@ end

function get_relevant_steady_state_and_state_update(::Val{:second_order},
parameter_values::Vector{S},
𝓂::β„³,
tol::AbstractFloat;
quadratic_matrix_equation_algorithm::Symbol = :schur,
sylvester_algorithm::Symbol = :doubling,
𝓂::β„³;
opts::CalculationOptions = merge_calculation_options()) where S <: Real
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false) where S <: Real
sss, converged, SS_and_pars, solution_error, βˆ‡β‚, βˆ‡β‚‚, 𝐒₁, 𝐒₂ = calculate_second_order_stochastic_steady_state(parameter_values,
𝓂,
# timer = timer,
opts = opts)

sss, converged, SS_and_pars, solution_error, βˆ‡β‚, βˆ‡β‚‚, 𝐒₁, 𝐒₂ = calculate_second_order_stochastic_steady_state(parameter_values, 𝓂, opts = opts) # timer = timer,

TT = 𝓂.timings

if !converged || solution_error > 1e-12
if verbose println("Could not find 2nd order stochastic steady state") end
if !converged || solution_error > opts.tol
if opts.verbose println("Could not find 2nd order stochastic steady state") end
return TT, SS_and_pars, [𝐒₁, 𝐒₂], collect(sss), converged
end

Expand All @@ -7301,49 +7292,38 @@ end

function get_relevant_steady_state_and_state_update(::Val{:pruned_second_order},
parameter_values::Vector{S},
𝓂::β„³,
tol::AbstractFloat;
quadratic_matrix_equation_algorithm::Symbol = :schur,
sylvester_algorithm::Symbol = :doubling,
𝓂::β„³;
opts::CalculationOptions = merge_calculation_options())::Tuple{timings, Vector{S}, Union{Matrix{S},Vector{AbstractMatrix{S}}}, Vector{Vector{S}}, Bool} where S <: Real
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false)::Tuple{timings, Vector{S}, Union{Matrix{S},Vector{AbstractMatrix{S}}}, Vector{Vector{S}}, Bool} where S <: Real
sss, converged, SS_and_pars, solution_error, βˆ‡β‚, βˆ‡β‚‚, 𝐒₁, 𝐒₂ = calculate_second_order_stochastic_steady_state(parameter_values,
𝓂,
pruning = true,
# timer = timer,
opts = opts)
sss, converged, SS_and_pars, solution_error, βˆ‡β‚, βˆ‡β‚‚, 𝐒₁, 𝐒₂ = calculate_second_order_stochastic_steady_state(parameter_values, 𝓂, pruning = true, opts = opts) # timer = timer,

TT = 𝓂.timings

if !converged || solution_error > 1e-12
if verbose println("Could not find 2nd order stochastic steady state") end
if !converged || solution_error > opts.tol
if opts.verbose println("Could not find 2nd order stochastic steady state") end
return TT, SS_and_pars, [𝐒₁, 𝐒₂], [zeros(𝓂.timings.nVars), zeros(𝓂.timings.nVars)], converged
end

all_SS = expand_steady_state(SS_and_pars,𝓂)

state = [zeros(𝓂.timings.nVars), collect(sss) - all_SS]


return TT, SS_and_pars, [𝐒₁, 𝐒₂], state, converged
end



function get_relevant_steady_state_and_state_update(::Val{:third_order},
parameter_values::Vector{S},
𝓂::β„³,
tol::AbstractFloat;
quadratic_matrix_equation_algorithm::Symbol = :schur,
sylvester_algorithm::Symbol = :bicgstab,
𝓂::β„³;
opts::CalculationOptions = merge_calculation_options()) where S <: Real
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false) where S <: Real
sss, converged, SS_and_pars, solution_error, βˆ‡β‚, βˆ‡β‚‚, βˆ‡β‚ƒ, 𝐒₁, 𝐒₂, 𝐒₃ = calculate_third_order_stochastic_steady_state(parameter_values, 𝓂, opts = opts) # timer = timer,

TT = 𝓂.timings

if !converged || solution_error > 1e-12
if verbose println("Could not find 3rd order stochastic steady state") end
if !converged || solution_error > opts.tol
if opts.verbose println("Could not find 3rd order stochastic steady state") end
return TT, SS_and_pars, [𝐒₁, 𝐒₂, 𝐒₃], collect(sss), converged
end

Expand All @@ -7358,18 +7338,15 @@ end

function get_relevant_steady_state_and_state_update(::Val{:pruned_third_order},
parameter_values::Vector{S},
𝓂::β„³,
tol::AbstractFloat;
quadratic_matrix_equation_algorithm::Symbol = :schur,
sylvester_algorithm::Symbol = :bicgstab,
𝓂::β„³;
opts::CalculationOptions = merge_calculation_options())::Tuple{timings, Vector{S}, Union{Matrix{S},Vector{AbstractMatrix{S}}}, Vector{Vector{S}}, Bool} where S <: Real
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false)::Tuple{timings, Vector{S}, Union{Matrix{S},Vector{AbstractMatrix{S}}}, Vector{Vector{S}}, Bool} where S <: Real
sss, converged, SS_and_pars, solution_error, βˆ‡β‚, βˆ‡β‚‚, βˆ‡β‚ƒ, 𝐒₁, 𝐒₂, 𝐒₃ = calculate_third_order_stochastic_steady_state(parameter_values, 𝓂, pruning = true, opts = opts) # timer = timer,

TT = 𝓂.timings

if !converged || solution_error > 1e-12
if verbose println("Could not find 3rd order stochastic steady state") end
if !converged || solution_error > opts.tol
if opts.verbose println("Could not find 3rd order stochastic steady state") end
return TT, SS_and_pars, [𝐒₁, 𝐒₂, 𝐒₃], [zeros(𝓂.timings.nVars), zeros(𝓂.timings.nVars), zeros(𝓂.timings.nVars)], converged
end

Expand All @@ -7383,22 +7360,16 @@ end

function get_relevant_steady_state_and_state_update(::Val{:first_order},
parameter_values::Vector{S},
𝓂::β„³,
tol::AbstractFloat;
quadratic_matrix_equation_algorithm::Symbol = :schur,
sylvester_algorithm::Symbol = :bicgstab,
𝓂::β„³;
opts::CalculationOptions = merge_calculation_options())::Tuple{timings, Vector{S}, Union{Matrix{S},Vector{AbstractMatrix{S}}}, Vector{Vector{Float64}}, Bool} where S <: Real
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false)::Tuple{timings, Vector{S}, Union{Matrix{S},Vector{AbstractMatrix{S}}}, Vector{Vector{Float64}}, Bool} where S <: Real
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameter_values,
tol = tol,
# timer = timer,
verbose = verbose)
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameter_values, opts = opts) # timer = timer,

state = zeros(𝓂.timings.nVars)

TT = 𝓂.timings

if solution_error > 1e-12 # || isnan(solution_error) if it's NaN the fisrt condition is false anyway
if solution_error > opts.tol # || isnan(solution_error) if it's NaN the fisrt condition is false anyway
# println("NSSS not found")
return TT, SS_and_pars, zeros(S, 0, 0), [state], solution_error < tol
end
Expand All @@ -7407,10 +7378,9 @@ function get_relevant_steady_state_and_state_update(::Val{:first_order},

𝐒₁, qme_sol, solved = calculate_first_order_solution(βˆ‡β‚;
T = TT,
quadratic_matrix_equation_algorithm = quadratic_matrix_equation_algorithm,
# timer = timer,
initial_guess = 𝓂.solution.perturbation.qme_solution,
verbose = verbose)
opts = opts)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down
Loading

0 comments on commit 32d2de0

Please sign in to comment.