From 3ee894064d4223084ba45d1da79fb4f2c29b7255 Mon Sep 17 00:00:00 2001 From: Thore Kockerols Date: Fri, 8 Nov 2024 23:13:43 +0000 Subject: [PATCH] 1/3 back 1/2 forward when trying to find SS --- src/MacroModelling.jl | 74 ++++++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index 2a6810a4..3d06e180 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -2368,11 +2368,11 @@ function write_block_solution!(𝓂, SS_solve_func, vars_to_solve, eqs_to_solve, push!(SS_solve_func,:(iters += solution[2][2])) push!(SS_solve_func,:(solution_error += solution[2][1])) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed after solving block with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed after solving block with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) if length(ss_and_aux_equations_error) > 0 push!(SS_solve_func,:(solution_error += $(Expr(:call, :+, ss_and_aux_equations_error...)))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for aux variables with error $(solution_error)") end; scale = (scale + solved_scale) / 2; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for aux variables with error $(solution_error)") end; scale = (scale + 2 * solved_scale) / 3; continue end)) end push!(SS_solve_func,:(sol = solution[1])) @@ -2899,7 +2899,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if parsed_eq_to_solve_for != minmax_fixed_eqs [push!(atoms_in_equations, a) for a in setdiff(get_symbols(parsed_eq_to_solve_for), get_symbols(minmax_fixed_eqs))] push!(min_max_errors,:(solution_error += abs($parsed_eq_to_solve_for))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for min max terms in equations with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for min max terms in equations with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) eq_to_solve = eval(minmax_fixed_eqs) end @@ -2941,7 +2941,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if (𝓂.solved_vars[end] ∈ 𝓂.βž•_vars) push!(SS_solve_func,:($(𝓂.solved_vars[end]) = min(max($(𝓂.bounds[𝓂.solved_vars[end]][1]), $(𝓂.solved_vals[end])), $(𝓂.bounds[𝓂.solved_vars[end]][2])))) push!(SS_solve_func,:(solution_error += $(Expr(:call,:abs, Expr(:call, :-, 𝓂.solved_vars[end], 𝓂.solved_vals[end]))))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical aux variables with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical aux variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) unique_βž•_eqs[𝓂.solved_vals[end]] = 𝓂.solved_vars[end] else @@ -2952,7 +2952,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if length(vcat(ss_and_aux_equations_error, ss_and_aux_equations_error_dep)) > 0 push!(SS_solve_func,vcat(ss_and_aux_equations, ss_and_aux_equations_dep)...) push!(SS_solve_func,:(solution_error += $(Expr(:call, :+, vcat(ss_and_aux_equations_error, ss_and_aux_equations_error_dep)...)))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical variables with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) end push!(SS_solve_func,:($(𝓂.solved_vars[end]) = $(rewritten_eqs[1]))) @@ -2960,7 +2960,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if haskey(𝓂.bounds, 𝓂.solved_vars[end]) && 𝓂.solved_vars[end] βˆ‰ 𝓂.βž•_vars push!(SS_solve_func,:(solution_error += abs(min(max($(𝓂.bounds[𝓂.solved_vars[end]][1]), $(𝓂.solved_vars[end])), $(𝓂.bounds[𝓂.solved_vars[end]][2])) - $(𝓂.solved_vars[end])))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for bounded variables with error $solution_error") end; scale = (scale + solved_scale) / 2; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for bounded variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) end end else @@ -3087,7 +3087,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo # end)) push!(SS_solve_func,:(if (current_best > 1e-8) && (solution_error < 1e-12) - reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp) + reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp) # solved_scale = scale end)) # push!(SS_solve_func,:(if length(𝓂.NSSS_solver_cache) > 100 popfirst!(𝓂.NSSS_solver_cache) end)) @@ -3121,6 +3121,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters) closest_solution_init = 𝓂.NSSS_solver_cache[end] + for pars in 𝓂.NSSS_solver_cache latest = sum(abs2,pars[end] - initial_parameters) if latest <= current_best @@ -3128,6 +3129,8 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo closest_solution_init = pars end end + + # closest_solution = copy(closest_solution_init) # solution_error = 1.0 # iters = 0 range_iters = 0 @@ -3136,11 +3139,12 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo # range_length = [ 1, 2, 4, 8,16,32,64,128,1024] scale = 1.0 - while range_iters < 100 && !(solution_error < 1e-12 && solved_scale == 1) + while range_iters < 500 && !(solution_error < 1e-12 && solved_scale == 1) range_iters += 1 # println(range_iters) # println(scale) # println(solved_scale) + # println(closest_solution[end]) # for range_ in range_length # rangee = range(0,1,range_+1) # for scale in rangee[2:end] @@ -3503,7 +3507,7 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false) end)) push!(SS_solve_func,:(if (current_best > 1e-8) && (solution_error < 1e-12) - reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp) + reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp) # solved_scale = scale end)) @@ -3529,6 +3533,7 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false) current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters) closest_solution_init = 𝓂.NSSS_solver_cache[end] + for pars in 𝓂.NSSS_solver_cache latest = sum(abs2,pars[end] - initial_parameters) if latest <= current_best @@ -3536,16 +3541,25 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false) closest_solution_init = pars end end + + # closest_solution = closest_solution_init # solution_error = 1.0 # iters = 0 + range_iters = 0 + solution_error = 1.0 solved_scale = 0 - range_length = [ 1, 2, 4, 8,16,32] - for range_ in range_length - rangee = range(0,1,range_+1) - for scale in rangee[2:end] - scale = 6*scale^5 - 15*scale^4 + 10*scale^3 # smootherstep + # range_length = [ 1, 2, 4, 8,16,32,64,128,1024] + scale = 1.0 - if scale <= solved_scale continue end + while range_iters < 500 && !(solution_error < 1e-12 && solved_scale == 1) + range_iters += 1 + + # for range_ in range_length + # rangee = range(0,1,range_+1) + # for scale in rangee[2:end] + # scale = 6*scale^5 - 15*scale^4 + 10*scale^3 # smootherstep + + # if scale <= solved_scale continue end current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters) closest_solution = 𝓂.NSSS_solver_cache[end] @@ -3557,7 +3571,9 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false) closest_solution = pars end end - + + # println(closest_solution) + if all(isfinite,closest_solution[end]) && initial_parameters != closest_solution_init[end] parameters = scale * initial_parameters + (1 - scale) * closest_solution_init[end] else @@ -3565,27 +3581,39 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false) end params_flt = parameters + # println(parameters) + $(parameters_in_equations...) $(par_bounds...) $(𝓂.calibration_equations_no_var...) NSSS_solver_cache_tmp = [] - iters = 0 solution_error = 0.0 + iters = 0 $(SS_solve_func...) if solution_error < 1e-12 + # println("solved for $scale; $range_iters") solved_scale = scale - if scale == 1 # return ComponentVector([$(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))...), $(𝓂.calibration_equations_parameters...)], Axis([sort(union(𝓂.exo_present,𝓂.var))...,𝓂.calibration_equations_parameters...])), solution_error # NSSS_solution = [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)] # NSSS_solution[abs.(NSSS_solution) .< 1e-12] .= 0 # doesnt work with Zygote return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters) end - elseif scale == 1 && range_ == range_length[end] - return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters) + + if scale > .95 + scale = 1 + else + scale = (scale + 1) / 2 + end + # else + # println("no sol") + # scale = (scale + solved_scale) / 2 + # println("scale $scale") + # elseif scale == 1 && range_ == range_length[end] + # return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters) end - end + # end end return [0.0], (1, 0) end) @@ -6864,7 +6892,9 @@ function get_NSSS_and_parameters(𝓂::β„³, SS_and_pars, (solution_error, iters) = 𝓂.SS_solve_func(parameter_values, 𝓂, verbose, false, 𝓂.solver_parameters) if solution_error > tol || isnan(solution_error) - if verbose println("Failed to find NSSS") end + if verbose + println("Failed to find NSSS") + end return (SS_and_pars, (10, iters))#, x -> (NoTangent(), NoTangent(), NoTangent(), NoTangent()) end