From c95c0cec5924586f81db86cb12e827a58719249f Mon Sep 17 00:00:00 2001 From: thorek1 Date: Sun, 15 Dec 2024 14:51:23 +0100 Subject: [PATCH] solver cache for scale and adjust scale dynamics --- src/MacroModelling.jl | 48 +++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index 8aefa42d..0611c3a8 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -2443,11 +2443,11 @@ function write_block_solution!(𝓂, SS_solve_func, vars_to_solve, eqs_to_solve, push!(SS_solve_func,:(iters += solution[2][2])) push!(SS_solve_func,:(solution_error += solution[2][1])) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed after solving block with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed after solving block with error $solution_error") end; scale = scale * .3 + solved_scale * .7; continue end)) if length(ss_and_aux_equations_error) > 0 push!(SS_solve_func,:(solution_error += $(Expr(:call, :+, ss_and_aux_equations_error...)))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for aux variables with error $(solution_error)") end; scale = (scale + 2 * solved_scale) / 3; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for aux variables with error $(solution_error)") end; scale = scale * .3 + solved_scale * .7; continue end)) end push!(SS_solve_func,:(sol = solution[1])) @@ -2974,7 +2974,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if parsed_eq_to_solve_for != minmax_fixed_eqs [push!(atoms_in_equations, a) for a in setdiff(get_symbols(parsed_eq_to_solve_for), get_symbols(minmax_fixed_eqs))] push!(min_max_errors,:(solution_error += abs($parsed_eq_to_solve_for))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for min max terms in equations with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for min max terms in equations with error $solution_error") end; scale = scale * .3 + solved_scale * .7; continue end)) eq_to_solve = eval(minmax_fixed_eqs) end @@ -3016,8 +3016,8 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if (𝓂.solved_vars[end] ∈ 𝓂.βž•_vars) push!(SS_solve_func,:($(𝓂.solved_vars[end]) = min(max($(𝓂.bounds[𝓂.solved_vars[end]][1]), $(𝓂.solved_vals[end])), $(𝓂.bounds[𝓂.solved_vars[end]][2])))) push!(SS_solve_func,:(solution_error += $(Expr(:call,:abs, Expr(:call, :-, 𝓂.solved_vars[end], 𝓂.solved_vals[end]))))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical aux variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) - + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical aux variables with error $solution_error") end; scale = scale * .3 + solved_scale * .7; continue end)) + unique_βž•_eqs[𝓂.solved_vals[end]] = 𝓂.solved_vars[end] else vars_to_exclude = [vcat(Symbol.(var_to_solve_for), 𝓂.βž•_vars), Symbol[]] @@ -3027,7 +3027,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if length(vcat(ss_and_aux_equations_error, ss_and_aux_equations_error_dep)) > 0 push!(SS_solve_func,vcat(ss_and_aux_equations, ss_and_aux_equations_dep)...) push!(SS_solve_func,:(solution_error += $(Expr(:call, :+, vcat(ss_and_aux_equations_error, ss_and_aux_equations_error_dep)...)))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for analytical variables with error $solution_error") end; scale = scale * .3 + solved_scale * .7; continue end)) end push!(SS_solve_func,:($(𝓂.solved_vars[end]) = $(rewritten_eqs[1]))) @@ -3035,7 +3035,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo if haskey(𝓂.bounds, 𝓂.solved_vars[end]) && 𝓂.solved_vars[end] βˆ‰ 𝓂.βž•_vars push!(SS_solve_func,:(solution_error += abs(min(max($(𝓂.bounds[𝓂.solved_vars[end]][1]), $(𝓂.solved_vars[end])), $(𝓂.bounds[𝓂.solved_vars[end]][2])) - $(𝓂.solved_vars[end])))) - push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for bounded variables with error $solution_error") end; scale = (scale + 2 * solved_scale) / 3; continue end)) + push!(SS_solve_func, :(if solution_error > 1e-12 if verbose println("Failed for bounded variables with error $solution_error") end; scale = scale * .3 + solved_scale * .7; continue end)) end end else @@ -3160,10 +3160,8 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo # current_best = latest # end # end)) - - push!(SS_solve_func,:(if (current_best > 1e-8) && (solution_error < 1e-12) + push!(SS_solve_func,:(if (current_best > 1e-8) && (solution_error < 1e-12) && (scale == 1) reverse_diff_friendly_push!(𝓂.NSSS_solver_cache, NSSS_solver_cache_tmp) - # solved_scale = scale end)) # push!(SS_solve_func,:(if length(𝓂.NSSS_solver_cache) > 100 popfirst!(𝓂.NSSS_solver_cache) end)) @@ -3219,14 +3217,21 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo solved_scale = 0 # range_length = [ 1, 2, 4, 8,16,32,64,128,1024] scale = 1.0 + + NSSS_solver_cache_scale = CircularBuffer{Vector{Vector{Float64}}}(500) + push!(NSSS_solver_cache_scale, closest_solution_init) # fail_fast_solvers_only = true # TODO: try have this run for 1000 and and use closest_solution based on the previous result and not on the cache. that way you dont crowd out good and diverse solutions in the cache and make sure he finds the other SS. rely on previous commit for way of implementing closest_solution while range_iters <= (cold_start ? 1 : 500) && !(solution_error < 1e-12 && solved_scale == 1) range_iters += 1 fail_fast_solvers_only = range_iters > 1 ? true : false - # println(range_iters) - # println(scale) - # println(solved_scale) + + if abs(solved_scale - scale) < 1e-4 + # println(NSSS_solver_cache_scale[end]) + break + end + + # println("i: $range_iters - scale: $scale - solved_scale: $solved_scale") # println(closest_solution[end]) # for range_ in range_length # rangee = range(0,1,range_+1) @@ -3235,10 +3240,11 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo # if scale <= solved_scale continue end - current_best = sum(abs2,𝓂.NSSS_solver_cache[end][end] - initial_parameters) - closest_solution = 𝓂.NSSS_solver_cache[end] + + current_best = sum(abs2,NSSS_solver_cache_scale[end][end] - initial_parameters) + closest_solution = NSSS_solver_cache_scale[end] - for pars in 𝓂.NSSS_solver_cache + for pars in NSSS_solver_cache_scale copy!(initial_parameters_tmp, pars[end]) β„’.axpy!(-1,initial_parameters,initial_parameters_tmp) @@ -3278,12 +3284,15 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo # NSSS_solution = [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)] # NSSS_solution[abs.(NSSS_solution) .< 1e-12] .= 0 # doesnt work with Zygote return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters) + else + reverse_diff_friendly_push!(NSSS_solver_cache_scale, NSSS_solver_cache_tmp) end if scale > .95 scale = 1 else - scale = (scale + 1) / 2 + # scale = (scale + 1) / 2 + scale = scale * .4 + .6 end # else # println("no sol") @@ -3695,12 +3704,15 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false) # NSSS_solution = [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)] # NSSS_solution[abs.(NSSS_solution) .< 1e-12] .= 0 # doesnt work with Zygote return [$(Symbol.(replace.(string.(sort(union(𝓂.var,𝓂.exo_past,𝓂.exo_future))), r"ᴸ⁽⁻?[⁰¹²³⁴⁡⁢⁷⁸⁹]+⁾" => ""))...), $(𝓂.calibration_equations_parameters...)], (solution_error, iters) + else + reverse_diff_friendly_push!(NSSS_solver_cache_scale, NSSS_solver_cache_tmp) end if scale > .95 scale = 1 else - scale = (scale + 1) / 2 + # scale = (scale + 1) / 2 + scale = scale * .4 + .6 end # else # println("no sol")