Skip to content

Commit

Permalink
1/3 back 1/2 forward when trying to find SS
Browse files Browse the repository at this point in the history
  • Loading branch information
Thore Kockerols authored and Thore Kockerols committed Nov 8, 2024
1 parent 5ab5bb7 commit 3ee8940
Showing 1 changed file with 52 additions and 22 deletions.
74 changes: 52 additions & 22 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2368,11 +2368,11 @@ function write_block_solution!(𝓂, SS_solve_func, vars_to_solve, eqs_to_solve,

push!(SS_solve_func,:(iters += solution[2][2]))
push!(SS_solve_func,:(solution_error += solution[2][1]))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed after solving block with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed after solving block with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end))

if length(ss_and_aux_equations_error) > 0
push!(SS_solve_func,:(solution_error += $(Expr(:call, :+, ss_and_aux_equations_error...))))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for aux variables with error $(solution_error)") end; scale = (scale + solved_scale) / 2; continue end))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for aux variables with error $(solution_error)") end; scale = (scale + 2 * solved_scale) / 3; continue end))
end

push!(SS_solve_func,:(sol = solution[1]))
Expand Down Expand Up @@ -2899,7 +2899,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
if parsed_eq_to_solve_for != minmax_fixed_eqs
[push!(atoms_in_equations, a) for a in setdiff(get_symbols(parsed_eq_to_solve_for), get_symbols(minmax_fixed_eqs))]
push!(min_max_errors,:(solution_error += abs($parsed_eq_to_solve_for)))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for min max terms in equations with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for min max terms in equations with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end))
eq_to_solve = eval(minmax_fixed_eqs)
end

Expand Down Expand Up @@ -2941,7 +2941,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
if (𝓂.solved_vars[end] ∈ 𝓂.βž•_vars)
push!(SS_solve_func,:($(𝓂.solved_vars[end]) = min(max($(𝓂.bounds[𝓂.solved_vars[end]][1]), $(𝓂.solved_vals[end])), $(𝓂.bounds[𝓂.solved_vars[end]][2]))))
push!(SS_solve_func,:(solution_error += $(Expr(:call,:abs, Expr(:call, :-, 𝓂.solved_vars[end], 𝓂.solved_vals[end])))))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical aux variables with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical aux variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end))

unique_βž•_eqs[𝓂.solved_vals[end]] = 𝓂.solved_vars[end]
else
Expand All @@ -2952,15 +2952,15 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
if length(vcat(ss_and_aux_equations_error, ss_and_aux_equations_error_dep)) > 0
push!(SS_solve_func,vcat(ss_and_aux_equations, ss_and_aux_equations_dep)...)
push!(SS_solve_func,:(solution_error += $(Expr(:call, :+, vcat(ss_and_aux_equations_error, ss_and_aux_equations_error_dep)...))))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical variables with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end))
end

push!(SS_solve_func,:($(𝓂.solved_vars[end]) = $(rewritten_eqs[1])))
end

if haskey(𝓂.bounds, 𝓂.solved_vars[end]) && 𝓂.solved_vars[end] βˆ‰ 𝓂.βž•_vars
push!(SS_solve_func,:(solution_error += abs(min(max($(𝓂.bounds[𝓂.solved_vars[end]][1]), $(𝓂.solved_vars[end])), $(𝓂.bounds[𝓂.solved_vars[end]][2])) - $(𝓂.solved_vars[end]))))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for bounded variables with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end))
push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for bounded variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end))
end
end
else
Expand Down Expand Up @@ -3087,7 +3087,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
# end))

push!(SS_solve_func,:(if (current_best > 1e-8) && (solution_error < 1e-12)
reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp)
reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp)
# solved_scale = scale
end))
# push!(SS_solve_func,:(if length(𝓂.NSSS_solver_cache) > 100 popfirst!(𝓂.NSSS_solver_cache) end))
Expand Down Expand Up @@ -3121,13 +3121,16 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo

current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters)
closest_solution_init = 𝓂.NSSS_solver_cache[end]

for pars in 𝓂.NSSS_solver_cache
latest = sum(abs2,pars[end] - initial_parameters)
if latest <= current_best
current_best = latest
closest_solution_init = pars
end
end

# closest_solution = copy(closest_solution_init)
# solution_error = 1.0
# iters = 0
range_iters = 0
Expand All @@ -3136,11 +3139,12 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
# range_length = [ 1, 2, 4, 8,16,32,64,128,1024]
scale = 1.0

while range_iters < 100 && !(solution_error < 1e-12 && solved_scale == 1)
while range_iters < 500 && !(solution_error < 1e-12 && solved_scale == 1)
range_iters += 1
# println(range_iters)
# println(scale)
# println(solved_scale)
# println(closest_solution[end])
# for range_ in range_length
# rangee = range(0,1,range_+1)
# for scale in rangee[2:end]
Expand Down Expand Up @@ -3503,7 +3507,7 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false)
end))

push!(SS_solve_func,:(if (current_best > 1e-8) && (solution_error < 1e-12)
reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp)
reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp)
# solved_scale = scale
end))

Expand All @@ -3529,23 +3533,33 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false)

current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters)
closest_solution_init = 𝓂.NSSS_solver_cache[end]

for pars in 𝓂.NSSS_solver_cache
latest = sum(abs2,pars[end] - initial_parameters)
if latest <= current_best
current_best = latest
closest_solution_init = pars
end
end

# closest_solution = closest_solution_init
# solution_error = 1.0
# iters = 0
range_iters = 0
solution_error = 1.0
solved_scale = 0
range_length = [ 1, 2, 4, 8,16,32]
for range_ in range_length
rangee = range(0,1,range_+1)
for scale in rangee[2:end]
scale = 6*scale^5 - 15*scale^4 + 10*scale^3 # smootherstep
# range_length = [ 1, 2, 4, 8,16,32,64,128,1024]
scale = 1.0

if scale <= solved_scale continue end
while range_iters < 500 && !(solution_error < 1e-12 && solved_scale == 1)
range_iters += 1

# for range_ in range_length
# rangee = range(0,1,range_+1)
# for scale in rangee[2:end]
# scale = 6*scale^5 - 15*scale^4 + 10*scale^3 # smootherstep

# if scale <= solved_scale continue end

current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters)
closest_solution = 𝓂.NSSS_solver_cache[end]
Expand All @@ -3557,35 +3571,49 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false)
closest_solution = pars
end
end


# println(closest_solution)

if all(isfinite,closest_solution[end]) && initial_parameters != closest_solution_init[end]
parameters = scale * initial_parameters + (1 - scale) * closest_solution_init[end]
else
parameters = copy(initial_parameters)
end
params_flt = parameters

# println(parameters)

$(parameters_in_equations...)
$(par_bounds...)
$(𝓂.calibration_equations_no_var...)
NSSS_solver_cache_tmp = []
iters = 0
solution_error = 0.0
iters = 0
$(SS_solve_func...)

if solution_error < 1e-12
# println("solved for $scale; $range_iters")
solved_scale = scale

if scale == 1
# return ComponentVector([$(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))...), $(𝓂.calibration_equations_parameters...)], Axis([sort(union(𝓂.exo_present,𝓂.var))...,𝓂.calibration_equations_parameters...])), solution_error
# NSSS_solution = [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)]
# NSSS_solution[abs.(NSSS_solution) .< 1e-12] .= 0 # doesnt work with Zygote
return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters)
end
elseif scale == 1 && range_ == range_length[end]
return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters)

if scale > .95
scale = 1
else
scale = (scale + 1) / 2
end
# else
# println("no sol")
# scale = (scale + solved_scale) / 2
# println("scale $scale")
# elseif scale == 1 && range_ == range_length[end]
# return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters)
end
end
# end
end
return [0.0], (1, 0)
end)
Expand Down Expand Up @@ -6864,7 +6892,9 @@ function get_NSSS_and_parameters(𝓂::β„³,
SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, verbose, false, 𝓂.solver_parameters)

if solution_error > tol || isnan(solution_error)
if verbose println("Failed to find NSSS") end
if verbose
println("Failed to find NSSS")
end

return (SS_and_pars, (10, iters))#, x -> (NoTangent(), NoTangent(), NoTangent(), NoTangent())
end
Expand Down

0 comments on commit 3ee8940

Please sign in to comment.